# Fine-tune multi-label classifier

## Setup

#### Colab (if necessary)

In [None]:
# check if on Colab
COLAB = True
try:
  from google import colab
except:
  COLAB = False

if COLAB:
    # shallow clone of current state of main branch 
    !git clone --branch main --single-branch --depth 1 --filter=blob:none https://github.com/haukelicht/advanced_text_analysis.git
    
    # make repo root findable for python
    import sys
    sys.path.append("/content/advanced_text_analysis/")


#### Load required libraries

In [None]:
from pathlib import Path
import shutil

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict

from src.utils.io import read_tabular
from src.finetuning import split_data

import torch
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    set_seed,
)

from src.metrics import (
    parse_sequence_classifier_prediction_output_multilabel,
    compute_sequence_classification_metrics_multilabel
)

In [None]:
def compute_metrics(pred):
    """Metric callback passed to the Trainer.

    Splits the raw prediction output into gold and predicted multi-label values
    and returns the metrics dict computed from them.
    """
    y_true, y_pred = parse_sequence_classifier_prediction_output_multilabel(pred)
    metrics = compute_sequence_classification_metrics_multilabel(y_true, y_pred)
    return metrics

In [None]:
# check which device is available (preference order: CUDA GPU, Apple MPS, CPU)
if torch.cuda.is_available():
    _backend = 'cuda'
elif torch.backends.mps.is_available():
    _backend = 'mps'
else:
    _backend = 'cpu'
device = torch.device(_backend)
device

In [None]:
# single random seed, reused below for data splitting, training, and sampling
SEED = 42
set_seed(SEED)  # transformers helper that seeds the global RNGs for reproducibility

In [None]:
# Hugging Face Hub identifier of the pre-trained checkpoint to fine-tune
MODEL_NAME = "answerdotai/ModernBERT-base"

In [None]:
# repository root: the clone location on Colab, otherwise a path relative to the
# working directory (presumably this notebook's folder)
if COLAB:
    base_path = Path("/content/advanced_text_analysis/")
else:
    base_path = Path("../../")


## Load and prepare the data

In [None]:
# folder holding the labeled `erlich_multilabel_2023` dataset
data_path = base_path.joinpath("data", "labeled", "erlich_multilabel_2023")

In [None]:
# local copy of the dataset (the "reqeuests" spelling matches the remote file name)
fp = data_path / "erlich_multilabel_2023-ati_reqeuests.tsv"
if not fp.exists():
    # first run: fetch the TSV from S3 and cache it under `data_path`
    remote_url = "https://cta-text-datasets.s3.eu-central-1.amazonaws.com/labeled/erlich_multilabel_2023/erlich_multilabel_2023-ati_reqeuests.tsv"
    df = pd.read_csv(remote_url, sep="\t")
    fp.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(fp, sep="\t", index=False)

In [None]:
# load the (locally cached) TSV into a DataFrame
df = read_tabular(fp)

In [None]:
# use the English translation as model input: drop the original-language text
# and expose `text_en` under the name `text`
df = df.drop(columns='text').rename(columns={'text_en': 'text'})

In [None]:
# seeded train/dev/test split (dev_size/test_size presumably fractions of all rows);
# with `return_dict=True` the result maps split name -> DataFrame
data_splits = split_data(df, dev_size=0.10, test_size=0.15, seed=SEED, return_dict=True)

In [None]:
# convert each split's DataFrame to a Hugging Face `Dataset` (dropping the pandas index)
datasets = DatasetDict(
    {
        split_name: Dataset.from_pandas(split_df, preserve_index=False)
        for split_name, split_df in data_splits.items()
    }
)

In [None]:
# label columns = every column except the identifier and the model input text
# (selected by name rather than position so column order does not matter;
# assumes `id` and `text` are the only non-label columns)
non_label_cols = {'id', 'text'}
label_cols = [col for col in datasets.column_names['train'] if col not in non_label_cols]
label_cols

In [None]:
# tokenizer matching the checkpoint (the fast implementation, if one is available)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

In [None]:
def preprocess_function(examples):
    """Tokenize the `text` field.

    Inputs are truncated to the tokenizer's single-sequence limit and left
    unpadded; padding happens per batch in the data collator.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        max_length=tokenizer.max_len_single_sentence,
        truncation=True,
        padding=False,
    )

def format_labels(example):
    """Add a `labels` field: one float per label column, in `label_cols` order.

    The label columns presumably hold 0/1 indicators.
    """
    example['labels'] = list(map(float, (example[col] for col in label_cols)))
    return example

In [None]:
# attach the multi-hot `labels` vector to every example, then tokenize;
# tokenization runs batched (the tokenizer accepts a list of texts), which
# avoids one tokenizer call per example
datasets = datasets.map(format_labels)
datasets = datasets.map(preprocess_function, batched=True)

In [None]:
# keep only the model inputs and targets; drop raw text, id, and the per-label columns
keep_cols = ['input_ids', 'attention_mask', 'labels']
_keep = set(keep_cols)
rm_cols = [name for name in datasets['train'].column_names if name not in _keep]
datasets = datasets.remove_columns(rm_cols)

## Prepare the model fine-tuning

In [None]:
# NOTE: the Trainer calls `model_init` to build the model each time training
#   starts, so the model is (re)loaded from the Hugging Face model hub here,
#   with one output per label column, the label <-> id mappings, and the
#   multi-label problem type.
def model_init():
    """Return a freshly loaded `MODEL_NAME` classifier over `label_cols`, set up for multi-label classification.

    NOTE(review): `trust_remote_code=True` lets the checkpoint run its own code;
    confirm it is still needed with the installed transformers version.
    """
    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    config.num_labels = len(label_cols)
    id2label = dict(enumerate(label_cols))
    config.label2id = {name: idx for idx, name in id2label.items()}
    config.id2label = id2label
    # independent per-label targets instead of a single class per example
    config.problem_type = "multi_label_classification"
    return AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        config=config,
        trust_remote_code=True,
        device_map='auto',
    )

### Define the training arguments

In [None]:
# output folder for the fine-tuned model (checkpoints and logs go in subfolders during training)
model_path = base_path / "models" / "erlich_multilabel"

In [None]:
out_dir = model_path
checkpoints_dir = out_dir / 'checkpoints'
logs_dir = out_dir / 'logs'

# evaluation, best-model selection, and early stopping all need a held-out dev
# split; without one they must be switched off, otherwise training fails
has_dev = 'dev' in datasets

training_args = TrainingArguments(
    
    # hyperparameters
    num_train_epochs=10,
    learning_rate=4e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,  # effective train batch: 16 * 2 = 32 per device
    per_device_eval_batch_size=32,
    weight_decay=0.3,
    optim='adamw_torch',
    
    # when to evaluate (only if there is a dev split to evaluate on)
    eval_strategy='epoch' if has_dev else 'no',
    do_eval=has_dev,
    # how to select "best" model
    metric_for_best_model='f1_macro' if has_dev else None,
    load_best_model_at_end=has_dev,
    # when to save
    save_strategy='epoch',
    save_total_limit=2 if has_dev else None, # don't save all model checkpoints
    # where to store results
    output_dir=str(checkpoints_dir),
    overwrite_output_dir=True,
    
    # logging
    logging_dir=str(logs_dir),
    logging_strategy='epoch',
    
    # reproducibility
    seed=SEED,
    data_seed=SEED,
    full_determinism=True,
)


# build callbacks: stop training once the dev metric has not improved by more
# than the threshold for `early_stopping_patience` consecutive evaluations
callbacks = []
if has_dev:
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.03))

### Create the trainer

In [None]:
# passing `model_init` (rather than a model) lets the Trainer build a fresh model when training starts
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),  # pads each batch dynamically
    processing_class=tokenizer,
    train_dataset=datasets['train'],
    eval_dataset=datasets.get('dev'),  # None when there is no dev split
    compute_metrics=compute_metrics,
    callbacks=callbacks,
)

## Train

In [None]:
print('Training ...')
# `train_hist` keeps the summary object returned by `Trainer.train`
train_hist = trainer.train()

### Evaluate

In [None]:
# apply the best model loaded after finishing training to the test set
# (`metric_key_prefix='test'` prefixes the returned metric names with `test`)
print('Evaluating ...')
test_res = trainer.evaluate(datasets['test'], metric_key_prefix='test')

In [None]:
# show the test-set metrics
test_res

In [None]:
# raw test-set output: `.predictions` holds the logits, `.label_ids` the gold labels
preds = trainer.predict(datasets['test'])

In [None]:
from sklearn.metrics import multilabel_confusion_matrix

# logits -> per-label probabilities (logistic sigmoid) -> 0/1 decisions at 0.5
probs = 1 / (1 + np.exp(-preds.predictions))
pred_binary = (probs > 0.5).astype(int)

# one 2x2 matrix per label, laid out [[TN, FP], [FN, TP]];
# flatten each into a row of four counts
cm = multilabel_confusion_matrix(preds.label_ids, pred_binary)
cm_wide = cm.reshape(-1, 4)

# per-label table with counts expressed as shares of all test examples
cm_df = pd.DataFrame(cm_wide, index=label_cols, columns=['TN', 'FP', 'FN', 'TP'])
cm_df = cm_df[['TP', 'TN', 'FP', 'FN']] / len(datasets['test'])
cm_df.round(3)

In [None]:
# which label classes get most frequently confused?
# FP rate = share of texts without the label that were wrongly assigned it;
# FN rate = share of texts with the label that were missed.
# (`cm_df` holds proportions, but the common denominator cancels in these ratios;
#  a label without negatives/positives yields NaN)
cm_df['FP_rate'] = cm_df['FP'] / (cm_df['FP'] + cm_df['TN'])
cm_df['FN_rate'] = cm_df['FN'] / (cm_df['FN'] + cm_df['TP'])
cm_df[['FP_rate', 'FN_rate']].sort_values('FP_rate', ascending=False).head(10)

In [None]:
# which label classes get most frequently confused with each other?
# For every test example, count each (missed true label, spuriously predicted label) pair:
#   rows    (`label_true`): label in the gold annotation but not predicted (false negative)
#   columns (`label_pred`): label predicted but absent from the gold annotation (false positive)
fn_mask = (preds.label_ids == 1) & (pred_binary == 0)
fp_mask = (preds.label_ids == 0) & (pred_binary == 1)
# (n_labels x n_examples) @ (n_examples x n_labels) sums the co-occurring FN/FP pairs
conf = fn_mask.T.astype(int) @ fp_mask.astype(int)
conf_df = pd.DataFrame(
    conf,
    index=pd.Index(label_cols, name='label_true'),
    columns=pd.Index(label_cols, name='label_pred'),
)
conf_df

### Inference

In [None]:
from datasets import Dataset

from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

from tqdm import tqdm

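We'll use the transformers library's `pipeline` for inference, i.e., to predict the label classes of new texts (here we apply it to the test split).

Specifically, we use the **text-classification** task and pass it the fine-tuned model from the trainer. We request the score of every label for each text and threshold each score separately, because the model is set up for multi-label classification.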

In [None]:
# text-classification pipeline around the fine-tuned model; `top_k=None` returns
# the score of every label for each text (so the deprecated `return_all_scores`
# flag is not needed)
classifier = pipeline("text-classification", model=trainer.model, tokenizer=tokenizer, top_k=None)

In [None]:
# for batch inference (see https://huggingface.co/docs/transformers/pipeline_tutorial#batch-inference):
# wrap the test texts so the pipeline can stream the `text` field in batches
test_text_ds = Dataset.from_pandas(data_splits['test'].loc[:, ['text']])
kd = KeyDataset(test_text_ds, 'text')

In [None]:
# apply the classifier to the dataset in batches; over-long texts are truncated
# (as during fine-tuning) so inputs never exceed the model's maximum length
preds = list(tqdm(classifier(kd, batch_size=64, truncation=True), total=len(kd)))

In [None]:
# binarize each label's score at 0.5: one `<label>__pred` column per label,
# one row per test text, aligned with the test split's index
preds_df = pd.DataFrame(
    [
        {f"{score['label']}__pred": int(score['score'] >= 0.5) for score in text_scores}
        for text_scores in preds
    ],
    index=data_splits['test'].index,
)

In [None]:
# long format with one row per (text, label): observed (`obs`) and predicted (`pred`) values side by side
long_df = pd.concat([data_splits['test'], preds_df], axis=1).melt(id_vars=['id', 'text'], var_name='col', value_name='value')
is_pred = long_df['col'].str.endswith('__pred')
tmp = (
    long_df
    .assign(
        what=np.where(is_pred, 'pred', 'obs'),
        col=long_df['col'].str.removesuffix('__pred'),
    )
    .pivot(index=['id', 'text', 'col'], columns='what', values='value')
    .reset_index()
)

In [None]:
# distribution of how many of the `len(label_cols)` labels were misclassified per text
misclassified = tmp['obs'].ne(tmp['pred']).astype(int).rename('misclassified')
n_wrong_per_text = misclassified.groupby(tmp['id']).sum()
n_wrong_per_text.value_counts().sort_index().to_frame(name='n')

In [None]:
# Sample (up to 20) texts with at least one misclassified label; a cell reads
# 'miss' where that label was misclassified for the text and is empty otherwise
misses = tmp.query('obs!=pred').assign(misclassified='miss')
miss_examples = misses.pivot_table(index=['id', 'text'], columns='col', values='misclassified', fill_value='', aggfunc='first').reset_index()
# `DataFrame.sample` raises when asked for more rows than exist, so cap the sample size
miss_examples = miss_examples.sample(min(20, len(miss_examples)), random_state=SEED)
miss_examples.columns.name = None
miss_examples

## Finally

#### Delete intermediate checkpoints and log files

In [None]:
import os
# finally: clean up intermediate training artefacts (the trained model stays in memory on `trainer`)
for tmp_dir in (checkpoints_dir, logs_dir):
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)

#### Save the best model (if desired)

In [None]:
# persist the model (and its tokenizer) to `out_dir`; `;` suppresses the cell output
trainer.save_model(output_dir=out_dir)
tokenizer.save_pretrained(save_directory=out_dir);

### Free the GPU and remove large objects

In [None]:
import gc
# release accelerator memory: move the fine-tuned weights to the CPU, drop the
# large objects, and hand PyTorch's cached accelerator memory back to the device
# (NOTE: `classifier` still references the model, now on the CPU)
trainer.model.to('cpu')
del trainer, tokenizer, datasets
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()