# Finetuning RoBERTa for NER: Evaluate Model
 

***

## Imports

In [1]:
import gc
import os
import pickle
from pprint import pprint

import numpy as np
import torch
from datasets import load_dataset, load_metric, concatenate_datasets, DatasetDict
from transformers import (BertTokenizerFast,
                          RobertaTokenizerFast,
                          AutoTokenizer,
                          BertForTokenClassification,
                          RobertaForTokenClassification,
                          DataCollatorForTokenClassification, 
                          AutoModelForTokenClassification, 
                          TrainingArguments, Trainer)

## Load Dataset

In [2]:
# Restore the preprocessed dataset object from disk.
# NOTE(review): unpickling executes arbitrary code embedded in the file, so this
# must only ever be pointed at a pickle produced by our own preprocessing step.
data_path = "./data/dataset_processed.pkl"
with open(data_path, mode="rb") as pickle_file:
    dataset = pickle.load(pickle_file)

## Load Model and Tokenizer

Information about model variants can be found here: https://huggingface.co/docs/transformers/model_doc/roberta

Load the fine-tuned model and its tokenizer:

In [3]:
# Housekeeping before the model is loaded: run a garbage-collection pass and
# release cached CUDA memory back to the driver (presumably to reclaim memory
# left over from earlier runs in the same kernel).
# `gc` is imported in the notebook's import cell together with the other modules.
gc.collect()
torch.cuda.empty_cache()

In [4]:
# Names of the NER tag classes; the integer ids stored in the `ner_tags` column
# are used as indices into this list.
label_list = dataset["train"].features["ner_tags"].feature.names

In [5]:
# Tokenizer and model weights are both read from the same fine-tuned checkpoint
# directory; keeping the path in one variable guarantees they cannot drift apart.
checkpoint_dir = "./results/checkpoint-final/"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(checkpoint_dir)

**Define Metrics:**

See https://huggingface.co/course/chapter7/2#metrics

In [6]:
# Load the seqeval metric, used below to score predicted tag sequences against
# reference tag sequences.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases
# in favour of the separate `evaluate` package — confirm against the installed
# version before upgrading.
metric = load_metric("seqeval")

  metric = load_metric("seqeval")


In [7]:
# Show one training record to inspect the columns that are available
# (raw tokens and tags alongside the tokenized model inputs).
sample_index = 150
print(dataset["train"][sample_index])

{'tokens': ['Auch', 'anschließend', 'blieb', 'er', 'der', 'Mannschaft', 'von', 'Trainer', 'Per', 'Olsson', 'treu', ',', 'im', 'Sommer', '2011', 'verlängerte', 'er', 'seinen', 'Kontrakt', 'erneut', 'langfristig', '.'], 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'langs': ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de'], 'spans': ['PER: Per Olsson'], 'input_ids': [12717, 133177, 178814, 72, 122, 132002, 542, 119205, 908, 9295, 4503, 1360, 34, 6, 4, 566, 29924, 1392, 241960, 13, 72, 25080, 3692, 44962, 119054, 165335, 6, 5], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [8]:
# Sanity-check the metric: scoring a tag sequence against itself must yield
# perfect precision, recall, F1 and accuracy.
example = dataset["train"][150]
# Translate integer tag ids into their string names. The result is stored under
# a dedicated name so it does not collide with the `labels` array produced later
# by the prediction cell.
example_labels = [label_list[tag_id] for tag_id in example["ner_tags"]]
metric.compute(predictions=[example_labels], references=[example_labels])

{'PER': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 1.0,
 'overall_f1': 1.0,
 'overall_accuracy': 1.0}

**Calculate evaluation metrics on the test set (precision, recall, F1, accuracy):**

In [9]:
# Build a Trainer that is only used for inference here: no training arguments or
# training data are supplied, just the model, its tokenizer, and a collator that
# batches token-classification examples.
data_collator = DataCollatorForTokenClassification(tokenizer)
trainer = Trainer(model=model, data_collator=data_collator, tokenizer=tokenizer)

In [10]:
# Run the model over the test split. The result is read through its named fields
# instead of tuple-unpacking into `_`, because assigning to `_` at the top level
# of a notebook overrides IPython's "last output" variable.
prediction_output = trainer.predict(dataset["test"])
# Reference label ids returned alongside the predictions.
labels = prediction_output.label_ids
# Reduce the raw per-class scores to the highest-scoring class id along the last axis.
predictions = np.argmax(prediction_output.predictions, axis=-1)

The following columns in the test set don't have a corresponding argument in `XLMRobertaForTokenClassification.forward` and have been ignored: spans, ner_tags, langs, tokens. If spans, ner_tags, langs, tokens are not expected by `XLMRobertaForTokenClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 20000
  Batch size = 16
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [11]:
# Tag-id -> tag-name lookup used when decoding predictions (the same list of
# class names that `label_list` holds).
label_names = dataset["train"].features["ner_tags"].feature.names

In [12]:
# Label id marking positions that must be excluded from scoring.
# NOTE(review): presumably the ignore index assigned to special / continuation
# sub-word tokens during preprocessing — confirm against the preprocessing notebook.
IGNORE_INDEX = -100


def _decode_scored_positions(tag_ids, gold_ids):
    """Return the tag names of `tag_ids` at positions whose gold id is not ignored."""
    return [
        label_names[tag_id]
        for tag_id, gold_id in zip(tag_ids, gold_ids)
        if gold_id != IGNORE_INDEX
    ]


# Reference tag names, one list per test example.
true_labels = [_decode_scored_positions(gold_row, gold_row) for gold_row in labels]

# Predicted tag names at exactly the same (non-ignored) positions.
true_predictions = [
    _decode_scored_positions(predicted_row, gold_row)
    for predicted_row, gold_row in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
pprint(results)

{'LOC': {'f1': 0.8512989302927002,
         'number': 21063,
         'precision': 0.843191132637854,
         'recall': 0.8595641646489104},
 'ORG': {'f1': 0.7326220690065083,
         'number': 16972,
         'precision': 0.7391017569107153,
         'recall': 0.7262550082488806},
 'PER': {'f1': 0.8671199011124845,
         'number': 14649,
         'precision': 0.8723316062176166,
         'recall': 0.8619701003481466},
 'overall_accuracy': 0.9284292732082127,
 'overall_f1': 0.8177536369506591,
 'overall_precision': 0.8182198236546062,
 'overall_recall': 0.817287981170754}
