In [8]:
from datasets import load_dataset, Dataset
from transformers import TrainingArguments, Trainer, LongformerTokenizer, LongformerForSequenceClassification, pipeline
import torch
from transformers import DataCollatorWithPadding
import evaluate
import numpy as np
from evaluate import evaluator
import torch.cuda
import pandas as pd
from sklearn.metrics import classification_report
from nltk.corpus import stopwords

In [9]:
# Load the ECtHR cases dataset in its "violation-prediction" configuration
# (a DatasetDict; the 'test' split is used further down).
echr = load_dataset(path="ecthr_cases", name="violation-prediction")

# NLTK's English stop-word list, used below to strip stop words from the case texts.
stop_words = stopwords.words("english")

In [10]:
# Hub id of the base Longformer checkpoint whose tokenizer is used for evaluation.
TOKENIZER_CHECKPOINT = "allenai/longformer-base-4096"
tokenizer = LongformerTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [11]:
def encode(examples):
    """Tokenize the ``text`` column of a batch with the notebook-level ``tokenizer``.

    Args:
        examples: Mapping with a ``"text"`` entry (a batch of strings when used
            with ``Dataset.map(..., batched=True)``).

    Returns:
        Whatever ``tokenizer`` returns for those texts, with truncation and
        padding both switched on.
    """
    tokenizer_options = {"truncation": True, "padding": True}
    batch_texts = examples["text"]
    return tokenizer(batch_texts, **tokenizer_options)

In [12]:
# Build the evaluation set from the raw split on every run, so re-executing this
# cell always starts from the same input.
test_split = echr['test']

# Each case stores its facts as a list of strings; join them into one document.
test_with_text = test_split.map(lambda example: {"text": "\n".join(example["facts"])})

# Collapse each case's label list to a binary target: 1 when the list is non-empty,
# 0 when it is empty (presumably "no violated article" -- confirm against the dataset card).
test_binary = test_with_text.map(
    lambda batch: {'labels': [int(bool(case_labels)) for case_labels in batch['labels']]},
    batched=True,
)

# Set membership is O(1); the NLTK list would be scanned once per token.
stop_word_set = set(stop_words)

test_df = pd.DataFrame(test_binary)
# Text cleaning: lowercase, split on whitespace, drop stop words, re-join with single
# spaces, then delete every run of digits. The regex is a raw string so that `\d` is
# not treated as an (invalid) string escape sequence.
test_df['text'] = (
    test_df['text']
    .str.lower()
    .str.split()
    .apply(lambda tokens: " ".join(token for token in tokens if token not in stop_word_set))
    .replace(r'\d+', '', regex=True)
)

test_dataset = Dataset.from_pandas(test_df)
# Adds the tokenizer outputs (e.g. input_ids) as extra columns next to 'text' and 'labels'.
test_dataset = test_dataset.map(encode, batched=True)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [13]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [14]:
# Local directory of the fine-tuned checkpoint being evaluated.
checkpoint_dir = "../models/longformer_ecthr_model/removed_stopwords/max_words/accuracy/checkpoint-382"

# Class index <-> class name; the reverse mapping is derived so the two cannot drift apart.
id2label = {0: "NON_VIOLATED", 1: "VIOLATED"}
label2id = {name: index for index, name in id2label.items()}

model = LongformerForSequenceClassification.from_pretrained(
    checkpoint_dir,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
# Moves the model to the selected device; as the cell's last expression the returned
# module is also displayed.
model.to(device)

LongformerForSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0-11): 12 x LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
          

In [15]:
# Wrap the fine-tuned model and its tokenizer in a text-classification pipeline
# on the selected device.
pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [16]:
# Classify the cleaned case texts directly. The pipeline tokenizes its own input, so
# decoding the stored (padded) input_ids back to text first only added a lossy
# tokenize -> decode -> tokenize round trip. `truncation=True` is forwarded to the
# tokenizer so over-long cases are cut to the model's maximum length.
# A generator is passed (as before), so results are produced as they are consumed.
results = pipe(
    (case_text for case_text in test_dataset["text"]),
    truncation=True,
)

In [17]:
# Keep only the predicted class name from each pipeline result; iterating `results`
# here is what consumes the pipeline output.
predictions = [prediction['label'] for prediction in results]

Initializing global attention on CLS token...
Input ids are automatically padded from 3427 to 3584 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 881 to 1024 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 309 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 521 to 1024 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1309 to 1536 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1342 to 1536 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1402 to 1536 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 1854 to 2048 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 926 to 1024 to be a multiple of `config.attention_window`: 512
Input ids are automatic

In [18]:
# Translate the 0/1 targets into class names through `id2label`, the same mapping
# handed to the model, instead of re-hardcoding the names here.
true_labels = [id2label[outcome] for outcome in test_dataset['labels']]

# Keyword arguments make the (truth, prediction) order explicit.
# NOTE(review): the saved output shows 0.26 accuracy with mirrored precision/recall
# between the two classes -- possibly a label-mapping mismatch between this notebook
# and the checkpoint; confirm before trusting these numbers.
report = classification_report(y_true=true_labels, y_pred=predictions)

In [19]:
# `report` is the text table built by classification_report; print() renders its line
# breaks, whereas a bare cell expression would show the escaped string repr.
print(report)

              precision    recall  f1-score   support

NON_VIOLATED       0.15      0.98      0.26       135
    VIOLATED       0.98      0.15      0.26       865

    accuracy                           0.26      1000
   macro avg       0.57      0.56      0.26      1000
weighted avg       0.87      0.26      0.26      1000

