## Fine-tuning BERT on GoEmotions for Emotion Classification

This notebook demonstrates fine-tuning a BERT model for multi-label emotion classification using the `GoEmotions` dataset, followed by export to ONNX and a sample inference with ONNX Runtime.

## Set-up and imports

In [None]:
!pip install transformers datasets torch scikit-learn --quiet

In [None]:
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)
from sklearn.metrics import f1_score

## Loading and preprocessing dataset

In [None]:
# Load GoEmotions and read the number of emotion classes from the label
# feature, so the class count is never hard-coded in the notebook.
dataset = load_dataset("go_emotions")
num_labels = dataset["train"].features["labels"].feature.num_classes

### Tokenizing

Tokenize the text with `BertTokenizerFast`, padding or truncating every sequence to a fixed length of 128 tokens.

In [None]:
# Checkpoint name shared by the tokenizer here and the model loaded later.
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

def tokenize(batch):
    """Tokenize the "text" column of a batch with the notebook-level tokenizer.

    Every sequence is padded or truncated to exactly 128 tokens.

    Args:
        batch: Mapping with a "text" entry holding the raw strings.

    Returns:
        Whatever the tokenizer returns for those strings.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    texts = batch["text"]
    return tokenizer(texts, **encode_options)

# Apply `tokenize` to the dataset in batches and rebind `dataset` to the result.
dataset = dataset.map(tokenize, batched=True)

### Multi-Label Preparation

Convert the dataset labels to multi-hot vectors and ensure the dataset returns PyTorch tensors, then define a custom collator that handles the multi-label tensors during batching.

In [None]:
def to_multihot(batch, n_classes=None):
    """Replace each example's list of label ids with a multi-hot float vector.

    Args:
        batch: Mapping with a "labels" entry holding, per example, the list of
            active class indices.
        n_classes: Length of every output vector. Defaults to the notebook-level
            `num_labels` when omitted, so `dataset.map(to_multihot, ...)` keeps
            working unchanged.

    Returns:
        The same `batch` object, with "labels" overwritten by one list of
        `n_classes` floats per example (1.0 at each active index, else 0.0).
    """
    if n_classes is None:
        n_classes = num_labels

    multi_hot = []
    for active_ids in batch["labels"]:
        vec = [0.0] * n_classes
        for class_id in active_ids:
            vec[class_id] = 1.0
        multi_hot.append(vec)
    batch["labels"] = multi_hot
    return batch

# NOTE: this rebinds `dataset` to the converted version, so running the cell a
# second time would feed the multi-hot vectors back into to_multihot as if they
# were class indices. Re-run from the dataset-loading cell instead.
dataset = dataset.map(to_multihot, batched=True)

# Request torch-formatted output restricted to the three listed columns.
dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"]
)

In [None]:
class MultiLabelCollator(DataCollatorWithPadding):
    """Padding collator that returns the "labels" entry as a float tensor."""

    def __call__(self, features):
        """Collate `features` with the parent collator, then cast "labels" to float."""
        collated = super().__call__(features)
        float_labels = collated["labels"].float()
        collated["labels"] = float_labels
        return collated

# Collator instance handed to the Trainer further down.
data_collator = MultiLabelCollator(tokenizer=tokenizer)

## Model

### Model initialization

In [None]:
# Start from the pretrained checkpoint and size the classification head to the
# number of emotion classes. problem_type marks this as multi-label, i.e. each
# example may carry several labels at once (the multi-hot vectors built earlier).
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

### Defining F1-macro metric for evaluation

In [None]:
def compute_metrics(pred):
    """Compute macro-averaged F1 for multi-label predictions.

    Args:
        pred: Evaluation output exposing `predictions` (raw logits) and
            `label_ids` (ground-truth labels).

    Returns:
        dict with a single key, "f1_macro".
    """
    logits = torch.tensor(pred.predictions)
    probs = torch.sigmoid(logits)
    # Each label is decided independently: predicted when its sigmoid
    # probability exceeds 0.5.
    y_pred = (probs > 0.5).int().numpy()
    y_true = pred.label_ids
    return {
        # zero_division=0 makes the score for a label with no positives an
        # explicit 0 instead of triggering UndefinedMetricWarning on every
        # evaluation pass.
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0)
    }

### Training arguments

In [None]:
# Evaluate and checkpoint once per epoch; load_best_model_at_end=True asks the
# Trainer to restore the best checkpoint when training finishes.
# NOTE(review): metric_for_best_model is not set, so "best" is decided by the
# Trainer's default criterion (presumably eval loss — confirm), not f1_macro.
training_args = TrainingArguments(
    output_dir="./bert_goemotions",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=100,
    load_best_model_at_end=True,
)

### Trainer set-up and training

In [None]:
# Train on the "train" split and evaluate on "validation"; the "test" split is
# not given to the Trainer here and is only used for the final evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

In [None]:
# Fine-tune, then evaluate on the held-out test split. The evaluate() call is
# the last expression, so its return value is shown as the cell output.
trainer.train()
trainer.evaluate(dataset["test"])

In [None]:
# Save model and tokenizer into the same directory; the ONNX export cell
# loads the checkpoint back from this path.
trainer.save_model("./bert_goemotions_model")
tokenizer.save_pretrained("./bert_goemotions_model")

In [None]:
!pip install optimum[onnxruntime]

## Saving and exporting the model

In [None]:

# Imported here rather than in the top import cell because optimum is only
# installed by the pip cell directly above.
from optimum.onnxruntime import ORTModelForSequenceClassification

# Load the fine-tuned checkpoint saved earlier, converting it to ONNX on load
# (export=True). No sample input is needed for this call, so the previously
# unused tokenized example has been dropped from this cell.
ort_model = ORTModelForSequenceClassification.from_pretrained(
    "./bert_goemotions_model",
    export=True
)

# Write the exported model to ./onnx_model for the inference cell below.
ort_model.save_pretrained("./onnx_model")

In [None]:
import onnxruntime as ort
import numpy as np
import torch

# Tokenize one example as NumPy arrays, padded/truncated to the same fixed
# length of 128 tokens used during training.
sample = tokenizer(
    "I am extremely happy today!",
    return_tensors="np",
    padding="max_length",
    truncation=True,
    max_length=128
)

session = ort.InferenceSession("onnx_model/model.onnx")

# Feed exactly the inputs the exported graph declares instead of a hard-coded
# list, so the feed dict always matches the graph's input names. A graph input
# the tokenizer did not produce surfaces as a KeyError right here.
inputs = {
    graph_input.name: sample[graph_input.name]
    for graph_input in session.get_inputs()
}

# None requests every graph output; the first one is used as the logits.
outputs = session.run(None, inputs)

# Multi-label head: an independent sigmoid per class rather than a softmax.
logits = torch.tensor(outputs[0])
probs = torch.sigmoid(logits)

print("Probabilities:", probs)