# Token Classification NER

## About
- Fine-tune DistilBERT on the WNUT 17 dataset for named-entity recognition (NER) of novel and emerging entities.
- Use the fine-tuned model for inference.

## Loading WNUT 17 dataset

In [1]:
from datasets import load_dataset

# Download WNUT 17 with all of its splits (train / validation / test).
wnut = load_dataset(path="wnut_17")

Downloading data:   0%|          | 0.00/494k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/115k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/192k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3394 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1009 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1287 [00:00<?, ? examples/s]

In [2]:
# Peek at the first training record to see the raw schema (id / tokens / ner_tags).
wnut["train"][0]

{'id': '0',
 'tokens': ['@paulwalk',
  'It',
  "'s",
  'the',
  'view',
  'from',
  'where',
  'I',
  "'m",
  'living',
  'for',
  'two',
  'weeks',
  '.',
  'Empire',
  'State',
  'Building',
  '=',
  'ESB',
  '.',
  'Pretty',
  'bad',
  'storm',
  'here',
  'last',
  'evening',
  '.'],
 'ner_tags': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  7,
  8,
  8,
  0,
  7,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]}

In [4]:
# Tag names in class-id order: position i is the string label for `ner_tags` value i.
label_list = wnut["train"].features["ner_tags"].feature.names
label_list

['O',
 'B-corporation',
 'I-corporation',
 'B-creative-work',
 'I-creative-work',
 'B-group',
 'I-group',
 'B-location',
 'I-location',
 'B-person',
 'I-person',
 'B-product',
 'I-product']

## Preprocessing

In [5]:
from transformers import AutoTokenizer

# Tokenizer of the same DistilBERT checkpoint that is fine-tuned later in the notebook.
tokenizer = AutoTokenizer.from_pretrained(
    "distilbert/distilbert-base-uncased"
)

In [6]:
# Tokenize one already-split sentence to show how words break into sub-word pieces
# and gain the [CLS] / [SEP] special tokens.
example = wnut["train"][0]
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
tokens

['[CLS]',
 '@',
 'paul',
 '##walk',
 'it',
 "'",
 's',
 'the',
 'view',
 'from',
 'where',
 'i',
 "'",
 'm',
 'living',
 'for',
 'two',
 'weeks',
 '.',
 'empire',
 'state',
 'building',
 '=',
 'es',
 '##b',
 '.',
 'pretty',
 'bad',
 'storm',
 'here',
 'last',
 'evening',
 '.',
 '[SEP]']

In [7]:
def tokenize_and_align_labels(example):
    """Tokenize a batch of pre-split sentences and align NER tags to sub-word tokens.

    Intended for ``Dataset.map(..., batched=True)``: despite the singular
    parameter name (kept so the call signature is unchanged), ``example`` is a
    batch.

    Args:
        example: Mapping with ``'tokens'`` (one list of words per sentence) and
            ``'ner_tags'`` (one list of integer tag ids per sentence, parallel
            to the words).

    Returns:
        The tokenizer output for the batch with an added ``'labels'`` entry: one
        list per sentence that carries the word's tag on the first sub-token of
        each word and -100 on every other position (special tokens and
        continuation pieces).
    """
    tokenized_inputs = tokenizer(
        example['tokens'],
        truncation=True,
        is_split_into_words=True
    )

    labels = []
    for i, label in enumerate(example['ner_tags']):
        # word_ids must come from THIS batch's encoding: entry k is the index of
        # the source word for token k, or None for special tokens.
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        prev_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None or word_idx == prev_word_idx:
                # Special token or a later piece of the same word: mark with -100
                # so the position is skipped when scoring.
                label_ids.append(-100)
            else:
                # First piece of a new word carries that word's tag.
                label_ids.append(label[word_idx])
            prev_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [None]:
# Apply the alignment to every split; batched=True passes a dict of lists per call.
tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)

In [None]:
from transformers import DataCollatorForTokenClassification

# Collator that assembles token-classification batches using the tokenizer above.
data_collator = DataCollatorForTokenClassification(tokenizer)

## Evaluate

In [None]:
import evaluate

# Sequence-labelling metric used by `compute_metrics`.
seqeval = evaluate.load(path="seqeval")

