In [53]:
from datasets import load_dataset, Dataset
from transformers import TrainingArguments, Trainer, LongformerTokenizer, LongformerForSequenceClassification, pipeline
import torch
from transformers import DataCollatorWithPadding
import evaluate
import numpy as np
from evaluate import evaluator
import torch.cuda
import pandas as pd
from sklearn.metrics import classification_report

In [54]:
echr = load_dataset("ecthr_cases",  "violation-prediction")

In [55]:
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [56]:
def encode(examples):
    return tokenizer( examples["text"],
                     truncation=True, 
                     padding=True)

In [57]:
train_dataset, val_dataset, test_dataset = echr['train'], echr['validation'], echr['test']
train_dataset, val_dataset, test_dataset = [dataset.map( lambda examples: {"text": "\n".join(examples["facts"])}) for dataset in [train_dataset, val_dataset, test_dataset]]
train_dataset, val_dataset, test_dataset = [dataset.map(encode, batched=True) for dataset in [train_dataset, val_dataset, test_dataset]]
train_dataset, val_dataset, test_dataset = [dataset.map( lambda examples: {'labels' :list(1 if examples['labels'][i] else 0 for i in range(len(examples['labels'])))}, batched=True) for dataset in [train_dataset, val_dataset, test_dataset]]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [58]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [59]:
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

In [62]:
device = "cpu"

In [63]:
id2label = {0: "NON_VIOLATED", 1: "VIOLATED"}
label2id = {"NON_VIOLATED": 0, "VIOLATED": 1}
model = LongformerForSequenceClassification.from_pretrained("../models/longformer_ecthr_model/checkpoint-382", num_labels=2, id2label=id2label, label2id=label2id)
model.to(device)

LongformerForSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0-11): 12 x LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
          

In [64]:
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)

In [69]:
results = pipe(tokenizer.decode(case, clean_up_tokenization_spaces=True, skip_special_tokens=True) for case in test_dataset['input_ids'])

In [70]:
predictions = list(result['label'] for result in results)

Input ids are automatically padded from 4095 to 4096 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 4091 to 4096 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 4092 to 4096 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 2360 to 2560 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 992 to 1024 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 417 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 2848 to 3072 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1316 to 1536 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1241 to 1536 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 555 to 1024 to be a multipl

In [71]:
report = classification_report(list('NON_VIOLATED' if outcome==0 else 'VIOLATED' for outcome in test_dataset['labels']), predictions)

In [72]:
print(report)

              precision    recall  f1-score   support

NON_VIOLATED       0.31      0.76      0.44       135
    VIOLATED       0.95      0.73      0.83       865

    accuracy                           0.74      1000
   macro avg       0.63      0.74      0.63      1000
weighted avg       0.86      0.74      0.77      1000



: 