In [None]:
# Install the notebook's extra dependencies via shell escapes.
# Only `datasets` is pinned; `evaluate`, `seqeval` and `accelerate` resolve to
# whatever is current at install time (`-U` upgrades accelerate).
# NOTE(review): consider pinning the remaining packages and using `%pip` so the
# install targets the running kernel's environment.
!pip install evaluate seqeval
!pip install datasets==2.15.0
!pip install accelerate -U

In [None]:
from datasets import load_dataset
from transformers import (AutoTokenizer,
                          AutoModelForTokenClassification,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForTokenClassification,
                         )
from itertools import chain

from tqdm import tqdm
import torch
import os
import numpy as np
import evaluate
import warnings

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Run on the GPU when torch can see one, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Label-scheme switch read by later cells: 'A' keeps every tag, 'B' keeps a reduced tag set.
system = 'B'
# Number of training epochs passed to TrainingArguments.
nr_epochs = 1
# seqeval scorer used for precision/recall/F1/accuracy.
metric = evaluate.load("seqeval")
# Confusion-matrix scorer loaded from the "BucketHeadP65/confusion_matrix" evaluate module.
confusion_matrix = evaluate.load("BucketHeadP65/confusion_matrix")
# Load every split of the "Babelscape/multinerd" dataset.
dataset = load_dataset("Babelscape/multinerd")

In [None]:
# Show which languages are present in the training split before filtering.
print(np.unique(dataset['train']['lang']))
# Keep only the examples whose `lang` field is 'en'. `dataset` is rebound,
# so the unfiltered data is no longer reachable under this name.
dataset = dataset.filter(lambda example: example['lang'] == 'en')
# Sanity check: list the languages left in the training split after filtering.
print(np.unique(dataset['train']['lang']))

In [None]:
# Tag name -> integer id for the `ner_tags` column.
# "O" is id 0; every entity type then contributes a consecutive B-/I- pair,
# in the order listed below (B-PER=1, I-PER=2, ..., B-VEHI=29, I-VEHI=30).
ner_tags_dict = {
    tag: tag_id
    for tag_id, tag in enumerate(
        ["O"]
        + [
            f"{prefix}-{entity}"
            for entity in (
                "PER", "ORG", "LOC", "ANIM", "BIO",
                "CEL", "DIS", "EVE", "FOOD", "INST",
                "MEDIA", "MYTH", "PLANT", "TIME", "VEHI",
            )
            for prefix in ("B", "I")
        ]
    )
}

In [None]:
# Build `label_list`, `id2label` and `label2id` for the selected `system`.
#   'A' -> keep every tag of `ner_tags_dict` unchanged.
#   'B' -> keep only `allowed_tags`; every other tag is collapsed to "O" and the
#          kept tags are renumbered 0..len(allowed_tags)-1.
label_list = list(ner_tags_dict)
if system == 'A':
  id2label = {i: label for i, label in enumerate(label_list)}
  label2id = {v: k for k, v in id2label.items()}
elif system == 'B':
  allowed_tags = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG',
                        'B-LOC', 'I-LOC', 'B-ANIM', 'I-ANIM', 'B-DIS', 'I-DIS']

  # Original tag ids listed in *allowed_tags order*, so that new id k always
  # corresponds to allowed_tags[k] (and therefore to label_list[k] below),
  # regardless of how ner_tags_dict happens to be ordered.
  allowed_values = [ner_tags_dict[tag] for tag in allowed_tags]
  # original tag id -> new compact id
  tags_values = {old_id: new_id for new_id, old_id in enumerate(allowed_values)}
  # New id of "O": the target for every tag that is not kept.
  outside_id = tags_values[ner_tags_dict['O']]

  def replace_values(example):
    """Remap one example's `ner_tags` to the reduced tag set.

    Tags listed in `allowed_tags` get their new compact id; all other tags
    become the id of "O". The example is modified and returned.
    """
    example['ner_tags'] = [tags_values.get(val, outside_id) for val in example['ner_tags']]
    return example

  # NOTE: `dataset` is rebound to the remapped version, so executing this cell a
  # second time would remap the already-remapped ids again. Re-run from the
  # dataset-loading cell instead.
  dataset = dataset.map(replace_values)
  label_list = list(allowed_tags)
  id2label = {i: label for i, label in enumerate(label_list)}
  label2id = {v: k for k, v in id2label.items()}
