In [1]:
from _0_mamba_vs_neo.models.MambaForSequenceClassification import MambaForSequenceClassification
import _0_mamba_vs_neo.datasets.ecthr.utils_ecthr as utils_ecthr

In [2]:
from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
import torch
import numpy as np
from peft import get_peft_model, LoraConfig, TaskType
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
import os

In [3]:
# Name of the Weights & Biases project; set before any Trainer is created so the
# "wandb" reporting configured further down can pick it up from the environment.
os.environ["WANDB_PROJECT"] = "mamba_vs_neo"

In [4]:
"""
CONFIGS:
"""

'\nCONFIGS:\n'

In [32]:
"""
    general:
        - RUN_NAME: str
            name of the run
        - OUTPUT_DIR: str
            directory to save the model and logs
        - SEED: int
            random seed to use
        - REPORT_TO: str
            integration that training logs and metrics are reported to (e.g. "wandb")
"""
RUN_NAME = "sample_mamba_run_continue_from_pretrained"
# Every run writes into its own sub-directory named after RUN_NAME.
OUTPUT_DIR = f"_0_mamba_vs_neo/models/mamba/{RUN_NAME}"
SEED = 42
REPORT_TO = "wandb"

In [6]:
"""
    dataset:
        - ALLEGATIONS: bool
            True: use allegation data for the cases, so what laws did the cases allegedly violate
            False: use court decisions, so what laws did the court decide the cases violated
        - SILVER: bool
            True: only use facts which were deemed relevant by the court
            False: use all facts
        - MULTI_LABEL: bool
            True: use multi-label classification (which law was (allegedly) violated)
            False: use binary classification (was there a law (allegedly) violated)
        - FREQUENCY_THRESHOLD: int
            minimum number of cases a law must be (allegedly) violated in to be considered
        - NUM_LABELS: int
            number of labels in the dataset (ecthr: 41)
        - MAX_LENGTH: int
            maximum number of tokens in a sequence     
"""
ALLEGATIONS = True
SILVER = True
MULTI_LABEL = True
FREQUENCY_THRESHOLD = 0
# NOTE: the per-label evaluation loop in this notebook skips label index 0,
# which is described there as unoccupied due to an indexing error.
NUM_LABELS = 41

MAX_LENGTH = 3000

