In [1]:
from _0_mamba_vs_neo.models.MambaForSequenceClassification import MambaForSequenceClassification
import _0_mamba_vs_neo.datasets.ecthr.utils_ecthr as utils_ecthr

In [2]:
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
import torch
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
import os

In [3]:
"""
Description:

Train for 6 epochs with a maximum sequence length of 512 tokens and without silver data, so as to recreate the setup of the original ECtHR paper.
"""

'\nDescription:\n\nTrain for 6 epochs with a maximum sequence length of 512 tokens and without silver data, so as to recreate the setup of the original ECtHR paper.\n'

In [4]:
# Name of the Weights & Biases project for this notebook's runs; set via the environment before any training objects are built.
os.environ["WANDB_PROJECT"] = "mamba_vs_neo_ecthr"

In [5]:
"""
CONFIGS:
"""

'\nCONFIGS:\n'

In [6]:
"""
    general:
        - RUN_NAME: str
            name of the run
        - OUTPUT_DIR: str
            directory to save the model and logs
        - SEED: int
            random seed to use
        - REPORT_TO: str
            experiment tracker to report results and logs to
"""
RUN_NAME = "mamba_run_512_tokens_6_epochs_no_silver"  # also forms the last component of OUTPUT_DIR
OUTPUT_DIR = f"_0_mamba_vs_neo/models/mamba/{RUN_NAME}"  # relative path
SEED = 42
REPORT_TO = "wandb"  # forwarded to TrainingArguments(report_to=...)

In [7]:
"""
    dataset:
        - ALLEGATIONS: bool
            True: use allegation data for the cases, so what laws did the cases allegedly violate
            False: use court decisions, so what laws did the court decide the cases violated
        - SILVER: bool
            True: only use facts which were deemed relevant by the court
            False: use all facts
        - MULTI_LABEL: bool
            True: use multi-label classification (which law was (allegedly) violated)
            False: use binary classification (was there a law (allegedly) violated)
        - FREQUENCY_THRESHOLD: int
            minimum number of cases a law must be (allegedly) violated in to be considered
        - NUM_LABELS: int
            number of labels in the dataset (ecthr: 41)
        - MAX_LENGTH: int
            maximum number of tokens in a sequence     
"""
# The first four flags are forwarded to utils_ecthr.load_ecthr_dataset.
ALLEGATIONS = True
SILVER = False
MULTI_LABEL = True
FREQUENCY_THRESHOLD = 0  # presumably 0 means "no minimum frequency" — confirm in utils_ecthr
NUM_LABELS = 41  # passed as num_labels when the classification model is created

MAX_LENGTH = 512  # passed as max_length to utils_ecthr.tokenize_dataset

