In [1]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer
import torch
from transformers import DataCollatorWithPadding
import evaluate
import numpy as np
from evaluate import evaluator

  from .autonotebook import tqdm as notebook_tqdm
2024-02-28 11:42:01.442813: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-28 11:42:01.442859: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-28 11:42:01.445348: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-28 11:42:01.691247: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Fetch the ECtHR cases dataset in its "violation-prediction" configuration.
echr = load_dataset(path="ecthr_cases", name="violation-prediction")

In [3]:
# Checkpoint whose vocabulary is used to tokenize the case texts.
TOKENIZER_CHECKPOINT = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [4]:
def encode(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Sequences are truncated but deliberately not padded here. Padding inside
    a batched ``map`` pads every example to the longest sequence of its map
    batch and stores those pad tokens in the dataset; padding is instead left
    to the batch-time ``DataCollatorWithPadding``.

    Args:
        examples: Mapping with a ``"text"`` entry (as passed by
            ``Dataset.map(..., batched=True)``).

    Returns:
        Whatever the module-level ``tokenizer`` returns for these texts.
    """
    return tokenizer(examples["text"], truncation=True)

In [None]:
def _prepare_split(split):
    """Build model-ready features for a single dataset split.

    Steps, in order: join the ``"facts"`` entries into one newline-separated
    ``"text"`` string, tokenize with ``encode``, then replace ``"labels"``
    with 1 where the original label entry is truthy and 0 otherwise.
    """
    with_text = split.map(lambda record: {"text": "\n".join(record["facts"])})
    tokenized = with_text.map(encode, batched=True)
    return tokenized.map(
        lambda batch: {'labels': [1 if entry else 0 for entry in batch['labels']]},
        batched=True,
    )


train_dataset, val_dataset, test_dataset = (
    _prepare_split(echr[split_name]) for split_name in ('train', 'validation', 'test')
)

In [None]:
# Collator handed to the Trainer; it is built around the same tokenizer
# that encoded the datasets.
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Load the metric implementations used to score predictions.
accuracy, f1 = (evaluate.load(metric_id) for metric_id in ("accuracy", "f1"))

In [None]:
def compute_metrics(eval_pred):
    """Score one evaluation pass with both accuracy and F1.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` is
            reduced with ``argmax`` over axis 1, so it is expected to hold
            one score per class along that axis.

    Returns:
        dict: The merged results of the module-level ``accuracy`` and ``f1``
        metrics for the predicted class ids.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    # `accuracy` is loaded alongside `f1`; report both instead of F1 alone.
    metrics = accuracy.compute(predictions=predicted_ids, references=labels)
    metrics.update(f1.compute(predictions=predicted_ids, references=labels))
    return metrics

In [None]:
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Label vocabulary for the binary classification head.
id2label = {0: "NON_VIOLATED", 1: "VIOLATED"}
label2id = {"NON_VIOLATED": 0, "VIOLATED": 1}

# The model checkpoint must be the one the tokenizer was loaded from
# ("distilbert-base-uncased"). Token ids are specific to a vocabulary, so
# pairing this tokenizer with a different architecture (previously
# "albert-base-v2") feeds the model ids that mean something else in, or fall
# outside of, its embedding table.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
# Assign the result so the cell does not print the whole module tree.
model = model.to(device)

In [None]:
# Settings for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="../models/distilbert_ecthr_model",
    push_to_hub=False,
    # Schedule and batch sizes.
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Optimizer.
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluation and saving use the same per-epoch cadence.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Wire the model, data, collator and metric function into one Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Run the training loop configured on `trainer`; the returned value is
# displayed as the cell output.
trainer.train()

In [None]:
# Score the fine-tuned classifier on the test split.
METRIC_NAMES = ["accuracy", "precision", "recall", "f1"]

task_evaluator = evaluator("text-classification")
# - Evaluate the in-memory `trainer.model` rather than a hard-coded checkpoint
#   directory: the old path did not match the Trainer's `output_dir` and its
#   step number depends on dataset size, batch size and device count.
# - Combine the metrics so inference over the test split runs once instead of
#   once per metric.
results = task_evaluator.compute(
    model_or_pipeline=trainer.model.eval(),  # eval mode: dropout disabled
    data=test_dataset,
    metric=evaluate.combine(METRIC_NAMES),
    tokenizer=tokenizer,
    strategy="simple",
    random_state=0,
    input_column='text',
    label_column='labels',
    label_mapping={"NON_VIOLATED": 0.0, "VIOLATED": 1.0},
)
# Look the scores up by name; `results` also carries timing entries, so
# relying on the position of the first item is fragile.
results_dict = {name: results[name] for name in METRIC_NAMES}

In [None]:
# Display the scores collected from the test-split evaluation.
results_dict