# Fine-tune a transformer model for text (sequence) classification

This notebook shows a minimal working example of how to **fine-tune a transformer model** for sequence classification.
**Sequence classification** refers to the task of assigning a label to a sequence (of tokens). In our case, the sequence is a sentence (sequence of words).

The focus in this notebook lies on the **general workflow**:

1. Load the labeled text dataset
1. Split the dataset into train, dev, and test splits
1. Tokenize the texts in each split
1. Define the evaluation metrics that quantify model performance
1. Prepare the model for fine-tuning
1. Set up a `Trainer` that handles the model fine-tuning
1. Use the `Trainer` to fine-tune on the training split examples, using the dev split examples to monitor performance
1. Evaluate the fine-tuned model on the test split

## Setup

If you run this notebook on Colab, you need to take a few extra steps:

In [1]:
# check if on colab
try:
    import google.colab
    COLAB = True
except ImportError:
    COLAB = False

if COLAB:
    # install required packages
    !pip install -q  scikit-learn==1.5.1 datasets==2.21.0 tokenizers==0.19.1 sentencepiece==0.2.0 protobuf==3.20.3 accelerate==0.33.0 transformers==4.44.1 torch~=2.4.0 seqeval==1.2.2

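if COLAB:
    # download custom utils
    # note: each `!` line runs in its own shell, so shell variables do not carry over between lines;
    # we therefore loop in Python and let IPython substitute the `{...}` expressions
    !mkdir -p utils
    base_url = 'https://raw.githubusercontent.com/haukelicht/advanced_text_analysis/main/notebooks/utils'
    for file in ['io.py', 'finetuning.py', 'metrics.py']:
        !curl -o utils/{file} {base_url}/{file}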

import os
data_path = os.path.join('..', 'data', 'labeled', 'bestvater_sentiment_2023', '')
if COLAB:
    # on Colab, read the data directly from GitHub
    data_path = 'https://raw.githubusercontent.com/haukelicht/advanced_text_analysis/main/data/labeled/bestvater_sentiment_2023/'

Next, we load the required modules, classes, and functions.

Note that some functions come from the `utils` folder.
These are functions I have defined to handle general tasks, like

- reading data from a tabular file (e.g., CSV);
- splitting the data into train, dev, and test splits;
- tokenization;
- etc.

These functions should be general enough for many use cases. 
You can use them in your research if you want.
But please double-check that they do what you want them to do if you want to publish results that depend on my code ;)

In [1]:
from utils.io import read_tabular
from utils.finetuning import (
    get_device, 
    split_data, 
    create_sequence_classification_dataset,
    preprocess_sequence_classification_dataset
)

from datasets import DatasetDict
from transformers import (
    set_seed,
    AutoTokenizer,
    DataCollatorWithPadding, 
    AutoModelForSequenceClassification, 
    Trainer,
    TrainingArguments
)

from utils.metrics import (
    parse_sequence_classifier_prediction_output,
    compute_sequence_classification_metrics_binary
)

In [2]:
SEED = 42
set_seed(SEED)

In [10]:
MODEL_NAME = 'roberta-base'
device = get_device()
print(f'Using device: {str(device)}')

Using device: mps


## Load and prepare the data

In [4]:
fp = data_path + 'bestvater_sentiment_2023-motn_responses_sentiment.tsv'
df = read_tabular(fp, columns=['text', 'label'])
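`read_tabular` is one of the custom helpers from `utils`, so its internals are not shown here. To give a rough idea of what reading the two needed columns from a tab-separated file involves, here is a standard-library sketch. The function `parse_tsv_columns` is hypothetical and, unlike `read_tabular` (whose data frame output the rest of the notebook relies on), it returns a list of dicts with all values as strings.

```python
import csv

def parse_tsv_columns(lines, columns):
    """Hypothetical stand-in for utils.io.read_tabular.

    `lines` is an open text file (or any iterable of lines) whose first row is a header.
    Returns one dict per data row, keeping only `columns` (all values are strings).
    """
    reader = csv.DictReader(lines, delimiter='\t')
    # accessing fieldnames reads the header row
    missing = [c for c in columns if c not in (reader.fieldnames or [])]
    if missing:
        raise KeyError(f'missing columns: {missing}')
    return [{c: row[c] for c in columns} for row in reader]
```

To read a local file, pass an open file handle, e.g. `with open(fp, newline='', encoding='utf-8') as f: rows = parse_tsv_columns(f, ['text', 'label'])`.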

In [5]:
len(df)

5417

In [6]:
df.label.value_counts(normalize=True)

label
0    0.565442
1    0.434558
Name: proportion, dtype: float64

In [7]:
data_splits = split_data(df, dev_size=0.15, test_size=0.15, seed=SEED, stratify_by='label', return_dict=True)
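`split_data` is also a custom helper. With `stratify_by='label'`, each split should keep roughly the label proportions of the full data (about 57% vs. 43% here). A minimal pure-Python sketch of stratified splitting, assuming rows are dicts and using the hypothetical name `stratified_split` (the actual helper works on data frames and may round split sizes differently):

```python
import random
from collections import defaultdict

def stratified_split(rows, label_key, dev_size=0.15, test_size=0.15, seed=42):
    """Hypothetical sketch: split `rows` (dicts) into train/dev/test with similar label proportions."""
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)
    rng = random.Random(seed)  # local RNG for reproducible shuffling
    splits = {'train': [], 'dev': [], 'test': []}
    for label in sorted(by_label):  # fixed label order keeps the result deterministic
        group = list(by_label[label])
        rng.shuffle(group)
        n_test = round(len(group) * test_size)
        n_dev = round(len(group) * dev_size)
        # carve test and dev examples out of each label group; the rest goes to train
        splits['test'].extend(group[:n_test])
        splits['dev'].extend(group[n_test:n_test + n_dev])
        splits['train'].extend(group[n_test + n_dev:])
    return splits
```

Within each split, rows end up grouped by label; that is harmless here because the `Trainer` shuffles the training data anyway.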

In [12]:
# note: always do this on the train split (the model can only be expected to predict classes it also sees during training)
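# sort the unique labels so that the label-to-id mapping does not depend on the order of rows in the data
label2id = {l: i for i, l in enumerate(sorted(data_splits['train'].label.unique()))}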
id2label = {i: l for l, i in label2id.items()}
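The two dictionaries translate between the original labels and the integer ids the model is trained on and predicts. A minimal sketch with hypothetical string labels shows the round trip; the `toy_` prefix keeps it from overwriting the notebook's own `label2id` and `id2label`, and sorting the unique labels makes the mapping independent of which label happens to appear first.

```python
train_labels = ['positive', 'negative', 'negative', 'positive', 'negative']  # hypothetical labels

toy_label2id = {l: i for i, l in enumerate(sorted(set(train_labels)))}
toy_id2label = {i: l for l, i in toy_label2id.items()}

encoded = [toy_label2id[l] for l in train_labels]      # what the model is trained on
predicted_ids = [1, 0, 0]                              # hypothetical model predictions
decoded = [toy_id2label[i] for i in predicted_ids]     # back to human-readable labels
```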

In [8]:
data_splits = DatasetDict({s: create_sequence_classification_dataset(df) for s, df in data_splits.items()})

In [13]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
data_splits = data_splits.map(lambda x: preprocess_sequence_classification_dataset(x, tokenizer=tokenizer, label2id=label2id, truncation=True), batched=True)

Map:   0%|          | 0/3793 [00:00<?, ? examples/s]

Map:   0%|          | 0/812 [00:00<?, ? examples/s]

Map:   0%|          | 0/812 [00:00<?, ? examples/s]

In [14]:
data_splits = data_splits.remove_columns(['text', 'label'])
data_splits.set_format('torch')
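The texts are tokenized with truncation but without padding at this point. Padding is added per batch by the `DataCollatorWithPadding` passed to the `Trainer` below, so each batch is only padded to the length of its longest sequence. A pure-Python sketch of that idea (the function `pad_batch` is hypothetical; the real collator delegates to the tokenizer's `pad` method and returns tensors):

```python
def pad_batch(batch_input_ids, pad_token_id):
    """Hypothetical sketch of dynamic padding: pad every sequence to the longest one in the batch."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)  # right-side padding
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

# illustrative token ids; for roberta-base, tokenizer.pad_token_id is 1
batch = pad_batch([[0, 713, 16, 2], [0, 2]], pad_token_id=1)
```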

## Prepare the model for fine-tuning with a `Trainer`

First, we define the `model_init` function that instantiates a pre-trained model with a sequence classification head that can be fine-tuned.
We will pass this function to the trainer instead of the model itself.
This ensures that every time we call `trainer.train()` below, we start from a fresh copy of the pre-trained model (i.e., no continued fine-tuning).
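Why passing a function rather than a model makes a difference can be illustrated with a toy sketch. `ToyModel`, `toy_train`, and `toy_model_init` are hypothetical stand-ins (no real model or `Trainer` involved): training mutates the object it receives, so reusing one instance continues from the previous run, whereas calling the factory each time starts from the same fresh weights.

```python
import random

class ToyModel:
    """Hypothetical stand-in for a model with a single trainable weight."""
    def __init__(self, seed=42):
        self.weight = random.Random(seed).uniform(-1, 1)

def toy_train(model, steps=3, lr=0.1):
    """Pretend training: shift the weight by a fixed amount per step (mutates `model`)."""
    for _ in range(steps):
        model.weight -= lr
    return model

def toy_model_init():
    return ToyModel(seed=42)

# Reusing one instance: the second run continues from the first run's weights
shared = toy_model_init()
toy_train(shared)
toy_train(shared)

# Calling the factory: each run starts from identical fresh weights
run1 = toy_train(toy_model_init())
run2 = toy_train(toy_model_init())
```

This is also why the "Some weights ... newly initialized" warning appears twice below: once when the `Trainer` is created and once when `trainer.train()` calls `model_init` again.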

In [None]:
def model_init():
    """Function to instantiate a fine-tunable sequence classification model"""
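    # load the pre-trained encoder and add a newly initialized classification head with one output per label
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(label2id))
    # make sure the loss is computed for single-label (cross-entropy) classification
    if model.config.problem_type is None:
        model.config.problem_type = 'single_label_classification'
    # store human-readable label names in the model config (only needed for string labels)
    if isinstance(id2label[0], str):
        model.config.id2label = id2label
        model.config.label2id = label2id
    model = model.to(device)
    return model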

Next, we define a `compute_metrics` function.
This function evaluates predicted against observed labels in held-out data (the dev split during fine-tuning and the test split afterwards).
My implementation reports standard metrics for binary classification (accuracy, balanced accuracy, precision, recall, F1-score).

If you want to adapt it, 

- keep the first line of the function body, which extracts the observed and predicted labels (`labels` and `predictions`)
- return a dictionary that reports evaluation metrics

In [20]:
def compute_metrics(p):
    labels, predictions = parse_sequence_classifier_prediction_output(p)
    return compute_sequence_classification_metrics_binary(y_true=labels, y_pred=predictions)
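The helpers in `utils.metrics` are not shown here. A self-contained sketch of a binary `compute_metrics` function could look as follows. It assumes that the object passed in unpacks into `(logits, labels)`, as the `Trainer`'s `EvalPrediction` does, and that class id 1 is the positive class; the name `binary_metrics_sketch` is hypothetical, and unlike the notebook's helper it does not report balanced accuracy.

```python
def binary_metrics_sketch(eval_pred):
    """Hypothetical sketch of a compute_metrics function for binary classification.

    `eval_pred` unpacks into (logits, labels); class id 1 is treated as the positive class.
    """
    logits, labels = eval_pred
    # predicted class = index of the largest logit in each row
    preds = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    correct = sum(p == y for p, y in zip(preds, labels))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # the key 'f1' is what metric_for_best_model='f1' refers to
    return {'accuracy': correct / len(labels), 'precision': precision, 'recall': recall, 'f1': f1}
```

The `Trainer` prefixes the returned keys with `eval_` (or with `test_` when calling `trainer.evaluate(..., metric_key_prefix='test')`), which is why the logs below show `eval_f1` and `test_f1`.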

Next, we define the **training arguments**.
I have added comments to group arguments based on what they are there for.
Here is what they do:

- *hyperparameters*: they govern how the model learns from the training data
    - `optim`: name of the optimization algorithm (handles parameter updating)
    - `num_train_epochs`: number of passes over all training examples
    - `per_device_train_batch_size`: number of training examples per batch on each device (one parameter update per batch)
- *evaluation*
    - `eval_strategy`: when to evaluate (`'epoch'` means after each epoch, i.e., after every completed pass over all training split examples)
- *model saving:*
    - `metric_for_best_model`: When we evaluate at the end of each epoch (see `eval_strategy`), we get one "checkpoint" per epoch. `metric_for_best_model` names the metric used to determine which of two model checkpoints performed better on the held-out dev split examples. **Important:** The name must be a key in the dictionary returned by the `compute_metrics` function (see above).
    - `load_best_model_at_end`: Whether to load the best checkpoint (judged by `metric_for_best_model`) when fine-tuning ends. `True` (recommended) means that after training, the `trainer` holds the best model instance (judged by the `metric_for_best_model` metric, e.g., F1, on the dev split examples). 
    - `save_total_limit`: the maximum number of checkpoints to keep on disk. Each checkpoint stores the model weights plus the optimizer state and can take up one or more GB. So set this to a low number (e.g., 2) to avoid filling up your disk. **Important:** With `load_best_model_at_end=True`, the best checkpoint is always kept in addition to the most recent one(s), so 2 is a sensible minimum.
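The logic behind `metric_for_best_model` and `greater_is_better` can be sketched by hand: with one checkpoint per epoch, the best checkpoint is the one with the highest dev-split F1 (the `Trainer` looks the name up with an `eval_` prefix, i.e., as `eval_f1`). Using the per-epoch `eval_f1` values from the training log further down in this notebook, a simplified version of that comparison (not the `Trainer`'s actual code) is:

```python
# eval_f1 per epoch, copied from the training log of this notebook
dev_f1_by_epoch = {1: 0.8847262247838616, 2: 0.9046242774566474, 3: 0.9165487977369166}

greater_is_better = True  # as set in TrainingArguments for metric_for_best_model='f1'
pick = max if greater_is_better else min
best_epoch = pick(dev_f1_by_epoch, key=dev_f1_by_epoch.get)
```

Here the last epoch wins, so with `load_best_model_at_end=True` the `trainer` ends up holding the epoch-3 checkpoint.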
    



In [17]:
# path to folder where model checkpoints and finetuning logs will be saved
dest = './../results/example_classifier/'
training_args = TrainingArguments(
    output_dir=dest,
    # hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    optim='adamw_torch',
    # use_mps_device=str(device)=='mps', # uncomment this when using older version of `transformers` library
    fp16=str(device).startswith('cuda'),
    # evaluation on dev set
    eval_strategy='epoch',
    # model saving
    metric_for_best_model='f1', # use 'f1_macro' for multiclass classification
    greater_is_better=True,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=2,
    # logging
    logging_strategy='epoch',
    logging_dir=dest+'logs',
    # for reproducibility
    seed=SEED,
    data_seed=SEED,
    full_determinism=True
)

Now we can create a `Trainer` instance that handles the fine-tuning and dev split evaluation.
We call this object `trainer`.

In [21]:
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=data_splits['train'],
    eval_dataset=data_splits['dev'],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer),
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Fine-tune

In [22]:
trainer.train()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



{'loss': 0.3803, 'grad_norm': 2.3094935417175293, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}



{'eval_loss': 0.26127269864082336, 'eval_accuracy': 0.9014778325123153, 'eval_accuracy_balanced': 0.8978071555975239, 'eval_f1': 0.8847262247838616, 'eval_precision': 0.9002932551319648, 'eval_recall': 0.8696883852691218, 'eval_runtime': 20.7806, 'eval_samples_per_second': 39.075, 'eval_steps_per_second': 1.251, 'epoch': 1.0}
{'loss': 0.2125, 'grad_norm': 0.3254886567592621, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}



{'eval_loss': 0.26307421922683716, 'eval_accuracy': 0.9187192118226601, 'eval_accuracy_balanced': 0.9150203361168201, 'eval_f1': 0.9046242774566474, 'eval_precision': 0.9233038348082596, 'eval_recall': 0.886685552407932, 'eval_runtime': 8.704, 'eval_samples_per_second': 93.291, 'eval_steps_per_second': 2.987, 'epoch': 2.0}
{'loss': 0.1355, 'grad_norm': 0.19413040578365326, 'learning_rate': 0.0, 'epoch': 3.0}



{'eval_loss': 0.3295578956604004, 'eval_accuracy': 0.9273399014778325, 'eval_accuracy_balanced': 0.9262437741857839, 'eval_f1': 0.9165487977369166, 'eval_precision': 0.9152542372881356, 'eval_recall': 0.9178470254957507, 'eval_runtime': 8.5952, 'eval_samples_per_second': 94.471, 'eval_steps_per_second': 3.025, 'epoch': 3.0}
{'train_runtime': 463.8785, 'train_samples_per_second': 24.53, 'train_steps_per_second': 1.539, 'train_loss': 0.2427623064912, 'epoch': 3.0}


TrainOutput(global_step=714, training_loss=0.2427623064912, metrics={'train_runtime': 463.8785, 'train_samples_per_second': 24.53, 'train_steps_per_second': 1.539, 'total_flos': 377377822814460.0, 'train_loss': 0.2427623064912, 'epoch': 3.0})
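The step count and learning rates in this log can be checked by hand. The sketch below assumes a single device, no gradient accumulation, and the `TrainingArguments` defaults for the learning rate (5e-5) and the schedule (linear decay to zero, no warmup), since the training arguments above do not override them. It reproduces the 714 global steps and the `learning_rate` values logged after each epoch.

```python
import math

n_train = 3793      # train split size (see the Map output of the tokenization cell)
batch_size = 16     # per_device_train_batch_size
epochs = 3          # num_train_epochs
base_lr = 5e-5      # assumed: TrainingArguments default learning_rate

steps_per_epoch = math.ceil(n_train / batch_size)  # the last, smaller batch still counts as a step
total_steps = steps_per_epoch * epochs

def linear_lr(step):
    """Learning rate after `step` updates under linear decay to zero without warmup."""
    return base_lr * (total_steps - step) / total_steps

lr_after_each_epoch = [linear_lr(steps_per_epoch * e) for e in range(1, epochs + 1)]
```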

In [25]:
trainer.evaluate(data_splits['test'], metric_key_prefix='test')


{'test_loss': 0.43572258949279785,
 'test_accuracy': 0.9076354679802956,
 'test_accuracy_balanced': 0.906197732476686,
 'test_f1': 0.8939179632248939,
 'test_precision': 0.8926553672316384,
 'test_recall': 0.8951841359773371,
 'test_runtime': 29.0263,
 'test_samples_per_second': 27.975,
 'test_steps_per_second': 0.896,
 'epoch': 3.0}
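To see how the binary metrics relate to each other, the reported test values can be reproduced from a single confusion matrix. The counts below are not printed by the notebook; they are back-calculated from the reported precision, recall, and the test split size (812), treating label id 1 as the positive class.

```python
# Confusion counts back-calculated from the reported test metrics (not output by the notebook)
tp, fp, fn, tn = 316, 38, 37, 421

n = tp + fp + fn + tn
precision = tp / (tp + fp)
recall = tp / (tp + fn)          # true positive rate
specificity = tn / (tn + fp)     # true negative rate
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / n
accuracy_balanced = (recall + specificity) / 2  # mean of the per-class recalls
```

Balanced accuracy averages the recall of both classes, so it stays informative even though the two labels are not equally frequent (about 57% vs. 43%).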