# Fine-tuning a Text Classifier with PyTorch Lightning and AI Accelerators
![Transformer block with an added classification head](./fine-tune.png)

Image Source: [How to Fine-Tune Language Models: First Principles to Scalable Performance](https://pub.towardsai.net/how-to-fine-tune-language-models-first-principles-to-scalable-performance-78f42b02f112)

## 1. Preprocessing

### Set a Random Seed

In [1]:
import lightning as L

L.seed_everything(42, workers=True)

Seed set to 42


42

### Read Data and Encode Labels Numerically

In [None]:
import pandas as pd
from datasets import load_dataset
 
# If you don't have a local data file to run this notebook, proceed as follows to download a dataset from Hugging Face:
all_data = load_dataset("intel/polite-guard", split="train")
all_data = all_data.to_pandas().sample(n=500)

# Assuming that your data is in "data.csv" with columns "label" and "text", comment out the previous lines and uncomment the following line:
# all_data = pd.read_csv("data.csv")

In [3]:
# Assign label mappings
unique_labels = sorted(all_data["label"].unique())
num_labels = len(unique_labels)

id2label = {k:v for k, v in enumerate(unique_labels)}
label2id = {v:k for k, v in id2label.items()}

all_data["label"] = all_data["label"].map(label2id)
label2id

{'impolite': 0, 'neutral': 1, 'polite': 2, 'somewhat polite': 3}

### Do a Train | Validation | Test Split

In [4]:
from sklearn.model_selection import train_test_split

# First split: 80% train, 20% temp (to be split into val and test)
train_df, temp = train_test_split(all_data, test_size=0.2, shuffle=True)

# Second split: 50% val, 50% test from the 20%
val_df, test_df = train_test_split(temp, test_size=0.5, shuffle=True)

## 2. Tokenization

### Pick a Pretrained Transformer and Load its Tokenizer

In [5]:
from transformers import AutoTokenizer

model_ckpt = "bert-base-uncased" # Pretrained transformer model for fine-tuning
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

### Tokenize and Create DataLoaders

In [None]:
import os

import torch
from torch.utils.data import DataLoader, Dataset

batch_size = 32
num_workers = 7

class TextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_dataframe = hasattr(data, "iloc")  # Check if the data is a dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.is_dataframe:
            row = self.data.iloc[idx]
            text = row["text"]
            label = row["label"]
        else:
            text = self.data[idx]
            label = None  # No label if data is a list of strings

        # Tokenize the input text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        item = {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }

        if label is not None:
            item["label"] = torch.tensor(label, dtype=torch.long)

        return item


os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Create Dataset objects
train_dataset = TextDataset(train_df, tokenizer)
val_dataset = TextDataset(val_df, tokenizer)
test_dataset = TextDataset(test_df, tokenizer)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

## 3. Fine-tuning

### Add a Classification Head to the Pretrained Transformer and Prepare the Model for Training

In [7]:
import torchmetrics
from torchmetrics.classification import F1Score

from transformers import (
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)


class LightningModel(L.LightningModule):
    def __init__(self, model_name, num_labels, label2id, id2label, learning_rate=5e-5, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()

        # Initialize hyperparameters and model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.id2label = id2label
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, label2id=label2id, id2label=id2label)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # Metrics
        self.val_f1 = F1Score(num_classes=num_labels, task="multiclass", average="weighted")
        self.test_f1 = F1Score(num_classes=num_labels, task="multiclass", average="weighted")

        self.val_acc = torchmetrics.Accuracy(num_classes=num_labels, task="multiclass")
        self.test_acc = torchmetrics.Accuracy(num_classes=num_labels, task="multiclass")

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def _shared_step(self, batch, stage):
        """
        A single step function for training, validation, and testing.
        """
        outputs = self(batch["input_ids"], batch["attention_mask"], labels=batch["label"])
        logits = outputs["logits"]
        loss = outputs["loss"]
        labels = batch["label"]

        # Update metrics
        if stage == "train":
            self.log("train_loss", loss)
            return loss
        
        if stage == "val":
            self.val_acc(logits, labels)
            self.val_f1(logits, labels)
            self.log("val_acc", self.val_acc, prog_bar=True)
            self.log("val_f1", self.val_f1, prog_bar=True)
            
        if stage == "test":
            self.test_acc(logits, labels)
            self.test_f1(logits, labels)
            self.log("test_acc", self.test_acc, prog_bar=True)
            self.log("test_f1", self.test_f1, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        
        # Compute the number of training steps
        num_training_steps = len(train_loader) * self.trainer.max_epochs
        num_warmup_steps = int(0.1 * num_training_steps)  # 10% warm-up

        lr_scheduler = {
            "scheduler": get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps),
            "interval": "step",  # Update every step
            "frequency": 1
        }
        
        return [optimizer], [lr_scheduler]

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(batch["input_ids"], batch["attention_mask"])
        logits = outputs["logits"]
        predictions = torch.argmax(logits, dim=-1)
        # Convert numeric predictions to labels using id2label
        return [self.id2label[pred.item()] for pred in predictions]
        
    def save_to_hub(
        self,
        repo_name,
        private=False,
        model_commit_message="Add fine-tuned model",
        tokenizer_commit_message="Add tokenizer",
        token=None
    ):
        """
        Push the model to the Hugging Face Hub
        """
        self.model.push_to_hub(
            repo_name,
            private=private,
            commit_message=model_commit_message,
            token=token
        )
        self.tokenizer.push_to_hub(
            repo_name,
            private=private,
            commit_message=tokenizer_commit_message,
            token=token
        )

lightning_model = LightningModel(model_ckpt, num_labels=num_labels, label2id=label2id, id2label=id2label, learning_rate=5e-5, weight_decay=0.01) 

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Fit the Model to the Training Dataset

In [8]:
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

callbacks = [
        ModelCheckpoint(save_top_k=1, mode="max", monitor="val_f1"),
        EarlyStopping(monitor="val_f1", patience=3, min_delta=0.005, mode="max", verbose=True),
    ]
logger = TensorBoardLogger(save_dir="./logs", name="Best-Validation-F1")

trainer = L.Trainer(
        max_epochs=2,
        callbacks=callbacks,
        accelerator="auto",
        precision="bf16-mixed",
        devices="auto",
        logger=logger,
        log_every_n_steps=10,
    )

trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name     | Type                          | Params | Mode 
-------------------------------------------------------------------
0 | model    | BertForSequenceClassification | 109 M  | eval 
1 | val_f1   | MulticlassF1Score             | 0      | train
2 | test_f1  | MulticlassF1Score             | 0      | train
3 | val_acc  | MulticlassAccuracy            | 0      | train
4 | test_acc | MulticlassAccuracy            | 0      | train
-------------------------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.941   Total estimated model params size (MB)
4         Modules in train mode
231       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_f1 improved. New best score: 0.980


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_f1 improved by 0.020 >= min_delta = 0.005. New best score: 1.000
`Trainer.fit` stopped: `max_epochs=2` reached.


### Test the Best Model

In [None]:
trainer.test(lightning_model, test_loader, ckpt_path="best")

Restoring states from the checkpoint path at ./logs/Best-Validation-F1/version_1/checkpoints/epoch=1-step=26.ckpt
Loaded model weights from the checkpoint at ./logs/Best-Validation-F1/version_1/checkpoints/epoch=1-step=26.ckpt


In [None]:
texts = [
    "I sincerely apologize for the inconvenience you've experienced. Please allow me a moment to resolve this for you as quickly as possible.",
    "I understand this isn't ideal, but could we move forward with this solution?",
    "The product specifications are as follows.",
    "You must be new here; you clearly don't know what you're doing.",
]
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset)

outputs = trainer.predict(lightning_model, dataloaders=dataloader)
for text, output in zip(texts, outputs):
    print(f'"{text}": {output[0]}')