In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # to enable deterministic behavior with CuBLAS
# NOTE: to avoid error
#   RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` 
#   or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it
#   uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an 
#   environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or 
#   CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to 
#   https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

from transformers import set_seed
set_seed(42, deterministic=True) # for reproducibility

In [2]:
import numpy as np
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead
# from src.finetuning.setfit_extensions.class_weights_head import compute_class_weights, SetFitHeadWithClassWeights
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping,
    EarlyStoppingTrainingArguments,
)
from transformers.trainer_utils import PredictionOutput

In [3]:
# load a example classification dataset from hf hub
from datasets import load_dataset
train_dataset = load_dataset("ag_news", split="train[:600]")
val_dataset = load_dataset("ag_news", split="test[:200]")
test_dataset = load_dataset("ag_news", split="test[200:400]")

num_classes = len(set(train_dataset["label"]))

### using the `SetfitModel`'s `fit()` method

In [4]:
model_id = "sentence-transformers/all-MiniLM-L6-v2"
body = SentenceTransformer(model_id, model_kwargs={"device_map": "auto"})

head = SetFitHead(
    in_features=body.get_sentence_embedding_dimension(),
    out_features=num_classes,
    device=body.device,
)

model = SetFitModelWithEarlyStopping(
    model_body=body,
    model_head=head,
    normalize_embeddings=True,
)
model.to(body.device);

In [5]:
args = EarlyStoppingTrainingArguments()
args.max_length = body.tokenizer.model_max_length

In [6]:
from sklearn.metrics import f1_score
def compute_metrics(p: PredictionOutput):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    acc = np.sum(preds == labels) / len(labels)
    f1 = f1_score(labels, preds, average="macro")
    return {"accuracy": acc, "macro_f1": f1}

In [7]:
model.fit(
    x_train=train_dataset["text"], y_train=train_dataset["label"],
    x_eval=val_dataset["text"], y_eval=val_dataset["label"],
    
    num_epochs=10,
    batch_size=16,
    body_learning_rate=args.body_classifier_learning_rate,
    head_learning_rate=args.head_learning_rate,
    l2_weight=args.l2_weight,
    
    max_length=body.tokenizer.model_max_length,
    
    show_progress_bar=True,
    end_to_end=False, # !!! fine-tune only head
    
    # added early stopping arguments
    compute_metrics=compute_metrics,
    metric_for_best_model="macro_f1", # NOTE: must match one of the keys returned by `compute_metrics`
    early_stopping_patience=2,
    early_stopping_threshold=0.03,
    greater_is_better=True,
)

Epoch:  10%|█         | 1/10 [00:00<00:04,  1.88it/s]

{'training loss': 0.997852025847686, 'validation loss': 1.087493910239293, 'accuracy': 0.48, 'macro_f1': 0.4134369245670616}


Epoch:  20%|██        | 2/10 [00:00<00:03,  2.31it/s]

{'training loss': 0.6535947197361996, 'validation loss': 0.9029704882548406, 'accuracy': 0.635, 'macro_f1': 0.6035682708096501}


Epoch:  30%|███       | 3/10 [00:01<00:02,  2.51it/s]

{'training loss': 0.5147908455447147, 'validation loss': 0.7719288881008441, 'accuracy': 0.695, 'macro_f1': 0.6850539109126752}


Epoch:  40%|████      | 4/10 [00:01<00:02,  2.57it/s]

{'training loss': 0.4385863715096524, 'validation loss': 0.7151274589391855, 'accuracy': 0.73, 'macro_f1': 0.7132838589981447}


Epoch:  50%|█████     | 5/10 [00:01<00:01,  2.64it/s]

{'training loss': 0.3909744169366987, 'validation loss': 0.6981672988488123, 'accuracy': 0.74, 'macro_f1': 0.7209796641915848}


Epoch:  60%|██████    | 6/10 [00:02<00:01,  2.68it/s]

{'training loss': 0.3556148684338519, 'validation loss': 0.7059849294332358, 'accuracy': 0.75, 'macro_f1': 0.7294087721878985}


Epoch:  60%|██████    | 6/10 [00:02<00:01,  2.21it/s]

{'training loss': 0.34054943017269435, 'validation loss': 0.6692329622231997, 'accuracy': 0.76, 'macro_f1': 0.741553673243814}
Early stopping triggered after 7 epochs
Loading best model





