# DistilBERT Fine Tuning

Fine-tune the DistilBERT model for a token classification problem on the `nlpaueb/finer-139` dataset.

In [None]:
import calendar
import functools
import shutil
from collections import Counter

import huggingface_hub
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics
import torch
from datasets import load_dataset
from evaluate import load as load_metric
from matplotlib import pyplot as plt
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForTokenClassification
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

# Apply seaborn's default theme to all subsequent plots in this notebook.
sns.set()
# Use the colorblind-friendly palette for every plot.
sns.set_palette(sns.color_palette("colorblind"))

In [None]:
# Interactive Hugging Face login; the credentials are needed later to push the model to the Hub.
huggingface_hub.notebook_login()

Define the constants up front so the run is easy to configure. The constants are grouped as follows:

*   Input related constants
*   Training related constants
*   Output related constants

In [None]:
# --- Input ---
# Plain-text stop word list, read line by line when building the word filters.
STOPWORDS_FILE = "stopwords.txt"
# Hugging Face checkpoint the token classifier starts from.
BASE_MODEL_CHECKPOINT = "distilbert-base-uncased"
# Hugging Face dataset with the tagged sentences.
DATA_SOURCE = "nlpaueb/finer-139"
# Local cache directory passed to `load_dataset`.
DATA_CACHE = ".cache"

# --- Training ---
# Name of the fine-tuned model, derived from the base checkpoint name.
OUTPUT_MODEL_NAME = f"{BASE_MODEL_CHECKPOINT}-on-mini-finer"
# Directory where the Trainer writes its checkpoints (removed before each training run).
TRAINED_MODEL_CHECKPOINT = f"checkpoints/{OUTPUT_MODEL_NAME}"
# Per-device batch size, used for both training and evaluation.
BATCH_SIZE = 16
# Passed to `TrainingArguments(use_cpu=...)`.
USE_CPU = False
# Number of training epochs.
N_EPOCHS = 20

# --- Output ---
# Hub repository the trained model and tokenizer are pushed to.
HUGGING_FACE_REPOSITORY = f"baluyotraf/{OUTPUT_MODEL_NAME}"
# Local directory for the exported ONNX model and its tokenizer.
ONNX_OUTPUT_PATH = f"onnx/{OUTPUT_MODEL_NAME}"

PyTorch's cross-entropy loss ignores targets equal to -100 (its default `ignore_index`), so -100 is used as the label of special and padding tokens. Define it as a constant for ease of use.

In [None]:
# Label value skipped by PyTorch's cross-entropy loss (its default `ignore_index`); assigned to
# special/padding tokens and to masked-out sub-words below.
PYTORCH_IGNORE = -100

Load the `finer-139` dataset using the `datasets` library from Hugging Face and check the labels from the dataset.

In [None]:
# Load all splits of the dataset, caching the download in DATA_CACHE.
data = load_dataset(DATA_SOURCE, cache_dir=DATA_CACHE)
# Names of the NER labels; a tag id indexes into this list.
labels = data["train"].features["ner_tags"].feature.names

Convert the data to a `pandas.DataFrame` to have easier access to the different data utilities it offers

In [None]:
# pandas copy of the training split, used for the exploratory analysis below.
data_df = data["train"].to_pandas()

A function is defined to count the words found at each tagged token. This helps build an intuition for what each label means.

Aside from the tokens and tags, the `calculate_word_count_per_tag` function also has a `surround` parameter that adds the surrounding tokens to the count, which gives more context about the label. A `count_filter` parameter is also provided to drop words that carry little information, such as stop words.

A `filter_no_information_words` function was also defined to make the counts more meaningful. Its rules were defined iteratively based on the data; generally, it removes numbers, stop words, month names, and money-related words and symbols.

