In [None]:
from collections import Counter
import shutil
import functools

import numpy as np
import torch
import pandas as pd
import huggingface_hub
from onnxruntime import InferenceSession
from datasets import load_dataset, Dataset
from optimum.onnxruntime import ORTModelForTokenClassification
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from evaluate import load as load_metric
import sklearn.metrics

In [None]:
# Interactive Hugging Face Hub login; needed by the push_to_hub cell later in the notebook.
huggingface_hub.notebook_login()

In [None]:
# --- Configuration -----------------------------------------------------------
# Base Hub checkpoint that gets fine-tuned.
BASE_MODEL_CHECKPOINT = "distilbert-base-uncased"
# Cache directory passed to `load_dataset`.
DATA_CACHE = ".cache"

# Name shared by the local checkpoint dir, the Hub repository and the ONNX export dir.
OUTPUT_MODEL_NAME = f"{BASE_MODEL_CHECKPOINT}-on-mini-finer"

# Trainer output directory (removed before training in the Trainer cell).
TRAINED_MODEL_CHECKPOINT = f"checkpoints/{OUTPUT_MODEL_NAME}"
BATCH_SIZE = 16
# Forwarded to TrainingArguments(use_cpu=...).
USE_CPU = False
N_EPOCHS = 20

# NOTE(review): the Hub namespace is hard-coded to a specific user account.
HUGGING_FACE_REPOSITORY = f"baluyotraf/{OUTPUT_MODEL_NAME}"
# Directory where the ONNX model and tokenizer are saved.
ONNX_OUTPUT_PATH = f"onnx/{OUTPUT_MODEL_NAME}"

In [None]:
# Label id ignored by the loss (PyTorch CrossEntropyLoss's default ignore_index); the
# metric and confusion-matrix helpers also drop positions carrying this value.
PYTORCH_IGNORE = -100

In [None]:
# FiNER-139 from the Hugging Face Hub (all splits), cached under DATA_CACHE.
data = load_dataset("nlpaueb/finer-139", cache_dir=DATA_CACHE)
# Tag names indexed by tag id. NOTE: functions below reuse the name `labels` for
# their own parameters/locals, shadowing this global inside them.
labels = data["train"].features["ner_tags"].feature.names

In [None]:
# Pandas copy of the training split, used for the label-frequency exploration below.
data_df = data["train"].to_pandas()

In [None]:
# For every tag id, count the number of training rows that contain it at least once
# (Counter(...).keys() de-duplicates tags within a row before exploding).
ner_counts_per_row = data_df["ner_tags"].map(lambda r: list(Counter(r).keys()))
ner_counts = Counter(ner_counts_per_row.explode())
# Keep tags whose name starts with "B" (presumably the B- entity-start tags, i.e. one
# row per entity type), sorted rarest first.
# NOTE(review): sort_values defaults to the unstable quicksort, so the relative order of
# tied counts is not guaranteed; that can change which types `head(4)` selects later.
ner_count_names_df = pd.DataFrame([
    {"idx": idx, "count": count, "label": labels[idx]}
    for idx, count in ner_counts.items() 
    if labels[idx].startswith("B")
]).sort_values("count", ascending=True)

In [None]:
# Preview the rarest entity types.
ner_count_names_df.head()

In [None]:
# Target entity types: the 4 rarest "B" tags from the frequency table above.
target_ner_df = ner_count_names_df.head(4)
# Original ids of those "B" tags (used to filter rows that mention a target entity).
target_label_idxs = set(target_ner_df["idx"])
# Entity type names without the two-character tag prefix (e.g. "B-").
target_label_names = {tag_name[2:] for tag_name in target_ner_df["label"]}

# Old tag id -> new contiguous id starting at 1. Every tag (any prefix) whose type is a
# target type is kept, in original id order; id 0 is reserved for all other tags.
target_ner_tag_map = {
    old_id: new_id
    for new_id, old_id in enumerate(
        [tag_id for tag_id, tag_name in enumerate(labels) if tag_name[2:] in target_label_names],
        start=1,
    )
}
target_labels = [labels[old_id] for old_id in target_ner_tag_map]

# Reduced label space: id 0 keeps labels[0]; ids 1..n are the remapped target tags.
target_id_to_label = {
    0: labels[0],
    **{new_id: labels[old_id] for old_id, new_id in target_ner_tag_map.items()},
}
target_label_to_id = {label: id_ for id_, label in target_id_to_label.items()}

