In [1]:
from _0_mamba_vs_neo.models.MambaForSequenceClassification import MambaForSequenceClassification
import _0_mamba_vs_neo.datasets.ecthr.utils_ecthr as utils_ecthr

In [2]:
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
import torch
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
import os

In [3]:
"""
Description:

Here we will test how training the model with weighted BCE loss will affect the performance of the model.
For speed we will use only sequences up to 512 tokens, but train for 6 epochs.
"""

'\nDescription:\n\nHere we will test how training the model with weighted BCE loss will affect the performance of the model.\nFor speed we will use only sequences up to 512 tokens, but train for 6 epochs.\n'

In [4]:
os.environ["WANDB_PROJECT"] = "mamba_vs_neo_ecthr"

In [5]:
"""
CONFIGS:
"""

'\nCONFIGS:\n'

In [6]:
"""
    general:
        - RUN_NAME: str
            name of the run
        - OUTPUT_DIR: str
            directory to save the model and logs
        - SEED: int
            random seed to use
        - REPORT_TO: str
"""
RUN_NAME = "mamba_run_weighted_512_tokens_6_epochs"
OUTPUT_DIR = f"_0_mamba_vs_neo/models/mamba/{RUN_NAME}"
SEED = 42
REPORT_TO = "wandb"

In [7]:
"""
    dataset:
        - ALLEGATIONS: bool
            True: use allegation data for the cases, so what laws did the cases allegedly violate
            False: use court decisions, so what laws did the court decide the cases violated
        - SILVER: bool
            True: only use facts which were deemed relevant by the court
            False: use all facts
        - MULTI_LABEL: bool
            True: use multi-label classification (which law was (allegedly) violated)
            False: use binary classification (was there a law (allegedly) violated)
        - FREQUENCY_THRESHOLD: int
            minimum number of cases a law must be (allegedly) violated in to be considered
        - NUM_LABELS: int
            number of labels in the dataset (ecthr: 41)
        - MAX_LENGTH: int
            maximum number of tokens in a sequence     
"""
ALLEGATIONS = True
SILVER = True
MULTI_LABEL = True
FREQUENCY_THRESHOLD = 0
NUM_LABELS = 41

MAX_LENGTH = 512

In [8]:
"""
    training:
        - EPOCHS: int
            number of times to iterate over the dataset
        - LEARNING_RATE: float
            rate at which the model learns
        - BATCH_SIZE: int
            number of sequences in a batch
        - GRADIENT_ACCUMULATION_STEPS: int
            number of batches to accumulate gradients over
        - USE_LENGTH_GROUPING: bool
            True: group sequences of similar length together to minimize padding
            False: do not group sequences by length
        - WARMUP_RATIO: float
            ratio of training steps to warmup steps
        - MAX_GRAD_NORM: float
            maximum gradient norm to clip to
        - WEIGHT_DECAY: float
            weight decay to apply to the model
"""
EPOCHS = 6
LEARNING_RATE = 2e-5
BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 2
print("true batch size:", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 0.3
WEIGHT_DECAY = 0.001

USE_LENGTH_GROUPING = True

true batch size: 16


In [9]:
"""
    evaluation:
        - EVAL_STEPS: int
            number of steps between evaluations
        - BATCH_SIZE_EVAL: int
            number of sequences in a batch for evaluation
        - LOGGING_STEPS: int
            number of steps between logging
        - EVAL_ACCUMULATION_STEPS: int
            number eval batches to calculate before copying to the cpu, if the eval requires a lot of memory this is helpful
"""
EVAL_STEPS = 200
BATCH_SIZE_EVAL = BATCH_SIZE
LOGGING_STEPS = 100
EVAL_ACCUMULATION_STEPS = 20

In [10]:
"""
    model:
        - MODEL_NAME: str
            name of the model to use
        - LORA_TASK_TYPE:
        - LORA_R: int
           r is the rank of the approximation
        - LORA_TARGET_MODULES: list
            list of modules to target with LoRA
"""
MODEL_NAME = "state-spaces/mamba-1.4b-hf"
LORA_TASK_TYPE = TaskType.SEQ_CLS
LORA_R = 8
LORA_TARGET_MODULES = ["x_proj", "embeddings", "in_proj", "out_proj"]

In [11]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    
    probs = 1 / (1 + np.exp(-logits))
    predictions = (probs > 0.5).astype(int)
    
    precision_macro, recall_macto, f1_macro, _ = precision_recall_fscore_support(labels, predictions, average='macro', zero_division=0)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, predictions, average='micro', zero_division=0)
    accuracy = accuracy_score(labels, predictions)

    return {
        'strict_accuracy': accuracy,
        'hamming_accuracy': 1 - hamming_loss(labels, predictions),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macto,
        'recall_micro': recall_micro
    }

