In [1]:
from _0_mamba_vs_neo.models.MambaForSequenceClassification import MambaForSequenceClassification
import _0_mamba_vs_neo.datasets.ecthr.utils_ecthr as utils_ecthr

In [2]:
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
import torch
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
import os

In [3]:
"""
Description:

Train for 6 epochs with 3000 tokens and without silver data, to try to improve on our current results.
"""

'\nDescription:\n\nTrain for 6 epochs with 3000 tokens without silver data, as to trying to improve our current results.\n'

In [4]:
# Name the Weights & Biases project up front so it is in place before any
# logging starts (REPORT_TO is set to "wandb" in the config cell).
os.environ["WANDB_PROJECT"] = "mamba_vs_neo_ecthr"

In [5]:
"""
CONFIGS:
"""

'\nCONFIGS:\n'

In [6]:
"""
    general:
        - RUN_NAME: str
            name of the run
        - OUTPUT_DIR: str
            directory to save the model and logs
        - SEED: int
            random seed to use
        - REPORT_TO: str
            experiment tracker the Trainer reports to (passed as report_to)
"""
# NOTE(review): the run name contains "512_tokens" while MAX_LENGTH is 3000 —
# confirm the name is intended before comparing runs by name.
RUN_NAME = "mamba_run_512_tokens_6_epochs_no_silver_3000"
# Passed to TrainingArguments as output_dir.
OUTPUT_DIR = f"_0_mamba_vs_neo/models/mamba/{RUN_NAME}"
SEED = 42
REPORT_TO = "wandb"

In [7]:
"""
    dataset:
        - ALLEGATIONS: bool
            True: use allegation data, i.e. which laws the cases allegedly violated
            False: use court decisions, i.e. which laws the court decided were violated
        - SILVER: bool
            True: only use facts which were deemed relevant by the court
            False: use all facts
        - MULTI_LABEL: bool
            True: use multi-label classification (which law was (allegedly) violated)
            False: use binary classification (was there a law (allegedly) violated)
        - FREQUENCY_THRESHOLD: int
            minimum number of cases a law must be (allegedly) violated in to be considered
        - NUM_LABELS: int
            number of labels in the dataset (ecthr: 41)
        - MAX_LENGTH: int
            maximum number of tokens in a sequence     
"""
# All four flags are forwarded to utils_ecthr.load_ecthr_dataset.
ALLEGATIONS = True
SILVER = False
MULTI_LABEL = True
FREQUENCY_THRESHOLD = 0
# Used as the classifier head size when the model is loaded; presumably has to
# match the number of label columns the dataset produces — verify when changing
# FREQUENCY_THRESHOLD.
NUM_LABELS = 41

# Passed as max_length to utils_ecthr.tokenize_dataset.
MAX_LENGTH = 3000