In [8]:
# free GPU
model.to("cpu");
del model
import torch
torch.cuda.empty_cache()
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  65536 KiB | 235349 KiB | 262097 MiB | 262033 MiB |
|       from large pool |  65536 KiB | 219863 KiB | 247782 MiB | 247718 MiB |
|       from small pool |      0 KiB |  16570 KiB |  14314 MiB |  14314 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  65536 KiB | 235349 KiB | 262097 MiB | 262033 MiB |
|       from large pool |  65536 KiB | 219863 KiB | 247782 MiB | 247718 MiB |
|       from small pool |      0 KiB |  16570 KiB |  14314 MiB |  14314 MiB |
|---------------------------------------------------------------

### with custom early-stopping trainer class

In [9]:
import numpy as np
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
from setfit.modeling import SetFitHead
from src.finetuning.setfit_extensions.class_weights_head import (
    compute_class_weights,
    SetFitHeadWithClassWeights
)
from src.finetuning.setfit_extensions.early_stopping import (
    SetFitModelWithEarlyStopping, 
    EarlyStoppingTrainingArguments,
    EarlyStoppingCallback,
    EarlyStoppingTrainer
)

def model_init(
        model_name: str="sentence-transformers/all-MiniLM-L6-v2",
        num_classes: int=2, 
        class_weights: np._typing.NDArray=None,
        **kwargs
    ) -> SetFitModelWithEarlyStopping:
    
    model_kwargs={"device_map": "auto", **kwargs}
    body = SentenceTransformer(model_name, model_kwargs=model_kwargs, trust_remote_code=True)
    
    # TODO: support multi-label classification
    head_kwargs = dict(
        in_features=body.get_sentence_embedding_dimension(),
        out_features=num_classes,
        device=body.device,
    )
    head = SetFitHeadWithClassWeights(**head_kwargs, class_weights=class_weights) if class_weights is not None else SetFitHead(**head_kwargs)
    
    return SetFitModelWithEarlyStopping(
        model_body=body,
        model_head=head.to(body.device),
        normalize_embeddings=True,
    )

In [10]:
model_id = "sentence-transformers/all-MiniLM-L6-v2"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

In [11]:
training_args = EarlyStoppingTrainingArguments(
    num_epochs=(1, 15),
    # sentence transformer (embedding) finetuning arts
    eval_strategy="steps", # NOTE: currently no effect on (early stopping in) classification head training
    eval_steps=25, # NOTE: overwrites 0 epochs above for sentence transformer finetuning
    max_steps=200,
    eval_max_steps=200,
    # early stopping config
    metric_for_best_model=("embedding_loss", "macro_f1"),
    greater_is_better=(False, True),
    load_best_model_at_end=True,
    save_total_limit=2, # NOTE: currently no effect on (early stopping in) classification head training
)

training_callbacks = [
    EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.03), # for sentence transformer finetuning
    EarlyStoppingCallback(early_stopping_patience=4, early_stopping_threshold=0.02), # for classifier finetuning
]

In [12]:
# compute class weights (inversely proportional to class frequencies)
class_weights = compute_class_weights(train_dataset["label"])

In [13]:
# initialize Trainer
trainer = EarlyStoppingTrainer(
    model_init=lambda : model_init(
        model_name=model_id,
        num_classes=num_classes,
        class_weights=class_weights,
    ),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=training_callbacks,
    compute_metrics=compute_metrics,
)
# fix max_length issue
trainer._args.max_length = trainer.st_trainer.model.tokenizer.model_max_length

# set seeds for reproducibility
trainer._args.seed = 42
trainer.st_trainer.args.seed = 42
trainer.st_trainer.args.data_seed = 42
trainer.st_trainer.args.full_determinism = True

# don't report to wandb or other experiment trackers
trainer._args.report_to = 'none'
trainer.st_trainer.args.report_to = 'none'

In [14]:
# train
trainer.train()

***** Running training *****
  Num unique pairs = 3200
  Batch size = 16


  Num epochs = 1


Step,Training Loss,Validation Loss
25,0.6673,0.232995
50,0.2649,0.215451
75,0.2649,0.239145
100,0.1438,0.22463
125,0.1438,0.233633


