In [None]:
import os

# cuBLAS workspace setting required for deterministic behavior with CuBLAS.
# NOTE: it is exported up front to avoid the error
#   RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)`
#   or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it
#   uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
#   environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or
#   CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to
#   https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")

from transformers import set_seed

# fixed seed + deterministic mode, for reproducibility
set_seed(42, deterministic=True)

In [None]:
import numpy as np
# print arrays with 4 decimals and suppress scientific notation for small values
np.set_printoptions(suppress=True, precision=4)

## Single-label classification

In [None]:
# load an example classification dataset from the HF hub (small slices of the official splits)
from datasets import load_dataset
train_dataset = load_dataset("ag_news", split="train[:600]")
val_dataset = load_dataset("ag_news", split="test[:200]")
test_dataset = load_dataset("ag_news", split="test[200:400]")

# Take the number of classes from the label schema when it exposes `num_classes`
# (as a `ClassLabel` feature does): counting the distinct labels seen in the
# 600-example slice would under-count if a class happened to be missing from it.
# Fall back to the observed labels for label columns without that attribute.
num_classes = (
    getattr(train_dataset.features["label"], "num_classes", None)
    or len(set(train_dataset["label"]))
)

In [None]:
from transformers.trainer_utils import PredictionOutput
from sklearn.metrics import f1_score
def compute_metrics(p: PredictionOutput):
    """Score single-label predictions with accuracy and macro-averaged F1.

    The predicted class of each row of `p.predictions` is its arg-max along
    axis 1; predictions are compared against `p.label_ids`.

    Returns:
        dict with the keys "accuracy" and "macro_f1".
    """
    y_true = p.label_ids
    y_pred = np.argmax(p.predictions, axis=1)
    n_correct = np.sum(y_pred == y_true)
    return {
        "accuracy": n_correct / len(y_true),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),
    }

### using the `SetFitModel`'s `fit()` method

In [None]:
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead
# from src.finetuning.setfit_extensions.class_weights_head import compute_class_weights, SetFitHeadWithClassWeights
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping,
    EarlyStoppingTrainingArguments
)

In [None]:
# embedding model ("body"); `device_map="auto"` is forwarded through `model_kwargs`
model_id = "sentence-transformers/all-MiniLM-L6-v2"
body = SentenceTransformer(model_id, model_kwargs=dict(device_map="auto"))

# classification head: input size = sentence embedding dimension, one output per class,
# created on the same device as the body
head = SetFitHead(
    device=body.device,
    out_features=num_classes,
    in_features=body.get_sentence_embedding_dimension(),
)

# combine body and head; embeddings are normalized
model = SetFitModelWithEarlyStopping(
    model_head=head,
    model_body=body,
    normalize_embeddings=True,
)
model.to(body.device);  # trailing `;` suppresses the cell output

In [None]:
# default training arguments; the `fit()` call below reads the learning rates
# and the L2 weight from this object
args = EarlyStoppingTrainingArguments()
# store the tokenizer's `model_max_length` as the arguments' `max_length`
args.max_length = body.tokenizer.model_max_length

In [None]:
model.fit(
    # training / evaluation data
    x_train=train_dataset["text"],
    y_train=train_dataset["label"],
    x_eval=val_dataset["text"],
    y_eval=val_dataset["label"],
    # optimisation settings (learning rates and L2 weight are taken from `args`)
    batch_size=16,
    num_epochs=10,
    head_learning_rate=args.head_learning_rate,
    body_learning_rate=args.body_classifier_learning_rate,
    l2_weight=args.l2_weight,
    max_length=body.tokenizer.model_max_length,
    end_to_end=True,
    show_progress_bar=True,
    # added early stopping arguments
    compute_metrics=compute_metrics,
    metric_for_best_model="macro_f1",  # NOTE: must match one of the keys returned by `compute_metrics`
    greater_is_better=True,
    early_stopping_threshold=0.03,
    early_stopping_patience=2,
)

In [None]:
import torch