In [8]:
"""
    training:
        - EPOCHS: int
            number of times to iterate over the dataset
        - LEARNING_RATE: float
            rate at which the model learns
        - BATCH_SIZE: int
            number of sequences in a batch
        - GRADIENT_ACCUMULATION_STEPS: int
            number of batches to accumulate gradients over
        - USE_LENGTH_GROUPING: bool
            True: group sequences of similar length together to minimize padding
            False: do not group sequences by length
        - WARMUP_RATIO: float
            fraction of the total training steps spent on learning-rate warmup
        - MAX_GRAD_NORM: float
            maximum gradient norm to clip to
        - WEIGHT_DECAY: float
            weight decay to apply to the model
"""
EPOCHS = 6
LEARNING_RATE = 2e-5
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
# Number of samples that contribute to one optimizer step on a single device.
print("true batch size:", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

# NOTE(review): this is passed as warmup_ratio together with
# lr_scheduler_type="constant" in the TrainingArguments cell — confirm that
# scheduler actually applies a warmup phase.
WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 0.3
WEIGHT_DECAY = 0.001

USE_LENGTH_GROUPING = True

true batch size: 16


In [9]:
"""
    evaluation:
        - EVAL_STEPS: int
            number of steps between evaluations
        - BATCH_SIZE_EVAL: int
            number of sequences in a batch for evaluation
        - LOGGING_STEPS: int
            number of steps between logging
        - EVAL_ACCUMULATION_STEPS: int
            number of eval batches to compute before copying the results to the CPU; helpful if evaluation requires a lot of memory
"""
# Also reused as save_steps in the TrainingArguments cell, so checkpoints are
# written at the same step interval as evaluations.
EVAL_STEPS = 200
# Evaluation uses the same per-device batch size as training.
BATCH_SIZE_EVAL = BATCH_SIZE
LOGGING_STEPS = 100
EVAL_ACCUMULATION_STEPS = 20

In [10]:
"""
    model:
        - MODEL_NAME: str
            name of the model to use
        - LORA_TASK_TYPE: TaskType
            PEFT task type the adapter is configured for (sequence classification here)
        - LORA_R: int
           r is the rank of the approximation
        - LORA_TARGET_MODULES: list
            list of modules to target with LoRA
"""
# Hugging Face checkpoint the classification model is loaded from.
MODEL_NAME = "state-spaces/mamba-1.4b-hf"
LORA_TASK_TYPE = TaskType.SEQ_CLS
LORA_R = 8
# Module names that receive LoRA adapters: the three mixer projections plus the
# token embedding matrix.
LORA_TARGET_MODULES = ["x_proj", "embeddings", "in_proj", "out_proj"]

In [11]:
def compute_metrics(eval_pred):
    """Compute multi-label classification metrics from raw logits.

    Args:
        eval_pred: pair ``(logits, labels)``. Both are assumed to be array-like
            with one row per sample and one column per label, labels being
            multi-hot 0/1 values — TODO confirm against the dataset.

    Returns:
        dict with strict (exact-match) accuracy, Hamming accuracy and the
        macro/micro averaged precision, recall and F1 scores.
    """
    logits, labels = eval_pred

    # sigmoid(x) > 0.5 is equivalent to x > 0, so the logits are thresholded
    # directly. This skips np.exp(-logits), which emits overflow warnings for
    # large-magnitude negative logits.
    predictions = (np.asarray(logits) > 0).astype(int)

    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )
    # accuracy_score on multi-label input only counts rows where every label matches.
    accuracy = accuracy_score(labels, predictions)

    return {
        'strict_accuracy': accuracy,
        'hamming_accuracy': 1 - hamming_loss(labels, predictions),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macro,
        'recall_micro': recall_micro
    }

