In [1]:
from datasets import load_dataset, concatenate_datasets, DatasetDict, load_from_disk

# Set to True to reuse the split previously saved to disk instead of
# rebuilding it from the hub dataset.
load_from_file = False

# NOTE: the condition must test the flag above. Testing the imported
# `load_from_disk` function is always truthy and would make the else-branch
# unreachable.
if load_from_file:
    datasets = load_from_disk("../swag.hf")
else:
    raw_datasets = load_dataset("swag", "regular")
    # the labels for the test split are not public, therefore we create our own split
    # 60% train, 20% validation, 20% test
    merged_datasets = concatenate_datasets([raw_datasets["train"], raw_datasets["validation"]])

    # Hold out 40% of the examples; the fixed seed makes the split reproducible
    train_testvalid = merged_datasets.train_test_split(test_size=0.4, seed=42)
    # Split the held-out 40% in half: 20% test, 20% validation
    test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42)
    # gather everyone if you want to have a single DatasetDict
    datasets = DatasetDict({
        'train': train_testvalid['train'],
        'test': test_valid['test'],
        'validation': test_valid['train']})

    # Persist the split so later runs can set load_from_file = True
    datasets.save_to_disk("../swag.hf")

In [2]:
from transformers import AutoTokenizer

# Load the tokenizer stored under the "answerdotai/ModernBERT-large" checkpoint name;
# preprocess_function below reads this module-level `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

In [3]:
import torch
import torch.nn as nn
from transformers import ModernBertModel

