<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/model/ECTHR/ECTHR_SIMCSE_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install and Load packages

In [None]:
!git clone https://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@github.com/danielsaggau/IR_LDC.git

In [2]:
%cd IR_LDC

/content/IR_LDC


In [None]:
!pip install -r requirements.txt

In [4]:
import numpy as np
import torch as nn
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
    Trainer,
    LongformerTokenizer
)

# load Datasets

In [None]:
from datasets import load_dataset
dataset = load_dataset("lex_glue", "ecthr_b")

# Connect to Huggingface
Alternativ 1 via pop up window and entering access token

In [None]:
#from huggingface_hub import notebook_login
#notebook_login()
#access code:
#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

Alternativ 2 using the direct command

In [6]:
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')"

# Set labels 
We set the labels to 10 and also pass this argument to the ```AutoModelForSequenceClassification``` function

In [7]:
label_list = list(range(10))
num_labels = len(label_list)

Instantiating the model and the tokenizer from our pre-trained model. This model was pre-trained similarly to `SIMCSE`. Further are using the ``use_fast=True`` specification of the tokenizer.

In [None]:
tokenizer = AutoTokenizer.from_pretrained("danielsaggau/simcse_longformer_ecthr_b", use_auth_token=True, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained("danielsaggau/simcse_longformer_ecthr_b", num_labels=10)

# Data Collator 
Set colaltor to ``pad_to_multiple`` of 8 for efficiency (FP16)

In [9]:
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) # fp16

Here we add global attention mask for the longformer as our attention mechanism is twofold 

In [10]:
def preprocess_function(examples):
        # Tokenize the texts
        cases = []
        padding = "max_length"
        max_seq_length=4096
        for case in examples['text']:
            cases.append(f' {tokenizer.sep_token} '.join([fact for fact in case]))
        batch = tokenizer(cases, padding=padding, max_length=4096, truncation=True)
        # use global attention on CLS token
        global_attention_mask = np.zeros((len(cases),max_seq_length), dtype=np.int32)
        global_attention_mask[:, 0] = 1
        batch['global_attention_mask'] = list(global_attention_mask)
        batch["labels"] = [[1 if label in labels else 0 for label in label_list] for labels in examples["labels"]]
        return batch

In [11]:
tokenized_data = dataset.map(preprocess_function, batched=True,remove_columns=['text'])

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [12]:
def compute_metrics(p: EvalPrediction):
        # Fix gold labels
        y_true = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_true[:, :-1] = p.label_ids
        y_true[:, -1] = (np.sum(p.label_ids, axis=1) == 0).astype('int32')
        # Fix predictions
        logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = (expit(logits) > 0.5).astype('int32')
        y_pred = np.zeros((p.label_ids.shape[0], p.label_ids.shape[1] + 1), dtype=np.int32)
        y_pred[:, :-1] = preds
        y_pred[:, -1] = (np.sum(preds, axis=1) == 0).astype('int32')
        # Compute scores
        macro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro', zero_division=0)
        micro_f1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro', zero_division=0)
        return {'macro-f1': macro_f1, 'micro-f1': micro_f1}

# Specify the training arguments 
Here we use a respective batch size of 6 as we are using longformer. Further we use a learning rate of 3e-5 as done by chalkidis et al in lexglue. 
Further we save results by epoch. The metric for the best model is our micro f1. One needs to ensure that the highest score is best so we use greater is better and we load the best model at the end. 
For more pronounced performance increase number of epochs. 

In [13]:
training_args = TrainingArguments(
    output_dir="/slbert_ecthr_b_classsification",
    learning_rate=3e-5,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    #push_to_hub=True,
    metric_for_best_model="micro-f1",
    greater_is_better=True,
    load_best_model_at_end = True
)

Specifying costum trainer with multiple label classification loss

In [14]:
#class MultilabelTrainer(Trainer):
#    def compute_loss(self, model, inputs, return_outputs=False):
#        labels = inputs.pop("labels")
#        outputs = model(**inputs)
#        logits = outputs.logits
#        #loss_fct = nn.BCELoss()
#        loss_fct = nn.BCEWithLogitsLoss()
#        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
#                        labels.float().view(-1, self.model.config.num_labels))
#        return (loss, outputs) if return_outputs else loss

In [15]:
from torch import cuda
cuda.empty_cache()

In [17]:
trainer =Trainer(
    model=model,
    compute_metrics=compute_metrics,
    args=training_args,
    eval_dataset=tokenized_data['validation'],
    train_dataset=tokenized_data['train'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)])
trainer.train()

***** Running training *****
  Num examples = 9000
  Num Epochs = 1
  Instantaneous batch size per device = 6
  Total train batch size (w. parallel, distributed & accumulation) = 6
  Gradient Accumulation steps = 1
  Total optimization steps = 1500


ValueError: ignored