In [8]:
"""
    training:
        - EPOCHS: int
            number of times to iterate over the dataset
        - LEARNING_RATE: float
            rate at which the model learns
        - BATCH_SIZE: int
            number of sequences in a batch
        - GRADIENT_ACCUMULATION_STEPS: int
            number of batches to accumulate gradients over
        - USE_LENGTH_GROUPING: bool
            True: group sequences of similar length together to minimize padding
            False: do not group sequences by length
        - WARMUP_RATIO: float
            fraction of the total training steps used for learning-rate warmup
        - MAX_GRAD_NORM: float
            maximum gradient norm to clip to
        - WEIGHT_DECAY: float
            weight decay to apply to the model
"""
EPOCHS = 6
LEARNING_RATE = 2e-5
BATCH_SIZE = 8  # used as the per-device train batch size
GRADIENT_ACCUMULATION_STEPS = 2
# Number of sequences contributing to one optimizer step (per device).
print("true batch size:", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 0.3
WEIGHT_DECAY = 0.001

USE_LENGTH_GROUPING = True  # forwarded to TrainingArguments(group_by_length=...)

true batch size: 16


In [9]:
"""
    evaluation:
        - EVAL_STEPS: int
            number of steps between evaluations
        - BATCH_SIZE_EVAL: int
            number of sequences in a batch for evaluation
        - LOGGING_STEPS: int
            number of steps between logging
        - EVAL_ACCUMULATION_STEPS: int
            number of eval batches to accumulate before copying the results to the CPU; helpful if the evaluation requires a lot of memory
"""
EVAL_STEPS = 200  # also reused as the checkpoint interval (save_steps) in the training arguments
BATCH_SIZE_EVAL = BATCH_SIZE  # evaluate with the same per-device batch size as training
LOGGING_STEPS = 100
EVAL_ACCUMULATION_STEPS = 20

In [10]:
"""
    model:
        - MODEL_NAME: str
            name of the model to use
        - LORA_TASK_TYPE: TaskType
            PEFT task type the LoRA adapter is configured for
        - LORA_R: int
            r is the rank of the low-rank approximation
        - LORA_TARGET_MODULES: list
            list of modules to target with LoRA
"""
MODEL_NAME = "state-spaces/mamba-1.4b-hf"  # checkpoint identifier passed to from_pretrained
LORA_TASK_TYPE = TaskType.SEQ_CLS  # sequence-classification task type
LORA_R = 8
# Module names that receive LoRA adapters (the embedding layer plus the mixer projections).
LORA_TARGET_MODULES = ["x_proj", "embeddings", "in_proj", "out_proj"]

In [11]:
def compute_metrics(eval_pred):
    """Compute multi-label classification metrics from raw logits.

    Args:
        eval_pred: pair ``(logits, labels)``. Both are assumed to be
            ``(n_samples, n_labels)`` arrays, with ``labels`` a multi-hot
            indicator matrix (the multi-label setup used in this notebook).

    Returns:
        dict with strict (exact-match) accuracy, Hamming accuracy, and
        macro/micro averaged precision, recall and F1.
    """
    logits, labels = eval_pred

    # sigmoid(x) > 0.5 is equivalent to x > 0. Thresholding the logits directly
    # avoids the overflow warning np.exp raises for large-magnitude negative
    # logits and the float rounding of tiny positive logits to exactly 0.5.
    predictions = (np.asarray(logits) > 0).astype(int)

    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )
    # accuracy_score is computed on whole rows here, hence "strict".
    accuracy = accuracy_score(labels, predictions)

    return {
        'strict_accuracy': accuracy,
        'hamming_accuracy': 1 - hamming_loss(labels, predictions),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macro,
        'recall_micro': recall_micro
    }