In [None]:
def remap_ner_tags(row):
    """Return the row's tags mapped into the reduced label space.

    Tags without an entry in ``target_ner_tag_map`` become id 0.
    Output is a new ``"target_ner_tags"`` column for ``Dataset.map``.
    """
    remapped = [target_ner_tag_map.get(old_tag, 0) for old_tag in row["ner_tags"]]
    return {"target_ner_tags": remapped}

# Keep only rows mentioning at least one target "B" tag, then add the remapped column.
target_data = data.filter(lambda row: not target_label_idxs.isdisjoint(row["ner_tags"]))
target_data = target_data.map(remap_ner_tags)

In [None]:
# Tokenizer of the base checkpoint. AutoTokenizer loads a "fast" tokenizer by default,
# which `word_ids()` in tokenize_and_align_labels needs.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_CHECKPOINT)
def tokenize_and_align_labels(examples, tokenizer, label_all_tokens=True, label_column="target_ner_tags"):
    """Tokenize pre-split words and align word-level tag ids to sub-word tokens.

    Meant for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict with a ``"tokens"`` column (one list of words per example)
            and a ``label_column`` column (one list of per-word tag ids per example).
        tokenizer: Callable returning a ``BatchEncoding`` that supports ``word_ids``.
            Either a fast tokenizer or a ``functools.partial`` wrapping one.
        label_all_tokens: If True, continuation sub-tokens repeat their word's tag id.
            Otherwise they get ``PYTORCH_IGNORE``.
        label_column: Name of the column holding word-level tag ids.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. It holds one list per
        example, with one id per token.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    aligned_labels = []
    for batch_index, word_tags in enumerate(examples[label_column]):
        previous_word_idx = None
        label_ids = []
        for word_idx in tokenized_inputs.word_ids(batch_index=batch_index):
            if word_idx is None:
                # Special tokens (and padding) have no source word.
                # Use PYTORCH_IGNORE so the loss and the metrics skip them.
                label_ids.append(PYTORCH_IGNORE)
            elif word_idx != previous_word_idx:
                # First sub-token of a word carries the word's tag.
                label_ids.append(word_tags[word_idx])
            else:
                # Continuation sub-token. NOTE: with label_all_tokens=True a "B-" tag is
                # repeated on every sub-token of the word. seqeval then counts each one
                # as the start of a separate entity.
                label_ids.append(word_tags[word_idx] if label_all_tokens else PYTORCH_IGNORE)
            previous_word_idx = word_idx
        aligned_labels.append(label_ids)

    tokenized_inputs["labels"] = aligned_labels
    return tokenized_inputs

In [None]:
# Tokenize every split and attach the sub-token-aligned "labels" column.
tokenized_target_data = target_data.map(
    tokenize_and_align_labels,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
)
tokenized_target_data

In [None]:
# Token-classification model on top of the base checkpoint, sized to the reduced label
# set. The id/label maps are passed through so predictions carry readable tag names.
model = AutoModelForTokenClassification.from_pretrained(
    BASE_MODEL_CHECKPOINT,
    num_labels=len(target_id_to_label),
    id2label=target_id_to_label,
    label2id=target_label_to_id,
)

In [None]:
# Evaluate, checkpoint and log once per epoch. With load_best_model_at_end the best
# checkpoint is restored after training; it is chosen by eval loss because
# metric_for_best_model is not set.
# NOTE(review): `evaluation_strategy` is deprecated in newer transformers releases
# (renamed `eval_strategy`), while `use_cpu` needs a fairly recent release.
# Pin the transformers version this notebook was run with.
training_args = TrainingArguments(
    TRAINED_MODEL_CHECKPOINT,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=N_EPOCHS,
    weight_decay=0.01,
    use_cpu=USE_CPU
)

In [None]:
# seqeval entity-level metrics (precision/recall/F1 per type and overall).
metric = load_metric("seqeval")
def compute_metrics(predictions, labels, id2label):
    """Compute seqeval metrics from raw logits and aligned label ids.

    Args:
        predictions: Logits array-like of rank 3. The predicted class is the argmax
            over axis 2.
        labels: Per-example label-id sequences aligned with ``predictions``. Positions
            equal to ``PYTORCH_IGNORE`` are dropped from predictions and references.
        id2label: Mapping from class id to tag string.

    Returns:
        The dict produced by the module-level seqeval ``metric.compute``.
    """
    predicted_ids = np.argmax(predictions, axis=2)

    # Pair predicted/true ids per example, skipping ignored (special/padding) positions.
    kept_pairs = [
        [(pred_id, label_id) for pred_id, label_id in zip(pred_row, label_row) if label_id != PYTORCH_IGNORE]
        for pred_row, label_row in zip(predicted_ids, labels)
    ]
    true_predictions = [[id2label[pred_id] for pred_id, _ in pairs] for pairs in kept_pairs]
    true_labels = [[id2label[label_id] for _, label_id in pairs] for pairs in kept_pairs]

    return metric.compute(predictions=true_predictions, references=true_labels)

def compute_training_metrics(p, id2label):
    """Adapter for ``Trainer(compute_metrics=...)`` that returns scalar overall metrics.

    Args:
        p: ``EvalPrediction`` from the Trainer, or any indexable
            ``(predictions, label_ids, ...)`` sequence.
        id2label: Mapping from class id to tag string.

    Returns:
        Dict with the overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    # Index explicitly rather than unpacking with `*p`. When the Trainer also passes
    # inputs to metrics, EvalPrediction has a third element. `compute_metrics(*p, id2label)`
    # would then receive four positional arguments and raise TypeError.
    results = compute_metrics(p[0], p[1], id2label)
    return {name: results[f"overall_{name}"] for name in ("precision", "recall", "f1", "accuracy")}