In [12]:
class SimpleBCELossTrainer(Trainer):
    """Trainer that scores logits with an unweighted ``BCEWithLogitsLoss``.

    The loss is computed here from the model's logits and the batch's
    ``labels`` entry instead of relying on a loss returned by the model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the BCE-with-logits loss for one batch.

        Args:
            model: the model to run the forward pass with.
            inputs: batch mapping; its ``"labels"`` entry is removed before the
                forward pass, so the model never receives the labels.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: accepted only so that Trainer versions which
                pass this keyword do not fail with a TypeError; it is not used
                because the loss already averages over all elements.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # BCEWithLogitsLoss needs floating-point targets.
        loss = self.loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

In [13]:
# Model and tokenizer are both loaded from MODEL_NAME so that changing the
# checkpoint in the config cell cannot leave them out of sync.
model = MambaForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MambaForSequenceClassification were not initialized from the model checkpoint at state-spaces/mamba-1.4b-hf and are newly initialized: ['backbone.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
# Reuse the EOS token id for padding on the tokenizer, and store the same id
# on the model object.
# NOTE(review): this sets `model.pad_token_id`, not `model.config.pad_token_id`;
# confirm which attribute MambaForSequenceClassification actually reads.
tokenizer.pad_token_id = tokenizer.eos_token_id
model.pad_token_id = tokenizer.eos_token_id

In [15]:
# Batches are padded by the collator at collation time using the tokenizer's
# padding settings.
collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,
)

In [16]:
# Load -> tokenize -> drop the raw "facts" text column, as one pipeline so the
# name `ecthr_dataset` only ever refers to the final, model-ready dataset.
ecthr_dataset = utils_ecthr.tokenize_dataset(
    utils_ecthr.load_ecthr_dataset(
        allegations=ALLEGATIONS,
        silver=SILVER,
        is_multi_label=MULTI_LABEL,
        frequency_threshold=FREQUENCY_THRESHOLD,
    ),
    tokenizer,
    max_length=MAX_LENGTH,
).remove_columns("facts")

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [17]:
# Pull the three splits out of the dataset dict in one step.
train, val, test = (
    ecthr_dataset[split_name]
    for split_name in ("train", "validation", "test")
)

In [18]:
# LoRA adapter settings taken from the model config cell; only rank, target
# modules, task type and bias handling are set explicitly here.
lora_config = LoraConfig(
    task_type=LORA_TASK_TYPE,
    r=LORA_R,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
)

In [19]:
# Wrap the base model with the LoRA adapters and report how many parameters
# remain trainable.
# NOTE(review): `model` is rebound to the PEFT wrapper, so re-running this cell
# without reloading the base model would wrap an already-wrapped model.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 8,428,352 || all params: 1,380,690,752 || trainable%: 0.6104


In [20]:
# Move the adapter-wrapped model to the GPU; as the cell's last expression the
# returned module is also displayed.
# NOTE(review): the device is hard-coded to "cuda", so this cell presumably
# fails on a machine without a CUDA device.
model.to("cuda")

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): MambaForSequenceClassification(
      (embeddings): lora.Embedding(
        (base_layer): Embedding(50280, 2048)
        (lora_dropout): ModuleDict(
          (default): Identity()
        )
        (lora_A): ModuleDict()
        (lora_B): ModuleDict()
        (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x50280 (cuda:0)])
        (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
      )
      (layers): ModuleList(
        (0-47): 48 x MambaBlock(
          (norm): MambaRMSNorm()
          (mixer): MambaMixer(
            (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
            (act): SiLU()
            (in_proj): lora.Linear(
              (base_layer): Linear(in_features=2048, out_features=8192, bias=False)
              (lora_dropout): ModuleDi

In [21]:
# Sanity check: list every parameter that will receive gradients.
trainable_names = [
    param_name
    for param_name, tensor in model.named_parameters()
    if tensor.requires_grad
]
for param_name in trainable_names:
    print(param_name)

base_model.model.embeddings.lora_embedding_A.default
base_model.model.embeddings.lora_embedding_B.default
base_model.model.layers.0.mixer.in_proj.lora_A.default.weight
base_model.model.layers.0.mixer.in_proj.lora_B.default.weight
base_model.model.layers.0.mixer.x_proj.lora_A.default.weight
base_model.model.layers.0.mixer.x_proj.lora_B.default.weight
base_model.model.layers.0.mixer.out_proj.lora_A.default.weight
base_model.model.layers.0.mixer.out_proj.lora_B.default.weight
base_model.model.layers.1.mixer.in_proj.lora_A.default.weight
base_model.model.layers.1.mixer.in_proj.lora_B.default.weight
base_model.model.layers.1.mixer.x_proj.lora_A.default.weight
base_model.model.layers.1.mixer.x_proj.lora_B.default.weight
base_model.model.layers.1.mixer.out_proj.lora_A.default.weight
base_model.model.layers.1.mixer.out_proj.lora_B.default.weight
base_model.model.layers.2.mixer.in_proj.lora_A.default.weight
base_model.model.layers.2.mixer.in_proj.lora_B.default.weight
base_model.model.layers.2.

In [22]:
# All tunable values come from the config cells at the top of the notebook.
# NOTE(review): lr_scheduler_type="constant" is combined with a non-zero
# warmup_ratio; confirm the constant schedule applies warmup (transformers also
# offers "constant_with_warmup"). Left unchanged to keep the run comparable.
training_args = TrainingArguments(
    # Bookkeeping and reproducibility.
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    seed=SEED,  # wire the configured seed in instead of relying on the library default
    report_to=REPORT_TO,
    logging_dir="logs",
    logging_steps=LOGGING_STEPS,
    # Optimisation.
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="constant",
    warmup_ratio=WARMUP_RATIO,
    max_grad_norm=MAX_GRAD_NORM,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE_EVAL,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    group_by_length=USE_LENGTH_GROUPING,
    fp16=False,
    gradient_checkpointing=True,
    # Evaluation and checkpointing share the same step interval, so a saved
    # checkpoint exists at every evaluated step for load_best_model_at_end.
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    eval_accumulation_steps=EVAL_ACCUMULATION_STEPS,
    save_strategy="steps",
    save_steps=EVAL_STEPS,
    load_best_model_at_end=True,
    label_names=["labels"],
)

In [23]:
# Assemble the BCE-loss trainer: model and arguments first, then the data it
# trains/evaluates on, then how batches are built and scored.
trainer = SimpleBCELossTrainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [24]:
# Run fine-tuning; the returned TrainOutput (global step, training loss,
# runtime metrics) is displayed as the cell result.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33melisabeth-fittschen[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss,Strict Accuracy,Hamming Accuracy,F1 Macro,F1 Micro,Precision Macro,Precision Micro,Recall Macro,Recall Micro
200,0.1678,0.133901,0.158,0.955707,0.014054,0.287843,0.016072,0.407778,0.02276,0.222424
400,0.1068,0.111182,0.156,0.956488,0.014651,0.288676,0.024899,0.421911,0.022534,0.219394
600,0.0911,0.09213,0.23,0.967585,0.062729,0.48025,0.070699,0.676957,0.057952,0.372121
800,0.0743,0.079394,0.327,0.973707,0.105332,0.596557,0.145408,0.779843,0.092426,0.48303
1000,0.069,0.074962,0.349,0.974585,0.118088,0.61943,0.146742,0.779412,0.103858,0.513939
1200,0.0665,0.071357,0.367,0.975902,0.12821,0.652846,0.147503,0.776756,0.11765,0.56303
1400,0.0616,0.068726,0.397,0.977293,0.151018,0.674817,0.201982,0.796373,0.133127,0.585455
1600,0.0599,0.065422,0.419,0.978171,0.142159,0.68497,0.192643,0.816961,0.128454,0.589697
1800,0.0589,0.064426,0.432,0.978756,0.169446,0.691244,0.206828,0.832622,0.150514,0.590909
2000,0.0553,0.063103,0.443,0.979634,0.177104,0.708958,0.210953,0.83429,0.160042,0.616364




TrainOutput(global_step=3372, training_loss=0.07888669800786643, metrics={'train_runtime': 67583.9234, 'train_samples_per_second': 0.799, 'train_steps_per_second': 0.05, 'total_flos': 6.229271310431731e+17, 'train_loss': 0.07888669800786643, 'epoch': 5.994666666666666})

In [25]:
# NOTE(review): looks like a debugging leftover — this only prints the trainer
# object's default repr and carries no information about the run.
print(trainer)

<__main__.SimpleBCELossTrainer object at 0x7fc2e5d49ae0>


In [26]:
# Score the trained model on the held-out test split; the metric keys from
# compute_metrics are reported with an "eval_" prefix.
trainer.evaluate(test)

{'eval_loss': 0.061180777847766876,
 'eval_strict_accuracy': 0.499,
 'eval_hamming_accuracy': 0.9808536585365853,
 'eval_f1_macro': 0.175226564708734,
 'eval_f1_micro': 0.7347076715106455,
 'eval_precision_macro': 0.2271931387119774,
 'eval_precision_micro': 0.8572555205047319,
 'eval_recall_macro': 0.156072879308217,
 'eval_recall_micro': 0.6428149024246008,
 'eval_runtime': 488.3146,
 'eval_samples_per_second': 2.048,
 'eval_steps_per_second': 1.024,
 'epoch': 5.994666666666666}

In [27]:
# Keep the raw test-set logits and label ids for the per-label analysis below.
# NOTE(review): predict() also reports metrics (same loss value as the
# evaluate(test) call above), so the test split is processed twice.
predictions = trainer.predict(test)

In [28]:
# Display the PredictionOutput (logits, label_ids and metrics) for inspection.
predictions

PredictionOutput(predictions=array([[ -5.309937 ,  -4.5318904,  -7.400669 , ...,  -9.525612 ,
         -9.525612 ,  -9.525612 ],
       [ -5.094165 ,  -4.204568 ,  -6.429409 , ...,  -9.783821 ,
         -9.783821 ,  -9.783821 ],
       [ -6.3716455,  -5.4453177,  -7.776255 , ..., -10.779707 ,
        -10.779707 , -10.779707 ],
       ...,
       [ -4.6644335,  -1.2084378,  -6.5028176, ...,  -9.241397 ,
         -9.241397 ,  -9.241397 ],
       [ -4.575066 ,   2.7364695,  -7.192851 , ...,  -9.672376 ,
         -9.672376 ,  -9.672376 ],
       [ -5.4789486,   2.2882545,  -6.3264813, ...,  -9.233077 ,
         -9.233077 ,  -9.233077 ]], dtype=float32), label_ids=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), metrics={'test_loss': 0.061180777847766876, 'test_strict_accuracy': 0.499, 'test_hamming_accuracy': 0.9808536585365853, 'test_

In [29]:
def calulate_metrics_index(predictions, index):
    """Compute binary precision/recall/F1 for one label column.

    Args:
        predictions: object exposing ``predictions`` (logits) and ``label_ids``,
            e.g. the result of ``trainer.predict``; both are indexed as
            ``[:, index]``.
        index: column index of the label to score.

    Returns:
        dict with ``f1``, ``precision``, ``recall``, ``count_cases`` (number of
        samples whose true label is set) and ``count_predicted`` (number of
        samples predicted positive).
    """
    label_logits = np.asarray(predictions.predictions)[:, index]
    label_truth = np.asarray(predictions.label_ids)[:, index]

    # sigmoid(x) > 0.5 is equivalent to x > 0; thresholding the logits directly
    # avoids np.exp overflow warnings for large-magnitude negative logits.
    # A separate name is used so the `predictions` argument is not shadowed.
    predicted = (label_logits > 0).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        label_truth, predicted, average='binary', zero_division=0
    )

    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        # Plain ints keep the printed report free of numpy scalar reprs.
        'count_cases': int(np.sum(label_truth)),
        'count_predicted': int(np.sum(predicted))
    }


In [30]:
# Invert ARTICLES_ID so that a label index looks up its article identifier;
# ARTICLES_DESC then maps that identifier to a human-readable description.
ids = {
    label_index: article
    for article, label_index in utils_ecthr.ARTICLES_ID.items()
}
desc = utils_ecthr.ARTICLES_DESC


In [31]:
# Per-label report on the test predictions: label index, article identifier,
# article description and the binary metrics for that label.
# The label count comes from NUM_LABELS instead of a hard-coded 41 so the report
# follows the config cell.
for i in range(NUM_LABELS):
    print("-"*50)
    print(f"Label {i}")
    print(ids[i])
    print(desc[ids[i]])
    print(calulate_metrics_index(predictions, i))

--------------------------------------------------
Label 0
2
Right to life
{'f1': 0.8467153284671532, 'precision': 0.9508196721311475, 'recall': 0.7631578947368421, 'count_cases': 76, 'count_predicted': 61}
--------------------------------------------------
Label 1
3
Prohibition of torture
{'f1': 0.8620689655172413, 'precision': 0.8695652173913043, 'recall': 0.8547008547008547, 'count_cases': 234, 'count_predicted': 230}
--------------------------------------------------
Label 2
4
Prohibition of slavery and forced labour
{'f1': 0.0, 'precision': 0.0, 'recall': 0.0, 'count_cases': 3, 'count_predicted': 0}
--------------------------------------------------
Label 3
5
Right to liberty and security
{'f1': 0.8148148148148148, 'precision': 0.9225806451612903, 'recall': 0.7295918367346939, 'count_cases': 196, 'count_predicted': 155}
--------------------------------------------------
Label 4
6
Right to a fair trial
{'f1': 0.8071519795657727, 'precision': 0.8123393316195373, 'recall': 0.80203045

In [32]:
# Spot check of a rare label (index 2): the test split has only 3 positive
# cases for it and the model predicts none of them.
calulate_metrics_index(predictions, 2)

{'f1': 0.0,
 'precision': 0.0,
 'recall': 0.0,
 'count_cases': 3,
 'count_predicted': 0}