In [1]:
from _0_mamba_vs_neo.models.MambaForSequenceClassification import MambaForSequenceClassification
import _0_mamba_vs_neo.datasets.ecthr.utils_ecthr as utils_ecthr

In [2]:
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
import torch
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
import os

In [3]:
"""
Description:

Train for 6 epochs with 24000 tokens without silver data, as to trying to improve our current results.
"""

'\nDescription:\n\nTrain for 6 epochs with 24000 tokens without silver data, as to trying to improve our current results.\n'

In [4]:
os.environ["WANDB_PROJECT"] = "mamba_vs_neo_ecthr"

In [5]:
"""
CONFIGS:
"""

'\nCONFIGS:\n'

In [6]:
"""
    general:
        - RUN_NAME: str
            name of the run
        - OUTPUT_DIR: str
            directory to save the model and logs
        - SEED: int
            random seed to use
        - REPORT_TO: str
"""
RUN_NAME = "mamba_run_24000_tokens_6_epochs_no_silver"
OUTPUT_DIR = f"_0_mamba_vs_neo/models/mamba/{RUN_NAME}"
SEED = 42
REPORT_TO = "wandb"

In [7]:
"""
    dataset:
        - ALLEGATIONS: bool
            True: use allegation data for the cases, so what laws did the cases allegedly violate
            False: use court decisions, so what laws did the court decide the cases violated
        - SILVER: bool
            True: only use facts which were deemed relevant by the court
            False: use all facts
        - MULTI_LABEL: bool
            True: use multi-label classification (which law was (allegedly) violated)
            False: use binary classification (was there a law (allegedly) violated)
        - FREQUENCY_THRESHOLD: int
            minimum number of cases a law must be (allegedly) violated in to be considered
        - NUM_LABELS: int
            number of labels in the dataset (ecthr: 41)
        - MAX_LENGTH: int
            maximum number of tokens in a sequence     
"""
ALLEGATIONS = True
SILVER = False
MULTI_LABEL = True
FREQUENCY_THRESHOLD = 0
NUM_LABELS = 41

MAX_LENGTH = 24000

In [8]:
"""
    training:
        - EPOCHS: int
            number of times to iterate over the dataset
        - LEARNING_RATE: float
            rate at which the model learns
        - BATCH_SIZE: int
            number of sequences in a batch
        - GRADIENT_ACCUMULATION_STEPS: int
            number of batches to accumulate gradients over
        - USE_LENGTH_GROUPING: bool
            True: group sequences of similar length together to minimize padding
            False: do not group sequences by length
        - WARMUP_RATIO: float
            ratio of training steps to warmup steps
        - MAX_GRAD_NORM: float
            maximum gradient norm to clip to
        - WEIGHT_DECAY: float
            weight decay to apply to the model
"""
EPOCHS = 6
LEARNING_RATE = 2e-5
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 16
print("true batch size:", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 0.3
WEIGHT_DECAY = 0.001

USE_LENGTH_GROUPING = True

true batch size: 16


In [9]:
"""
    evaluation:
        - EVAL_STEPS: int
            number of steps between evaluations
        - BATCH_SIZE_EVAL: int
            number of sequences in a batch for evaluation
        - LOGGING_STEPS: int
            number of steps between logging
        - EVAL_ACCUMULATION_STEPS: int
            number eval batches to calculate before copying to the cpu, if the eval requires a lot of memory this is helpful
"""
EVAL_STEPS = 200
BATCH_SIZE_EVAL = BATCH_SIZE
LOGGING_STEPS = 100
EVAL_ACCUMULATION_STEPS = 20

In [10]:
"""
    model:
        - MODEL_NAME: str
            name of the model to use
        - LORA_TASK_TYPE:
        - LORA_R: int
           r is the rank of the approximation
        - LORA_TARGET_MODULES: list
            list of modules to target with LoRA
"""
MODEL_NAME = "state-spaces/mamba-1.4b-hf"
LORA_TASK_TYPE = TaskType.SEQ_CLS
LORA_R = 8
LORA_TARGET_MODULES = ["x_proj", "embeddings", "in_proj", "out_proj"]

In [11]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    
    probs = 1 / (1 + np.exp(-logits))
    predictions = (probs > 0.5).astype(int)
    
    precision_macro, recall_macto, f1_macro, _ = precision_recall_fscore_support(labels, predictions, average='macro', zero_division=0)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, predictions, average='micro', zero_division=0)
    accuracy = accuracy_score(labels, predictions)

    return {
        'strict_accuracy': accuracy,
        'hamming_accuracy': 1 - hamming_loss(labels, predictions),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macto,
        'recall_micro': recall_micro
    }

In [12]:
class SimpleBCELossTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = torch.nn.BCEWithLogitsLoss()
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

In [13]:
model = MambaForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MambaForSequenceClassification were not initialized from the model checkpoint at state-spaces/mamba-1.4b-hf and are newly initialized: ['backbone.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.pad_token_id = tokenizer.eos_token_id

In [15]:
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding = True)