class ModernBERTForMultipleChoice(nn.Module):
    """ModernBERT encoder with a single-logit head for multiple-choice tasks.

    Every (example, choice) pair is encoded independently; the hidden state of
    the first token is projected to one score, and the scores of all choices
    of an example are compared with a cross-entropy loss.
    """

    def __init__(self, model_name="answerdotai/ModernBERT-large", dropout_prob=0.1):
        """Build the encoder and the scoring head.

        Args:
            model_name: checkpoint name or path passed to
                ``ModernBertModel.from_pretrained``.
            dropout_prob: dropout probability applied before the classifier.
        """
        super().__init__()
        self.modernBERT = ModernBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Read the width from the loaded config instead of hard-coding 1024 so
        # that checkpoints of other sizes work as well.
        hidden_size = self.modernBERT.config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score every choice and optionally compute the loss.

        Args:
            input_ids: tensor of shape (batch_size, num_choices, seq_len).
            attention_mask: optional tensor with the same shape as
                ``input_ids``; when omitted, ``None`` is forwarded to the encoder.
            labels: optional tensor of shape (batch_size,) holding the index
                of the correct choice.

        Returns:
            dict with ``"loss"`` (``None`` when no labels are given) and
            ``"logits"`` of shape (batch_size, num_choices).
        """
        bsz, num_choices, seq_len = input_ids.size()

        # Flatten the choice dimension into the batch dimension
        flat_input_ids = input_ids.reshape(-1, seq_len)
        # attention_mask is optional: only reshape it when it was provided
        flat_attention_mask = (
            attention_mask.reshape(-1, seq_len) if attention_mask is not None else None
        )

        outputs = self.modernBERT(input_ids=flat_input_ids, attention_mask=flat_attention_mask)
        hidden_state = outputs.last_hidden_state  # (batch_size*num_choices, seq_len, hidden_size)
        # Use the representation of the first token as the sequence summary
        pooled_output = hidden_state[:, 0, :]

        logits = self.classifier(self.dropout(pooled_output))  # (batch_size*num_choices, 1)
        reshaped_logits = logits.view(bsz, num_choices)  # (batch_size, num_choices)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            loss = self.loss_fct(reshaped_logits, labels)

        return {"loss": loss, "logits": reshaped_logits}

In [4]:
# adopted from: https://huggingface.co/docs/transformers/tasks/multiple_choice
def preprocess_function(examples):
    """Tokenize a batch of SWAG examples into per-choice sequence pairs.

    For every example the context (``sent1``) is paired with the question
    header (``sent2``) followed by each of the four candidate endings.

    Args:
        examples: batch mapping with the columns ``sent1``, ``sent2``,
            ``ending0`` .. ``ending3`` and ``label``.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` regrouped to
        (batch_size, num_choices, seq_len) and the unchanged ``labels``.
    """
    num_choices = 4  # SWAG provides ending0 .. ending3
    question_headers = examples["sent2"]
    answer_choices = [examples[f"ending{i}"] for i in range(num_choices)]

    # Build the flat lists directly (one entry per example/choice pair);
    # flattening nested lists with sum(..., []) is quadratic in the batch size.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    second_sentences = [
        f"{header} {choice}"
        for header, choices in zip(question_headers, zip(*answer_choices))
        for choice in choices
    ]

    tokenized = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )

    # Regroup the flat (batch_size * num_choices, seq_len) tensors per example
    seq_len = tokenized["input_ids"].shape[-1]
    return {
        "input_ids": tokenized["input_ids"].view(-1, num_choices, seq_len),
        "attention_mask": tokenized["attention_mask"].view(-1, num_choices, seq_len),
        "labels": examples["label"],
    }

# Run preprocess_function over the `datasets` DatasetDict in batched mode;
# each call receives a dict of column lists rather than a single example.
encoded_dataset = datasets.map(preprocess_function, batched=True)

Map:   0%|          | 0/56131 [00:00<?, ? examples/s]

Map:   0%|          | 0/18711 [00:00<?, ? examples/s]

Map:   0%|          | 0/18710 [00:00<?, ? examples/s]

In [5]:
import numpy as np

def compute_metrics(eval_predictions):
    """Return the accuracy of the arg-max prediction.

    Args:
        eval_predictions: pair ``(predictions, label_ids)`` where
            ``predictions`` holds one score per choice along axis 1.

    Returns:
        dict with a single ``"accuracy"`` entry as a Python float.
    """
    scores, gold = eval_predictions
    chosen = np.argmax(scores, axis=1)
    # 1.0 where the highest-scoring choice matches the label, else 0.0
    hits = (chosen == gold).astype(np.float32)
    return {"accuracy": hits.mean().item()}

In [None]:
from transformers import Trainer, TrainingArguments

# Instantiate the multiple-choice model defined above
model = ModernBERTForMultipleChoice("answerdotai/ModernBERT-large")

# Hyper-parameters: evaluate every 1000 optimizer steps during 3 epochs
training_args = TrainingArguments(
    output_dir="modernBert-large-finetuned-swag",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="steps",
    eval_steps=1000,
)

# Train on the train split, monitor accuracy on the validation split
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
)

# Release cached GPU memory left over from earlier cells before training
torch.cuda.empty_cache()

trainer.train()

Step,Training Loss,Validation Loss,Accuracy
1000,0.6494,0.570847,0.782416
2000,0.5568,0.578723,0.788081
3000,0.5607,0.518744,0.80294
4000,0.2357,0.621283,0.803848


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [17]:
# Evaluate the trained model through the Trainer and display the returned metrics
evaluated = trainer.evaluate()
evaluated

{'eval_loss': 1.1972671747207642,
 'eval_accuracy': 0.8199358582496643,
 'eval_runtime': 479.6592,
 'eval_samples_per_second': 39.007,
 'eval_steps_per_second': 2.439,
 'epoch': 3.0}

In [19]:
import pandas as pd
# Turn the Trainer's log history into a DataFrame (one row per logged event)
training_history_modernBert = pd.DataFrame(trainer.state.log_history)
# Truncate the fractional epoch to an integer so rows can be grouped per epoch
training_history_modernBert["epoch"] = training_history_modernBert["epoch"].astype(int)
# Show the first non-null value of every column within each epoch
training_history_modernBert.groupby("epoch").first()

Unnamed: 0_level_0,loss,grad_norm,learning_rate,step,eval_loss,eval_accuracy,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0,0.7273,11.111187,1.9e-05,500,0.570847,0.782416,481.6994,38.842,2.429,,,,,
1,0.2357,15.915008,1.2e-05,4000,0.621283,0.803848,478.0415,39.139,2.447,,,,,
2,0.0387,13.706006,6e-06,7500,1.16525,0.81737,477.7302,39.164,2.449,,,,,
3,,,,10527,1.197267,0.819936,477.6933,39.167,2.449,16674.5278,10.099,0.631,0.0,0.284195


In [18]:
# Run the trained model on the test split we created ourselves and display its metrics;
# `test_pred.predictions` is reused by the confusion-matrix cell.
test_pred = trainer.predict(encoded_dataset["test"])
test_pred.metrics

{'test_loss': 1.1635264158248901,
 'test_accuracy': 0.8218160271644592,
 'test_runtime': 478.3647,
 'test_samples_per_second': 39.115,
 'test_steps_per_second': 2.446}

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Prepare the data: select the logged metrics, give them display names and
# drop the last row of the log history
history_wide = (
    training_history_modernBert[["loss", "eval_loss", "step", "eval_accuracy"]]
    .rename(columns={
        "loss": "Train. Loss",
        "eval_loss": "Eval. Loss",
        "step": "Training Steps",
        "eval_accuracy": "Accuracy",
    })
    .iloc[:-1]
)
# Long format: one (step, variable, value) row per logged number
data = pd.melt(history_wide, ['Training Steps']).dropna()

# One series per metric
train_loss = data[data['variable'] == "Train. Loss"]
eval_loss = data[data['variable'] == "Eval. Loss"]
acc = data[data['variable'] == "Accuracy"]

# Plot using Matplotlib's explicit figure/axes interface
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_loss["Training Steps"], train_loss["value"], marker='o', label="Train. Loss")
ax.plot(eval_loss["Training Steps"], eval_loss["value"], marker='o', label="Eval. Loss")
ax.plot(acc["Training Steps"], acc["value"], marker='o', label="Accuracy")

# Labels and title describe what is actually plotted (both losses and the accuracy)
ax.set_ylabel('Accuracy/Loss')
ax.set_xlabel('Step')
ax.set_title('ModernBert-Large: Training Loss vs Evaluation Loss and Accuracy')
ax.legend()
ax.grid(True)

fig.savefig('ModernBert_large_swag_finetuned.png')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

import evaluate

# Load the confusion-matrix metric
metric = evaluate.load("confusion_matrix")

# Predicted choice = index of the highest score for each test example
predictions = np.argmax(test_pred.predictions, axis=-1)

# Compute confusion matrix from the predictions and the test-split labels
conf_matrix = metric.compute(
    predictions=predictions,
    references=datasets["test"]["label"],
)['confusion_matrix']

# Define class labels manually (adjust as needed)
labels = ["1", "2", "3", "4"]

# Plotting the confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=labels,
    yticklabels=labels,
    ax=ax,
)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('ModernBert-Large: Confusion Matrix')
fig.savefig('swag_ModernBert_large_confusion_matrix.png')