# free GPU: move the model to the CPU, drop the notebook's reference to it,
# then ask PyTorch to release its cached CUDA memory
model.to("cpu");
del model
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

### with custom early-stopping trainer class

In [None]:
import numpy as np
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead
from src.finetuning.setfit_extensions.class_weights_head import (
    compute_class_weights,
    SetFitHeadWithClassWeights
)
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping, 
    EarlyStoppingTrainingArguments,
    EarlyStoppingCallback,
    EarlyStoppingTrainer
)

In [None]:
from typing import Optional


def model_init(
        model_name: str="sentence-transformers/all-MiniLM-L6-v2",
        num_classes: int=2,
        class_weights: Optional[np.ndarray]=None,
        **kwargs
    ) -> SetFitModelWithEarlyStopping:
    """Build a fresh SetFit model (sentence-transformer body + classification head).

    Args:
        model_name: Sentence-transformers checkpoint loaded as the model body.
        num_classes: Number of outputs of the classification head.
        class_weights: Optional class weights. When given, the head is a
            `SetFitHeadWithClassWeights` that receives them; otherwise a plain
            `SetFitHead` is created.
        **kwargs: Extra entries merged into the `model_kwargs` passed to
            `SentenceTransformer`; they may override the default
            `device_map="auto"`.

    Returns:
        A `SetFitModelWithEarlyStopping` with `normalize_embeddings=True` whose
        head lives on the same device as the body.
    """
    model_kwargs = {"device_map": "auto", **kwargs}
    # NOTE(security): `trust_remote_code=True` permits custom code shipped with the
    # checkpoint to be executed at load time -- only pass model ids you trust.
    body = SentenceTransformer(model_name, model_kwargs=model_kwargs, trust_remote_code=True)

    head_kwargs = dict(
        in_features=body.get_sentence_embedding_dimension(),
        out_features=num_classes,
        device=body.device,
    )
    # only the class-weighted head accepts (and needs) `class_weights`
    if class_weights is not None:
        head = SetFitHeadWithClassWeights(**head_kwargs, class_weights=class_weights)
    else:
        head = SetFitHead(**head_kwargs)

    return SetFitModelWithEarlyStopping(
        model_body=body,
        model_head=head.to(body.device),
        normalize_embeddings=True,
    )

In [None]:
# checkpoint handed to `model_init` by the trainer's `model_init` lambda
# (same id as in the `fit()` example above)
model_id = "sentence-transformers/all-MiniLM-L6-v2"

In [None]:
# Tuple-valued settings presumably hold one value per training phase,
# (sentence transformer finetuning, classifier finetuning) -- cf. the order of
# the callbacks below; TODO confirm against `EarlyStoppingTrainingArguments`.
training_args = EarlyStoppingTrainingArguments(
    num_epochs=(1, 15),
    # misc
    end_to_end=True,
    # sentence transformer (embedding) finetuning args
    eval_strategy="steps",  # NOTE: currently no effect on (early stopping in) classification head training
    # NOTE (original author): overwrites the epochs above for sentence transformer finetuning.
    # The original note says "0 epochs" although `num_epochs` is (1, 15) -- TODO confirm.
    eval_steps=25,
    max_steps=200,
    eval_max_steps=200,
    # early stopping config
    load_best_model_at_end=True,
    metric_for_best_model=("embedding_loss", "f1"),
    greater_is_better=(False, True),
    save_total_limit=2,  # NOTE: currently no effect on (early stopping in) classification head training
)

training_callbacks = [
    # for sentence transformer finetuning
    EarlyStoppingCallback(early_stopping_threshold=0.03, early_stopping_patience=2),
    # for classifier finetuning
    EarlyStoppingCallback(early_stopping_threshold=0.02, early_stopping_patience=4),
]

In [None]:
# compute class weights (inversely proportional to class frequencies);
# they are handed to `model_init` by the trainer's `model_init` lambda
class_weights = compute_class_weights(train_dataset["label"])

