In [None]:
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn

## 1. Data Preparation

In [None]:
from datasets import load_dataset

# CoNLL-2003 NER corpus from the Hugging Face Hub (split into train/validation/test below).
dataset = load_dataset(path="conll2003")

In [None]:
# Drop columns the notebook never uses. Only names that are still present are
# removed, so re-running this cell is a no-op instead of an error
# (remove_columns raises on a column that no longer exists).
unused_columns = ["id", "pos_tags", "chunk_tags"]
dataset = dataset.remove_columns(
    [col for col in unused_columns if col in dataset["train"].column_names]
)
dataset

In [None]:
# Give each standard split its own name for the Dataset wrappers below.
dataset_train, dataset_val, dataset_test = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
from transformers import AutoTokenizer

# Fixed length every example is padded/truncated to by NER_Dataset.
# NOTE(review): 113 looks like the longest CoNLL-2003 sentence measured in words — confirm;
# sub-word sequences can be longer than their word count and will then be truncated.
MAX_LEN = 113
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="dslim/bert-base-NER")

In [None]:
from torch.utils.data import Dataset

class NER_Dataset(Dataset):
    """Wrap a CoNLL-style split as fixed-length token-classification examples.

    Each example is split into the tokenizer's sub-word pieces instead of looking
    whole words up in the vocabulary (which turns every word missing from the
    vocabulary into ``[UNK]``). Every piece inherits the NER tag of the word it
    came from, so the label sequence stays aligned with ``input_ids``.

    Args:
        dataset: split exposing ``"tokens"`` (a list of words per example) and
            ``"ner_tags"`` (one integer tag per word).
        tokenizer: Hugging Face *fast* tokenizer; ``word_ids()`` is required for
            the word-to-piece alignment.
        max_len: length every example is padded/truncated to. ``None`` (default)
            uses the notebook-level ``MAX_LEN``.
        label_pad_id: label written at padded positions (and special tokens).
            The default 0 ('O') keeps the notebook's accuracy metric, which
            ignores label 0, unaffected by padding. Pass -100 to keep padding
            out of the loss, together with a metric that also masks -100.
        add_special_tokens: whether the tokenizer adds [CLS]/[SEP]. ``False``
            matches how the sample sentence is tokenized at inference time.
    """

    def __init__(self, dataset, tokenizer, max_len: Optional[int] = None,
                 label_pad_id: int = 0, add_special_tokens: bool = False):
        super().__init__()
        self.tokens = dataset["tokens"]
        self.labels = dataset["ner_tags"]
        self.tokenizer = tokenizer
        # Resolve the notebook constant lazily so the class does not need it at definition time.
        self.max_len = MAX_LEN if max_len is None else max_len
        self.label_pad_id = label_pad_id
        self.add_special_tokens = add_special_tokens

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` tensors of length ``max_len``."""
        words = self.tokens[idx]
        word_tags = self.labels[idx]

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=self.add_special_tokens,
            truncation=True,
            max_length=self.max_len,
        )
        # word_ids() gives, for every piece, the index of the word it came from
        # (None for special tokens); use it to copy word tags onto pieces.
        piece_tags = [
            self.label_pad_id if word_id is None else word_tags[word_id]
            for word_id in encoding.word_ids()
        ]

        input_ids = self.pad_and_truncate(encoding["input_ids"], pad_id=self.tokenizer.pad_token_id)
        labels = self.pad_and_truncate(piece_tags, pad_id=self.label_pad_id)
        attention_mask = self.pad_and_truncate(encoding["attention_mask"], pad_id=0)

        return {
            "input_ids": torch.as_tensor(input_ids),
            "labels": torch.as_tensor(labels),
            "attention_mask": torch.as_tensor(attention_mask),
        }

    def pad_and_truncate(self, inputs: List[int], pad_id: int) -> List[int]:
        """Right-pad ``inputs`` with ``pad_id`` up to ``max_len``, or cut it to ``max_len``."""
        inputs = list(inputs)
        if len(inputs) < self.max_len:
            return inputs + [pad_id] * (self.max_len - len(inputs))
        return inputs[: self.max_len]

In [None]:
train_set = NER_Dataset(dataset_train, tokenizer)
val_set = NER_Dataset(dataset_val, tokenizer)
test_set = NER_Dataset(dataset_test, tokenizer)

## 2. Model

In [None]:
from transformers import AutoModelForTokenClassification

# Tag ids this notebook assumes for the dataset's `ner_tags` column.
label2id = {
    'O': 0,
    'B-PER': 1,
    'I-PER': 2,
    'B-ORG': 3,
    'I-ORG': 4,
    'B-LOC': 5,
    'I-LOC': 6,
    'B-MISC': 7,
    'I-MISC': 8,
}
id2label = {idx: tag for tag, idx in label2id.items()}

# Pass the mapping so model.config agrees with the tag ids used for training and decoding.
# NOTE(review): the checkpoint ships its own id2label, which may order the tags differently
# (e.g. MISC first); without this, a saved/pipelined model would mislabel its predictions.
model = AutoModelForTokenClassification.from_pretrained(
    "dslim/bert-base-NER",
    id2label=id2label,
    label2id=label2id,
)
model

## 3. Training

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred, ignore_label_ids=(-100, 0), metric=None):
    """Token-level accuracy for the Trainer over positions whose label is not ignored.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer``; the argmax is
            taken over the last axis of ``logits``, and ``labels`` must match the
            remaining shape.
        ignore_label_ids: label values left out of the score. The default drops -100
            (the Hugging Face "ignore" value) and 0, which is both the 'O' tag and the
            label this notebook's dataset writes at padded positions. As a result, the
            score is accuracy over entity-tagged tokens only: it does not penalise
            spurious entities predicted on 'O' tokens.
        metric: object with ``compute(predictions=..., references=...)``; defaults to
            the module-level ``accuracy`` metric.

    Returns:
        The dict produced by ``metric.compute``.
    """
    logits, labels = eval_pred
    metric = accuracy if metric is None else metric
    predictions = np.argmax(logits, axis=-1)
    keep = ~np.isin(labels, list(ignore_label_ids))
    return metric.compute(predictions=predictions[keep], references=labels[keep])

In [None]:
from transformers import TrainingArguments, Trainer

# Fine-tune for 10 epochs, evaluating and checkpointing once per epoch. At the end,
# the best checkpoint is restored (ranked by validation loss unless a metric is given).
training_args = TrainingArguments(
    output_dir="out_dir",
    optim="adamw_torch",
    learning_rate=1e-4,
    num_train_epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_set,
    eval_dataset=val_set,
    compute_metrics=compute_metrics,
)

trainer.train()

## 4. Testing

In [None]:
# Prefix the results with "test_" so they are not mistaken for the per-epoch
# validation ("eval_") entries in the trainer's log history.
trainer.evaluate(eval_dataset=test_set, metric_key_prefix="test")

## 5. Inference on a sample sentence

In [None]:
test_sentence = "France won the World Cup in Russia in 2018"
# No [CLS]/[SEP]: the training dataset feeds token ids without special tokens,
# so the sample is encoded the same way.
inputs = tokenizer(test_sentence, return_tensors="pt", add_special_tokens=False)
inputs

In [None]:
# Move the encoded tensors to wherever the model lives instead of assuming a CUDA GPU.
inputs = inputs.to(model.device)

In [None]:
# Explicit eval mode (disables dropout) and no autograd graph for inference,
# rather than relying on whatever mode the Trainer left the model in.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
outputs.logits.shape

In [None]:
# Highest-scoring tag id for every token of the single sentence in the batch.
preds = outputs.logits.argmax(dim=-1)[0].cpu().numpy()
preds

In [None]:
# One tag name per sub-word token, separated by single spaces (no trailing space).
pred_tags = " ".join(id2label[int(tag_id)] for tag_id in preds)
pred_tags