Epoch:   7%|▋         | 1/15 [00:01<00:23,  1.69s/it]

{'training loss': 0.5792229017739494, 'validation loss': 0.611130109205842, 'accuracy': 0.795, 'macro_f1': 0.7751903735632184}


Epoch:  13%|█▎        | 2/15 [00:03<00:21,  1.67s/it]

{'training loss': 0.35001990529398125, 'validation loss': 0.5642310957238078, 'accuracy': 0.82, 'macro_f1': 0.7934924319869581}


Epoch:  20%|██        | 3/15 [00:04<00:19,  1.66s/it]

{'training loss': 0.2867766917310655, 'validation loss': 0.5155245288088918, 'accuracy': 0.81, 'macro_f1': 0.7901657210375367}


Epoch:  27%|██▋       | 4/15 [00:06<00:18,  1.66s/it]

{'training loss': 0.23860107495139043, 'validation loss': 0.6621418678294867, 'accuracy': 0.78, 'macro_f1': 0.7540520662287904}


Epoch:  27%|██▋       | 4/15 [00:08<00:22,  2.08s/it]

{'training loss': 0.21036491377900043, 'validation loss': 0.6237878849543631, 'accuracy': 0.8, 'macro_f1': 0.7743435281075662}
Early stopping triggered after 5 epochs
Loading best model





In [15]:
# eval
probs = trainer.model.predict_proba(test_dataset['text'], as_numpy=True)
preds = probs.argmax(axis=1)

In [16]:
from sklearn.metrics import classification_report
print(classification_report(test_dataset['label'], preds))

              precision    recall  f1-score   support

           0       0.69      0.79      0.73        42
           1       0.95      0.81      0.88        70
           2       0.70      0.72      0.71        43
           3       0.73      0.78      0.75        45

    accuracy                           0.78       200
   macro avg       0.77      0.77      0.77       200
weighted avg       0.79      0.78      0.78       200



### with multi-label classification

In [53]:
dataset_id = 'acloudfan/toxicity-multi-label-classifier'

from datasets import load_dataset
train_dataset = load_dataset(dataset_id, split="train")
val_dataset = load_dataset(dataset_id, split="validation")
test_dataset = load_dataset(dataset_id, split="test")

In [54]:
from collections import Counter

label_cols = ['toxic', 'threat', 'insult', 'identity_hate']
for col in label_cols:
    print(Counter(train_dataset[col]))

Counter({1: 58, 0: 31})
Counter({0: 69, 1: 20})
Counter({0: 63, 1: 26})
Counter({0: 69, 1: 20})


In [55]:
id2label = dict(enumerate(label_cols))
label2id = {v: k for k, v in id2label.items()}
# convert to multi-label format

In [56]:
# convert to multi-label format
def format_dataset_multi_label(example):
    example['label'] = [example[label_col] for label_col in label2id.keys()]
    return example

In [57]:
train_dataset = train_dataset.map(format_dataset_multi_label, batched=False)
train_dataset = train_dataset.rename_column("comment_text", "text")
train_dataset = train_dataset.remove_columns(label_cols)

val_dataset = val_dataset.map(format_dataset_multi_label, batched=False)
val_dataset = val_dataset.rename_column("comment_text", "text")
val_dataset = val_dataset.remove_columns(label_cols)

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

In [59]:
train_dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 89
})

In [62]:
model_id = "sentence-transformers/all-MiniLM-L6-v2"
body = SentenceTransformer(model_id, model_kwargs={"device_map": "auto"})

head = SetFitHead(
    in_features=body.get_sentence_embedding_dimension(),
    
    out_features=num_classes,
    device=body.device,
)

model = SetFitModelWithEarlyStopping(
    model_body=body,
    model_head=head,
    multi_target_strategy="one-vs-rest",
    normalize_embeddings=True,
)
model.to(body.device);

In [63]:
args = EarlyStoppingTrainingArguments()
args.max_length = body.tokenizer.model_max_length

In [94]:
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