In [None]:
# initialize Trainer; the lambda defers model construction to the trainer
trainer = EarlyStoppingTrainer(
    args=training_args,
    model_init=lambda: model_init(
        class_weights=class_weights,
        num_classes=num_classes,
        model_name=model_id,
    ),
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    metric_kwargs=dict(average="macro"),
    metric="f1",
    callbacks=training_callbacks,
)

# fix max_length issue: copy the tokenizer's limit onto the trainer's arguments
trainer._args.max_length = trainer.st_trainer.model.tokenizer.model_max_length

# set seeds for reproducibility (on both argument objects)
trainer._args.seed = 42
trainer.st_trainer.args.seed = 42
trainer.st_trainer.args.data_seed = 42
trainer.st_trainer.args.full_determinism = True

# don't report to wandb or other experiment trackers
trainer._args.report_to = "none"
trainer.st_trainer.args.report_to = "none"

In [None]:
# train with the arguments and early-stopping callbacks configured above
trainer.train()

In [None]:
# verify best model loaded: re-evaluate on the validation split after training
trainer.evaluate(val_dataset)

In [None]:
# eval on the held-out test slice; "test" is passed as the second positional argument
# (presumably `metric_key_prefix`, cf. the keyword used in the multi-label section -- TODO confirm)
trainer.evaluate(test_dataset, "test")

In [None]:
from sklearn.metrics import classification_report
# predict the test texts and print the classification report
preds = trainer.model.predict(test_dataset["text"], as_numpy=True)
print(classification_report(y_true=test_dataset["label"], y_pred=preds))

## Multi-label classification

In [None]:
from datasets import load_dataset
from collections import Counter

dataset_id = "acloudfan/toxicity-multi-label-classifier"

# the three splits are loaded in the order train, validation, test
train_dataset, val_dataset, test_dataset = (
    load_dataset(dataset_id, split=split_name)
    for split_name in ("train", "validation", "test")
)

label_cols = ["toxic", "threat", "insult", "identity_hate"]
# value counts of every label column in the training split
for col in label_cols:
    print(f"{col}: {dict(Counter(train_dataset[col]))}")

num_classes = len(label_cols)
# index <-> label-name lookups, following the order of `label_cols`
label2id = {name: idx for idx, name in enumerate(label_cols)}
id2label = {idx: name for name, idx in label2id.items()}

# convert to multi-label format
def format_dataset_multi_label(example):
    """Gather the per-label columns of `example` into one `label` list.

    The list follows the key order of the module-level `label2id` mapping.
    `example` is modified in place and returned.
    """
    example["label"] = [example[name] for name in label2id]
    return example

# apply the same three steps to every split (train, validation, test -- in that order):
#   1. build the list-valued `label` column,
#   2. rename the text column to "text",
#   3. drop the original per-label columns
train_dataset, val_dataset, test_dataset = (
    split.map(format_dataset_multi_label, batched=False)
    .rename_column("comment_text", "text")
    .remove_columns(label_cols)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

def compute_metrics_multilabel(p):
    """
    Compute evaluation metrics for multi-label classification.

    Args:
        p: transformers.trainer_utils.PredictionOutput (or any object with the
           same attributes); only `.predictions` (logits) and `.label_ids`
           (targets) are read.

    Returns:
        dict with
          - "f1_macro": macro-averaged F1 (`zero_division=0.0`),
          - "accuracy": `accuracy_score(labels, preds)`; for multi-label input
            scikit-learn computes the subset accuracy (exact match ratio).
    """
    # unpack
    logits, labels = p.predictions, p.label_ids
    # apply sigmoid to get probabilities in [0,1]
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    # threshold at 0.5 for binary decisions
    preds = (probs > 0.5).astype(int)

    # Only the reported metrics are computed. Micro-F1, macro precision/recall
    # and a hand-rolled subset accuracy used to be evaluated here on every call
    # and then thrown away.
    return {
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0.0),
        "accuracy": accuracy_score(labels, preds),
    }

# test:
# y_true = train_dataset['label'][:10]
# # simulate some predictions by sampling (shape (10, num_classes)) uniformly from 0-1
# np.random.seed(42)
# y_pred = np.random.rand(10, num_classes)
# p = PredictionOutput(predictions=y_pred, label_ids=y_true, metrics=None)
# compute_metrics_multilabel(p)