else:
  # Fail here rather than with a NameError on id2label/label2id further down.
  raise ValueError(f"Unknown system {system!r}; expected 'A' or 'B'.")

In [None]:
# Checkpoint names for the pretrained model and its tokenizer (both the same here).
model_name_or_path = 'distilbert-base-cased'
tokenizer_name_or_path = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
# Token-classification model with one output per entry of `label_list`;
# the id<->label maps are passed through to the model constructor.
model = AutoModelForTokenClassification.from_pretrained(
    model_name_or_path,
    num_labels = len(label_list),
    id2label=id2label,
    label2id=label2id,
)
# Move the model to the selected device. As the cell's last expression, the
# return value of `.to()` is also what the notebook displays.
model.to(device)

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word-level tags to tokens.

    Args:
        examples: batch mapping with "tokens" (one list of words per sentence)
            and "ner_tags" (one list of tag ids per sentence, one id per word).

    Returns:
        The tokenizer output with an added "labels" entry. For each sentence,
        the first sub-token of every word carries that word's tag id; special
        tokens and the remaining sub-tokens of a word are set to -100.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    aligned_labels = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        # word_ids[k] is the index of the word token k came from (None for special tokens).
        word_ids = encoded.word_ids(batch_index=row)
        # Same sequence shifted right by one: the word index of the preceding token.
        preceding = [None] + word_ids[:-1]
        aligned_labels.append([
            word_tags[current] if current is not None and current != previous else -100
            for current, previous in zip(word_ids, preceding)
        ])

    encoded["labels"] = aligned_labels
    return encoded