In [12]:
class SimpleBCELossTrainer(Trainer):
    """Trainer that scores the model's logits with ``BCEWithLogitsLoss``.

    Labels are cast to float before the loss is computed, so each label
    column is treated as an independent binary target.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the BCE-with-logits loss for one batch.

        Args:
            model: model to run the forward pass on.
            inputs: batch mapping; its ``"labels"`` entry is removed before
                the remaining entries are passed to the model.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: accepted for compatibility with newer
                ``Trainer`` versions that pass this argument; unused because
                the loss is already averaged by ``BCEWithLogitsLoss``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

In [13]:
# Load the pretrained backbone with a classification head of NUM_LABELS outputs,
# and the tokenizer from the same checkpoint so the two cannot drift apart.
model = MambaForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MambaForSequenceClassification were not initialized from the model checkpoint at state-spaces/mamba-1.4b-hf and are newly initialized: ['backbone.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
# Reuse the end-of-sequence token id as the padding id on the tokenizer.
tokenizer.pad_token_id = tokenizer.eos_token_id
# NOTE(review): this sets a plain attribute on the model object; confirm that
# MambaForSequenceClassification reads `model.pad_token_id` rather than `model.config.pad_token_id`.
model.pad_token_id = tokenizer.eos_token_id

In [15]:
# Batch collator that pads the tokenized examples with the tokenizer's padding settings.
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding = True)

In [16]:
# Load the ECtHR splits as configured, tokenize them to at most MAX_LENGTH
# tokens, then drop the raw "facts" column.
ecthr_dataset = utils_ecthr.tokenize_dataset(
    utils_ecthr.load_ecthr_dataset(
        allegations=ALLEGATIONS,
        silver=SILVER,
        is_multi_label=MULTI_LABEL,
        frequency_threshold=FREQUENCY_THRESHOLD,
    ),
    tokenizer,
    max_length=MAX_LENGTH,
).remove_columns("facts")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [17]:
# Pull the three named splits out of the dataset in one step.
train, val, test = (
    ecthr_dataset[split_name] for split_name in ("train", "validation", "test")
)

In [18]:
# LoRA adapter configuration built from the model config constants; bias terms are left untouched.
lora_config = LoraConfig(
    task_type=LORA_TASK_TYPE,
    r=LORA_R,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
)

In [19]:
# Wrap the base model with the LoRA adapters. `model` is rebound to the wrapped
# model, so re-running this cell alone would wrap the already-wrapped model again.
model = get_peft_model(model, lora_config)
# Print the trainable vs. total parameter counts after wrapping.
model.print_trainable_parameters()

trainable params: 8,428,352 || all params: 1,380,690,752 || trainable%: 0.6104


In [20]:
# Move the wrapped model to the CUDA device; as the cell's last expression, the returned module's repr is displayed.
model.to("cuda")

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): MambaForSequenceClassification(
      (embeddings): lora.Embedding(
        (base_layer): Embedding(50280, 2048)
        (lora_dropout): ModuleDict(
          (default): Identity()
        )
        (lora_A): ModuleDict()
        (lora_B): ModuleDict()
        (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x50280 (cuda:0)])
        (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
      )
      (layers): ModuleList(
        (0-47): 48 x MambaBlock(
          (norm): MambaRMSNorm()
          (mixer): MambaMixer(
            (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
            (act): SiLU()
            (in_proj): lora.Linear(
              (base_layer): Linear(in_features=2048, out_features=8192, bias=False)
              (lora_dropout): ModuleDi

In [21]:
# Sanity check: print the name of every parameter that still requires gradients.
for trainable_name in (
    pname for pname, tensor in model.named_parameters() if tensor.requires_grad
):
    print(trainable_name)

base_model.model.embeddings.lora_embedding_A.default
base_model.model.embeddings.lora_embedding_B.default
base_model.model.layers.0.mixer.in_proj.lora_A.default.weight
base_model.model.layers.0.mixer.in_proj.lora_B.default.weight
base_model.model.layers.0.mixer.x_proj.lora_A.default.weight
base_model.model.layers.0.mixer.x_proj.lora_B.default.weight
base_model.model.layers.0.mixer.out_proj.lora_A.default.weight
base_model.model.layers.0.mixer.out_proj.lora_B.default.weight
base_model.model.layers.1.mixer.in_proj.lora_A.default.weight
base_model.model.layers.1.mixer.in_proj.lora_B.default.weight
base_model.model.layers.1.mixer.x_proj.lora_A.default.weight
base_model.model.layers.1.mixer.x_proj.lora_B.default.weight
base_model.model.layers.1.mixer.out_proj.lora_A.default.weight
base_model.model.layers.1.mixer.out_proj.lora_B.default.weight
base_model.model.layers.2.mixer.in_proj.lora_A.default.weight
base_model.model.layers.2.mixer.in_proj.lora_B.default.weight
base_model.model.layers.2.

In [22]:
# Training configuration assembled from the config constants defined earlier.
# NOTE(review): with lr_scheduler_type="constant" the scheduler does not apply
# warmup, so WARMUP_RATIO has no effect; use "constant_with_warmup" if warmup is intended.
training_args = TrainingArguments(
    output_dir= OUTPUT_DIR,
    run_name= RUN_NAME,
    seed= SEED,  # wire the documented SEED config into the run instead of relying on the library default
    learning_rate= LEARNING_RATE,
    lr_scheduler_type= "constant",
    warmup_ratio= WARMUP_RATIO,
    max_grad_norm= MAX_GRAD_NORM,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size= BATCH_SIZE_EVAL,
    gradient_accumulation_steps= GRADIENT_ACCUMULATION_STEPS,
    group_by_length= USE_LENGTH_GROUPING,
    num_train_epochs= EPOCHS,
    weight_decay= WEIGHT_DECAY,
    # Evaluate and checkpoint on the same step interval so that the best
    # checkpoint can be reloaded at the end of training.
    eval_strategy="steps",
    eval_steps= EVAL_STEPS,
    eval_accumulation_steps = EVAL_ACCUMULATION_STEPS,
    save_strategy="steps",
    save_steps= EVAL_STEPS,
    load_best_model_at_end=True,
    report_to= REPORT_TO,
    fp16=False,
    gradient_checkpointing=True,
    logging_dir="logs",
    logging_steps= LOGGING_STEPS,
    label_names=["labels"],
)

In [23]:
# Build the BCE-loss trainer: train on the train split, evaluate on the
# validation split, and report the metrics from compute_metrics.
trainer = SimpleBCELossTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train,
    eval_dataset=val,
    compute_metrics=compute_metrics
)

In [24]:
# Run fine-tuning; evaluation and checkpointing both happen every EVAL_STEPS steps per training_args.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33melisabeth-fittschen[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss,Strict Accuracy,Hamming Accuracy,F1 Macro,F1 Micro,Precision Macro,Precision Micro,Recall Macro,Recall Micro
200,0.1771,0.122381,0.157,0.956293,0.014287,0.297255,0.010271,0.421111,0.023462,0.229697
400,0.1009,0.107689,0.152,0.958659,0.022542,0.296972,0.026174,0.470434,0.024648,0.21697
600,0.0904,0.09411,0.227,0.967585,0.067642,0.490219,0.091834,0.667712,0.063485,0.387273
800,0.078,0.085677,0.246,0.96978,0.093127,0.528359,0.145729,0.710338,0.077974,0.420606
1000,0.076,0.082664,0.269,0.970439,0.104257,0.557664,0.138818,0.700917,0.090861,0.46303
1200,0.0721,0.081232,0.305,0.971756,0.115184,0.586429,0.139213,0.713913,0.103006,0.497576
1400,0.069,0.079026,0.312,0.971732,0.118843,0.584141,0.152821,0.715919,0.104814,0.493333
1600,0.0675,0.077372,0.321,0.972854,0.124705,0.59863,0.163641,0.739092,0.108302,0.50303
1800,0.0658,0.075399,0.343,0.973634,0.135626,0.613514,0.176146,0.748038,0.120841,0.52
2000,0.0623,0.074077,0.363,0.974439,0.138071,0.624103,0.181169,0.764499,0.122133,0.527273




TrainOutput(global_step=3372, training_loss=0.08599307818486314, metrics={'train_runtime': 19534.4487, 'train_samples_per_second': 2.764, 'train_steps_per_second': 0.173, 'total_flos': 1.998290794319493e+17, 'train_loss': 0.08599307818486314, 'epoch': 5.994666666666666})

In [25]:
# Print the trainer object's default repr.
print(trainer)

<__main__.SimpleBCELossTrainer object at 0x7ff5b5e11b70>


In [26]:
# Evaluate on the test split; the returned metric keys are prefixed with "eval_".
trainer.evaluate(test)

{'eval_loss': 0.07208921760320663,
 'eval_strict_accuracy': 0.421,
 'eval_hamming_accuracy': 0.9753414634146341,
 'eval_f1_macro': 0.15345974648025873,
 'eval_f1_micro': 0.6453875833041038,
 'eval_precision_macro': 0.20668286416975057,
 'eval_precision_micro': 0.7931034482758621,
 'eval_recall_macro': 0.1322420328399865,
 'eval_recall_micro': 0.5440567711413364,
 'eval_runtime': 105.4129,
 'eval_samples_per_second': 9.487,
 'eval_steps_per_second': 1.186,
 'epoch': 5.994666666666666}

In [27]:
# Keep the raw test-set prediction output (logits and label ids) for the per-label analysis.
predictions = trainer.predict(test)

In [28]:
# Display the raw PredictionOutput (logits, label ids and metrics).
predictions

PredictionOutput(predictions=array([[ -4.6086802 ,  -3.5544908 ,  -6.64375   , ...,  -8.15938   ,
         -8.15938   ,  -8.15938   ],
       [ -5.0444384 ,  -3.7203033 ,  -5.7259474 , ...,  -8.585681  ,
         -8.585681  ,  -8.585681  ],
       [ -5.581774  ,  -4.9260125 ,  -7.6950803 , ..., -10.603534  ,
        -10.603534  , -10.603534  ],
       ...,
       [ -4.815941  ,   0.41324547,  -5.3264475 , ...,  -7.7252154 ,
         -7.7252154 ,  -7.7252154 ],
       [ -4.581134  ,  -0.406994  ,  -7.3289795 , ...,  -9.254812  ,
         -9.254812  ,  -9.254812  ],
       [ -4.9448485 ,  -3.463598  ,  -5.5648823 , ...,  -8.310236  ,
         -8.310236  ,  -8.310236  ]], dtype=float32), label_ids=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), metrics={'test_loss': 0.07208921760320663, 'test_strict_accuracy': 0.421, 'test_hamming_ac

In [29]:
def calulate_metrics_index(predictions, index):
    """Compute binary metrics for a single label column of a prediction output.

    Args:
        predictions: object exposing ``predictions`` (logits) and
            ``label_ids`` arrays, each indexed as ``[:, index]`` (for example
            the result of ``trainer.predict``).
        index: position of the label column to evaluate.

    Returns:
        dict with the binary ``f1``, ``precision`` and ``recall`` for that
        label, the number of positive gold labels (``count_cases``) and the
        number of positive predictions (``count_predicted``).
    """
    label_logits = predictions.predictions[:, index]
    gold_labels = predictions.label_ids[:, index]

    # sigmoid(x) > 0.5 is equivalent to x > 0; thresholding the logits directly
    # avoids np.exp overflow warnings for large-magnitude negative logits.
    # A separate local name is used so the `predictions` argument is not rebound.
    predicted_labels = (label_logits > 0).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        gold_labels, predicted_labels, average='binary', zero_division=0
    )

    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'count_cases': np.sum(gold_labels),
        'count_predicted': np.sum(predicted_labels)
    }


In [30]:
# Invert ARTICLES_ID so that its values become the lookup keys; the report loop
# looks entries up by label index.
ids = {index: article for article, index in utils_ecthr.ARTICLES_ID.items()}
# Descriptions are looked up by the article identifiers obtained from `ids`.
desc = utils_ecthr.ARTICLES_DESC


In [31]:
# Per-label report: article id, description and binary metrics on the test predictions.
# Iterate over NUM_LABELS rather than a hard-coded 41 so the loop follows the config.
for i in range(NUM_LABELS):
    print("-"*50)
    print(f"Label {i}")
    print(ids[i])
    print(desc[ids[i]])
    print(calulate_metrics_index(predictions, i))

--------------------------------------------------
Label 0
2
Right to life
{'f1': 0.6666666666666666, 'precision': 0.7627118644067796, 'recall': 0.5921052631578947, 'count_cases': 76, 'count_predicted': 59}
--------------------------------------------------
Label 1
3
Prohibition of torture
{'f1': 0.7320261437908496, 'precision': 0.7466666666666667, 'recall': 0.717948717948718, 'count_cases': 234, 'count_predicted': 225}
--------------------------------------------------
Label 2
4
Prohibition of slavery and forced labour
{'f1': 0.0, 'precision': 0.0, 'recall': 0.0, 'count_cases': 3, 'count_predicted': 0}
--------------------------------------------------
Label 3
5
Right to liberty and security
{'f1': 0.6827794561933535, 'precision': 0.837037037037037, 'recall': 0.576530612244898, 'count_cases': 196, 'count_predicted': 135}
--------------------------------------------------
Label 4
6
Right to a fair trial
{'f1': 0.7472826086956522, 'precision': 0.804093567251462, 'recall': 0.697969543147

In [32]:
# Inspect label index 2 on its own (3 positive test cases, none predicted).
calulate_metrics_index(predictions, 2)

{'f1': 0.0,
 'precision': 0.0,
 'recall': 0.0,
 'count_cases': 3,
 'count_predicted': 0}