In [None]:
def calculate_word_count_per_tag(df, tokens, tags, surround=(0, 1), count_filter=lambda w: True):
    """Count the lower-cased words found around every tagged token, grouped by tag id.

    Args:
        df: Table-like object where ``df[tokens]`` and ``df[tags]`` yield, row by row,
            the token sequence and the matching tag-id sequence.
        tokens: Column name holding the token sequences.
        tags: Column name holding the tag-id sequences.
        surround: ``(start, stop)`` offsets relative to each tagged position; the slice
            ``tokens[pos + start : pos + stop]`` (clipped to the sentence) is collected.
            The default ``(0, 1)`` collects only the tagged token itself.
        count_filter: Predicate applied to each collected word; words for which it
            returns a falsy value are not counted.

    Returns:
        dict mapping each non-zero tag id that occurs to a ``Counter`` of its collected words.
    """
    collected = {}
    for sentence, sentence_tags in zip(df[tokens], df[tags]):
        lowered = [word.lower() for word in sentence]
        for position, tag_id in enumerate(sentence_tags):
            # Tag id 0 is treated as "no entity" and skipped.
            if tag_id == 0:
                continue
            lo, hi = (position + offset for offset in surround)
            window = lowered[max(0, lo) : min(len(lowered), hi)]
            collected.setdefault(tag_id, []).extend(window)

    return {tag_id: Counter(filter(count_filter, words)) for tag_id, words in collected.items()}


# Stop words are read from a plain-text file, one word per line (surrounding whitespace removed).
# The encoding is pinned so the result does not depend on the platform's default locale.
with open(STOPWORDS_FILE, encoding="utf-8") as f:
    STOPWORDS = {line.strip() for line in f}

# Month names in the current locale, lower-cased. Index 0 of ``calendar.month_name`` is the
# empty string, so "" is part of this set too (which also filters out empty tokens).
CALENDAR_NAMES = {calendar.month_name[idx].lower() for idx in range(0, 13)}
# Magnitude words that accompany amounts.
NUMERIC_WORDS = {
    "million",
    "billion",
}
# Punctuation and currency/percent symbols that appear around amounts.
NUMERIC_SYMBOLS = {",", "$", ".", "-", "%"}
# Every word that `filter_no_information_words` rejects by membership.
NO_INFORMATION_WORDS = STOPWORDS | CALENDAR_NAMES | NUMERIC_WORDS | NUMERIC_SYMBOLS


def filter_no_information_words(w):
    """Return True when ``w`` is worth counting, False when it carries no information.

    A word carries no information when ``float`` can parse it (note that this also accepts
    strings such as "nan", "inf" or "1e5") or when it is in ``NO_INFORMATION_WORDS``.
    """
    try:
        float(w)
    except ValueError:
        # Not numeric: keep it unless it is a stop word, month name or numeric word/symbol.
        return w not in NO_INFORMATION_WORDS
    return False


# Count the informative words in a window of 5 tokens on each side of every tagged token
# (slice [idx - 5, idx + 6) around the tagged position).
word_count_on_tag = calculate_word_count_per_tag(
    data_df,
    "tokens",
    "ner_tags",
    count_filter=filter_no_information_words,
    surround=(-5, 6),
)

The number of sentences containing the tag was also computed. The unique tags per sentence were extracted and then counted. The count only focused on the `B-` labels as they mark the beginning of an entity.

In [None]:
# Unique tag ids of each sentence, so a tag is counted at most once per sentence.
ner_counts_per_row = data_df["ner_tags"].map(set)
# Number of sentences in which each tag id occurs.
ner_counts = Counter(ner_counts_per_row.explode())
# Keep only the labels starting with "B" (entity starts), sorted from rarest to most common.
ner_count_names_df = pd.DataFrame(
    [
        {"idx": tag_id, "count": n_sentences, "label": labels[tag_id]}
        for tag_id, n_sentences in ner_counts.items()
        if labels[tag_id].startswith("B")
    ]
).sort_values(by="count")
ner_count_names_df.head()

The focus will be on the labels with the fewest members. This was done mostly to save computation time and resources, but the code can be extended to any number of labels.

In [None]:
# The four rarest B- labels (ner_count_names_df is sorted by ascending count).
target_ner_df = ner_count_names_df.head(4)

The counts of the smallest labels do not differ much from one another, so the samples are kept as they are, without any augmentation techniques.