def compute_metrics_multilabel(p):
    """
    Compute evaluation metrics for multi-label classification.
    eval_pred: transformers.trainer_utils.PredictionOutput
               (contains .predictions and .label_ids)
    """
    # unpack
    logits, labels = p.predictions, p.label_ids
    # apply sigmoid to get probabilities in [0,1]
    probs = 1 / (1 + np.exp(-logits))
    # threshold at 0.5 for binary decisions
    preds = (probs > 0.5).astype(int)

    # compute metrics
    accuracy = accuracy_score(labels, preds)
    f1_macro = f1_score(labels, preds, average="macro", zero_division=0)
    f1_micro = f1_score(labels, preds, average="micro", zero_division=0)
    precision_macro = precision_score(labels, preds, average="macro", zero_division=0)
    recall_macro = recall_score(labels, preds, average="macro", zero_division=0)

    # optional: subset accuracy (exact match ratio)
    subset_acc = (labels == preds).all(axis=1).mean()

    return {
        "accuracy": accuracy,
        "subset_accuracy": subset_acc,
        "f1_macro": f1_macro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
    }

# test:
# y_true = train_dataset['label'][:10]
# # simulate some predictions by sampling (shape (10, num_classes)) uniformly from 0-1
# np.random.seed(42)
# y_pred = np.random.rand(10, num_classes)
# p = PredictionOutput(predictions=y_pred, label_ids=y_true, metrics=None)
# compute_metrics_multilabel(p)

In [96]:
model.fit(
    x_train=train_dataset["text"], y_train=train_dataset["label"],
    x_eval=val_dataset["text"], y_eval=val_dataset["label"],
    
    num_epochs=10,
    batch_size=16,
    body_learning_rate=args.body_classifier_learning_rate,
    head_learning_rate=args.head_learning_rate,
    l2_weight=args.l2_weight,
    
    max_length=body.tokenizer.model_max_length,
    
    show_progress_bar=True,
    end_to_end=True, # !!! fine-tune only head
    
    # added early stopping arguments
    compute_metrics=compute_metrics_multilabel,
    metric_for_best_model="f1_macro", # NOTE: must match one of the keys returned by `compute_metrics`
    early_stopping_patience=2,
    early_stopping_threshold=0.03,
    greater_is_better=True,
)

Epoch:  10%|█         | 1/10 [00:00<00:01,  5.19it/s]

{'training loss': 1.4056835174560547, 'validation loss': 1.6588252385457356, 'accuracy': 0.3142857142857143, 'subset_accuracy': 0.3142857142857143, 'f1_macro': 0.6158771929824561, 'f1_micro': 0.7068965517241379, 'precision_macro': 0.6293290043290043, 'recall_macro': 0.6833964646464646}


Epoch:  20%|██        | 2/10 [00:00<00:01,  5.78it/s]

{'training loss': 1.3107593655586243, 'validation loss': 1.7910070419311523, 'accuracy': 0.34285714285714286, 'subset_accuracy': 0.34285714285714286, 'f1_macro': 0.6670193841246473, 'f1_micro': 0.7368421052631579, 'precision_macro': 0.6842532467532467, 'recall_macro': 0.7111742424242424}


Epoch:  30%|███       | 3/10 [00:00<00:01,  6.11it/s]

{'training loss': 1.2375029027462006, 'validation loss': 1.581046462059021, 'accuracy': 0.37142857142857144, 'subset_accuracy': 0.37142857142857144, 'f1_macro': 0.7113844393592678, 'f1_micro': 0.7521367521367521, 'precision_macro': 0.6918290043290044, 'recall_macro': 0.766729797979798}


Epoch:  40%|████      | 4/10 [00:00<00:00,  6.49it/s]

{'training loss': 1.1650748252868652, 'validation loss': 1.883272687594096, 'accuracy': 0.37142857142857144, 'subset_accuracy': 0.37142857142857144, 'f1_macro': 0.7113844393592678, 'f1_micro': 0.7521367521367521, 'precision_macro': 0.6918290043290044, 'recall_macro': 0.766729797979798}


Epoch:  40%|████      | 4/10 [00:00<00:01,  5.12it/s]

{'training loss': 1.1123599310715993, 'validation loss': 1.6481937567392986, 'accuracy': 0.37142857142857144, 'subset_accuracy': 0.37142857142857144, 'f1_macro': 0.6875749155497439, 'f1_micro': 0.7413793103448276, 'precision_macro': 0.683495670995671, 'recall_macro': 0.7389520202020201}
Early stopping triggered after 5 epochs
Loading best model