### simple `fit` method

In [None]:
# embedding model ("body"); `device_map="auto"` is forwarded through `model_kwargs`
model_id = "sentence-transformers/all-MiniLM-L6-v2"
body = SentenceTransformer(model_id, model_kwargs=dict(device_map="auto"))

# multi-target head: one output per label column, created on the body's device
head = SetFitHead(
    multitarget=True,
    device=body.device,
    out_features=num_classes,
    in_features=body.get_sentence_embedding_dimension(),
)

# combine body and head with the "one-vs-rest" multi-target strategy
model = SetFitModelWithEarlyStopping(
    model_head=head,
    model_body=body,
    normalize_embeddings=True,
    multi_target_strategy="one-vs-rest",
)
model.to(body.device);  # trailing `;` suppresses the cell output

In [None]:
# default training arguments, with `max_length` taken from the tokenizer
args = EarlyStoppingTrainingArguments()
args.max_length = body.tokenizer.model_max_length

model.fit(
    # training / evaluation data
    x_train=train_dataset["text"],
    y_train=train_dataset["label"],
    x_eval=val_dataset["text"],
    y_eval=val_dataset["label"],
    # optimisation settings (learning rates, L2 weight and max length come from `args`)
    batch_size=16,
    num_epochs=30,
    head_learning_rate=args.head_learning_rate,
    body_learning_rate=args.body_classifier_learning_rate,
    l2_weight=args.l2_weight,
    max_length=args.max_length,
    end_to_end=True,
    show_progress_bar=True,
    # added early stopping arguments
    compute_metrics=compute_metrics_multilabel,
    metric_for_best_model="f1_macro",  # NOTE: must match one of the keys returned by `compute_metrics`
    greater_is_better=True,
    early_stopping_threshold=0.03,
    early_stopping_patience=4,
)

In [None]:
# verify best model loaded: score the fitted model's logits on the validation split
logits = model.predict_logits(val_dataset["text"], as_numpy=True)
p = PredictionOutput(
    label_ids=np.array(val_dataset["label"]),
    predictions=logits,
    metrics={},
)
compute_metrics_multilabel(p)

### with trainer

In [None]:
import numpy as np
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead
from src.finetuning.setfit_extensions.class_weights_head import (
    compute_class_weights,
    SetFitHeadWithClassWeights
)
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping, 
    EarlyStoppingTrainingArguments,
    EarlyStoppingCallback,
    EarlyStoppingTrainer
)

In [None]:
from typing import Literal, Optional
def multiclass_model_init(
        model_name: str="sentence-transformers/all-MiniLM-L6-v2",
        num_classes: int=2,
        multi_target_strategy: Optional[Literal["one-vs-rest", "multi-output"]]="one-vs-rest",
        class_weights: Optional[np.ndarray]=None,
        **kwargs
    ) -> SetFitModelWithEarlyStopping:
    """Build a fresh SetFit model whose head can be multi-target.

    Args:
        model_name: Sentence-transformers checkpoint loaded as the model body.
        num_classes: Number of outputs of the classification head.
        multi_target_strategy: Strategy forwarded to the SetFit model. The head is
            created with `multitarget=True` whenever this is not `None`; passing
            `None` yields a single-target head.
        class_weights: Optional class weights. When given, the head is a
            `SetFitHeadWithClassWeights` that receives them; otherwise a plain
            `SetFitHead` is created.
        **kwargs: Extra entries merged into the `model_kwargs` passed to
            `SentenceTransformer`; they may override the default
            `device_map="auto"`.

    Returns:
        A `SetFitModelWithEarlyStopping` with `normalize_embeddings=True` whose
        head lives on the same device as the body.
    """
    model_kwargs = {"device_map": "auto", **kwargs}
    # NOTE(security): `trust_remote_code=True` permits custom code shipped with the
    # checkpoint to be executed at load time -- only pass model ids you trust.
    body = SentenceTransformer(model_name, model_kwargs=model_kwargs, trust_remote_code=True)

    head_kwargs = dict(
        in_features=body.get_sentence_embedding_dimension(),
        out_features=num_classes,
        device=body.device,
        multitarget=(multi_target_strategy is not None),
    )
    # only the class-weighted head accepts (and needs) `class_weights`
    if class_weights is not None:
        head = SetFitHeadWithClassWeights(**head_kwargs, class_weights=class_weights)
    else:
        head = SetFitHead(**head_kwargs)

    return SetFitModelWithEarlyStopping(
        model_body=body,
        model_head=head.to(body.device),
        multi_target_strategy=multi_target_strategy,
        normalize_embeddings=True,
    )