In [None]:
import numpy as np
# String tags of the sample sentence, looked up from its integer tag ids.
labels = [label_list[tag_id] for tag_id in example["ner_tags"]]

In [None]:
def compute_metrics(p):
    """Turn raw model outputs into seqeval precision / recall / F1 / accuracy.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` holds per-token class
            scores indexed as ``[sentence][token][class]``; ``labels`` holds the
            matching integer tag ids, with -100 on positions to ignore.

    Returns:
        Dict with keys ``'precision'``, ``'recall'``, ``'f1'`` and
        ``'accuracy'`` taken from seqeval's overall scores.
    """
    predictions, labels = p
    # Highest-scoring class per token.
    predictions = np.argmax(predictions, axis=2)

    true_predictions = []
    true_labels = []
    for prediction, label in zip(predictions, labels):
        # Drop positions marked -100 (special tokens / non-first sub-tokens) so
        # only real word-level tags are scored.
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(prediction, label) if gold_id != -100]
        true_predictions.append([label_list[pred_id] for pred_id, _ in kept])
        true_labels.append([label_list[gold_id] for _, gold_id in kept])

    results = seqeval.compute(
        predictions=true_predictions,
        references=true_labels
    )

    return {
        'precision': results['overall_precision'],
        'recall': results['overall_recall'],
        'f1': results['overall_f1'],
        'accuracy': results['overall_accuracy']
    }

## Train

In [None]:
# Class id -> tag name, in the same order as the dataset's `ner_tags` feature.
id2label = dict(enumerate([
    "O",
    "B-corporation",
    "I-corporation",
    "B-creative-work",
    "I-creative-work",
    "B-group",
    "I-group",
    "B-location",
    "I-location",
    "B-person",
    "I-person",
    "B-product",
    "I-product",
]))

# Tag name -> class id: the exact inverse of `id2label`.
label2id = {tag_name: class_id for class_id, tag_name in id2label.items()}

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

# DistilBERT with a token-classification head; the label maps are stored in the
# model config so predictions can be decoded back to tag names.
model = AutoModelForTokenClassification.from_pretrained(
    "distilbert/distilbert-base-uncased",
    num_labels=len(id2label),  # 13 WNUT 17 tags; derived so it cannot drift from the maps
    id2label=id2label,
    label2id=label2id
)

In [None]:
# Evaluate and checkpoint once per epoch; the two strategies must match for
# load_best_model_at_end to pick among the saved checkpoints.
training_args = TrainingArguments(
    output_dir="models/wnut_topic_cls_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

# NOTE(review): the best checkpoint is selected on the 'test' split, so the reported
# test metrics are optimistic; consider tokenized_wnut['validation'] for selection.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_wnut['train'],
    eval_dataset=tokenized_wnut['test'],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Keep the training summary so it can still be shown as the cell output.
train_output = trainer.train()

# Training only writes `checkpoint-*` sub-folders; save the final model (and its
# tokenizer) to the root of `output_dir`, which is the path the inference cells load.
trainer.save_model()
train_output

## Inference

In [None]:
# Sample sentence used by the inference cells below.
text = (
    "The Golden State Warriors are an American professional basketball team "
    "based in San Francisco."
)

In [None]:
from transformers import pipeline

# High-level route: a ready-made NER pipeline built from the saved model directory.
classifier = pipeline(task="ner", model="models/wnut_topic_cls_model")
classifier(text)

In [None]:
from transformers import AutoTokenizer

# Manual route, step 1: load the tokenizer from the fine-tuned model directory
# and encode the sample sentence as PyTorch tensors.
tokenizer = AutoTokenizer.from_pretrained("models/wnut_topic_cls_model")
inputs = tokenizer(text, return_tensors="pt")


In [None]:
import torch
from transformers import AutoModelForTokenClassification

# Manual route, step 2: load the fine-tuned weights and run a forward pass.
model = AutoModelForTokenClassification.from_pretrained("models/wnut_topic_cls_model")

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    logits = model(**inputs).logits

In [None]:
# Manual route, step 3: take the highest-scoring class along the last (class) axis
# of the logits. The tensor method is used so this cell needs no `torch` import.
predictions = logits.argmax(dim=2)

# Decode the first (only) sentence's class ids via the label map stored in the model config.
predicted_token_class = [model.config.id2label[class_id.item()] for class_id in predictions[0]]
predicted_token_class