In [7]:
"""
    training:
        - EPOCHS: int
            number of times to iterate over the dataset
        - LEARNING_RATE: float
            rate at which the model learns
        - BATCH_SIZE: int
            number of sequences in a batch
        - GRADIENT_ACCUMULATION_STEPS: int
            number of batches to accumulate gradients over
        - USE_LENGTH_GROUPING: bool
            True: group sequences of similar length together to minimize padding
            False: do not group sequences by length
        - WARMUP_RATIO: float
            fraction of the total training steps used for learning-rate warmup
        - MAX_GRAD_NORM: float
            maximum gradient norm to clip to
        - WEIGHT_DECAY: float
            weight decay to apply to the model
"""
EPOCHS = 3
LEARNING_RATE = 2e-5
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 8
# Effective batch size per optimizer step: batch size times accumulation steps.
print("true batch size:", BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

WARMUP_RATIO = 0.1
MAX_GRAD_NORM = 0.3
WEIGHT_DECAY = 0.001

USE_LENGTH_GROUPING = True

true batch size: 16


In [8]:
"""
    evaluation:
        - EVAL_STEPS: int
            number of steps between evaluations
        - BATCH_SIZE_EVAL: int
            number of sequences in a batch for evaluation
        - LOGGING_STEPS: int
            number of steps between logging
        - EVAL_ACCUMULATION_STEPS: int
            number of eval batches to accumulate before copying the results to the CPU; helpful if the eval requires a lot of memory
"""
EVAL_STEPS = 200  # also reused as the checkpoint interval (save_steps) in TrainingArguments
BATCH_SIZE_EVAL = BATCH_SIZE  # evaluate with the same batch size as training
LOGGING_STEPS = 100
EVAL_ACCUMULATION_STEPS = 20

In [9]:
"""
    model:
        - MODEL_NAME: str
            name of the model to use
        - LORA_TASK_TYPE: TaskType
            PEFT task type the LoRA adapter is configured for
        - LORA_R: int
            r is the rank of the approximation
        - LORA_TARGET_MODULES: list
            list of modules to target with LoRA
"""
MODEL_NAME = "state-spaces/mamba-1.4b-hf"
LORA_TASK_TYPE = TaskType.SEQ_CLS
LORA_R = 8
# LoRA is applied to the token embeddings as well as to the in/x/out projections
# of the Mamba mixer blocks.
LORA_TARGET_MODULES = ["x_proj", "embeddings", "in_proj", "out_proj"]

In [10]:
def compute_metrics(eval_pred):
    """Compute multi-label classification metrics for the Trainer's evaluation loop.

    Args:
        eval_pred: pair ``(logits, labels)``. ``logits`` are the raw
            (pre-sigmoid) scores and ``labels`` the 0/1 targets, both expected
            as 2-D arrays with one row per sample and one column per label.
            If ``logits`` arrives as a tuple (a model returning several
            outputs), its first element is used.

    Returns:
        dict with subset ("strict") accuracy, Hamming accuracy and the macro-
        and micro-averaged precision, recall and F1 scores.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)

    # Numerically stable sigmoid: exp() is only evaluated on non-positive
    # values, so very negative logits cannot trigger an overflow warning.
    # The resulting 0/1 decisions are identical to 1 / (1 + exp(-logits)) > 0.5.
    exp_neg_abs = np.exp(-np.abs(logits))
    probs = np.where(logits >= 0, 1 / (1 + exp_neg_abs), exp_neg_abs / (1 + exp_neg_abs))
    predictions = (probs > 0.5).astype(int)

    # zero_division=0: labels that are never predicted/present score 0 instead of raising.
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )

    return {
        # fraction of samples whose complete label vector is predicted correctly
        'strict_accuracy': accuracy_score(labels, predictions),
        # fraction of individual (sample, label) decisions that are correct
        'hamming_accuracy': 1 - hamming_loss(labels, predictions),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_micro': precision_micro,
        'recall_macro': recall_macro,
        'recall_micro': recall_micro
    }

In [11]:
class SimpleBCELossTrainer(Trainer):
    """Trainer that optimises an unweighted binary cross-entropy loss on the logits.

    Each label is treated as an independent binary decision, which matches
    the multi-label setup (sigmoid per label) used by ``compute_metrics``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Sigmoid + BCE in one op; default reduction averages over all elements.
        self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the BCE-with-logits loss for one batch.

        Args:
            model: model to run the forward pass with.
            inputs: batch dict; its "labels" entry is removed before the
                forward pass and used as the loss target.
            return_outputs: if True, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: accepted for compatibility with newer
                ``transformers`` releases, which pass this keyword to
                ``compute_loss``; it is not used by this loss.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # BCEWithLogitsLoss needs floating-point targets.
        loss = self.loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

In [12]:
# Load the pretrained backbone with a freshly initialised classification head.
model = MambaForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
# Load the tokenizer from the same checkpoint name so both always stay in sync
# when MODEL_NAME is changed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of MambaForSequenceClassification were not initialized from the model checkpoint at state-spaces/mamba-1.4b-hf and are newly initialized: ['backbone.classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [13]:
# Reuse the EOS token id as the padding token id for both tokenizer and model.
# NOTE(review): this sets `pad_token_id` directly on the model object rather than
# on `model.config` — confirm that MambaForSequenceClassification reads it from there.
tokenizer.pad_token_id = tokenizer.eos_token_id
model.pad_token_id = tokenizer.eos_token_id

In [14]:
# Pad every batch dynamically using the tokenizer's padding settings.
collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,
)