In [None]:
# checkpoint handed to `multiclass_model_init` by the trainer's `model_init` lambda
model_id = "sentence-transformers/all-MiniLM-L6-v2"

In [None]:
# Tuple-valued settings presumably hold one value per training phase,
# (sentence transformer finetuning, classifier finetuning) -- TODO confirm
# against `EarlyStoppingTrainingArguments`.
training_args = EarlyStoppingTrainingArguments(
    num_epochs=(1, 15),
    # train end to end
    end_to_end=True,
    # sentence transformer (embedding) finetuning args
    eval_strategy="steps",  # NOTE: currently no effect on (early stopping in) classification head training
    # NOTE (original author): overwrites the epochs above for sentence transformer finetuning.
    # The original note says "0 epochs" although `num_epochs` is (1, 15) -- TODO confirm.
    eval_steps=25,
    max_steps=200,
    eval_max_steps=200,
    # early stopping config
    load_best_model_at_end=True,
    metric_for_best_model=("embedding_loss", "f1"),
    greater_is_better=(False, True),
    save_total_limit=2,  # NOTE: currently no effect on (early stopping in) classification head training
)

In [None]:
# compute class weights (inversely proportional to class frequencies)
# NOTE: these weights are only displayed here; the `class_weights=` argument of the
# trainer's `model_init` lambda is currently commented out, so they are not used in training
class_weights = compute_class_weights(train_dataset["label"], multitarget=True)
class_weights

In [None]:
training_callbacks = [
    # for sentence transformer finetuning
    EarlyStoppingCallback(early_stopping_threshold=0.03, early_stopping_patience=2),
    # for classifier finetuning
    EarlyStoppingCallback(early_stopping_threshold=0.02, early_stopping_patience=4),
]

# initialize Trainer; the lambda defers model construction to the trainer
trainer = EarlyStoppingTrainer(
    args=training_args,
    model_init=lambda: multiclass_model_init(
        num_classes=num_classes,
        model_name=model_id,
        # class_weights=class_weights,
    ),
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    metric_kwargs=dict(average="macro"),
    metric="f1",
    callbacks=training_callbacks,
    # compute_metrics=compute_metrics_multilabel,
)

# fix max_length issue: copy the tokenizer's limit onto the trainer's arguments
trainer._args.max_length = trainer.st_trainer.model.tokenizer.model_max_length

# set seeds for reproducibility (on both argument objects)
trainer._args.seed = 42
trainer.st_trainer.args.seed = 42
trainer.st_trainer.args.data_seed = 42
trainer.st_trainer.args.full_determinism = True

# don't report to wandb or other experiment trackers
trainer._args.report_to = "none"
trainer.st_trainer.args.report_to = "none"

In [None]:
# train with the arguments and early-stopping callbacks configured above
trainer.train()

In [None]:
# evaluate the trained model on the validation split, passing "eval" as `metric_key_prefix`
trainer.evaluate(val_dataset, metric_key_prefix="eval")

In [None]:
from sklearn.metrics import classification_report
# predict the test texts and print the classification report, one row per label column
preds = trainer.model.predict(test_dataset["text"], as_numpy=True)
print(
    classification_report(
        y_true=test_dataset["label"],
        y_pred=preds,
        target_names=label_cols,
        zero_division=0,
    )
)