def compute_metrics(p):
    """Score a batch of predictions with the seqeval `metric`.

    Args:
        p: pair `(predictions, labels)`; `predictions` holds per-token class
            scores (argmax is taken over axis 2) and `labels` holds the gold
            tag ids, with -100 marking positions to ignore.

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy".
    """
    scores, gold = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold):
        # Drop every position whose gold label is the ignore index.
        kept = [
            (predicted_id, gold_id)
            for predicted_id, gold_id in zip(predicted_row, gold_row)
            if gold_id != -100
        ]
        true_predictions.append([label_list[predicted_id] for predicted_id, _ in kept])
        true_labels.append([label_list[gold_id] for _, gold_id in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        name: results["overall_" + name]
        for name in ("precision", "recall", "f1", "accuracy")
    }

In [None]:
# Tokenize every split in batches and attach the aligned labels; the original
# columns of the train split are dropped from the result.
tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Collator built from the tokenizer; passed to the Trainer below.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
# Checkpoint directory: one for system 'A', and the 'B' directory for any other value.
output_dir = "./distilbert-base-cased-system-A" if system == 'A' else "./distilbert-base-cased-system-B"

In [None]:
# Training configuration: evaluate and save a checkpoint once per epoch,
# writing everything under `output_dir`; nothing is pushed to the Hub.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version (it is not pinned above).
args = TrainingArguments(
    output_dir,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=nr_epochs,
    weight_decay=0.01,
    push_to_hub=False,
)

# Train on the tokenized train split, evaluating on the validation split with
# `compute_metrics`.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
trainer.train()

In [None]:
def compute_inference_metrics(label, pred):
    """Score flat sequences of gold and predicted tag ids.

    Args:
        label: flat sequence of gold tag ids; -100 marks positions to ignore.
        pred: flat sequence of predicted tag ids, aligned with `label`.

    Returns:
        A 3-tuple `(summary, confusion, results)`:
        - summary: dict with overall "precision", "recall", "f1", "accuracy";
        - confusion: output of the `confusion_matrix` scorer on the tag ids;
        - results: the full, unmodified output of the seqeval `metric`.

    NOTE(review): the whole input is handed to seqeval as one single sequence,
    so sentence boundaries are not visible to the entity-level scorer.
    """
    # Keep only the positions that carry a real gold label.
    kept = [
        (predicted_id, gold_id)
        for predicted_id, gold_id in zip(pred, label)
        if gold_id != -100
    ]
    true_predictions = [label_list[predicted_id] for predicted_id, _ in kept]
    true_labels = [label_list[gold_id] for _, gold_id in kept]

    # The confusion-matrix scorer takes ids, so map the tag names back through label2id.
    confusion_pred = [label2id[tag] for tag in true_predictions]
    confusion_label = [label2id[tag] for tag in true_labels]

    results = metric.compute(predictions=[true_predictions], references=[true_labels])
    confusion = confusion_matrix.compute(predictions=confusion_pred, references=confusion_label)

    summary = {
        name: results["overall_" + name]
        for name in ("precision", "recall", "f1", "accuracy")
    }
    return summary, confusion, results

In [None]:
# --- Reload the most recent checkpoint written under `output_dir` ---
# Keep only directories named "checkpoint-<number>". Anything else living in
# output_dir (e.g. "runs", stray files, editor folders) is ignored instead of
# crashing the numeric sort below.
ckpt_dirs = [
  name for name in os.listdir(output_dir)
  if name.startswith('checkpoint-')
  and name.split('-')[1].isdecimal()
  and os.path.isdir(os.path.join(output_dir, name))
]
if not ckpt_dirs:
  raise FileNotFoundError(
    f"No 'checkpoint-<step>' directory found in {output_dir!r}; run training first."
  )
# Sort by the numeric step so that e.g. checkpoint-1000 comes after checkpoint-500.
ckpt_dirs = sorted(ckpt_dirs, key=lambda name: int(name.split('-')[1]))
last_ckpt = ckpt_dirs[-1]
last_ckpt_path = os.path.join(output_dir, last_ckpt)
# NOTE: `tokenizer` and `model` are rebound to the reloaded checkpoint from here on.
tokenizer = AutoTokenizer.from_pretrained(last_ckpt_path)
model = AutoModelForTokenClassification.from_pretrained(last_ckpt_path).to(device)
model = model.eval()

# --- Predict the test split one example at a time ---
y_test = []
y_pred = []
for example in tqdm(tokenized_datasets['test']):
  y_test.append(example['labels'])
  # Build the batch-of-one tensors directly as int64 on the target device
  # (no intermediate float tensor).
  inputs = {
    'input_ids': torch.tensor([example['input_ids']], dtype=torch.long, device=device),
    'attention_mask': torch.tensor([example['attention_mask']], dtype=torch.long, device=device),
  }
  with torch.no_grad():
    logits = model(**inputs).logits
  # Most likely tag id per token for the single sequence in the batch.
  pred = np.argmax(logits.cpu().numpy(), axis=2)[0]
  y_pred.append(pred)


# Flatten the per-example lists into one long sequence each, as expected by
# compute_inference_metrics.
# NOTE(review): after flattening, sentence boundaries are no longer visible to
# the entity-level scorer — confirm this is acceptable for the reported numbers.
y_test = list(chain.from_iterable(y_test))
y_pred = list(chain.from_iterable(y_pred))
score, confusion, results = compute_inference_metrics(y_test, y_pred)

In [None]:
# Display the confusion matrix (converted to nested lists) together with the summary scores.
confusion['confusion_matrix'].tolist(), score

In [None]:
# Display the full seqeval output returned by compute_inference_metrics.
results