In [16]:
ecthr_dataset = utils_ecthr.load_ecthr_dataset(allegations=ALLEGATIONS, silver=SILVER, is_multi_label=MULTI_LABEL, frequency_threshold=FREQUENCY_THRESHOLD)
ecthr_dataset = utils_ecthr.tokenize_dataset(ecthr_dataset, tokenizer, max_length=MAX_LENGTH)
ecthr_dataset = ecthr_dataset.remove_columns("facts")

In [17]:
train = ecthr_dataset["train"]
val = ecthr_dataset["validation"]
test = ecthr_dataset["test"]

In [18]:
lora_config =  LoraConfig(
        r=LORA_R,
        target_modules=LORA_TARGET_MODULES,
        task_type=LORA_TASK_TYPE,
        bias="none"
)

In [19]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 8,428,352 || all params: 1,380,690,752 || trainable%: 0.6104


In [20]:
model.to("cuda")

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): MambaForSequenceClassification(
      (embeddings): lora.Embedding(
        (base_layer): Embedding(50280, 2048)
        (lora_dropout): ModuleDict(
          (default): Identity()
        )
        (lora_A): ModuleDict()
        (lora_B): ModuleDict()
        (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x50280 (cuda:0)])
        (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
      )
      (layers): ModuleList(
        (0-47): 48 x MambaBlock(
          (norm): MambaRMSNorm()
          (mixer): MambaMixer(
            (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
            (act): SiLU()
            (in_proj): lora.Linear(
              (base_layer): Linear(in_features=2048, out_features=8192, bias=False)
              (lora_dropout): ModuleDi

In [21]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

base_model.model.embeddings.lora_embedding_A.default
base_model.model.embeddings.lora_embedding_B.default
base_model.model.layers.0.mixer.in_proj.lora_A.default.weight
base_model.model.layers.0.mixer.in_proj.lora_B.default.weight
base_model.model.layers.0.mixer.x_proj.lora_A.default.weight
base_model.model.layers.0.mixer.x_proj.lora_B.default.weight
base_model.model.layers.0.mixer.out_proj.lora_A.default.weight
base_model.model.layers.0.mixer.out_proj.lora_B.default.weight
base_model.model.layers.1.mixer.in_proj.lora_A.default.weight
base_model.model.layers.1.mixer.in_proj.lora_B.default.weight
base_model.model.layers.1.mixer.x_proj.lora_A.default.weight
base_model.model.layers.1.mixer.x_proj.lora_B.default.weight
base_model.model.layers.1.mixer.out_proj.lora_A.default.weight
base_model.model.layers.1.mixer.out_proj.lora_B.default.weight
base_model.model.layers.2.mixer.in_proj.lora_A.default.weight
base_model.model.layers.2.mixer.in_proj.lora_B.default.weight
base_model.model.layers.2.

In [22]:
training_args = TrainingArguments(
    output_dir= OUTPUT_DIR,
    run_name= RUN_NAME,
    learning_rate= LEARNING_RATE,
    lr_scheduler_type= "constant",
    warmup_ratio= WARMUP_RATIO,
    max_grad_norm= MAX_GRAD_NORM,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size= BATCH_SIZE_EVAL,
    gradient_accumulation_steps= GRADIENT_ACCUMULATION_STEPS,#
    group_by_length= USE_LENGTH_GROUPING,
    num_train_epochs= EPOCHS,
    weight_decay= WEIGHT_DECAY,
    eval_strategy="steps",
    eval_steps= EVAL_STEPS,
    eval_accumulation_steps = EVAL_ACCUMULATION_STEPS,
    save_strategy="steps",
    save_steps= EVAL_STEPS,
    load_best_model_at_end=True,
    report_to= REPORT_TO,
    fp16=False,
    gradient_checkpointing=True,
    logging_dir="logs",
    logging_steps= LOGGING_STEPS,
    label_names=["labels"],
)

In [23]:
trainer = SimpleBCELossTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train,
    eval_dataset=val,
    compute_metrics=compute_metrics
)

In [24]:
trainer.train(resume_from_checkpoint=True)

[34m[1mwandb[0m: Currently logged in as: [33melisabeth-fittschen[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss,Strict Accuracy,Hamming Accuracy,F1 Macro,F1 Micro,Precision Macro,Precision Micro,Recall Macro,Recall Micro
800,0.0768,0.079067,0.318,0.973122,0.107863,0.603026,0.141859,0.743339,0.09721,0.507273
1000,0.0684,0.071631,0.371,0.975732,0.123269,0.634326,0.150037,0.805789,0.1086,0.52303
1200,0.0653,0.067899,0.385,0.976805,0.136091,0.665494,0.1906,0.792959,0.123839,0.573333
1400,0.0595,0.065477,0.415,0.97839,0.1641,0.692147,0.21031,0.811075,0.145344,0.603636
1600,0.0574,0.062717,0.439,0.97922,0.16002,0.703136,0.219521,0.827049,0.142279,0.611515
1800,0.0567,0.061104,0.466,0.979976,0.186992,0.712233,0.211954,0.844555,0.170635,0.615758
2000,0.0531,0.060221,0.468,0.980512,0.187606,0.724007,0.209797,0.841767,0.173293,0.635152
2200,0.0538,0.059561,0.477,0.98061,0.197476,0.729867,0.234647,0.830626,0.180219,0.650909
2400,0.0509,0.058996,0.477,0.980585,0.190786,0.72721,0.235547,0.836751,0.171892,0.64303
2600,0.0516,0.059575,0.467,0.98061,0.204646,0.731146,0.235661,0.827085,0.187172,0.655152




TrainOutput(global_step=3372, training_loss=0.04639049222602256, metrics={'train_runtime': 62193.7453, 'train_samples_per_second': 0.868, 'train_steps_per_second': 0.054, 'total_flos': 8.460914581745948e+17, 'train_loss': 0.04639049222602256, 'epoch': 5.995555555555556})

In [25]:
print(trainer)

<__main__.SimpleBCELossTrainer object at 0x7f405c1a3ac0>


In [26]:
trainer.evaluate(test)

{'eval_loss': 0.05875522270798683,
 'eval_strict_accuracy': 0.503,
 'eval_hamming_accuracy': 0.9813170731707317,
 'eval_f1_macro': 0.19638829828582866,
 'eval_f1_micro': 0.7401628222523745,
 'eval_precision_macro': 0.26463790068564746,
 'eval_precision_micro': 0.86793953858393,
 'eval_recall_macro': 0.17357790370550572,
 'eval_recall_micro': 0.6451803666469544,
 'eval_runtime': 464.9518,
 'eval_samples_per_second': 2.151,
 'eval_steps_per_second': 2.151,
 'epoch': 5.995555555555556}

In [27]:
predictions = trainer.predict(test)

In [28]:
predictions

PredictionOutput(predictions=array([[ -5.176667 ,  -3.4321737,  -7.507284 , ...,  -9.13706  ,
         -9.13706  ,  -9.13706  ],
       [ -3.9836233,  -3.4444451,  -6.4422345, ...,  -9.466854 ,
         -9.466854 ,  -9.466854 ],
       [ -5.2913375,  -5.354499 ,  -7.378887 , ..., -10.212474 ,
        -10.212474 , -10.212474 ],
       ...,
       [ -5.2673464,  -2.318149 ,  -6.707821 , ...,  -9.407799 ,
         -9.407799 ,  -9.407799 ],
       [ -4.0101576,   3.081163 ,  -7.1755342, ...,  -9.695277 ,
         -9.695277 ,  -9.695277 ],
       [ -5.0343504,   1.9087962,  -6.6021767, ...,  -9.3357935,
         -9.3357935,  -9.3357935]], dtype=float32), label_ids=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), metrics={'test_loss': 0.05875522270798683, 'test_strict_accuracy': 0.503, 'test_hamming_accuracy': 0.9813170731707317, 'test_f

In [29]:
def calulate_metrics_index(predictions, index):
    logits = predictions.predictions
    labels = predictions.label_ids
    
    logits = logits[:, index]
    labels = labels[:, index]
    
    probs = 1 / (1 + np.exp(-logits))
    predictions = (probs > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary', zero_division=0)
    
    count_correct = np.sum(labels)
    count_predicted = np.sum(predictions)
    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'count_cases': count_correct,
        'count_predicted': count_predicted
    }


In [30]:
ids = utils_ecthr.ARTICLES_ID
ids = {v: k for k, v in ids.items()}
desc = utils_ecthr.ARTICLES_DESC


In [31]:
for i in range(0, 41):
    print("-"*50)
    print(f"Label {i}")
    print(ids[i])
    print(desc[ids[i]])
    print(calulate_metrics_index(predictions, i))

--------------------------------------------------
Label 0
2
Right to life
{'f1': 0.835820895522388, 'precision': 0.9655172413793104, 'recall': 0.7368421052631579, 'count_cases': 76, 'count_predicted': 58}
--------------------------------------------------
Label 1
3
Prohibition of torture
{'f1': 0.869198312236287, 'precision': 0.8583333333333333, 'recall': 0.8803418803418803, 'count_cases': 234, 'count_predicted': 240}
--------------------------------------------------
Label 2
4
Prohibition of slavery and forced labour
{'f1': 0.0, 'precision': 0.0, 'recall': 0.0, 'count_cases': 3, 'count_predicted': 0}
--------------------------------------------------
Label 3
5
Right to liberty and security
{'f1': 0.8287292817679558, 'precision': 0.9036144578313253, 'recall': 0.7653061224489796, 'count_cases': 196, 'count_predicted': 166}
--------------------------------------------------
Label 4
6
Right to a fair trial
{'f1': 0.821522309711286, 'precision': 0.8505434782608695, 'recall': 0.79441624365

In [32]:
calulate_metrics_index(predictions, 2)

{'f1': 0.0,
 'precision': 0.0,
 'recall': 0.0,
 'count_cases': 3,
 'count_predicted': 0}