def plot_confusion_matrix(predictions, labels, names=None, normalize=None):
    """Plot a token-level confusion matrix from logits and aligned label ids.

    Args:
        predictions: Rectangular logits array-like. The predicted class is the argmax
            over the last axis.
        labels: Rectangular label-id array-like. After flattening it must line up with
            the flattened predictions. Entries equal to ``PYTORCH_IGNORE`` are excluded.
        names: Optional display names, one per class id ``0..len(names) - 1``. When
            given, the matrix is built over exactly those ids. Rows and columns then
            match the names even if a class never occurs in the data.
        normalize: Forwarded to ``sklearn.metrics.confusion_matrix``
            (``None``, ``"true"``, ``"pred"`` or ``"all"``).
    """
    flat_predictions = np.asarray(predictions).argmax(-1).reshape(-1)
    flat_labels = np.asarray(labels).reshape(-1)

    valid = flat_labels != PYTORCH_IGNORE
    # Without an explicit label set sklearn uses only the classes present in the data.
    # The matrix can then be smaller than `names`, and ConfusionMatrixDisplay fails.
    class_ids = None if names is None else np.arange(len(names))

    matrix = sklearn.metrics.confusion_matrix(
        flat_labels[valid], flat_predictions[valid], labels=class_ids, normalize=normalize
    )
    # Local name avoids shadowing IPython's `display` in the notebook namespace.
    cm_display = sklearn.metrics.ConfusionMatrixDisplay(matrix, display_labels=names)
    cm_display.plot()

In [None]:
# Start from an empty output directory so checkpoints from earlier runs are not mixed in.
shutil.rmtree(TRAINED_MODEL_CHECKPOINT, ignore_errors=True)
# Pads inputs per batch and pads "labels" with -100 (its default label_pad_token_id).
data_collator = DataCollatorForTokenClassification(tokenizer)
# NOTE(review): newer transformers releases prefer `processing_class=` over `tokenizer=`.
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_target_data["train"],
    eval_dataset=tokenized_target_data["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=functools.partial(compute_training_metrics, id2label=target_id_to_label)
)

In [None]:
# Baseline: test-set metrics before fine-tuning, for comparison with the post-training cell.
test_output = trainer.predict(tokenized_target_data["test"])
compute_metrics(test_output.predictions, test_output.label_ids, target_id_to_label)

In [None]:
# Fine-tune on the train split, evaluating on the validation split every epoch.
training_result = trainer.train()

In [None]:
# Test-set metrics after fine-tuning (same call as the baseline cell).
test_output = trainer.predict(tokenized_target_data["test"])
compute_metrics(test_output.predictions, test_output.label_ids, target_id_to_label)

In [None]:
# Publish the fine-tuned weights and tokenizer (uses the Hub login from the first cells).
model.push_to_hub(HUGGING_FACE_REPOSITORY)
tokenizer.push_to_hub(HUGGING_FACE_REPOSITORY)

In [None]:
# Load the pushed Hub model through Optimum, converting it to ONNX on load (export=True).
ort_model = ORTModelForTokenClassification.from_pretrained(HUGGING_FACE_REPOSITORY, export=True)

