# Fine-tune a multi-label classifier using `setfit`

## Setup 

#### Colab (if using Colab)

In [1]:
# check if on Colab
COLAB = True
try:
  from google import colab
except:
  COLAB = False

if COLAB:
    # shallow clone of current state of main branch 
    !git clone --branch main --single-branch --depth 1 --filter=blob:none https://github.com/haukelicht/advanced_text_analysis.git
    
    # make repo root findable for python
    import sys
    sys.path.append("/content/advanced_text_analysis/")

    !pip -q "sentence-transformers==5.1.0 setfit==1.1.3"

#### Load required libraries

In [2]:
# !pip install setfit==1.1.3

In [3]:
from pathlib import Path
import shutil

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict

import torch
from transformers import set_seed
from src.setfit_utils import model_init as setfit_model_init
from src.setfit_utils import get_class_weights
from setfit import TrainingArguments, Trainer

In [4]:
from sklearn.metrics import hamming_loss, accuracy_score, f1_score, label_ranking_loss

def multilabel_metrics(y_pred, y_true, threshold=0.5):
    """Compute multi-label classification metrics from raw (pre-sigmoid) scores.

    Scores are squashed with a sigmoid and binarized with ``probs > threshold``;
    the un-thresholded probabilities are used for the ranking loss.

    Args:
        y_pred: array-like of raw scores/logits with one column per label
            (assumed shape ``(n_samples, n_labels)``, as the scikit-learn
            multi-label metrics used below expect).
        y_true: array-like multi-hot (0/1) gold labels of the same shape.
        threshold: probability above which a label counts as predicted.
            Defaults to 0.5.

    Returns:
        dict with keys ``hamming_loss``, ``subset_accuracy``, ``f1_macro``,
        ``f1_micro`` and ``ranking_loss``.
    """
    # accept lists and other array-likes, not only numpy arrays
    # (unary minus is not defined on plain Python lists)
    scores = np.asarray(y_pred, dtype=float)

    # sigmoid; exp overflows to inf for very negative scores, which still
    # yields the correct probability of 0, so the overflow warning is noise
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-scores))
    pred_binary = (probs > threshold).astype(int)

    return {
        "hamming_loss": hamming_loss(y_true, pred_binary),
        "subset_accuracy": accuracy_score(y_true, pred_binary),
        "f1_macro": f1_score(y_true, pred_binary, average="macro"),
        "f1_micro": f1_score(y_true, pred_binary, average="micro"),
        "ranking_loss": label_ranking_loss(y_true, probs),
    }

**Interpretation**

- *Hamming Loss*
  - Measures the fraction of labels that are incorrectly predicted (either a 0 instead of 1 or vice versa).
  - Lower is better; `0.0` means perfect prediction.
  - Formula: `(number of wrong labels) / (number of total labels)`
  - Good for understanding average label-wise error rate.

- *Subset Accuracy (Exact Match Ratio)*
  - Fraction of examples where **all** labels are predicted correctly.
  - Very strict; requires the entire label set to be correct per sample.
  - Value ranges from `0.0` (no perfect predictions) to `1.0` (all perfect).
  - Not very forgiving if you're slightly wrong on multi-hot labels.

- *F1-Macro*
  - Calculates F1 score **per label**, then takes the unweighted average.
  - Treats all labels equally regardless of how often they appear.
  - Sensitive to performance on rare labels.
  - Useful when class imbalance is a concern and all labels are important.

- *F1-Micro*
  - Aggregates true positives, false positives, and false negatives across all labels before computing F1.
  - Gives more weight to frequent labels.
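  - Appropriate when overall performance across all label decisions matters more than per-label performance; with strongly varying label frequencies it is dominated by the frequent labels.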
  - Often higher than macro F1 in imbalanced datasets.

- *Ranking Loss*
  - Averages, over samples, the fraction of (relevant, irrelevant) label pairs in which the **relevant label** gets a lower score than the irrelevant one.
  - Lower is better; `0.0` means perfect ranking.
  - Requires access to the **raw prediction scores** (before thresholding).
  - Useful in retrieval or recommendation scenarios where ranking quality matters.


In [5]:
# check which device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

In [None]:
SEED = 42
set_seed(SEED)

In [None]:
base_path = Path("/content/advanced_text_analysis/" if COLAB else "../../")
data_path = base_path / "data/labeled/erlich_multilabel_2023"

## Load and prepare the data

In [None]:
fp = data_path / "erlich_multilabel_2023-ati_reqeuests.tsv"
if not fp.exists():
    url = "https://cta-text-datasets.s3.eu-central-1.amazonaws.com/labeled/erlich_multilabel_2023/erlich_multilabel_2023-ati_reqeuests.tsv"
    df = pd.read_csv(url, sep="\t")
    fp.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(fp, sep="\t", index=False)

In [None]:
from src.utils.io import read_tabular

df = read_tabular(fp)

*NOTE*: `setfit` is a few-shot learning framework that works well with little labeled data, so we subsample the dataset to speed up training.

In [None]:
df = df.sample(2_000, random_state=SEED)

In [None]:
del df['text']
df.rename(columns={'text_en': 'text'}, inplace=True)

In [None]:
from src.finetuning import split_data

data_splits = split_data(df, dev_size=0.25, test_size=0.25, seed=SEED, return_dict=True)

In [None]:
datasets = DatasetDict({
    s: Dataset.from_pandas(df, preserve_index=False)
    for s, df in data_splits.items()
})

In [None]:
label_cols = datasets.column_names['train']
label_cols = [c for c in label_cols if c not in ['id', 'text']]

id2label = {i: l for i, l in enumerate(label_cols)}
label2id = {l: i for i, l in enumerate(label_cols)}

id2label

In [None]:
def format_labels(example):
    example['labels'] = [float(example[col]) for col in label_cols]
    return example

datasets = datasets.map(format_labels)

In [None]:
keep_cols = ['text' 'labels']
rm_cols = [col for col in datasets['train'].column_names if col not in keep_cols]
datasets.remove_columns(rm_cols);

In [None]:
datasets.num_rows

## Prepare the model fine-tuning

In [None]:
# training label matrix (presumably n_train x n_labels, multi-hot) -> per-label weights
labs = np.array(datasets['train']['labels'])
class_weights = get_class_weights(labs, multitarget=True).astype(float)
class_weights

In [None]:
def model_init():
    """Build a fresh SetFit model; handed to `Trainer(model_init=...)`."""
    init_kwargs = dict(
        model_name=MODEL_NAME,
        id2label=id2label,
        # one-vs-rest: one binary decision per label, as multi-label needs
        multitarget_strategy='one-vs-rest',
        class_weights=class_weights,
    )
    return setfit_model_init(**init_kwargs)

In [None]:
from src.metrics import compute_sequence_classification_metrics_multilabel

def compute_metrics(y_pred, y_true):
    """Metric callback passed to the setfit `Trainer` via `metric=`.

    Args:
        y_pred: predictions for the evaluation split as handed over by setfit
            (a torch tensor or an array-like).
        y_true: gold multi-hot label vectors.

    Returns:
        Whatever `compute_sequence_classification_metrics_multilabel` returns.
    """
    # normalize predictions to a CPU numpy array: `.numpy()` alone fails for
    # numpy input and for tensors that still live on a GPU
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    else:
        y_pred = np.asarray(y_pred)
    return compute_sequence_classification_metrics_multilabel(np.array(y_true), y_pred)

### Define the training arguments

In [None]:
model_path = base_path / "models" / "erlich_multilabel_setfit"

In [None]:
out_dir = model_path
checkpoints_dir = out_dir / 'checkpoints'
logs_dir = out_dir / 'logs'

from sentence_transformers.losses import CosineSimilarityLoss
from transformers import EarlyStoppingCallback

training_args = TrainingArguments(
    
    num_epochs=(1, 10), # embedding finetuning and classification head training epochs, respectively
    max_steps=250,
    batch_size=(16, 8),
    max_length=512,
    end_to_end=True,
    
    loss=CosineSimilarityLoss,
    
    # when to evaluate
    eval_strategy='steps',
    eval_steps=50,
    eval_max_steps=50,
    # how to select "best" model
    # do_eval=bool('dev' in datasets),
    metric_for_best_model='embedding_loss',
    load_best_model_at_end=True,
    # when to save
    save_strategy='steps',
    save_steps=50,
    save_total_limit=2 if 'dev' in datasets else None, # don't save all model checkpoints
    # where to store results
    output_dir=checkpoints_dir,
    
    # logging
    logging_dir=logs_dir,
    logging_strategy='steps',
    
    report_to='none',
    seed=SEED,
)

# build callbacks
callbacks = []
if 'dev' in datasets:
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.03))

### Create the trainer

In [None]:
trainer = Trainer(
    model_init=model_init,
    train_dataset=datasets['train'],
    # like `save_total_limit` and the callbacks, tolerate a missing dev split
    # instead of failing with a KeyError here
    eval_dataset=datasets.get('dev'),
    args=training_args,
    column_mapping={"text": "text", "labels": "label"},
    metric=compute_metrics,
    callbacks=callbacks,
)

# for deterministic results: pin the seed on the setfit trainer and on the
# inner sentence-transformers trainer, and request full determinism there
trainer._args.seed = SEED
trainer.st_trainer.args.seed = SEED
trainer.st_trainer.args.data_seed = SEED
trainer.st_trainer.args.full_determinism = True

# adapt max length: let the tokenizer and the embedding body accept inputs up
# to `training_args.max_length` tokens
trainer.model.model_body.tokenizer.model_max_length = training_args.max_length
trainer.model.model_body.max_seq_length = training_args.max_length

## Train

In [None]:
# run fine-tuning with the trainer and arguments configured above
print('Training ...')
trainer.train()

### Evaluate

In [None]:
# inspect the held-out test split
datasets['test']

In [None]:
# apply the model held by the trainer after training (the best checkpoint,
# since `load_best_model_at_end=True`) to the test split;
# `metric_key_prefix` presumably prefixes the returned metric names with 'test'
print('Evaluating ...')
test_res = trainer.evaluate(datasets['test'], metric_key_prefix='test')

In [None]:
test_res

### Inference

In [None]:
classifier = trainer.model.predict

In [None]:
# for batch inference (see https://huggingface.co/docs/transformers/pipeline_tutorial#batch-inference)
preds = classifier(data_splits['test']['text'].tolist(), as_numpy=True)

In [None]:
preds_df = pd.DataFrame(preds, columns=[l+'__pred' for l in label_cols], index=data_splits['test'].index)

In [None]:
tmp = pd.concat([data_splits['test'], preds_df], axis=1)
tmp = tmp.melt(id_vars=['id', 'text'], var_name='col', value_name='value')
tmp['what'] = 'obs'
tmp.loc[tmp['col'].str.endswith('__pred'), 'what'] = 'pred'
tmp['col'] = tmp['col'].str.removesuffix('__pred')
tmp = tmp.pivot(index=['id', 'text', 'col'], columns='what', values='value').reset_index()

In [None]:
# distribution of how many of the `len(label_cols)` labels were misclassified per text
tmp.assign(misclassified=lambda df: df.eval('obs != pred').astype(int)).groupby('id').misclassified.sum().value_counts().sort_index().to_frame(name='n')

In [None]:
# Sample examples of misclassifications
miss_examples = tmp.query('obs!=pred').assign(misclassified=lambda df: 'miss').pivot_table(index=['id', 'text'], columns='col', values='misclassified', fill_value='', aggfunc='first').reset_index().sample(20, random_state=SEED)
miss_examples.columns.name = None
miss_examples

## Finally

#### Delete intermediate checkpoints and log files

In [None]:
# finally: clean up
if checkpoints_dir.exists():
    shutil.rmtree(checkpoints_dir)
if logs_dir.exists():
    shutil.rmtree(logs_dir)

#### Save the best model (if desired)

In [None]:
# persist the fine-tuned model via the public `save_pretrained` API rather than
# the private `_save_pretrained` hook
trainer.model.save_pretrained(out_dir)

### Free the GPU and remove large objects

In [None]:
import gc

# move the model's weights off the accelerator (without rebinding `trainer`
# to the model), then drop the large objects held here. Note that
# `classifier` (a bound method of the model) still references the model.
trainer.model.to('cpu')
del trainer, datasets
gc.collect()
# hand memory cached by PyTorch's CUDA allocator back to the device
if torch.cuda.is_available():
    torch.cuda.empty_cache()

: 