# Fine-Tuning BERT on GLUE - MNLI

From [GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding - Wang et al.](https://arxiv.org/pdf/1804.07461):

> The Multi-Genre Natural Language Inference Corpus (Williams et al., 2018) is a crowdsourced collection of sentence pairs with textual entailment annotations. Given a premise sentence and a hypothesis sentence, the task is to predict whether the premise entails the hypothesis (entailment), contradicts the hypothesis (contradiction), or neither (neutral). The premise sentences are gathered from ten different sources, including transcribed speech, fiction, and government reports. We use the standard test set, for which we obtained private labels from the authors, and evaluate on both the matched (in-domain) and mismatched (cross-domain) sections. We also use and recommend the SNLI corpus (Bowman et al., 2015) as 550k examples of auxiliary training data.

## 0. Configuration

In [1]:
import os

# Where to store the huggingface data. On the provided Jupyterlab instance that should be within the shared group folder.
os.environ['HF_HOME'] = '../groups/192.039-2024W/bert/huggingface/cache'

In [2]:
import pandas as pd

pd.options.mode.chained_assignment = None

In [3]:
from pathlib import Path

import numpy as np
from transformers import set_seed

# RANDOMNESS SEED
SEED = 42
set_seed(SEED)
np.random.seed(SEED)

# Which dataset to load
DATASET_NAME = "glue"
DATASET_TASK = "mnli"

PRE_TRAINED_CHECKPOINT = "google-bert/bert-base-uncased"

TRAIN_OUTPUT_DIR = (
    Path("../groups/192.039-2024W/bert") / "training" / f"{DATASET_NAME}-{DATASET_TASK}"
)

BATCH_SIZE = 32  # The original paper reports a batch size of 32 for GLUE tasks
NUM_EPOCHS = 5  # The original paper reports 3 fine-tuning epochs for GLUE tasks; we train for 5 and keep the best checkpoint
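
To get a feeling for what these two settings imply, the number of optimizer steps can be worked out by hand. The sketch below is only an estimate under stated assumptions: it takes 392,702 as the size of the MNLI train split (the figure listed on the dataset card), a single device, and no gradient accumulation.

```python
import math

# Assumptions (not read from the notebook state): MNLI train split size as
# listed on the dataset card, one device, no gradient accumulation.
num_train_examples = 392_702
batch_size = 32
num_epochs = 5

# The last, partially filled batch still counts as one step, hence the ceiling.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

print(f"{steps_per_epoch=}, {total_steps=}")
```

With these numbers one epoch takes 12,272 steps and the full run 61,360 steps.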

In [None]:
import torch

if torch.cuda.is_available():
  device = torch.device("cuda")
  device_count = torch.cuda.device_count()
  device_name = torch.cuda.get_device_name(0)

  print(f"There are {device_count} GPU(s) available.")
  print(f"GPU used: {device_name}")
  ! nvidia-smi -q --display=MEMORY,COMPUTE

else:
  print("No GPU available, using CPU.")
  device = torch.device("cpu")

## 1. Dataset

In [5]:
# In the GLUE dataset different tasks have different accessor keys
_task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
}

In [None]:
from datasets import load_dataset
import pandas as pd

dataset = load_dataset(DATASET_NAME, DATASET_TASK)
dataset

The MNLI dataset is special in that it provides two validation and two test datasets: one matched (in-domain) and one mismatched (cross-domain).

In [None]:
pd.DataFrame(dataset["train"]).sample(10)

In [None]:
unique_labels_in_dataset = pd.DataFrame(dataset["train"])["label"].unique()
num_labels = len(unique_labels_in_dataset)

print(f"{unique_labels_in_dataset=}")
print(f"{num_labels=}")

The GLUE benchmark suite keeps the labels of its test datasets secret, which is a common practice in machine learning benchmarks. Withholding the labels ensures that the test set is used solely for evaluation and that models cannot be trained on it. This encourages researchers to develop models that generalize well rather than models tuned to score highly on one specific test set.

In [None]:
pd.DataFrame(dataset["test_matched"]).sample(10)

The only way to get an evaluation on the test dataset is to train a model and send it to New York University, which maintains the GLUE benchmark leaderboard, for evaluation. However, this option only exists for researchers about to publish a paper, so we cannot use it.

Instead, we will split the validation datasets to create two custom test datasets for our experiment. We will keep the train split as it is.

We will use the first split (80%) as the new matched validation dataset and the second split (20%) as the new matched test dataset:

In [None]:
# Pass the seed explicitly so that the split is reproducible across runs
new_matched_validation_test_split = dataset["validation_matched"].train_test_split(test_size=0.2, seed=SEED)
new_matched_validation_test_split["validation_matched"] = new_matched_validation_test_split.pop("train")
new_matched_validation_test_split["test_matched"] = new_matched_validation_test_split.pop("test")
new_matched_validation_test_split

And we do the same for the mismatched case:

In [None]:
# Pass the seed explicitly so that the split is reproducible across runs
new_mismatched_validation_test_split = dataset["validation_mismatched"].train_test_split(test_size=0.2, seed=SEED)
new_mismatched_validation_test_split["validation_mismatched"] = new_mismatched_validation_test_split.pop("train")
new_mismatched_validation_test_split["test_mismatched"] = new_mismatched_validation_test_split.pop("test")
new_mismatched_validation_test_split

In [None]:
dataset["validation_matched"] = new_matched_validation_test_split["validation_matched"]
dataset["test_matched"] = new_matched_validation_test_split["test_matched"]
dataset["validation_mismatched"] = new_mismatched_validation_test_split["validation_mismatched"]
dataset["test_mismatched"] = new_mismatched_validation_test_split["test_mismatched"]
dataset
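
The idea behind the 80/20 split can be illustrated with plain row indices. This is a minimal sketch, not how the `datasets` library implements `train_test_split`; the helper `split_indices` and the row count of 1000 are made up for illustration.

```python
import random


def split_indices(n_rows, test_size=0.2, seed=42):
    """Shuffle row indices with a fixed seed and cut them into two disjoint parts."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    n_test = round(n_rows * test_size)
    # The first `n_test` shuffled rows become the test part, the rest stays for validation.
    return indices[n_test:], indices[:n_test]


validation_idx, test_idx = split_indices(1000)
print(len(validation_idx), len(test_idx))
print(set(validation_idx).isdisjoint(test_idx))
```

Because the two index lists never overlap, no example used for model selection ends up in the held-out test part, and a fixed seed makes the same split come out on every run.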

We now have a test dataset with labels, which is __not__ part of our training data:

In [None]:
pd.DataFrame(dataset["test_matched"]).sample(10)

In [None]:
pd.DataFrame(dataset["test_mismatched"]).sample(10)

## 2. BERT-base

In [16]:
PRE_TRAINED_CHECKPOINT = "google-bert/bert-base-uncased"

### 2.1 Tokenization

In [17]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT, do_lower_case="uncased" in PRE_TRAINED_CHECKPOINT)

BERT has a maximum sequence length of 512 tokens. We can check the sequence lengths that result from tokenizing our dataset to see whether any example exceeds this limit:

In [None]:
first_sentence_key, second_sentence_key = _task_to_keys[DATASET_TASK]

if second_sentence_key is None:  # Simply tokenize sentence

    for split in dataset.keys():
        max_len = 0
        for sentence in dataset[split][first_sentence_key]:
            # Tokenize the text and add `[CLS]` and `[SEP]` tokens.
            input_ids = tokenizer.encode(sentence, add_special_tokens=True)
            
            max_len = max(max_len, len(input_ids))
        

        print(f"Max length in {split=}: {max_len}")

else:  # Append both sentences via [SEP] and tokenize

    for split in dataset.keys():
        max_len = 0
        for sentence1, sentence2 in zip(dataset[split][first_sentence_key], dataset[split][second_sentence_key]):
            # Tokenize the text and add `[CLS]` and `[SEP]` tokens.
            input_ids = tokenizer.encode(sentence1, sentence2,  add_special_tokens=True)
            
            max_len = max(max_len, len(input_ids))
        

        print(f"Max length in {split=}: {max_len}")


In [None]:
def tokenize_func(item):
    """Tokenize passed item. 
    
    Depending on the dataset task, the passed item contains either one sentence or two sentences.
    In the latter case, the two sentences are joined via a [SEP] token.
    """
    if second_sentence_key is None:
        return tokenizer(item[first_sentence_key], add_special_tokens=True, truncation=True)
    else:
        return tokenizer(item[first_sentence_key], item[second_sentence_key], add_special_tokens=True, truncation=True)

tokenized_dataset = dataset.map(tokenize_func, batched=True)

Here is an example of a tokenized dataset item:

In [None]:
with pd.option_context('display.max_colwidth', 400):
    display(pd.DataFrame(tokenized_dataset["train"][:1]).transpose())

Tokenization added three fields: `input_ids`, which contains the tokenized sentence pair with one `[CLS]` (101) and two `[SEP]` (102) tokens added; `token_type_ids`, which marks the first and second segment of the input where applicable; and `attention_mask`, which marks the positions the model should attend to.
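
How the three fields relate to each other can be sketched without the tokenizer. The helper `encode_pair` and the word-piece ids below are hypothetical; only the special-token ids 101 and 102 are taken from the text above.

```python
CLS_ID, SEP_ID = 101, 102  # ids of [CLS] and [SEP], as shown in the tokenized example


def encode_pair(premise_ids, hypothesis_ids):
    """Lay out a sentence pair as [CLS] premise [SEP] hypothesis [SEP]."""
    input_ids = [CLS_ID] + premise_ids + [SEP_ID] + hypothesis_ids + [SEP_ID]
    # Segment 0 covers [CLS], the premise and the first [SEP]; segment 1 covers the rest.
    token_type_ids = [0] * (len(premise_ids) + 2) + [1] * (len(hypothesis_ids) + 1)
    # Nothing is padded yet, so every position is attended to.
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }


# Made-up word-piece ids standing in for a two-token premise and a three-token hypothesis.
encoded = encode_pair([2023, 2003], [2008, 2001, 2204])
for name, values in encoded.items():
    print(name, values)
```

All three lists have the same length, which is what lets the model line up each token with its segment and its mask bit.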

#### Dealing with Padding

Hugging Face's `transformers` library provides a `DataCollatorWithPadding` class, which allows us to use dynamic padding.  
Dynamic padding fills each sequence with `[PAD]` tokens up to the length of the longest sequence within its batch, instead of padding to the maximum sequence length within the entire dataset.  
This avoids unnecessary padding and therefore improves execution efficiency.

In [None]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Example: Select a few samples from the training set
samples = tokenized_dataset["train"][:3]
samples = {k: v for k, v in samples.items() if k not in ["idx", first_sentence_key, second_sentence_key]}  # Drop `idx` and `sentence` columns, as DataCollator can't process those.
pd.DataFrame(samples["input_ids"])

In [None]:
# Apply padding using data_collator
batch = data_collator(samples)
pd.DataFrame(batch["input_ids"])


We can see that `data_collator` appends `[PAD]` (0) tokens until every item reaches the length of the longest item in the passed batch.
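
The effect of dynamic padding can be reproduced with plain lists. This is a simplified sketch of the idea, not the implementation of `DataCollatorWithPadding`; the helper `pad_batch` and the token ids are made up.

```python
PAD_ID = 0  # id of [PAD], as seen in the padded batch above


def pad_batch(sequences, pad_id=PAD_ID):
    """Pad every sequence to the length of the longest one in this batch only."""
    target_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (target_len - len(seq)) for seq in sequences]
    # Padded positions get a 0 in the attention mask so the model ignores them.
    attention_mask = [[1] * len(seq) + [0] * (target_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask


batch_a = [[101, 7, 8, 102], [101, 7, 102]]
batch_b = [[101, 7, 8, 9, 10, 11, 102], [101, 102]]

ids_a, mask_a = pad_batch(batch_a)
ids_b, mask_b = pad_batch(batch_b)
print(len(ids_a[0]), len(ids_b[0]))  # each batch gets its own padded length
```

The first batch is padded to 4 tokens and the second to 7, whereas padding to a dataset-wide maximum would stretch both to the same, larger length.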

### 2.2 Metrics

The GLUE dataset specifies one or more evaluation metrics depending on the selected task.

In [None]:
import numpy as np
import evaluate

metric = evaluate.load(DATASET_NAME, DATASET_TASK)
metric

Depending on the selected GLUE task we optimize for different evaluation metrics. See BERT paper p.6:

> F1 scores are reported for QQP and MRPC, Spearman correlations are reported for STS-B, and accuracy scores are reported for the other tasks. We exclude entries that use BERT as one of their components.

In [24]:
_task_to_metric = {
    "cola": "matthews_correlation",
    "mnli": "accuracy",
    "mrpc": "f1",
    "qnli": "accuracy",
    "qqp": "f1",
    "rte": "accuracy",
    "sst2": "accuracy",
    "stsb": "spearmanr",
}

metric_for_best_model = _task_to_metric[DATASET_TASK]

In [None]:
def get_metric_name_for_specific_task():
    """Helper function to derive the evaluation metric name for the specified GLUE task.

    The tasks specified by the GLUE benchmark use different evaluation metrics.
    Unfortunately, there is no easy way to derive their names after loading the corresponding metric function via HuggingFace's `evaluate` library.
    However, we can simply do a "trial run" and inspect the keys of its output.
    """
    output = metric.compute(
        predictions=[1, 0], references=[1, 1]
    )  # dummy input - we just want to inspect the returned dictionary.
    metric_names = output.keys()
    
    return list(metric_names)


metric_names = get_metric_name_for_specific_task()
print(f'We will use "{metric_names}" as an evaluation metric for the task {DATASET_TASK}')

In [26]:
assert metric_for_best_model in metric_names, "Metric to optimize for not found in evaluation metrics provided by GLUE"

### 2.3 Training

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    PRE_TRAINED_CHECKPOINT,
    num_labels=num_labels,
    torch_dtype="auto",
)

In [28]:
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT.replace("/", "_")).resolve(),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    learning_rate=2e-5,  # Original paper uses the best of 5e-5, 4e-5, 3e-5, and 2e-5
    weight_decay=0.01,  # Original paper uses 0.01 during pre-training
    save_total_limit=3,  # Keep at most three checkpoints; the best one is never deleted
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy_avg",  # `eval_accuracy_avg` will be computed via a custom callback to be the avg of the accuracy for both validation datasets (matched and mismatched)
)

In [29]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if DATASET_TASK != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)
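
For a classification task such as MNLI, `compute_metrics` boils down to taking the arg-max over the logits and comparing it with the labels. The following sketch spells that out in plain Python; the helpers and the toy logits are illustrative and do not replace the `evaluate` metric used in the notebook.

```python
def argmax(row):
    """Index of the largest value in a list of logits."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, labels):
    """Share of examples whose highest-scoring class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Four made-up examples with three class scores each.
toy_logits = [[2.1, 0.3, -1.0], [0.1, 0.2, 1.5], [0.0, 3.0, 0.5], [1.2, 0.9, 0.1]]
toy_labels = [0, 2, 1, 1]
toy_accuracy = accuracy(toy_logits, toy_labels)
print(toy_accuracy)
```

Three of the four toy predictions match their label, so the accuracy is 0.75.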

As this dataset provides two evaluation datasets (matched and mismatched), we run two evaluations at every evaluation step. We then average both accuracies and store the result in the `metrics` dictionary of `transformers` as `eval_accuracy_avg`. This average also serves as `metric_for_best_model`.

In [30]:
from transformers import TrainerCallback, TrainerState, TrainerControl


class AverageMatchedAndMismatchedAccuracies(TrainerCallback):
    """Callback to save the `eval_mnli_matched_accuracy` after the first evaluation step (MNLI matched)
    and then compute the average of the MNLI matched and mismatched accuracies on the second evaluation step.
    """

    matched_acc = None

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, float],
        **kwargs
    ):
        """Event called after an evaluation phase."""
        if self.matched_acc is None:
            # We are in the first evaluation step (matched) - save result metric for later use
            self.matched_acc = metrics["eval_mnli_matched_accuracy"]
            return

        # We are in the second evaluation step (mismatched)
        # Use the `matched_acc` saved before and the `mismatched_acc` to compute average:
        mismatched_acc = metrics["eval_mnli_mismatched_accuracy"]
        metrics["eval_accuracy_avg"] = (self.matched_acc + mismatched_acc) / 2

        self.matched_acc = None
        return
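
The two-phase logic of the callback can be exercised without a `Trainer`. The class below is a plain-Python stand-in written for illustration only (it does not subclass `TrainerCallback`), and the accuracy values are made up.

```python
class AverageOfTwoEvaluations:
    """Remembers the accuracy of the first call and writes the average on the second call."""

    def __init__(self):
        self.matched_acc = None

    def on_evaluate(self, metrics):
        if self.matched_acc is None:
            # First call (matched): only remember the value.
            self.matched_acc = metrics["eval_mnli_matched_accuracy"]
            return
        # Second call (mismatched): add the average and reset for the next round.
        mismatched_acc = metrics["eval_mnli_mismatched_accuracy"]
        metrics["eval_accuracy_avg"] = (self.matched_acc + mismatched_acc) / 2
        self.matched_acc = None


callback = AverageOfTwoEvaluations()
first = {"eval_mnli_matched_accuracy": 0.75}
second = {"eval_mnli_mismatched_accuracy": 0.5}
callback.on_evaluate(first)
callback.on_evaluate(second)
print(second)
```

The average only appears in the metrics of the second call, so the approach relies on the matched dataset being evaluated before the mismatched one.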

In [31]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_dataset["train"],
    eval_dataset={
        "mnli_matched": tokenized_dataset["validation_matched"],
        "mnli_mismatched": tokenized_dataset["validation_mismatched"],
    },
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[AverageMatchedAndMismatchedAccuracies]
)

In [None]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
print(f"--- {training_arguments.metric_for_best_model=}")
training_summary_bert_base = trainer.train()

In [None]:
training_summary_bert_base

We can call `trainer.evaluate()` to check that the `trainer` instance did indeed reload the model checkpoint with the highest evaluation score:

In [None]:
best_model_evaluation = trainer.evaluate()
best_model_evaluation

In [None]:
training_history_bert_base = pd.DataFrame(trainer.state.log_history)
training_history_bert_base.epoch = training_history_bert_base.epoch.astype(int)
training_history_bert_base.groupby("epoch").first()

In [None]:
import seaborn as sns

data = training_history_bert_base[["loss", "eval_mnli_matched_loss", "epoch", "eval_mnli_matched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_bert_base.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (matched {PRE_TRAINED_CHECKPOINT})"))

In [None]:
import seaborn as sns

data = training_history_bert_base[["loss", "eval_mnli_mismatched_loss", "epoch", "eval_mnli_mismatched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_bert_base.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (mismatched {PRE_TRAINED_CHECKPOINT})"))

### 2.4 Evaluation

In [None]:
import seaborn as sns

sns.countplot(x='label', data=pd.DataFrame(tokenized_dataset["test_matched"]))

from IPython.display import Markdown, display
display(Markdown("### Label frequency in matched test dataset"))

In [None]:
import seaborn as sns

sns.countplot(x='label', data=pd.DataFrame(tokenized_dataset["test_mismatched"]))

from IPython.display import Markdown, display
display(Markdown("### Label frequency in mismatched test dataset"))

The dataset classes seem to be somewhat balanced.
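
A count plot gives a visual impression; the same check can be made numeric by computing the share of each label. The helper `label_shares` and the toy labels below are hypothetical.

```python
from collections import Counter


def label_shares(labels):
    """Return the fraction of examples per label, ordered by label id."""
    counts = Counter(labels)
    total = len(labels)
    return {label: counts[label] / total for label in sorted(counts)}


# Made-up labels for illustration; three classes as in MNLI.
toy_labels = [0, 1, 2, 0, 1, 2, 0, 1]
shares = label_shares(toy_labels)
print(shares)
```

In the notebook, the same helper could be applied to `tokenized_dataset["test_matched"]["label"]`; shares close to one third per class would support the impression of balance.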

In [None]:
predictions_matched = trainer.predict(tokenized_dataset["test_matched"])
predictions_mismatched = trainer.predict(tokenized_dataset["test_mismatched"])

In [None]:
import sklearn.metrics

bert_base_matched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_matched"]["label"],
    predictions_matched.predictions.argmax(-1),
)
plot = sns.heatmap(bert_base_matched_cm, annot=True, fmt="d")
# `confusion_matrix(y_true, y_pred)` puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")

from IPython.display import Markdown, display

display(
    Markdown(f"### Prediction Confusion Matrix (matched - {PRE_TRAINED_CHECKPOINT})")
)
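
When reading such a heatmap it helps to recall the orientation of the matrix: `sklearn.metrics.confusion_matrix(y_true, y_pred)` counts true labels along the rows and predicted labels along the columns. The sketch below rebuilds that layout with plain lists; `toy_confusion_matrix` and its inputs are made up for illustration.

```python
def toy_confusion_matrix(y_true, y_pred, num_classes):
    """matrix[i][j] = number of examples with true label i that were predicted as j."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 2, 0]
cm = toy_confusion_matrix(y_true, y_pred, num_classes=3)
for row in cm:
    print(row)
```

The diagonal holds the correct predictions; the entry in the last row, first column is an example of class 2 that was predicted as class 0.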

In [None]:
import sklearn.metrics

bert_base_mismatched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_mismatched"]["label"],
    predictions_mismatched.predictions.argmax(-1),
)
plot = sns.heatmap(bert_base_mismatched_cm, annot=True, fmt="d")
# `confusion_matrix(y_true, y_pred)` puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")

from IPython.display import Markdown, display

display(
    Markdown(
        f"### Prediction Confusion Matrix (mismatched - {PRE_TRAINED_CHECKPOINT})"
    )
)

In [None]:
predictions_matched.metrics

In [None]:
predictions_mismatched.metrics

In [None]:
display(Markdown(f"### Best Model performance:"))
results = pd.DataFrame(
    data=[training_summary_bert_base.metrics["train_runtime"]]
    + list(best_model_evaluation.values())
    + [
        predictions_matched.metrics["test_accuracy"],
        predictions_mismatched.metrics["test_accuracy"],
    ],
    index=["train_runtime_s"]
    + list(best_model_evaluation.keys())
    + [
        "matched_test_accuracy",
        "mismatched_test_accuracy",
    ],
    columns=["our BERT_BASE"],
).drop(
    # Drop runtime measurements
    index=[
        "eval_mnli_matched_runtime",
        "eval_mnli_mismatched_runtime",
        "eval_mnli_matched_samples_per_second",
        "eval_mnli_mismatched_samples_per_second",
        "eval_mnli_matched_steps_per_second",
        "eval_mnli_mismatched_steps_per_second",
        "epoch",
    ]
)
# Achieved scores from original BERT paper:
results["original BERT_BASE"] = ["-", "-", "-", "-", "-", "-", 0.846, 0.834]
results["original BERT_LARGE"] = ["-", "-", "-", "-", "-", "-", 0.867, 0.859]
print(
    f'"Our Model" based on {PRE_TRAINED_CHECKPOINT}, best performance on validation data.'
)
print(
    '"BERT_BASE" and "BERT_LARGE" performance on GLUE testing data as reported in original paper.'
)
results

## 3. BERT-Large

In [46]:
PRE_TRAINED_CHECKPOINT = "google-bert/bert-large-uncased"

### 3.1 Tokenization

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT, do_lower_case="uncased" in PRE_TRAINED_CHECKPOINT)

def tokenize_func(item):
    """Tokenize passed item. 
    
    Depending on the dataset task, the passed item contains either one sentence or two sentences.
    In the latter case, the two sentences are joined via a [SEP] token.
    """
    if second_sentence_key is None:
        return tokenizer(item[first_sentence_key], add_special_tokens=True, truncation=True)
    else:
        return tokenizer(item[first_sentence_key], item[second_sentence_key], add_special_tokens=True, truncation=True)

tokenized_dataset = dataset.map(tokenize_func, batched=True)

In [48]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### 3.2 Metrics

In [None]:
import evaluate

metric = evaluate.load(DATASET_NAME, DATASET_TASK)

metric_for_best_model = _task_to_metric[DATASET_TASK]
metric_names = get_metric_name_for_specific_task()
print(f'We will use "{metric_names}" as an evaluation metric for the task {DATASET_TASK}')

In [50]:
assert metric_for_best_model in metric_names, "Metric to optimize for not found in evaluation metrics provided by GLUE"

### 3.3 Training

In [51]:
BATCH_SIZE = 32  # BERT-large might need a smaller batch size
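
If BERT-large does not fit into GPU memory at this batch size, one common option is to lower the per-device batch size and compensate with gradient accumulation (`gradient_accumulation_steps` in `TrainingArguments`). The arithmetic below is a sketch of that trade-off; the notebook itself keeps a batch size of 32, and the helper name is made up.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps=1, num_devices=1):
    """Number of examples that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


# Configuration used in this notebook: 32 examples per step on one device.
print(effective_batch_size(32))
# Hypothetical memory-saving alternative: 8 examples per step, accumulated over 4 steps.
print(effective_batch_size(8, gradient_accumulation_steps=4))
```

Both settings update the weights with gradients from 32 examples, so the results should stay comparable with the BERT-base run.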

In [None]:
import gc

try:
    del model
    del trainer
except NameError:  # nothing to free if the previous section was not run
    pass


gc.collect()

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    PRE_TRAINED_CHECKPOINT,
    num_labels=num_labels,
    torch_dtype="auto",
)

from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT.replace("/", "_")).resolve(),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    learning_rate=2e-5,  # Original paper uses the best of 5e-5, 4e-5, 3e-5, and 2e-5
    weight_decay=0.01,  # Original paper uses 0.01 during pre-training
    save_total_limit=3,  # Keep at most three checkpoints; the best one is never deleted
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy_avg",  # `eval_accuracy_avg` will be computed via a custom callback to be the avg of the accuracy for both validation datasets (matched and mismatched)
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_dataset["train"],
    eval_dataset={
        "mnli_matched": tokenized_dataset["validation_matched"],
        "mnli_mismatched": tokenized_dataset["validation_mismatched"],
    },
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[AverageMatchedAndMismatchedAccuracies]
)

In [None]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
print(f"--- {training_arguments.metric_for_best_model=}")
training_summary_bert_large = trainer.train()

In [None]:
training_summary_bert_large

In [None]:
best_model_evaluation = trainer.evaluate()
best_model_evaluation

In [None]:
training_history_bert_large = pd.DataFrame(trainer.state.log_history)
training_history_bert_large.epoch = training_history_bert_large.epoch.astype(int)
training_history_bert_large.groupby("epoch").first()

In [None]:
import seaborn as sns

data = training_history_bert_large[["loss", "eval_mnli_matched_loss", "epoch", "eval_mnli_matched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_bert_large.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (matched {PRE_TRAINED_CHECKPOINT})"))

In [None]:
import seaborn as sns

data = training_history_bert_large[["loss", "eval_mnli_mismatched_loss", "epoch", "eval_mnli_mismatched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_bert_large.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (mismatched {PRE_TRAINED_CHECKPOINT})"))

### 3.4 Evaluation

In [None]:
predictions_matched = trainer.predict(tokenized_dataset["test_matched"])
predictions_mismatched = trainer.predict(tokenized_dataset["test_mismatched"])

In [None]:
import sklearn.metrics

bert_large_matched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_matched"]["label"],
    predictions_matched.predictions.argmax(-1),
)
plot = sns.heatmap(bert_large_matched_cm, annot=True, fmt="d")
# `confusion_matrix(y_true, y_pred)` puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")

from IPython.display import Markdown, display

display(
    Markdown(f"### Prediction Confusion Matrix (matched - {PRE_TRAINED_CHECKPOINT})")
)

In [None]:
import sklearn.metrics

bert_large_mismatched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_mismatched"]["label"],
    predictions_mismatched.predictions.argmax(-1),
)
plot = sns.heatmap(bert_large_mismatched_cm, annot=True, fmt="d")
# `confusion_matrix(y_true, y_pred)` puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")

from IPython.display import Markdown, display

display(
    Markdown(
        f"### Prediction Confusion Matrix (mismatched - {PRE_TRAINED_CHECKPOINT})"
    )
)

In [None]:
display(Markdown(f"### Best Model performance:"))
results["our BERT_LARGE"] = [
    training_summary_bert_large.metrics["train_runtime"],
    best_model_evaluation["eval_mnli_matched_loss"],
    best_model_evaluation["eval_mnli_matched_accuracy"],
    best_model_evaluation["eval_mnli_mismatched_loss"],
    best_model_evaluation["eval_mnli_mismatched_accuracy"],
    best_model_evaluation["eval_accuracy_avg"],
    predictions_matched.metrics["test_accuracy"],
    predictions_mismatched.metrics["test_accuracy"],
]
results = results[
    [
        "our BERT_BASE",
        "original BERT_BASE",
        "our BERT_LARGE",
        "original BERT_LARGE",
    ]
]
print('"BERT_BASE" and "BERT_LARGE" performance on GLUE testing data as reported in original paper.')
results

## 4. ModernBERT-base

In [64]:
PRE_TRAINED_CHECKPOINT = "answerdotai/ModernBERT-base" 

### 4.1 Tokenization

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT, do_lower_case="uncased" in PRE_TRAINED_CHECKPOINT)

def tokenize_func(item):
    """Tokenize passed item. 
    
    Depending on the dataset task, the passed item contains either one sentence or two sentences.
    In the latter case, the two sentences are joined via a [SEP] token.
    """
    if second_sentence_key is None:
        return tokenizer(item[first_sentence_key], add_special_tokens=True, truncation=True)
    else:
        return tokenizer(item[first_sentence_key], item[second_sentence_key], add_special_tokens=True, truncation=True)

tokenized_dataset = dataset.map(tokenize_func, batched=True)

In [66]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### 4.2 Metrics

In [None]:
import evaluate

metric = evaluate.load(DATASET_NAME, DATASET_TASK)

metric_for_best_model = _task_to_metric[DATASET_TASK]
metric_names = get_metric_name_for_specific_task()
print(f'We will use "{metric_names}" as an evaluation metric for the task {DATASET_TASK}')

In [68]:
assert metric_for_best_model in metric_names, "Metric to optimize for not found in evaluation metrics provided by GLUE"

### 4.3 Training

In [69]:
BATCH_SIZE = 32

In [None]:
import gc

try:
    del model
    del trainer
except NameError:  # nothing to free if the previous section was not run
    pass


gc.collect()

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    PRE_TRAINED_CHECKPOINT,
    num_labels=num_labels,
    reference_compile=False
)

from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT.replace("/", "_")).resolve(),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    lr_scheduler_type="linear",
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    learning_rate=8e-5,  # Original paper recommends 8e-5
    weight_decay=0.01,  # Original paper uses 0.01 during pre-training
    save_total_limit=3,  # Keep at most three checkpoints; the best one is never deleted
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy_avg",  # `eval_accuracy_avg` will be computed via a custom callback to be the avg of the accuracy for both validation datasets (matched and mismatched)
    bf16=True,
    bf16_full_eval=True,
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_dataset["train"],
    eval_dataset={
        "mnli_matched": tokenized_dataset["validation_matched"],
        "mnli_mismatched": tokenized_dataset["validation_mismatched"],
    },
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[AverageMatchedAndMismatchedAccuracies]
)

In [None]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
print(f"--- {training_arguments.metric_for_best_model=}")
training_summary_modernbert_base = trainer.train()

In [None]:
training_summary_modernbert_base

In [None]:
best_model_evaluation = trainer.evaluate()
best_model_evaluation

In [None]:
training_history_modernbert_base = pd.DataFrame(trainer.state.log_history)
training_history_modernbert_base.epoch = training_history_modernbert_base.epoch.astype(int)
training_history_modernbert_base.groupby("epoch").first()

In [None]:
import seaborn as sns

data = training_history_modernbert_base[["loss", "eval_mnli_matched_loss", "epoch", "eval_mnli_matched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_modernbert_base.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (matched {PRE_TRAINED_CHECKPOINT})"))

In [None]:
import seaborn as sns

data = training_history_modernbert_base[["loss", "eval_mnli_mismatched_loss", "epoch", "eval_mnli_mismatched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_modernbert_base.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (mismatched {PRE_TRAINED_CHECKPOINT})"))

### 4.4 Evaluation

In [None]:
predictions_matched = trainer.predict(tokenized_dataset["test_matched"])
predictions_mismatched = trainer.predict(tokenized_dataset["test_mismatched"])
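
`trainer.predict` returns the raw logits in `.predictions`; the predicted class is the arg-max over the last axis, which is what the following cells compute with `.argmax(-1)`. As a minimal sketch on made-up logits (the helpers `argmax` and `accuracy` are hypothetical, written in plain Python for illustration):

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits, labels):
    """Fraction of rows whose arg-max class equals the reference label."""
    predicted = [argmax(row) for row in logits]
    correct = sum(p == t for p, t in zip(predicted, labels))
    return correct / len(labels)


# Made-up logits for four examples and three classes.
example_logits = [
    [2.0, 0.1, -1.0],
    [0.2, 1.5, 0.3],
    [-0.5, 0.0, 0.9],
    [1.2, 0.4, 0.8],
]
example_labels = [0, 1, 2, 2]
print(accuracy(example_logits, example_labels))  # 3 of 4 correct -> 0.75
```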

In [None]:
import sklearn.metrics

modernbert_base_matched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_matched"]["label"],
    predictions_matched.predictions.argmax(-1),
)
plot = sns.heatmap(modernbert_base_matched_cm, annot=True, fmt="d")
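# confusion_matrix(y_true, y_pred) puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")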

from IPython.display import Markdown, display

display(
    Markdown(f"### Prediction Confusion Matrix (matched - {PRE_TRAINED_CHECKPOINT})")
)
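
For reading the heatmap: `sklearn.metrics.confusion_matrix(y_true, y_pred)` returns a matrix whose entry `(i, j)` counts the examples with true label `i` that were predicted as `j`, so rows correspond to true labels and columns to predictions. The sketch below reproduces that convention in plain Python on made-up labels; the local `confusion_matrix` function is a hypothetical stand-in, not the scikit-learn implementation.

```python
def confusion_matrix(true_labels, predicted_labels, num_classes):
    """Count matrix with rows indexed by true label and columns by predicted label."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true, predicted in zip(true_labels, predicted_labels):
        matrix[true][predicted] += 1
    return matrix


# Made-up labels for six examples and three classes.
true_labels = [0, 0, 1, 2, 2, 2]
predicted_labels = [0, 1, 1, 2, 0, 2]
for row in confusion_matrix(true_labels, predicted_labels, num_classes=3):
    print(row)
# [1, 1, 0]
# [0, 1, 0]
# [1, 0, 2]
```

The diagonal holds the correctly classified examples; everything off the diagonal is an error.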

In [None]:
import sklearn.metrics

modernbert_base_mismatched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_mismatched"]["label"],
    predictions_mismatched.predictions.argmax(-1),
)
plot = sns.heatmap(modernbert_base_mismatched_cm, annot=True, fmt="d")
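# confusion_matrix(y_true, y_pred) puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")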

from IPython.display import Markdown, display

display(
    Markdown(
        f"### Prediction Confusion Matrix (mismatched - {PRE_TRAINED_CHECKPOINT})"
    )
)

In [None]:
predictions_matched.metrics

In [None]:
predictions_mismatched.metrics

In [None]:
display(Markdown("### Best Model performance:"))
results["our ModernBERT_BASE"] = [
    training_summary_modernbert_base.metrics["train_runtime"],
    best_model_evaluation["eval_mnli_matched_loss"],
    best_model_evaluation["eval_mnli_matched_accuracy"],
    best_model_evaluation["eval_mnli_mismatched_loss"],
    best_model_evaluation["eval_mnli_mismatched_accuracy"],
    best_model_evaluation["eval_accuracy_avg"],
    predictions_matched.metrics["test_accuracy"],
    predictions_mismatched.metrics["test_accuracy"],
]
results = results[
    [
        "our BERT_BASE",
        "original BERT_BASE",
        "our ModernBERT_BASE",
        "our BERT_LARGE",
        "original BERT_LARGE",
    ]
]
print('"BERT_BASE" and "BERT_LARGE" performance on GLUE testing data as reported in original paper.')
results

## 5. ModernBERT-Large

In [84]:
PRE_TRAINED_CHECKPOINT = "answerdotai/ModernBERT-large" 

### 5.1 Tokenization

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT, do_lower_case="uncased" in PRE_TRAINED_CHECKPOINT)

def tokenize_func(item):
    """Tokenize passed item. 
    
    Depending on dataset task the passed item will either contain one sentence or two sentences.
    In the latter case the two sentences are joined via a [SEP] token.
    """
    if second_sentence_key is None:
        return tokenizer(item[first_sentence_key], add_special_tokens=True, truncation=True)
    else:
        return tokenizer(item[first_sentence_key], item[second_sentence_key], add_special_tokens=True, truncation=True)

tokenized_dataset = dataset.map(tokenize_func, batched=True)

In [86]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
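
`DataCollatorWithPadding` pads dynamically: by default each batch is padded only to the length of its longest sequence rather than to a global maximum. A minimal plain-Python sketch of that idea follows; the helper `pad_batch`, the padding ID `0`, and the token IDs are assumptions for illustration (the real collator uses the tokenizer's own padding token and returns tensors).

```python
def pad_batch(sequences, pad_id=0):
    """Pad every sequence to the length of the longest one in the batch."""
    longest = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]
    # 1 marks real tokens, 0 marks padding that the model should ignore.
    attention_mask = [
        [1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences
    ]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token IDs for two sequences of different length.
batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
print(batch["input_ids"])       # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```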

### 5.2 Metrics

In [None]:
import evaluate

metric = evaluate.load(DATASET_NAME, DATASET_TASK)

metric_for_best_model = _task_to_metric[DATASET_TASK]
metric_names = get_metric_name_for_specific_task()
print(f'We will use "{metric_names}" as an evaluation metric for the task {DATASET_TASK}')

In [88]:
assert metric_for_best_model in metric_names, "Metric to optimize for not found in evaluation metrics provided by GLUE"

### 5.3 Training

In [89]:
BATCH_SIZE = 32

In [90]:
import gc

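# Release the previous model and trainer (if any) before loading the larger checkpoint.
try:
    del model
    del trainer
except NameError:
    pass  # nothing to free if the previous section was not run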


gc.collect()
torch.cuda.empty_cache()

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    PRE_TRAINED_CHECKPOINT,
    num_labels=num_labels,
    reference_compile=False
)

from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT.replace("/", "_")).resolve(),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    lr_scheduler_type="linear",
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
    learning_rate=8e-5,  # Original paper recommends 8e-5
    weight_decay=0.01,  # Original paper uses 0.01 on pre-training
    save_total_limit=3,  # Keep at most three checkpoints (the best one plus the most recent ones)
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy_avg",  # `eval_accuracy_avg` will be computed via a custom callback to be the avg of the accuracy for both validation datasets (matched and mismatched)
    bf16=True,
    bf16_full_eval=True,
)

from transformers import Trainer

validation_key = "validation_mismatched" if DATASET_TASK == "mnli-mm" else "validation_matched" if DATASET_TASK == "mnli" else "validation"

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized_dataset["train"],
    eval_dataset={
        "mnli_matched": tokenized_dataset["validation_matched"],
        "mnli_mismatched": tokenized_dataset["validation_mismatched"],
    },
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[AverageMatchedAndMismatchedAccuracies]
)
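
With `lr_scheduler_type="linear"` and no warmup configured, the learning rate decays linearly from `learning_rate` at the first step to zero at the last one. The sketch below shows the shape of that schedule; `linear_lr` is a hypothetical helper and `total_steps = 1000` is a made-up value, not the step count of this run.

```python
def linear_lr(step, total_steps, base_lr=8e-5, warmup_steps=0):
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


total_steps = 1000  # made-up number of optimizer steps
for step in (0, 250, 500, 1000):
    print(step, linear_lr(step, total_steps))
```

Halfway through training the rate is half of the initial value, and it reaches zero at the final step.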

In [None]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
print(f"--- {training_arguments.metric_for_best_model=}")
training_summary_modernbert_large = trainer.train()

In [None]:
training_summary_modernbert_large

In [None]:
best_model_evaluation = trainer.evaluate()
best_model_evaluation

In [None]:
training_history_modernbert_large = pd.DataFrame(trainer.state.log_history)
training_history_modernbert_large.epoch = training_history_modernbert_large.epoch.astype(int)
training_history_modernbert_large.groupby("epoch").first()

In [None]:
import seaborn as sns

data = training_history_modernbert_large[["loss", "eval_mnli_matched_loss", "epoch", "eval_mnli_matched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_modernbert_large.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (matched {PRE_TRAINED_CHECKPOINT})"))

In [None]:
import seaborn as sns

data = training_history_modernbert_large[["loss", "eval_mnli_mismatched_loss", "epoch", "eval_mnli_mismatched_accuracy", ]]
data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
data = data[:-1]  # drop last row, as this row just contains the values for the best checkpoint again
data = pd.melt(data, ['Training Epoch']).dropna()


plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_modernbert_large.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")

from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs (mismatched {PRE_TRAINED_CHECKPOINT})"))

### 5.4 Evaluation

In [None]:
predictions_matched = trainer.predict(tokenized_dataset["test_matched"])
predictions_mismatched = trainer.predict(tokenized_dataset["test_mismatched"])

In [None]:
import sklearn.metrics

modernbert_large_matched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_matched"]["label"],
    predictions_matched.predictions.argmax(-1),
)
plot = sns.heatmap(modernbert_large_matched_cm, annot=True, fmt="d")
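# confusion_matrix(y_true, y_pred) puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")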

from IPython.display import Markdown, display

display(
    Markdown(f"### Prediction Confusion Matrix (matched - {PRE_TRAINED_CHECKPOINT})")
)

In [None]:
import sklearn.metrics

modernbert_large_mismatched_cm = sklearn.metrics.confusion_matrix(
    tokenized_dataset["test_mismatched"]["label"],
    predictions_mismatched.predictions.argmax(-1),
)
plot = sns.heatmap(modernbert_large_mismatched_cm, annot=True, fmt="d")
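# confusion_matrix(y_true, y_pred) puts true labels on rows (y-axis) and predictions on columns (x-axis)
plot.set_xlabel("Predicted label")
plot.set_ylabel("True label")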

from IPython.display import Markdown, display

display(
    Markdown(
        f"### Prediction Confusion Matrix (mismatched - {PRE_TRAINED_CHECKPOINT})"
    )
)

In [None]:
predictions_matched.metrics

In [None]:
predictions_mismatched.metrics

In [None]:
display(Markdown("### Best Model performance:"))
results["our ModernBERT_LARGE"] = [
    # Use the runtime of the ModernBERT-large run, not the base run
    training_summary_modernbert_large.metrics["train_runtime"],
    best_model_evaluation["eval_mnli_matched_loss"],
    best_model_evaluation["eval_mnli_matched_accuracy"],
    best_model_evaluation["eval_mnli_mismatched_loss"],
    best_model_evaluation["eval_mnli_mismatched_accuracy"],
    best_model_evaluation["eval_accuracy_avg"],
    predictions_matched.metrics["test_accuracy"],
    predictions_mismatched.metrics["test_accuracy"],
]
results = results[
    [
        "our BERT_BASE",
        "original BERT_BASE",
        "our ModernBERT_BASE",
        "our BERT_LARGE",
        "original BERT_LARGE",
        "our ModernBERT_LARGE",
    ]
]
print('"BERT_BASE" and "BERT_LARGE" performance on GLUE testing data as reported in original paper.')
results

## 6. Summary

In [None]:
import matplotlib.pyplot as plt

titles = ["BERT-base", "BERT-large", "ModernBERT-base", "ModernBERT-large"]
training_histories = [training_history_bert_base, training_history_bert_large, training_history_modernbert_base, training_history_modernbert_large]

fig, axes = plt.subplots(nrows=2,ncols=len(training_histories), sharey=True, sharex=True)

def draw_loss_eval_plot(title, history, ax, mismatched=False):
    eval_loss_key = "eval_mnli_mismatched_loss" if mismatched else "eval_mnli_matched_loss"
    eval_acc_key = "eval_mnli_mismatched_accuracy" if mismatched else "eval_mnli_matched_accuracy"

    data = history[["loss", eval_loss_key, "epoch", eval_acc_key]]
    data.columns = ["Train. Loss", "Eval. Loss", "Training Epoch", "Acc."]
    data = data[:-1]
    data = pd.melt(data, ['Training Epoch']).dropna()

    plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True, ax=ax)
    plot.set_ylabel("")
    plot.set(xticks=list(set(history.epoch)))
    plot.legend(title="", loc='upper left')
    plot.set_title(title)

for title, history, ax in zip(titles, training_histories, axes[0]):
    draw_loss_eval_plot(title, history, ax)
for title, history, ax in zip(titles, training_histories, axes[1]):
    draw_loss_eval_plot(title, history, ax, mismatched=True)

for ax in axes[0][1:]:
    ax.get_legend().remove()
for ax in axes[1]:
    ax.get_legend().remove()

axes[0][0].set_ylabel("MATCHED Evaluation / Test Set", weight='bold', fontsize=14)
axes[1][0].set_ylabel("MISMATCHED Evaluation / Test Set", weight='bold', fontsize=14)

fig.set_figwidth(20)
fig.set_figheight(10)
fig.tight_layout()

In [None]:
titles

In [None]:
titles = ["BERT-base", "BERT-large", "ModernBERT-base", "ModernBERT-large"]
# Select our own runs explicitly, in the same order as `titles`, so that each
# accuracy is paired with the right model when zipped below.
our_columns = [
    "our BERT_BASE",
    "our BERT_LARGE",
    "our ModernBERT_BASE",
    "our ModernBERT_LARGE",
]
our_results_matched = results.loc["matched_test_accuracy", our_columns]
our_results_mismatched = results.loc["mismatched_test_accuracy", our_columns]
titles_matched = [
    title + " - " + f"Matched Acc.: {matched_acc:.2f}"
    for title, matched_acc in zip(titles, our_results_matched)
]
titles_mismatched = [
    title + " - " + f"Mismatched Acc.: {mismatched_acc:.2f}"
    for title, mismatched_acc in zip(titles, our_results_mismatched)
]
matched_cms = [
    bert_base_matched_cm,
    bert_large_matched_cm,
    modernbert_base_matched_cm,
    modernbert_large_matched_cm,
]
mismatched_cms = [
    bert_base_mismatched_cm,
    bert_large_mismatched_cm,
    modernbert_base_mismatched_cm,
    modernbert_large_mismatched_cm,
]

fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(4.2, 4.2))


def draw_confusion_matrix_plot(title, cm, ax):
    include_cbar = title in (titles_matched[-1], titles_mismatched[-1])
    plot = sns.heatmap(
        cm, annot=True, fmt="d", square=True, cmap="viridis", cbar=include_cbar, ax=ax
    )
    # Rows of the matrix are true labels (y-axis), columns are predictions (x-axis)
    plot.set_xlabel("Predicted label")
    plot.set_ylabel("True label")
    plot.set_title(title)


for title, cm, ax in zip(titles_matched, matched_cms, axes[0]):
    draw_confusion_matrix_plot(title, cm, ax)
# The second row shows the mismatched confusion matrices
for title, cm, ax in zip(titles_mismatched, mismatched_cms, axes[1]):
    draw_confusion_matrix_plot(title, cm, ax)


fig.set_figwidth(18)
fig.set_figheight(8)
fig.tight_layout(h_pad=2)

In [None]:
results

In [None]:
speedup = results["our BERT_BASE"]["train_runtime_s"] / results["our ModernBERT_BASE"]["train_runtime_s"] 
speedup