# Replace any previous export, then save the ONNX model next to its tokenizer.
shutil.rmtree(ONNX_OUTPUT_PATH, ignore_errors=True)
ort_model.save_pretrained(ONNX_OUTPUT_PATH)
tokenizer.save_pretrained(ONNX_OUTPUT_PATH)

In [None]:
def predict_from_model(model, tokenizer, tokenized_data):
    """Run a token-classification model on pre-tokenized, padded data and return logits.

    Args:
        model: Callable model exposing a ``device`` attribute and returning an object
            with a ``logits`` tensor. Examples: a PyTorch token-classification model or
            an Optimum ORT model.
        tokenizer: Unused. Kept so existing call sites keep working.
        tokenized_data: Mapping or dataset whose ``"input_ids"`` and
            ``"attention_mask"`` columns are rectangular (already padded).

    Returns:
        ``numpy.ndarray`` of logits, always in host memory.
    """
    # Rows must already be padded; torch.tensor cannot build a tensor from ragged lists.
    input_ids = torch.tensor(tokenized_data["input_ids"], dtype=torch.long, device=model.device)
    attention_mask = torch.tensor(tokenized_data["attention_mask"], dtype=torch.long, device=model.device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        predictions = model(input_ids=input_ids, attention_mask=attention_mask)
    # .cpu() is required before .numpy() when the model lives on an accelerator.
    return predictions.logits.detach().cpu().numpy()

In [None]:
# Tokenize the test split with padding so all rows share one length. Token ids, masks
# and labels can then be stacked into dense tensors (torch) or arrays (ONNX Runtime).
# Padding positions get PYTORCH_IGNORE labels because their word id is None.
# `batch_size=None` makes `map` process the whole split as a single batch. With the
# default batch size of 1000, `padding="longest"` would only pad within each chunk,
# and a larger split would end up with rows of different lengths.
tokenized_padded_test_data = target_data["test"].map(
    lambda rows: tokenize_and_align_labels(rows, functools.partial(tokenizer, padding="longest")),
    batched=True,
    batch_size=None,
)

In [None]:
# Logits from the Optimum ONNX model on the padded test set.
ort_pred = predict_from_model(ort_model, tokenizer, tokenized_padded_test_data)

In [None]:
# Entity-level metrics for the Optimum ONNX model; expected to be close to the PyTorch results.
compute_metrics(ort_pred, tokenized_padded_test_data["labels"], target_id_to_label)

In [None]:
# Token-level confusion matrix for the Optimum ONNX model (axes show class ids; no names passed).
plot_confusion_matrix(ort_pred, tokenized_padded_test_data["labels"])

In [None]:
# NOTE: nn.Module.to() moves the module in place and returns it.
# `cpu_model` and `model` are the same object, now on CPU.
cpu_model = model.to("cpu")

In [None]:
# Run the same padded test set through the PyTorch model on CPU, for comparison with
# the ONNX export.
with torch.no_grad():
    cpu_pred = predict_from_model(cpu_model, tokenizer, tokenized_padded_test_data)

# Parity check (previously `cpu_pred` was computed but never used): the largest
# absolute logit difference between PyTorch and the Optimum ONNX model. A value near
# float32 rounding noise suggests the export is faithful.
np.abs(cpu_pred - ort_pred).max()

In [None]:
# Raw ONNX Runtime session on the exported file, bypassing Optimum.
ort_session = InferenceSession(f"{ONNX_OUTPUT_PATH}/model.onnx")

In [None]:
# ONNX Runtime matches input dtypes against the graph, so feed explicit int64 arrays.
# Relying on implicit conversion of nested Python lists, whose integer type can be
# platform-dependent, is fragile. NOTE(review): int64 is assumed to be the exported
# input type; confirm with `ort_session.get_inputs()`.
ort_inputs = {
    key: np.asarray(tokenized_padded_test_data[key], dtype=np.int64)
    for key in ["input_ids", "attention_mask"]
}
ort_output = ort_session.run(output_names=["logits"], input_feed=ort_inputs)

In [None]:
# Entity-level metrics from the raw ONNX Runtime logits (first requested output).
compute_metrics(ort_output[0], tokenized_padded_test_data["labels"], target_id_to_label)

In [None]:
# Token-level confusion matrix from the raw ONNX Runtime logits.
plot_confusion_matrix(ort_output[0], tokenized_padded_test_data["labels"])