<a href="https://colab.research.google.com/github/calvinli2024/mitbforalldemo/blob/main/fine_tuning/bert-fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Purpose

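The purpose of this notebook is to experiment with fine-tuning DistilBERT (`distilbert/distilbert-base-uncased`) as a multilabel text classifier in PyTorch, using PyTorch Ignite for the training loop and the `maximuspowers/philosophy-schools-multilabel` dataset.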

# Setup

## Packages

In [None]:
%pip install pytorch-ignite datasets transformers
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

## Imports

In [None]:
import copy

import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping
from ignite.metrics import Loss, Precision, Recall
from sklearn.model_selection import KFold, train_test_split
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, logging

## Dataset

### Philosophy Schools Multilabel

In [None]:
phil_ds = load_dataset("maximuspowers/philosophy-schools-multilabel")["train"]

# Drop the metadata columns; what remains is the text ("summary") plus the label columns.
phil_df = pd.DataFrame(phil_ds).drop(columns=["title", "description", "link", "source", "philosophy_schools"])

train, test = train_test_split(phil_df, test_size=0.2, shuffle=True, random_state=42)

# Every remaining column other than the text is treated as one label of the multilabel target.
columns = [name for name in phil_df.columns if name != "summary"]

x_train, x_test = (list(split["summary"].values) for split in (train, test))
y_train, y_test = (list(split[columns].values) for split in (train, test))

# 5-fold cross-validation over the training split; `test` is kept aside as a holdout.
cv = KFold(n_splits=5, shuffle=True, random_state=42)

### BertTextDataset

In [None]:
class BertTextDataset(Dataset):
  """Pre-tokenised DistilBERT inputs paired with multilabel targets.

  All texts are tokenised once, up front, into PyTorch tensors that are
  truncated to at most 512 tokens and padded to the longest text; items are
  then looked up by position.
  """

  def __init__(self, input, target):
    # `input`: list of texts; `target`: per-example label vectors, indexable by position.
    encoder = DistilBertTokenizerFast.from_pretrained("distilbert/distilbert-base-uncased")
    self.input = encoder(input, max_length=512, truncation=True, padding=True, return_tensors='pt')
    self.target = target

  def __len__(self):
    return self.input['input_ids'].size(0)

  def __getitem__(self, idx):
    item = {key: self.input[key][idx] for key in ("input_ids", "attention_mask")}
    item["labels"] = torch.tensor(self.target[idx])
    return item

train_ds = BertTextDataset(x_train, y_train)
test_ds = BertTextDataset(x_test, y_test)  # holdout test set; only train_ds is used for cross-validation

# Device

In [None]:
# Prefer the GPU when PyTorch can see one; tensors below are moved to `device`.
device = "cpu"
if torch.cuda.is_available():
  device = "cuda"

print(f"Using device '{device}'")

## Multilabel Classifier

In [None]:
# Pretrained DistilBERT with a freshly initialised classification head that emits
# one logit per label column (the length of a training label vector).
model = DistilBertForSequenceClassification.from_pretrained(
  "distilbert/distilbert-base-uncased",
  problem_type="multi_label_classification",
  num_labels=len(y_train[0]),
)

model.to(device=device, dtype=torch.float)

## Optimizer

In [None]:
# AdamW with a small learning rate (2e-5 to 5e-5) is the standard recipe for
# fine-tuning pretrained transformers. Adam at lr=0.01 is orders of magnitude too
# large and quickly destroys the pretrained DistilBERT weights.
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

## Loss Function

In [None]:
# Multilabel objective: an independent sigmoid + binary cross-entropy per label,
# computed from raw logits and averaged over all elements.
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

## Scheduler

In [None]:
# Multiply the learning rate by 0.9 each time scheduler.step() is called.
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.9)

# Fine Tuning

## Train

In [None]:
def train_step(
    engine,
    batch
):
  """Run one optimisation step on a batch and return the batch loss as a float.

  `batch` is a collated dict from BertTextDataset with 'input_ids',
  'attention_mask' and 'labels' tensors.
  """
  model.train()
  optimizer.zero_grad()

  input_ids = batch['input_ids'].to(device=device)
  attention_mask = batch['attention_mask'].to(device=device)
  # BCEWithLogitsLoss needs float targets.
  labels = batch['labels'].to(dtype=torch.float, device=device)

  # Labels are deliberately not passed to the model. Passing them makes it compute
  # its own loss, which would then be discarded in favour of loss_fn below.
  logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
  loss = loss_fn(logits, labels)

  loss.backward()
  optimizer.step()

  return loss.item()

trainer = Engine(train_step)

## Test

In [None]:
def eval_step(engine, batch):
  """Run one validation batch without gradients and return ``(logits, labels)``.

  Raw logits are returned so that ``Loss(BCEWithLogitsLoss)`` sees what it
  expects. The precision/recall metrics threshold them via ``threshold_logits``.
  """
  model.eval()

  with torch.no_grad():
    input_ids = batch['input_ids'].to(device=device)
    attention_mask = batch['attention_mask'].to(device=device)
    labels = batch['labels'].to(dtype=torch.float, device=device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

  return logits, labels

evaluator = Engine(eval_step)


def threshold_logits(output, threshold=0.5):
  """Map ``(logits, labels)`` to ``(0/1 predictions, 0/1 labels)`` at a sigmoid threshold."""
  logits, labels = output
  preds = (torch.sigmoid(logits) >= threshold).long()
  return preds, labels.long()


def make_f1(output_transform=lambda output: output):
  """Build an F1 metric averaged over labels from multilabel precision and recall.

  The underlying metrics expect ``(preds, labels)`` of the same shape with 0/1
  values (after ``output_transform``). The epsilon keeps a label whose
  precision + recall is 0 at F1 = 0 instead of turning the mean into nan.
  """
  precision = Precision(average=False, is_multilabel=True, output_transform=output_transform)
  recall = Recall(average=False, is_multilabel=True, output_transform=output_transform)
  return (precision * recall * 2 / (precision + recall + 1e-20)).mean()


metric = Loss(loss_fn)

metric.attach(evaluator, 'loss')

# The F1 reported by the evaluator; it thresholds the logits returned by eval_step.
make_f1(threshold_logits).attach(evaluator, 'f1')

# A separate, unattached F1 for manual ``F1.update((preds, labels))`` calls on
# already-thresholded predictions (an attached MetricsLambda refuses update()).
F1 = make_f1()

## Checkpoint

In [None]:
# Checkpoint model, optimizer, LR scheduler and trainer state after every epoch,
# keeping at most two checkpoint files in the working directory (n_saved=2).
# The scheduler is included so that resuming from a checkpoint continues the same
# learning-rate decay instead of restarting it from the initial learning rate.
checkpointer = Checkpoint(
  to_save={'model': model, 'optimizer': optimizer, 'scheduler': scheduler, 'trainer': trainer},
  save_handler=DiskSaver('./', create_dir=False, require_empty=False),
  n_saved=2
)

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpointer)

## Early Stopping

In [None]:
def score_function(engine):
  """Score for EarlyStopping: the negated validation loss.

  ignite's EarlyStopping treats a *higher* score as an improvement, so the loss
  must be negated. Returning the raw loss would count every decrease in
  validation loss as "no improvement" and stop the healthiest runs first.
  """
  return -engine.state.metrics['loss']

# Terminate `trainer` after a single evaluator run (patience=1) in which
# score_function does not improve on its best value so far.
early_stopping = EarlyStopping(
  trainer=trainer,
  score_function=score_function,
  patience=1,
)

evaluator.add_event_handler(Events.COMPLETED, early_stopping)

## Scheduler Step

In [None]:
def scheduler_step():
  """Advance the learning-rate scheduler once per completed training epoch."""
  scheduler.step()

trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler_step)

## Run

In [None]:
max_epochs = 9
validate_every = 3
# 512 padded sequences (up to 512 tokens each) per step with gradients is far
# beyond what typically fits in GPU memory when fine-tuning DistilBERT.
batch_size = 16

ProgressBar().attach(trainer)

# Snapshot the pre-training state so every fold starts from the same weights,
# optimizer state and learning-rate schedule. Otherwise fold k keeps training the
# model from fold k-1, whose training rows include fold k's validation rows, so
# the validation scores leak.
# NOTE: assumes the model/optimizer/scheduler cells ran right before this one
# (Restart & Run All); re-running only this cell snapshots already-trained weights.
initial_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())
initial_scheduler_state = copy.deepcopy(scheduler.state_dict())

# Validation loader for the current fold. It is rebound at the start of every fold
# and read by run_eval at call time.
valid_dl = None

# Register the validation handlers once, outside the fold loop. Registering them
# per fold would stack one extra evaluation + log line for every fold.
@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def run_eval():
  evaluator.run(valid_dl)

@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_evaluation_results():
  metrics = evaluator.state.metrics

  print(f"Epoch: {trainer.state.epoch}, Loss: {metrics['loss']}, F1: {metrics['f1']}\n")

for fold, (train_idx, valid_idx) in enumerate(cv.split(train_ds)):
  print(f"Fold: {fold}")
  print("-----------------------------")

  # Reset everything that would otherwise carry over from the previous fold.
  model.load_state_dict(initial_model_state)
  optimizer.load_state_dict(initial_optimizer_state)
  scheduler.load_state_dict(initial_scheduler_state)
  early_stopping.counter = 0
  early_stopping.best_score = None
  checkpointer.reset()
  # Force trainer.run to start a fresh run at epoch 0 instead of resuming a fold
  # that EarlyStopping terminated part-way through.
  trainer.state.max_epochs = None

  train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=torch.utils.data.SubsetRandomSampler(train_idx),
  )

  valid_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=torch.utils.data.SubsetRandomSampler(valid_idx),
  )

  trainer.run(train_dl, max_epochs=max_epochs)

## Holdout Test

In [None]:
test_dl = DataLoader(
  test_ds,
  batch_size=256
)

# Fresh, unattached metrics for the holdout set so that nothing accumulated earlier
# (or attached to the evaluator) affects this score.
holdout_precision = Precision(average=False, is_multilabel=True)
holdout_recall = Recall(average=False, is_multilabel=True)
# Per-label F1 averaged over labels. The epsilon keeps a label with
# precision + recall == 0 at F1 = 0 instead of nan.
holdout_f1 = (holdout_precision * holdout_recall * 2 / (holdout_precision + holdout_recall + 1e-20)).mean()

# NOTE(review): this scores whatever weights the last CV fold left in `model`;
# consider retraining on the full training split before the holdout evaluation.
model.eval()

with torch.no_grad():
  for data in test_dl:
    input_ids = data['input_ids'].to(device=device)
    attention_mask = data['attention_mask'].to(device=device)
    labels = data['labels'].to(dtype=torch.float, device=device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    preds = (torch.sigmoid(logits) >= 0.5).long()

    holdout_f1.update((preds.cpu(), labels.long().cpu()))

f1_score = holdout_f1.compute()

print(f"Holdout F1: {f1_score}")