In [None]:
def plot_count_distribution(df, x, y, title="Count of the target labels"):
    """Draw a seaborn bar plot of ``df`` and show it.

    Args:
        df: DataFrame holding the data to plot.
        x: Column mapped to the x axis.
        y: Column mapped to the y axis.
        title: Plot title. Defaults to the title used so far, so existing calls are unchanged.
    """
    ax = sns.barplot(df, x=x, y=y)
    ax.set_title(title)
    plt.show()


# Bar plot of the number of sentences containing each target label.
plot_count_distribution(target_ner_df, x="count", y="label")

The common words around the labels are printed out to get more understanding of the labels.

In [None]:
# Print the 20 most frequent informative words collected around each target label.
for idx, label in target_ner_df[["idx", "label"]].itertuples(index=False, name=None):
    print(label)
    for word, count in word_count_on_tag[idx].most_common(20):
        print(f"\t{word}: {count}")

The `nlpaueb/finer-139` data contains many labels; however, this exercise only uses 4 entity types (with both their `B-` and `I-` labels), so the labels need to be remapped. To do this, a mapping was created from the index of each label in the data to its index in the reduced problem, with index 0 kept for untagged tokens.

In [None]:
# Original ids of the selected B- labels; used to keep only sentences that contain them.
target_label_idxs = set(target_ner_df["idx"])
# Entity names without their two-character prefix (e.g. "B-"), so both B- and I- variants match.
target_label_names = set(target_ner_df["label"].str[2:])

# Map every original tag id whose entity is a target to a new id 1..n, keeping the original
# label order; id 0 stays free for the untagged label.
target_ner_tag_map = {
    old_id: new_id
    for new_id, old_id in enumerate(
        [label_id for label_id, name in enumerate(labels) if name[2:] in target_label_names],
        start=1,
    )
}
target_labels = [labels[old_id] for old_id in target_ner_tag_map]

# New id -> label name (0 is the dataset's first label) and its inverse, for the model config.
target_id_to_label = {0: labels[0], **dict(enumerate(target_labels, start=1))}
target_label_to_id = {name: new_id for new_id, name in target_id_to_label.items()}

The data was filtered to remove the sentences without the target tags. This was not mandatory, but was also done to save computation time and resources. The labels were also mapped to the new labels; thus, a column `target_ner_tags` was created from the `ner_tags` column.

In [None]:
def remap_ner_tags(row):
    """Translate ``row["ner_tags"]`` into the reduced label space.

    Tags present in ``target_ner_tag_map`` get their new id; every other tag becomes 0.

    Returns:
        dict with a single ``"target_ner_tags"`` entry, as expected by ``Dataset.map``.
    """
    return {"target_ner_tags": [target_ner_tag_map.get(tag_id, 0) for tag_id in row["ner_tags"]]}


# Keep only sentences that contain at least one of the target B- tags. The predicate returns an
# explicit bool, as `Dataset.filter` expects, instead of relying on the truthiness of a set.
target_data = data.filter(lambda x: not target_label_idxs.isdisjoint(x["ner_tags"]))
# Add the "target_ner_tags" column with the tags translated to the reduced label space.
target_data = target_data.map(remap_ner_tags)

A tokenizer was loaded from the `DistilBERT` repository, and a function to tokenize the words and align the labels was defined. The alignment is needed because the tokenizer can split words into subwords and can therefore produce more tokens than there are words.

The implementation of the `tokenize_and_align_labels` was mostly taken from the `DistilBERT` documentation, with some minor refactors.

In [None]:
# Tokenizer matching the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_CHECKPOINT)