In [12]:
class WeightedBCELossTrainer(Trainer):
    def __init__(self, *args, weight=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=weight)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss = self.loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

In [13]:
model = MambaForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MambaForSequenceClassification were not initialized from the model checkpoint at state-spaces/mamba-1.4b-hf and are newly initialized: ['backbone.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.pad_token_id = tokenizer.eos_token_id

In [15]:
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding = True)

In [16]:
ecthr_dataset = utils_ecthr.load_ecthr_dataset(allegations=ALLEGATIONS, silver=SILVER, is_multi_label=MULTI_LABEL, frequency_threshold=FREQUENCY_THRESHOLD)
ecthr_dataset = utils_ecthr.tokenize_dataset(ecthr_dataset, tokenizer, max_length=MAX_LENGTH)
ecthr_dataset = ecthr_dataset.remove_columns("facts")

In [17]:
train = ecthr_dataset["train"]
val = ecthr_dataset["validation"]
test = ecthr_dataset["test"]

In [18]:
lora_config =  LoraConfig(
        r=LORA_R,
        target_modules=LORA_TARGET_MODULES,
        task_type=LORA_TASK_TYPE,
        bias="none"
)

In [19]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 8,428,352 || all params: 1,380,690,752 || trainable%: 0.6104


In [20]:
model.to("cuda")

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): MambaForSequenceClassification(
      (embeddings): lora.Embedding(
        (base_layer): Embedding(50280, 2048)
        (lora_dropout): ModuleDict(
          (default): Identity()
        )
        (lora_A): ModuleDict()
        (lora_B): ModuleDict()
        (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x50280 (cuda:0)])
        (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
      )
      (layers): ModuleList(
        (0-47): 48 x MambaBlock(
          (norm): MambaRMSNorm()
          (mixer): MambaMixer(
            (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
            (act): SiLU()
            (in_proj): lora.Linear(
              (base_layer): Linear(in_features=2048, out_features=8192, bias=False)
              (lora_dropout): ModuleDi

In [21]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

base_model.model.embeddings.lora_embedding_A.default
base_model.model.embeddings.lora_embedding_B.default
base_model.model.layers.0.mixer.in_proj.lora_A.default.weight
base_model.model.layers.0.mixer.in_proj.lora_B.default.weight
base_model.model.layers.0.mixer.x_proj.lora_A.default.weight
base_model.model.layers.0.mixer.x_proj.lora_B.default.weight
base_model.model.layers.0.mixer.out_proj.lora_A.default.weight
base_model.model.layers.0.mixer.out_proj.lora_B.default.weight
base_model.model.layers.1.mixer.in_proj.lora_A.default.weight
base_model.model.layers.1.mixer.in_proj.lora_B.default.weight
base_model.model.layers.1.mixer.x_proj.lora_A.default.weight
base_model.model.layers.1.mixer.x_proj.lora_B.default.weight
base_model.model.layers.1.mixer.out_proj.lora_A.default.weight
base_model.model.layers.1.mixer.out_proj.lora_B.default.weight
base_model.model.layers.2.mixer.in_proj.lora_A.default.weight
base_model.model.layers.2.mixer.in_proj.lora_B.default.weight
base_model.model.layers.2.

In [22]:
training_args = TrainingArguments(
    output_dir= OUTPUT_DIR,
    run_name= RUN_NAME,
    learning_rate= LEARNING_RATE,
    lr_scheduler_type= "constant",
    warmup_ratio= WARMUP_RATIO,
    max_grad_norm= MAX_GRAD_NORM,
    per_device_train_batch_size= BATCH_SIZE,
    per_device_eval_batch_size= BATCH_SIZE_EVAL,
    gradient_accumulation_steps= GRADIENT_ACCUMULATION_STEPS,#
    group_by_length= USE_LENGTH_GROUPING,
    num_train_epochs= EPOCHS,
    weight_decay= WEIGHT_DECAY,
    eval_strategy="steps",
    eval_steps= EVAL_STEPS,
    eval_accumulation_steps = EVAL_ACCUMULATION_STEPS,
    save_strategy="steps",
    save_steps= EVAL_STEPS,
    load_best_model_at_end=True,
    report_to= REPORT_TO,
    fp16=False,
    gradient_checkpointing=True,
    logging_dir="logs",
    logging_steps= LOGGING_STEPS,
    label_names=["labels"],
)

In [23]:
labels = train["labels"]

In [24]:
# make labels torch:
labels = torch.tensor(labels)
num_positives = labels.sum(dim=0)
num_negatives = len(labels) - num_positives
pos_weight = num_negatives.float() / num_positives.float()
pos_weight[torch.isinf(pos_weight)] = 1.0


In [25]:
num_positives

tensor([ 623, 1740,   26, 1623, 5437,   72, 1056,   81,  441,  162,   16, 1665,
         444,    6,    0,   31,   42,  547,  119,  159,  187, 1558,   15,   61,
           1,    0,   48,    0,    7,    0,    0,    1,    7,   17,    1,   29,
           2,    5,    0,    0,    0])

In [26]:
num_negatives

tensor([8377, 7260, 8974, 7377, 3563, 8928, 7944, 8919, 8559, 8838, 8984, 7335,
        8556, 8994, 9000, 8969, 8958, 8453, 8881, 8841, 8813, 7442, 8985, 8939,
        8999, 9000, 8952, 9000, 8993, 9000, 9000, 8999, 8993, 8983, 8999, 8971,
        8998, 8995, 9000, 9000, 9000])

In [27]:
pos_weight

tensor([1.3446e+01, 4.1724e+00, 3.4515e+02, 4.5453e+00, 6.5532e-01, 1.2400e+02,
        7.5227e+00, 1.1011e+02, 1.9408e+01, 5.4556e+01, 5.6150e+02, 4.4054e+00,
        1.9270e+01, 1.4990e+03, 1.0000e+00, 2.8932e+02, 2.1329e+02, 1.5453e+01,
        7.4630e+01, 5.5604e+01, 4.7128e+01, 4.7766e+00, 5.9900e+02, 1.4654e+02,
        8.9990e+03, 1.0000e+00, 1.8650e+02, 1.0000e+00, 1.2847e+03, 1.0000e+00,
        1.0000e+00, 8.9990e+03, 1.2847e+03, 5.2841e+02, 8.9990e+03, 3.0934e+02,
        4.4990e+03, 1.7990e+03, 1.0000e+00, 1.0000e+00, 1.0000e+00])

In [28]:
trainer = WeightedBCELossTrainer(
    weight=pos_weight.to("cuda"),
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train,
    eval_dataset=val,
    compute_metrics=compute_metrics
)

In [29]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33melisabeth-fittschen[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss,Strict Accuracy,Hamming Accuracy,F1 Macro,F1 Micro,Precision Macro,Precision Micro,Recall Macro,Recall Micro
200,1.0369,1.584261,0.038,0.930171,0.074029,0.383373,0.063058,0.297361,0.106422,0.539394
400,1.8906,1.981325,0.083,0.943171,0.079508,0.431984,0.069702,0.361338,0.106163,0.53697
600,1.0485,1.91495,0.07,0.923195,0.102383,0.409082,0.084207,0.296276,0.168236,0.660606
800,1.3893,1.759286,0.108,0.915024,0.120466,0.402196,0.088826,0.280517,0.263728,0.710303
1000,1.8528,1.76462,0.077,0.913098,0.129301,0.406661,0.096367,0.280367,0.277682,0.74
1200,0.946,1.633213,0.072,0.906415,0.127928,0.390275,0.096436,0.264484,0.290291,0.744242
1400,2.0937,1.631404,0.097,0.921585,0.138094,0.431678,0.105019,0.304717,0.281983,0.74
1600,0.6167,1.658565,0.1,0.935244,0.150332,0.47311,0.116455,0.351726,0.27533,0.722424
1800,0.8313,1.291687,0.073,0.919195,0.159579,0.425325,0.118118,0.297934,0.369758,0.74303
2000,1.0743,1.418048,0.089,0.924146,0.163807,0.439438,0.124548,0.312724,0.336466,0.738788




TrainOutput(global_step=3372, training_loss=1.047440564505146, metrics={'train_runtime': 17056.7804, 'train_samples_per_second': 3.166, 'train_steps_per_second': 0.198, 'total_flos': 1.702218752087163e+17, 'train_loss': 1.047440564505146, 'epoch': 5.994666666666666})

In [30]:
print(trainer)

<__main__.WeightedBCELossTrainer object at 0x7f83d0230eb0>


In [31]:
trainer.evaluate(test)

{'eval_loss': 3.6794471740722656,
 'eval_strict_accuracy': 0.066,
 'eval_hamming_accuracy': 0.9121707317073171,
 'eval_f1_macro': 0.14599509621044499,
 'eval_f1_micro': 0.4070475876831879,
 'eval_precision_macro': 0.11024248968984837,
 'eval_precision_micro': 0.2820629849383843,
 'eval_recall_macro': 0.3498547274539948,
 'eval_recall_micro': 0.7309284447072738,
 'eval_runtime': 105.2343,
 'eval_samples_per_second': 9.503,
 'eval_steps_per_second': 1.188,
 'epoch': 5.994666666666666}

In [32]:
predictions = trainer.predict(test)

In [33]:
predictions

PredictionOutput(predictions=array([[-2.806864  , -2.3981493 , -2.6462197 , ..., -6.029908  ,
        -6.029908  , -6.029908  ],
       [-1.4677405 , -1.8144476 , -2.4256995 , ..., -7.036352  ,
        -7.036352  , -7.036352  ],
       [-2.2026987 , -2.365809  , -2.1277614 , ..., -5.0918813 ,
        -5.0918813 , -5.0918813 ],
       ...,
       [-0.38720113, -0.90709853, -2.1523602 , ..., -5.989273  ,
        -5.989273  , -5.989273  ],
       [ 0.48376253,  2.9684055 , -3.400179  , ..., -6.8790054 ,
        -6.8790054 , -6.8790054 ],
       [-0.40080598,  3.2161222 , -3.1352289 , ..., -6.5014496 ,
        -6.5014496 , -6.5014496 ]], dtype=float32), label_ids=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), metrics={'test_loss': 3.6794471740722656, 'test_strict_accuracy': 0.066, 'test_hamming_accuracy': 0.9121707317073171, 'test_f1

In [34]:
def calulate_metrics_index(predictions, index):
    logits = predictions.predictions
    labels = predictions.label_ids
    
    logits = logits[:, index]
    labels = labels[:, index]
    
    probs = 1 / (1 + np.exp(-logits))
    predictions = (probs > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary', zero_division=0)
    
    count_correct = np.sum(labels)
    count_predicted = np.sum(predictions)
    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'count_cases': count_correct,
        'count_predicted': count_predicted
    }


In [35]:
ids = utils_ecthr.ARTICLES_ID
ids = {v: k for k, v in ids.items()}
desc = utils_ecthr.ARTICLES_DESC


In [36]:
# starting at 1 because 0 is not occupied due to an indexing error
for i in range(1, 41):
    print("-"*50)
    print(f"Label {i}")
    print(ids[i])
    print(desc[ids[i]])
    print(calulate_metrics_index(predictions, i))

--------------------------------------------------
Label 1
3
Prohibition of torture
{'f1': 0.6261127596439169, 'precision': 0.47954545454545455, 'recall': 0.9017094017094017, 'count_cases': 234, 'count_predicted': 440}
--------------------------------------------------
Label 2
4
Prohibition of slavery and forced labour
{'f1': 0.0, 'precision': 0.0, 'recall': 0.0, 'count_cases': 3, 'count_predicted': 1}
--------------------------------------------------
Label 3
5
Right to liberty and security
{'f1': 0.5936920222634509, 'precision': 0.46647230320699706, 'recall': 0.8163265306122449, 'count_cases': 196, 'count_predicted': 343}
--------------------------------------------------
Label 4
6
Right to a fair trial
{'f1': 0.6190476190476191, 'precision': 0.690625, 'recall': 0.5609137055837563, 'count_cases': 394, 'count_predicted': 320}
--------------------------------------------------
Label 5
7
No punishment without law
{'f1': 0.05660377358490566, 'precision': 0.03125, 'recall': 0.3, 'count_ca

In [37]:
calulate_metrics_index(predictions, 2)

{'f1': 0.0,
 'precision': 0.0,
 'recall': 0.0,
 'count_cases': 3,
 'count_predicted': 1}

In [37]:
""" It seems like the weights are too high for unimportant labels, so we might have to do more research"""