<a href="https://colab.research.google.com/github/calvinli2024/mitbforalldemo/blob/main/fine_tuning/bert-fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Purpose

The purpose of this notebook is to experiment with building a multilabel text classifier in PyTorch by fine-tuning DistilBERT with the Hugging Face `Trainer`.

# Setup

## Packages

In [None]:
# Install the Hugging Face and metrics libraries imported below. The other
# imports (pandas, torch, tqdm) are not installed by this cell; presumably
# the runtime (e.g. Colab) already provides them. TODO confirm.
%pip install datasets transformers torchmetrics

## Imports

In [None]:
from tqdm import tqdm
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, TrainingArguments, Trainer, EvalPrediction
from datasets import load_dataset, Features, Value, Sequence
import pandas as pd
from torchmetrics.classification import MultilabelF1Score
from torch.cuda import empty_cache
from torch import from_numpy

## Dataset

### Philosophy Schools Multilabel

In [None]:
# Load the "train" split of the philosophy-schools dataset.
phil_ds = load_dataset("maximuspowers/philosophy-schools-multilabel", split="train")

# Drop the columns that are not used anywhere below.
remove_cols = ["title", "description", "link", "source", "philosophy_schools"]
phil_ds = phil_ds.remove_columns(remove_cols)

# Every remaining column other than the "summary" text is treated as a label.
label_cols = [name for name in phil_ds.features if name != "summary"]

# Index <-> label-name lookups handed to the model configuration.
id2label = dict(enumerate(label_cols))
label2id = {name: position for position, name in id2label.items()}

# Tokenizer loaded from the same checkpoint name as the classifier.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert/distilbert-base-uncased")

### Encoding

In [None]:
def encode_data(data):
  """Tokenize a batch of summaries and attach one label row per example.

  Args:
    data: A batch mapping column name -> list of values, as supplied by
      `Dataset.map(..., batched=True)`. It must contain "summary" and every
      column named in `label_cols`.

  Returns:
    The tokenizer's encoding for the batch with an added "labels" entry:
    one list per example holding that example's value for each label
    column, in `label_cols` order.
  """
  summaries = data["summary"]

  # Truncate/pad every summary to exactly 512 tokens so all rows share a length.
  encoding = tokenizer(summaries, truncation=True, padding="max_length", max_length=512)

  # Transpose the per-label columns into per-example rows directly, instead
  # of growing an empty DataFrame one column at a time.
  label_columns = [data[label_col] for label_col in label_cols]
  encoding["labels"] = [list(row) for row in zip(*label_columns)]

  return encoding

# Tokenize the whole dataset in batches. The original columns are removed, so
# only the tokenizer outputs and "labels" remain.
encoded_dataset = phil_ds.map(encode_data, batched=True, remove_columns=phil_ds.column_names)

# Return torch tensors when rows are read.
encoded_dataset.set_format("torch")

# Store labels as float32 and the token columns in compact integer types.
encoded_dataset = encoded_dataset.cast(Features({
  "labels": Sequence(Value("float32")),
  "input_ids": Sequence(Value("int32")),
  "attention_mask": Sequence(Value("int8"))
}))

# Hold out 20% for validation. The seed is fixed so the split (and therefore
# the reported metrics) is reproducible across Restart & Run All.
train_valid_split = encoded_dataset.train_test_split(test_size=0.2, seed=42)

## Multilabel Classifier

In [None]:
# DistilBERT with a sequence-classification head: one output per label
# column, configured for the multi-label problem type.
model = DistilBertForSequenceClassification.from_pretrained(
  "distilbert/distilbert-base-uncased",
  # Task definition.
  problem_type="multi_label_classification",
  num_labels=len(label_cols),
  # Human-readable label names stored in the model config.
  label2id=label2id,
  id2label=id2label,
)

# Train

## Compute Metrics

In [None]:
def compute_metrics(ep: EvalPrediction):
  """Compute the multilabel F1 score for one evaluation pass.

  Args:
    ep: Predictions and label ids for the evaluation set. Both are NumPy
      arrays (predictions may also be a tuple whose first item is the
      array of model outputs).

  Returns:
    A dict with a single "f1" entry holding the score as a Python float.
  """
  # Some models return extra outputs next to the class scores; in that case
  # the predictions arrive as a tuple with the class scores first.
  raw_preds = ep.predictions[0] if isinstance(ep.predictions, tuple) else ep.predictions

  preds = from_numpy(raw_preds)
  # The labels are stored as float32 in the dataset, but torchmetrics
  # documents the target as an integer tensor, so convert before scoring.
  target = from_numpy(ep.label_ids).long()

  metric = MultilabelF1Score(num_labels=len(label_cols))
  f1 = metric(preds, target)

  # Hand back a plain float rather than a 0-dim tensor so the value can be
  # logged and serialized as-is.
  return {"f1": f1.item()}

## Training Arguments

In [None]:
# Per-device batch size shared by training and evaluation.
batch_size = 32

training_args = TrainingArguments(
    "bert-fine-tuning",  # output directory for checkpoints
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for load_best_model_at_end to work.
    eval_strategy="epoch",
    save_strategy="epoch",
    # 1e-3 is far above the range normally used to fine-tune a pretrained
    # transformer (the library default is 5e-5) and tends to make training
    # diverge; use a conventional fine-tuning rate instead.
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    # Reload the checkpoint with the best "f1" (the key returned by
    # compute_metrics) once training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # Disable external experiment-tracking integrations.
    report_to="none"
)

## Trainer

In [None]:
# Release cached CUDA memory before building the trainer.
empty_cache()

# Wire the model, arguments, data splits, tokenizer and metric function
# together.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    processing_class=tokenizer,
    eval_dataset=train_valid_split["test"],
    train_dataset=train_valid_split["train"],
)

# Run fine-tuning; the returned training summary is the cell output.
trainer.train()