def tokenize_and_align_labels(examples, tokenizer, label_all_tokens=True):
    """Tokenize pre-split words and expand the word-level tags to sub-word tokens.

    Args:
        examples: Batch with ``"tokens"`` (list of word lists) and ``"target_ner_tags"``
            (list of per-word tag-id lists), as passed by ``Dataset.map(..., batched=True)``.
        tokenizer: Callable tokenizer whose output provides ``word_ids(batch_index=...)``.
        label_all_tokens: When truthy, every sub-word of a word repeats the word's tag;
            otherwise only the first sub-word is tagged and the rest get ``PYTORCH_IGNORE``.

    Returns:
        The tokenizer output with an added ``"labels"`` entry (one list per example).

    NOTE(review): with ``label_all_tokens=True`` a B- tag is copied to every sub-word of the
    word, so seqeval may count each sub-word as a separate entity; consider giving the
    continuation sub-words the matching I- tag instead.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    aligned = []
    for batch_index, word_tags in enumerate(examples["target_ner_tags"]):
        token_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=batch_index):
            if word is None:
                # Special tokens have no word id: exclude them from the loss.
                token_labels.append(PYTORCH_IGNORE)
            elif word == last_word and not label_all_tokens:
                # Continuation sub-word whose label is masked out.
                token_labels.append(PYTORCH_IGNORE)
            else:
                # First sub-word of a word, or a continuation when label_all_tokens is set.
                token_labels.append(word_tags[word])
            last_word = word
        aligned.append(token_labels)

    encoded["labels"] = aligned
    return encoded

In [None]:
# Tokenize every split in batches and attach the aligned sub-word labels.
tokenized_target_data = target_data.map(
    functools.partial(tokenize_and_align_labels, tokenizer=tokenizer),
    batched=True,
)
tokenized_target_data

A token classification model was defined on top of `DistilBERT`. The number of labels and the label mappings were set to fit the smaller problem defined here.

In [None]:
# DistilBERT with a token classification head sized for the reduced label set; the id/label
# mappings are passed along so predictions can be reported by label name.
model = AutoModelForTokenClassification.from_pretrained(
    BASE_MODEL_CHECKPOINT, num_labels=len(target_id_to_label), id2label=target_id_to_label, label2id=target_label_to_id
)

The training arguments were defined. Evaluation, checkpointing and logging were done at each epoch, and the best model was loaded at the end of training. Otherwise, the values were taken from the `DistilBERT` documentation.

In [None]:
# Evaluate, checkpoint and log once per epoch, and reload the best checkpoint when training ends
# (chosen by evaluation loss, since no metric_for_best_model is set).
training_args = TrainingArguments(
    output_dir=TRAINED_MODEL_CHECKPOINT,
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    # Optimisation hyper-parameters.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=N_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    use_cpu=USE_CPU,
)

Metric helpers for model evaluation were defined. There are three types of metrics defined:

*   Sequence metrics with per label metrics
*   Overall metrics for training
*   Confusion matrix for complete prediction picture

In [None]:
# seqeval metric from the `evaluate` library: per-entity and overall scores on label-name sequences.
metric = load_metric("seqeval")


def compute_metrics(predictions, labels, id2label):
    """Compute seqeval metrics from token scores and reference label ids.

    Args:
        predictions: Per-token scores indexed as ``[example, token, class]``; the class with
            the highest score is taken as the prediction.
        labels: Reference label ids per ``[example, token]``; positions equal to
            ``PYTORCH_IGNORE`` (special/padding tokens) are dropped before scoring.
        id2label: Mapping from label id to label name.

    Returns:
        The dictionary returned by ``metric.compute``.
    """
    predicted_ids = np.argmax(predictions, axis=2)

    pred_names, ref_names = [], []
    for example_preds, example_refs in zip(predicted_ids, labels):
        kept = [(p, r) for p, r in zip(example_preds, example_refs) if r != PYTORCH_IGNORE]
        pred_names.append([id2label[p] for p, _ in kept])
        ref_names.append([id2label[r] for _, r in kept])

    return metric.compute(predictions=pred_names, references=ref_names)


def compute_training_metrics(p, id2label):
    """Summarize the seqeval results into the scalar metrics reported during training.

    Args:
        p: Evaluation output passed by ``Trainer`` to ``compute_metrics`` (with ``predictions``
            and ``label_ids`` attributes), or a plain ``(predictions, label_ids)`` pair.
        id2label: Mapping from label id to label name.

    Returns:
        dict with the overall precision, recall, F1 and accuracy.
    """
    # Read the fields explicitly: unpacking ``*p`` would also yield any extra fields the
    # evaluation output carries (e.g. model inputs when the Trainer is configured to pass them)
    # and push ``id2label`` out of its positional slot.
    if hasattr(p, "predictions"):
        predictions, label_ids = p.predictions, p.label_ids
    else:
        predictions, label_ids = p

    results = compute_metrics(predictions, label_ids, id2label)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


def plot_confusion_matrix(predictions, labels, names=None, normalize=None):
    """Plot the token-level confusion matrix of the predictions.

    Args:
        predictions: Per-token scores whose last axis holds the class scores; the arg-max is
            used as the predicted class.
        labels: Reference label ids with the same leading shape as ``predictions``;
            positions equal to ``PYTORCH_IGNORE`` are excluded.
        names: Optional display names, one per class id ``0..len(names) - 1``.
        normalize: Forwarded to ``sklearn.metrics.confusion_matrix``.
    """
    flat_predictions = np.asarray(predictions).argmax(-1).reshape(-1)
    flat_labels = np.asarray(labels).reshape(-1)

    valid_labels = flat_labels != PYTORCH_IGNORE

    # With names, pin the class axis to 0..len(names) - 1 so the matrix always matches the
    # display labels, even when some class never occurs in this data.
    class_ids = None if names is None else list(range(len(names)))
    confusion_matrix = sklearn.metrics.confusion_matrix(
        flat_labels[valid_labels], flat_predictions[valid_labels], labels=class_ids, normalize=normalize
    )
    display = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels=names)
    display.plot()

A `DataCollatorForTokenClassification` was defined to save memory. It pads each batch to the length of its longest sequence rather than to a global maximum length.

This was used in the `Trainer`, together with the training args, the tokenized dataset, and the metrics function. The tokenizer was also provided for proper padding and future fine-tuning.

In [None]:
# Remove checkpoints left over from earlier runs.
shutil.rmtree(TRAINED_MODEL_CHECKPOINT, ignore_errors=True)
# Pads inputs and labels per batch.
data_collator = DataCollatorForTokenClassification(tokenizer)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_target_data["train"],
    eval_dataset=tokenized_target_data["validation"],
    # NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
    # `processing_class=`; confirm against the installed version.
    tokenizer=tokenizer,
    compute_metrics=functools.partial(compute_training_metrics, id2label=target_id_to_label),
)

Check the performance of the model before training. The performance is expected to be very poor, since the classification head has not been trained on any of these labels yet. This is only a baseline to measure the improvement from training.

In [None]:
# Baseline: score the untrained classification head on the test split.
test_output = trainer.predict(tokenized_target_data["test"])
compute_metrics(test_output.predictions, test_output.label_ids, target_id_to_label)

Run the training. Training metrics are displayed alongside the process.

In [None]:
# Fine-tune; evaluation, logging and checkpointing happen every epoch (see training_args).
training_result = trainer.train()

Check the performance of the trained model on the test set to verify the training.

In [None]:
# Score the fine-tuned model (the best checkpoint, since load_best_model_at_end=True) on the test split.
test_output = trainer.predict(tokenized_target_data["test"])
compute_metrics(test_output.predictions, test_output.label_ids, target_id_to_label)

Upload the model and the tokenizer to the `Hugging Face` Hub repository. This makes the model publicly available and easily reusable.

In [None]:
# Publish the fine-tuned model and its tokenizer to the same Hub repository.
model.push_to_hub(HUGGING_FACE_REPOSITORY)
tokenizer.push_to_hub(HUGGING_FACE_REPOSITORY)

Create an `ONNX Runtime` model using the `Hugging Face` utilities. The tokenizer was also saved on the same path so that the complete pipeline can be read from the same path.

In [None]:
# Export the model to ONNX (export=True). NOTE: it is loaded from HUGGING_FACE_REPOSITORY,
# so this relies on the push above having succeeded.
ort_model = ORTModelForTokenClassification.from_pretrained(HUGGING_FACE_REPOSITORY, export=True)

# Replace any previous export, then store the ONNX model and the tokenizer side by side.
shutil.rmtree(ONNX_OUTPUT_PATH, ignore_errors=True)
ort_model.save_pretrained(ONNX_OUTPUT_PATH)
tokenizer.save_pretrained(ONNX_OUTPUT_PATH)

A utility function `predict_from_model` returns the logits of a model for already tokenized data. This helps in comparing the `ONNX` model with the PyTorch model.

In [None]:
def predict_from_model(model, tokenizer, tokenized_data):
    """Run ``model`` on tokenized, padded data and return the logits as a NumPy array.

    Args:
        model: Token classification model that exposes ``device``, accepts ``input_ids`` and
            ``attention_mask`` tensors and returns an object with a ``logits`` tensor (used
            here with both the PyTorch and the ONNX Runtime model).
        tokenizer: Unused; kept so existing calls keep working.
        tokenized_data: Mapping with ``"input_ids"`` and ``"attention_mask"`` rows of equal
            length (padded beforehand, since each column is stacked into a single tensor).

    Returns:
        ``numpy.ndarray`` holding the model's logits.
    """
    input_ids = torch.tensor(tokenized_data["input_ids"], device=model.device)
    attention_mask = torch.tensor(tokenized_data["attention_mask"], device=model.device)
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        predictions = model(input_ids=input_ids, attention_mask=attention_mask)
    # Move to host memory first: ``Tensor.numpy()`` fails for tensors on an accelerator.
    return predictions.logits.detach().cpu().numpy()

A padded copy of the test data was created for prediction purposes.

In [None]:
# Padded test split for direct inference. padding="longest" pads to the longest sequence of each
# map batch, so the whole split is processed as one batch (batch_size=None) to give every row the
# same length. The labels are padded as well, because tokens without a word id are labelled
# PYTORCH_IGNORE by tokenize_and_align_labels.
tokenized_padded_test_data = target_data["test"].map(
    lambda rows: tokenize_and_align_labels(rows, functools.partial(tokenizer, padding="longest")),
    batched=True,
    batch_size=None,
)

The metrics calculated from the `ONNX Runtime` model seem to match the training results. To get more information about the predictions, the confusion matrix was also plotted.

In [None]:
# Logits of the ONNX Runtime model (through optimum) on the padded test data.
ort_pred = predict_from_model(ort_model, tokenizer, tokenized_padded_test_data)

In [None]:
# seqeval metrics of the ONNX predictions.
compute_metrics(ort_pred, tokenized_padded_test_data["labels"], target_id_to_label)

In [None]:
# Token-level confusion matrix of the ONNX predictions (axes show class ids, no names passed).
plot_confusion_matrix(ort_pred, tokenized_padded_test_data["labels"])

The actual PyTorch model was also tested directly, instead of relying on the `Trainer.predict` method.

In [None]:
# Module.to() moves the parameters in place and returns the same module, so `model` is on the CPU too.
cpu_model = model.to("cpu")

In [None]:
# Logits of the plain PyTorch model, computed without gradient tracking.
with torch.no_grad():
    cpu_pred = predict_from_model(cpu_model, tokenizer, tokenized_padded_test_data)

The `ONNX Runtime` model was also run through the `ONNX Runtime` API directly to make sure that the results are consistent.

In [None]:
# Load the exported graph directly with ONNX Runtime. The providers are listed explicitly, because
# builds that ship several execution providers (e.g. the GPU package) refuse to create a session
# without them.
ort_session = InferenceSession(f"{ONNX_OUTPUT_PATH}/model.onnx", providers=["CPUExecutionProvider"])

In [None]:
# Feed typed arrays instead of Python lists, so the integer type does not depend on the platform
# default. int64 matches the tensors `torch.tensor` builds in predict_from_model for the same
# columns — NOTE(review): confirm with `ort_session.get_inputs()` if the export changes.
ort_output = ort_session.run(
    output_names=["logits"],
    input_feed={
        key: np.asarray(tokenized_padded_test_data[key], dtype=np.int64) for key in ["input_ids", "attention_mask"]
    },
)

In [None]:
# seqeval metrics of the raw ONNX Runtime logits (the first and only requested output).
compute_metrics(ort_output[0], tokenized_padded_test_data["labels"], target_id_to_label)

In [None]:
# Token-level confusion matrix of the raw ONNX Runtime logits.
plot_confusion_matrix(ort_output[0], tokenized_padded_test_data["labels"])