In [15]:
# Load the configured ECtHR variant, tokenize it and drop the raw "facts" text
# column in one pipeline, so `ecthr_dataset` is bound exactly once.
ecthr_dataset = utils_ecthr.tokenize_dataset(
    utils_ecthr.load_ecthr_dataset(
        allegations=ALLEGATIONS,
        silver=SILVER,
        is_multi_label=MULTI_LABEL,
        frequency_threshold=FREQUENCY_THRESHOLD,
    ),
    tokenizer,
    max_length=MAX_LENGTH,
).remove_columns("facts")

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [16]:
# Pull the three splits out of the dataset dict in one step.
train, val, test = (
    ecthr_dataset[split_name] for split_name in ("train", "validation", "test")
)

In [17]:
# LoRA adapter settings; bias="none" leaves bias terms out of the adapter.
lora_config = LoraConfig(
    task_type=LORA_TASK_TYPE,
    r=LORA_R,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
)

In [18]:
# Wrap the base model with the LoRA adapters and report the trainable-parameter share.
# NOTE(review): this rebinds `model`, so re-running the cell presumably wraps an
# already wrapped model — restart the kernel before running it again.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 8,428,352 || all params: 1,380,690,752 || trainable%: 0.6104


In [19]:
# Move the LoRA-wrapped model to the GPU (requires a CUDA device). As the last
# expression of the cell, the returned model is also rendered as the cell output.
model.to("cuda")

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): MambaForSequenceClassification(
      (embeddings): lora.Embedding(
        (base_layer): Embedding(50280, 2048)
        (lora_dropout): ModuleDict(
          (default): Identity()
        )
        (lora_A): ModuleDict()
        (lora_B): ModuleDict()
        (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 8x50280 (cuda:0)])
        (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.FloatTensor of size 2048x8 (cuda:0)])
      )
      (layers): ModuleList(
        (0-47): 48 x MambaBlock(
          (norm): MambaRMSNorm()
          (mixer): MambaMixer(
            (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
            (act): SiLU()
            (in_proj): lora.Linear(
              (base_layer): Linear(in_features=2048, out_features=8192, bias=False)
              (lora_dropout): ModuleDi

In [20]:
# Sanity check: list every parameter that will receive gradient updates.
for name in (n for n, p in model.named_parameters() if p.requires_grad):
    print(name)

base_model.model.embeddings.lora_embedding_A.default
base_model.model.embeddings.lora_embedding_B.default
base_model.model.layers.0.mixer.in_proj.lora_A.default.weight
base_model.model.layers.0.mixer.in_proj.lora_B.default.weight
base_model.model.layers.0.mixer.x_proj.lora_A.default.weight
base_model.model.layers.0.mixer.x_proj.lora_B.default.weight
base_model.model.layers.0.mixer.out_proj.lora_A.default.weight
base_model.model.layers.0.mixer.out_proj.lora_B.default.weight
base_model.model.layers.1.mixer.in_proj.lora_A.default.weight
base_model.model.layers.1.mixer.in_proj.lora_B.default.weight
base_model.model.layers.1.mixer.x_proj.lora_A.default.weight
base_model.model.layers.1.mixer.x_proj.lora_B.default.weight
base_model.model.layers.1.mixer.out_proj.lora_A.default.weight
base_model.model.layers.1.mixer.out_proj.lora_B.default.weight
base_model.model.layers.2.mixer.in_proj.lora_A.default.weight
base_model.model.layers.2.mixer.in_proj.lora_B.default.weight
base_model.model.layers.2.

In [33]:
training_args = TrainingArguments(
    # --- run identity, reproducibility and logging ---
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    seed=SEED,  # make the configured SEED effective for the Trainer's seeding
    report_to=REPORT_TO,
    logging_dir="logs",
    logging_steps=LOGGING_STEPS,
    # --- optimisation ---
    learning_rate=LEARNING_RATE,
    # NOTE(review): the "constant" schedule has no warmup phase, so WARMUP_RATIO
    # is presumably ignored; switch to "constant_with_warmup" if warmup is intended.
    lr_scheduler_type="constant",
    warmup_ratio=WARMUP_RATIO,
    max_grad_norm=MAX_GRAD_NORM,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    group_by_length=USE_LENGTH_GROUPING,
    fp16=False,
    gradient_checkpointing=True,
    # --- evaluation and checkpointing ---
    per_device_eval_batch_size=BATCH_SIZE_EVAL,
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    eval_accumulation_steps=EVAL_ACCUMULATION_STEPS,
    save_strategy="steps",
    # Checkpoints are written on the same steps as evaluations, so the best
    # evaluated checkpoint can be restored by load_best_model_at_end.
    save_steps=EVAL_STEPS,
    load_best_model_at_end=True,
    label_names=["labels"],
)

In [34]:
#small_val = val.shuffle(seed=SEED).select(range(100))

In [35]:
# Wire the LoRA model, the data splits, the padding collator and the metric
# callback into the BCE-loss trainer.
trainer = SimpleBCELossTrainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [36]:
# Run the fine-tuning; evaluation and checkpointing follow the step intervals
# configured in training_args.
trainer.train()



Step,Training Loss,Validation Loss,Strict Accuracy,Hamming Accuracy,F1 Macro,F1 Micro,Precision Macro,Precision Micro,Recall Macro,Recall Micro
200,0.0642,0.070432,0.382,0.975878,0.136611,0.645393,0.18118,0.790167,0.119901,0.545455
400,0.0636,0.069303,0.404,0.977073,0.164553,0.668313,0.210678,0.799831,0.142621,0.573939
600,0.0623,0.068039,0.411,0.97739,0.15782,0.675534,0.236244,0.799503,0.137039,0.584848
800,0.059,0.067197,0.423,0.977366,0.178139,0.676204,0.232433,0.796875,0.158223,0.587273
1000,0.0592,0.067428,0.409,0.977415,0.180919,0.681129,0.223342,0.788676,0.160467,0.599394
1200,0.0582,0.06623,0.428,0.977707,0.186049,0.682197,0.227377,0.800163,0.163919,0.594545
1400,0.0551,0.066496,0.432,0.978049,0.190497,0.692833,0.227061,0.792969,0.171139,0.615152
1600,0.0539,0.06609,0.418,0.977659,0.187602,0.68128,0.233063,0.799837,0.164653,0.593333


[[-11.030687   -4.7208385  -3.0755806 ...  -8.125722  -11.030687
  -11.030687 ]
 [ -9.587326   -3.6709726   3.0577922 ...  -7.4198103  -9.587326
   -9.587326 ]
 [-10.254232   -3.3438926   2.5248585 ...  -8.12047   -10.254232
  -10.254232 ]
 ...
 [-11.273832   -4.7210336  -4.712103  ...  -8.715168  -11.273832
  -11.273832 ]
 [ -9.560636   -4.7206535  -1.9770163 ...  -7.628247   -9.560636
   -9.560636 ]
 [ -9.939376   -3.6101463   2.3114445 ...  -7.641819   -9.939376
   -9.939376 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-11.7527685  -4.508392   -3.176836  ...  -8.17596   -11.7527685
  -11.7527685]
 [-10.283288   -3.7687006   3.3299887 ...  -7.5767975 -10.283288
  -10.283288 ]
 [-11.255176   -3.720572    2.4142003 ...  -8.470324  -11.255176
  -11.255176 ]
 ...
 [-12.321575   -4.773405   -4.7494354 ...  -9.088423  -12.321575
  -12.321575 ]
 [-10.243725   -4.4806547  -1.5168746 ...  -7.7567677 -10.243725
  -10.243725 ]
 [-10.521891   -3.7549613   2.0810704 ...  -7.6671762 -10.521891
  -10.521891 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-12.63067    -4.358422   -3.7619557 ...  -8.188254  -12.63067
  -12.63067  ]
 [-10.989179   -3.8232074   3.3973696 ...  -7.697352  -10.989179
  -10.989179 ]
 [-12.079408   -3.7654476   2.4848952 ...  -8.66497   -12.079408
  -12.079408 ]
 ...
 [-12.942959   -4.783126   -4.946003  ...  -8.873002  -12.942959
  -12.942959 ]
 [-10.945094   -4.5104446  -2.0077193 ...  -7.9126754 -10.945094
  -10.945094 ]
 [-11.094442   -3.73681     2.492265  ...  -7.6264324 -11.094442
  -11.094442 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-12.762568   -4.726145   -3.4313607 ...  -8.068798  -12.762568
  -12.762568 ]
 [-11.257126   -3.9023404   3.8617055 ...  -7.7945786 -11.257126
  -11.257126 ]
 [-12.24145    -3.9259946   3.0788739 ...  -8.648427  -12.24145
  -12.24145  ]
 ...
 [-13.702964   -5.146306   -5.023329  ...  -9.307158  -13.702964
  -13.702964 ]
 [-11.437416   -4.6418023  -2.2527463 ...  -8.235993  -11.437416
  -11.437416 ]
 [-10.937367   -3.816062    2.935114  ...  -7.3858566 -10.937367
  -10.937367 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-12.609631   -4.6056743  -3.460998  ...  -7.4849377 -12.609631
  -12.609631 ]
 [-11.403215   -4.0275497   4.158847  ...  -7.684322  -11.403215
  -11.403215 ]
 [-12.585565   -4.1192155   3.3455005 ...  -8.680235  -12.585565
  -12.585565 ]
 ...
 [-14.331941   -5.1899624  -5.120717  ...  -9.615245  -14.331941
  -14.331941 ]
 [-11.416542   -4.512309   -1.789133  ...  -8.048155  -11.416542
  -11.416542 ]
 [-10.799052   -3.8888886   2.788811  ...  -6.9954233 -10.799052
  -10.799052 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-13.95679    -4.765257   -3.8723934 ...  -8.497732  -13.95679
  -13.95679  ]
 [-11.904991   -4.2277365   4.181346  ...  -7.984435  -11.904991
  -11.904991 ]
 [-13.475343   -4.321329    3.111916  ...  -9.258588  -13.475343
  -13.475343 ]
 ...
 [-15.004106   -5.1711     -5.1810107 ...  -9.759948  -15.004106
  -15.004106 ]
 [-12.0788965  -4.418902   -1.648253  ...  -8.474896  -12.0788965
  -12.0788965]
 [-11.983456   -4.2322235   3.0280879 ...  -7.739266  -11.983456
  -11.983456 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-13.105121   -4.652276   -3.6668723 ...  -7.5791535 -13.105121
  -13.105121 ]
 [-12.016555   -4.485792    4.125521  ...  -7.9402375 -12.016555
  -12.016555 ]
 [-13.6510935  -4.7080083   3.083308  ...  -9.220643  -13.6510935
  -13.6510935]
 ...
 [-14.457153   -5.0404267  -5.151036  ...  -9.14406   -14.457153
  -14.457153 ]
 [-12.140647   -4.5408225  -1.9883429 ...  -8.43836   -12.140647
  -12.140647 ]
 [-11.677112   -4.362741    2.8263578 ...  -7.376642  -11.677112
  -11.677112 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[[-14.810602   -4.8406076  -3.892897  ...  -8.968035  -14.810602
  -14.810602 ]
 [-12.286702   -4.5661373   4.443317  ...  -8.092192  -12.286702
  -12.286702 ]
 [-14.133541   -4.6240635   3.4322042 ...  -9.568611  -14.133541
  -14.133541 ]
 ...
 [-16.198198   -5.4035745  -4.922159  ... -10.347116  -16.198198
  -16.198198 ]
 [-12.298047   -4.3582034  -1.599482  ...  -8.47441   -12.298047
  -12.298047 ]
 [-12.543616   -4.5637197   3.489311  ...  -7.887649  -12.543616
  -12.543616 ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TrainOutput(global_step=1686, training_loss=0.05955262851601959, metrics={'train_runtime': 21803.8176, 'train_samples_per_second': 1.238, 'train_steps_per_second': 0.077, 'total_flos': 1.951673380408558e+17, 'train_loss': 0.05955262851601959, 'epoch': 2.997333333333333})

In [39]:
# Debugging leftover: this only prints the trainer's default object repr.
print(trainer)

<__main__.SimpleBCELossTrainer object at 0x7f7d7a422590>


In [40]:
# Evaluate the trained model on the test split; the returned dict holds the
# loss and the compute_metrics values under an "eval_" prefix.
trainer.evaluate(test)

[[-12.291445   -5.1227155  -4.258211  ...  -7.951714  -12.291445
  -12.291445 ]
 [-10.858002   -3.2784262  -2.6313453 ...  -6.9643283 -10.858002
  -10.858002 ]
 [-12.777686   -5.4600825  -4.741484  ...  -7.543176  -12.777686
  -12.777686 ]
 ...
 [-12.751511   -2.916143   -2.1151383 ...  -8.090708  -12.751511
  -12.751511 ]
 [-12.929234   -4.266504    3.9500709 ...  -8.730714  -12.929234
  -12.929234 ]
 [-11.38766    -3.5118544   4.4145746 ...  -7.567207  -11.38766
  -11.38766  ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.06835978478193283,
 'eval_strict_accuracy': 0.443,
 'eval_hamming_accuracy': 0.9775609756097561,
 'eval_f1_macro': 0.16792489466318994,
 'eval_f1_micro': 0.6836313617606602,
 'eval_precision_macro': 0.21634561072212818,
 'eval_precision_micro': 0.8167625308134757,
 'eval_recall_macro': 0.14640327529675798,
 'eval_recall_micro': 0.5878178592548787,
 'eval_runtime': 323.9407,
 'eval_samples_per_second': 3.087,
 'eval_steps_per_second': 1.543,
 'epoch': 2.997333333333333}

In [42]:
# Keep the raw test-set logits and label ids for the per-label analysis.
predictions = trainer.predict(test)

[[-12.291445   -5.1227155  -4.258211  ...  -7.951714  -12.291445
  -12.291445 ]
 [-10.858002   -3.2784262  -2.6313453 ...  -6.9643283 -10.858002
  -10.858002 ]
 [-12.777686   -5.4600825  -4.741484  ...  -7.543176  -12.777686
  -12.777686 ]
 ...
 [-12.751511   -2.916143   -2.1151383 ...  -8.090708  -12.751511
  -12.751511 ]
 [-12.929234   -4.266504    3.9500709 ...  -8.730714  -12.929234
  -12.929234 ]
 [-11.38766    -3.5118544   4.4145746 ...  -7.567207  -11.38766
  -11.38766  ]]
(1000, 41)


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [44]:
# Display the PredictionOutput (logits, label ids and test metrics).
predictions

PredictionOutput(predictions=array([[-12.291445 ,  -5.1227155,  -4.258211 , ...,  -7.951714 ,
        -12.291445 , -12.291445 ],
       [-10.858002 ,  -3.2784262,  -2.6313453, ...,  -6.9643283,
        -10.858002 , -10.858002 ],
       [-12.777686 ,  -5.4600825,  -4.741484 , ...,  -7.543176 ,
        -12.777686 , -12.777686 ],
       ...,
       [-12.751511 ,  -2.916143 ,  -2.1151383, ...,  -8.090708 ,
        -12.751511 , -12.751511 ],
       [-12.929234 ,  -4.266504 ,   3.9500709, ...,  -8.730714 ,
        -12.929234 , -12.929234 ],
       [-11.38766  ,  -3.5118544,   4.4145746, ...,  -7.567207 ,
        -11.38766  , -11.38766  ]], dtype=float32), label_ids=array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]]), metrics={'test_loss': 0.06835978478193283, 'test_strict_accuracy': 0.443, 'test_hamming_accuracy': 0.9775609756097561, 'test_f

In [66]:
def calulate_metrics_index(predictions, index):
    """Compute binary precision/recall/F1 for a single label column.

    Args:
        predictions: object exposing ``predictions`` (raw logits) and
            ``label_ids`` (0/1 targets), both indexable as ``[:, index]`` —
            e.g. the result of ``trainer.predict``.
        index: column of the label to evaluate.

    Returns:
        dict with ``f1``, ``precision`` and ``recall`` for that label, plus
        ``count_cases`` (number of positive targets) and ``count_predicted``
        (number of positive predictions).
    """
    logits = np.asarray(predictions.predictions)[:, index]
    labels = np.asarray(predictions.label_ids)[:, index]

    # Numerically stable sigmoid: exp() is only evaluated on non-positive
    # values, so very negative logits cannot trigger an overflow warning.
    # The resulting 0/1 decisions are identical to 1 / (1 + exp(-logits)) > 0.5.
    exp_neg_abs = np.exp(-np.abs(logits))
    probs = np.where(logits >= 0, 1 / (1 + exp_neg_abs), exp_neg_abs / (1 + exp_neg_abs))
    # Separate name so the `predictions` argument is not shadowed.
    predicted = (probs > 0.5).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predicted, average='binary', zero_division=0
    )

    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'count_cases': np.sum(labels),
        'count_predicted': np.sum(predicted)
    }


In [67]:
# Invert ARTICLES_ID so that a label index can be looked up to get the key
# used by ARTICLES_DESC.
ids = {label_index: article for article, label_index in utils_ecthr.ARTICLES_ID.items()}
desc = utils_ecthr.ARTICLES_DESC


In [70]:
# Starting at 1 because index 0 is not occupied due to an indexing error.
# The upper bound comes from NUM_LABELS so the loop follows the dataset config
# instead of repeating the literal 41.
for i in range(1, NUM_LABELS):
    print("-"*50)
    print(f"Label {i}")
    print(ids[i])
    print(desc[ids[i]])
    print(calulate_metrics_index(predictions, i))

--------------------------------------------------
Label 1
2
Right to life
{'f1': 0.8029197080291971, 'precision': 0.9016393442622951, 'recall': 0.7236842105263158, 'count_cases': 76, 'count_predicted': 61}
--------------------------------------------------
Label 2
3
Prohibition of torture
{'f1': 0.8300220750551877, 'precision': 0.8584474885844748, 'recall': 0.8034188034188035, 'count_cases': 234, 'count_predicted': 219}
--------------------------------------------------
Label 3
4
Prohibition of slavery and forced labour
{'f1': 0.0, 'precision': 0.0, 'recall': 0.0, 'count_cases': 3, 'count_predicted': 0}
--------------------------------------------------
Label 4
5
Right to liberty and security
{'f1': 0.7605633802816901, 'precision': 0.8490566037735849, 'recall': 0.6887755102040817, 'count_cases': 196, 'count_predicted': 159}
--------------------------------------------------
Label 5
6
Right to a fair trial
{'f1': 0.7512953367875648, 'precision': 0.7671957671957672, 'recall': 0.73604060

In [62]:
# Spot-check the metrics of a single label (index 2).
# NOTE(review): the saved output of this cell shows a 'count' key, which the
# current calulate_metrics_index no longer returns — re-run to refresh it.
calulate_metrics_index(predictions, 2)

{'f1': 0.8300220750551877,
 'precision': 0.8584474885844748,
 'recall': 0.8034188034188035,
 'count': 234}