In [None]:
import os
# Configure cuBLAS for deterministic behavior; this is done here, ahead of the
# `set_seed(..., deterministic=True)` call below.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # to enable deterministic behavior with CuBLAS
# NOTE: this avoids the following error:
#   RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)`
#   or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it
#   uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
#   environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or
#   CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to
#   https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

from transformers import set_seed
# Fix the random seed and request deterministic algorithms.
set_seed(42, deterministic=True) # for reproducibility

In [None]:
import numpy as np
# Display arrays with 4 digits of precision and without scientific notation.
np.set_printoptions(precision=4, suppress=True)

from src.finetuning.setfit_extensions.class_weights_head import compute_class_weights

### Prepare the dataset

In [None]:
# load an example classification dataset from the HF hub
from datasets import load_dataset, Dataset
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

dataset_id = "jakartaresearch/semeval-absa"
# Training data: up to the first 1000 rows of the "train" split.
# Validation and test data are disjoint slices of the "validation" split
# (rows 0-199 and rows 200-399 respectively).
train_dataset = load_dataset(dataset_id, "restaurant", split="train[:1000]")
val_dataset = load_dataset(dataset_id, "restaurant", split="validation[:200]")
test_dataset = load_dataset(dataset_id, "restaurant", split="validation[200:400]")

def to_span_classification_format(example, allowed_labels=None):
    """Expand one example into one record per annotated aspect.

    Args:
        example: mapping with a 'text' entry and an 'aspects' entry; the latter
            holds parallel sequences under the keys 'from', 'to' and 'polarity'.
        allowed_labels: optional container of polarity labels to keep. When
            ``None``, every aspect is kept.

    Returns:
        A list of dicts with the keys 'text', 'span' (a ``(from, to)`` tuple)
        and 'label', in the order in which the aspects are listed.
    """
    annotations = example['aspects']
    triples = zip(annotations['from'], annotations['to'], annotations['polarity'])
    return [
        {'text': example['text'], 'span': (start, end), 'label': polarity}
        for start, end, polarity in triples
        if allowed_labels is None or polarity in allowed_labels
    ]

def dataset_to_span_classification_format(dataset, allowed_labels=None):
    """Convert a whole dataset into a flat span-classification ``Dataset``.

    Every row of ``dataset`` is expanded with ``to_span_classification_format``
    and the resulting records are concatenated in row order.
    """
    records = []
    for row in dataset.to_list():
        records.extend(to_span_classification_format(row, allowed_labels))
    return Dataset.from_list(records)

# Only aspects whose polarity is one of these labels are kept.
allowed_labels = ['negative', 'neutral', 'positive']
# NOTE: the *_dataset names are rebound to the converted datasets, whose rows only
# carry 'text', 'span' and 'label'. Re-running this cell without reloading the raw
# datasets first fails, because the conversion reads the 'aspects' field.
train_dataset = dataset_to_span_classification_format(train_dataset, allowed_labels)
val_dataset = dataset_to_span_classification_format(val_dataset, allowed_labels)
test_dataset = dataset_to_span_classification_format(test_dataset, allowed_labels)

In [None]:
# Number of distinct labels that occur in the training set.
num_classes = len(set(train_dataset["label"]))
num_classes

In [None]:
# Label names get integer ids in sorted order; id2label is the inverse lookup.
label2id = {name: idx for idx, name in enumerate(sorted(set(train_dataset["label"])))}
id2label = {idx: name for name, idx in label2id.items()}
label2id

In [None]:
# apply label2id mapping to labels
def apply_label_mapping(x, mapping):
    """Replace ``x['label']`` in place with ``mapping[x['label']]`` and return ``x``.

    Raises:
        KeyError: if the current label is not a key of ``mapping``.
    """
    current_label = x['label']
    x['label'] = mapping[current_label]
    return x
# Encode the string labels of every split as integer ids, one example at a time.
train_dataset = train_dataset.map(apply_label_mapping, fn_kwargs={"mapping": label2id}, batched=False)
val_dataset = val_dataset.map(apply_label_mapping, fn_kwargs={"mapping": label2id}, batched=False)
test_dataset = test_dataset.map(apply_label_mapping, fn_kwargs={"mapping": label2id}, batched=False)

In [None]:
# Class weights computed from the (integer-encoded) training labels; they are
# passed to the classification head when the models are built further down.
class_weights = compute_class_weights(train_dataset['label'])
class_weights

In [None]:
from transformers.trainer_utils import PredictionOutput
from sklearn.metrics import f1_score

def compute_metrics(p: PredictionOutput):
    """Compute accuracy and macro-averaged F1 from a prediction output.

    Args:
        p: prediction output; ``p.predictions`` holds one row of per-class
            scores per example and ``p.label_ids`` the true class ids.

    Returns:
        Dict with the keys "accuracy" and "macro_f1".
    """
    predicted_ids = np.argmax(p.predictions, axis=1)  # highest-scoring class per row
    true_ids = p.label_ids
    n_correct = np.sum(predicted_ids == true_ids)
    return {
        "accuracy": n_correct / len(true_ids),
        "macro_f1": f1_score(true_ids, predicted_ids, average="macro"),
    }

### Using the `SetFitModel`'s `fit()` method

In [None]:
from setfit.modeling import SetFitHead
from src.finetuning.setfit_extensions.class_weights_head import SetFitHeadWithClassWeights
from src.finetuning.setfit_extensions.span_embedding import (
    SetFitModelForSpanClassification,
    SentenceTransformerForSpanEmbedding,
)

In [None]:
# NOTE: spans need to be converted back to tuples manually, as the datasets library converts tuples to lists
def dataset_to_xy_inputs(dataset):
    """Split a span-classification dataset into model inputs and labels.

    Returns:
        A pair ``(X, y)``: ``X`` is a list of ``(text, span)`` pairs with every
        span converted to a tuple, and ``y`` is the dataset's 'label' column.
    """
    inputs = []
    for row in dataset:
        # the span is turned back into a tuple here, whatever sequence type it is stored as
        inputs.append((row['text'], tuple(row['span'])))
    return inputs, dataset['label']

# Build (inputs, labels) for the train, validation and test splits.
x_train, y_train = dataset_to_xy_inputs(train_dataset)
x_eval, y_eval = dataset_to_xy_inputs(val_dataset)
x_test, y_test = dataset_to_xy_inputs(test_dataset)

In [None]:
# Seed again right before the model is created.
set_seed(42, deterministic=True)

model_id = "sentence-transformers/all-MiniLM-L6-v2"
# Embedding body.
body = SentenceTransformerForSpanEmbedding(model_id, model_kwargs={"device_map": "auto"})

# Classification head: its input size is the body's embedding dimension, it has one
# output per class, is created on the body's device and receives the class weights
# computed from the training labels.
head = SetFitHeadWithClassWeights(
    in_features=body.get_sentence_embedding_dimension(),
    out_features=num_classes,
    device=body.device,
    class_weights=class_weights
)

# Combine body and head into the span-classification model.
model = SetFitModelForSpanClassification(
    model_body=body,
    model_head=head,
    normalize_embeddings=True,
)
model.to(body.device);  # trailing ';' suppresses the cell output

In [None]:
from setfit import TrainingArguments
# from src.finetuning.setfit_extensions.span_embedding import SetFitTrainerForSpanClassification
# Default training arguments: the `model.fit(...)` call in the next cell reads its
# learning rates and `l2_weight` from this object.
args = TrainingArguments()
# NOTE(review): the next cell passes `body.tokenizer.model_max_length` directly as
# `max_length`, so this assignment is not read by any visible code — confirm it is needed.
args.max_length = body.tokenizer.model_max_length

In [None]:
# Fit on the training pairs and evaluate on the validation pairs. Learning rates and
# l2_weight come from `args`; the last group of keyword arguments configures early
# stopping on the "macro_f1" value returned by `compute_metrics`.
model.fit(
    x_train=x_train, y_train=y_train,
    x_eval=x_eval, y_eval=y_eval,
    
    num_epochs=15,
    batch_size=16,
    body_learning_rate=args.body_classifier_learning_rate,
    head_learning_rate=args.head_learning_rate,
    l2_weight=args.l2_weight,
    
    max_length=body.tokenizer.model_max_length,
    
    show_progress_bar=True,
    end_to_end=True, # !!! also fine-tune the body (important for adapting the sentence embedding model to span embedding)
    
    # added early stopping arguments
    compute_metrics=compute_metrics,
    metric_for_best_model="macro_f1", # NOTE: must match one of the keys returned by `compute_metrics`
    early_stopping_patience=3,
    early_stopping_threshold=0.02,
    greater_is_better=True,
)

In [None]:
# Predict on the test inputs, given as a list of (text, span) pairs.
y_pred = model.predict(inputs=x_test, as_numpy=True)

In [None]:
# verify that it accepts different input formats
# (a) texts plus spans given as separate arguments
y_pred2 = model.predict(texts=test_dataset['text'], spans=test_dataset['span'], as_numpy=True)

# (b) texts plus the span substrings themselves, cut out of each text using the span offsets
span_texts = [ex['text'][slice(*ex['span'])] for ex in test_dataset]
y_pred3 = model.predict(texts=test_dataset['text'], span_texts=span_texts, as_numpy=True)

In [None]:
# Pairwise comparison of the three prediction arrays. `np.array_equal` returns a plain
# bool and simply yields False when the shapes differ, whereas `all(a == b)` raises
# in that case.
np.array_equal(y_pred, y_pred2), np.array_equal(y_pred, y_pred3), np.array_equal(y_pred2, y_pred3)

In [None]:
from sklearn.metrics import classification_report
# `labels` is passed explicitly so that the report always has one row per class, in id
# order, matching `target_names` — even if a class is missing from both y_test and y_pred
# (without `labels`, scikit-learn raises on such a mismatch).
print(classification_report(
    y_test, y_pred,
    labels=list(range(num_classes)),
    target_names=[id2label[i] for i in range(num_classes)],
))

In [None]:
# free GPU
# Move the model to the CPU, drop this notebook's `model` reference and ask PyTorch
# to release its cached CUDA memory.
# NOTE(review): `body` and `head` from the cells above still reference the submodules.
model.to("cpu");
del model
import torch
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())

### with custom early-stopping trainer class

In [None]:
from setfit.modeling import SetFitHead
from src.finetuning.setfit_extensions.class_weights_head import SetFitHeadWithClassWeights
from src.finetuning.setfit_extensions.span_embedding import (
    SentenceTransformerForSpanEmbedding,
    SetFitModelForSpanClassification,
    SetFitTrainerForSpanClassification,
)
from src.finetuning.setfit_extensions.early_stopping import (
    EarlyStoppingTrainingArguments,
    EarlyStoppingCallback,
)

In [None]:
def model_init(
        model_name: str="sentence-transformers/all-MiniLM-L6-v2",
        num_classes: int=2,
        class_weights: "np.ndarray | None"=None,
        **kwargs
    ) -> SetFitModelForSpanClassification:
    """Build a fresh span-classification model (embedding body plus classification head).

    Args:
        model_name: name or path of the sentence-transformers checkpoint used for the body.
        num_classes: number of output classes of the classification head.
        class_weights: optional per-class weights. When given, a
            ``SetFitHeadWithClassWeights`` head is created and receives them;
            otherwise a plain ``SetFitHead`` is used.
        **kwargs: extra entries merged into the body's ``model_kwargs``; they
            take precedence over the default ``device_map="auto"``.

    Returns:
        A ``SetFitModelForSpanClassification`` with ``normalize_embeddings=True``.
    """
    # The annotation of `class_weights` is a string on purpose: it is not evaluated at
    # definition time, so it neither depends on NumPy's private `np._typing` module
    # nor needs an extra typing import.
    model_kwargs = {"device_map": "auto", **kwargs}
    # NOTE(security): trust_remote_code=True lets the model repository run its own Python
    # code when loading; only pass model names from sources you trust.
    body = SentenceTransformerForSpanEmbedding(model_name, model_kwargs=model_kwargs, trust_remote_code=True)

    # TODO: support multi-label classification
    head_kwargs = dict(
        in_features=body.get_sentence_embedding_dimension(),
        out_features=num_classes,
        device=body.device,
    )
    if class_weights is not None:
        head_kwargs['class_weights'] = class_weights
        head = SetFitHeadWithClassWeights(**head_kwargs)
    else:
        head = SetFitHead(**head_kwargs)

    return SetFitModelForSpanClassification(
        model_body=body,
        model_head=head.to(body.device),
        normalize_embeddings=True,
    )

In [None]:
# Same checkpoint as in the `fit()`-based section above.
model_id = "sentence-transformers/all-MiniLM-L6-v2"

In [None]:
# NOTE(review): the 2-tuples below presumably follow the order
# (sentence-transformer finetuning, classifier finetuning), like the two callbacks
# further down — confirm against EarlyStoppingTrainingArguments.
training_args = EarlyStoppingTrainingArguments(
    num_epochs=(1, 15),
    # sentence transformer (embedding) finetuning args
    eval_strategy="steps", # NOTE: currently no effect on (early stopping in) classification head training
    # NOTE(review): the original note here said this "overwrites 0 epochs above" for sentence
    # transformer finetuning, but num_epochs above is (1, 15) — confirm which setting takes precedence.
    eval_steps=25,
    max_steps=50,
    eval_max_steps=200,
    # early stopping config
    metric_for_best_model=("embedding_loss", "f1"),
    greater_is_better=(False, True),
    load_best_model_at_end=True,
    save_total_limit=2, # NOTE: currently no effect on (early stopping in) classification head training
    # misc
    end_to_end=True,
)

# One early-stopping callback per training phase.
training_callbacks = [
    EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.03), # for sentence transformer finetuning
    EarlyStoppingCallback(early_stopping_patience=4, early_stopping_threshold=0.02), # for classifier finetuning
]

In [None]:
# initialize Trainer
# `model_init` is wrapped in a zero-argument lambda so that the model name, class
# count and class weights defined in this notebook are bound into it.
trainer = SetFitTrainerForSpanClassification(
    model_init=lambda : model_init(
        model_name=model_id,
        num_classes=num_classes,
        class_weights=class_weights,
    ),
    metric="f1",
    metric_kwargs={"average": "macro"},
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=training_callbacks,
    # compute_metrics=compute_metrics,
)
# fix max_length issue
# NOTE(review): the settings below are patched directly onto the trainer's private
# `_args` and onto `st_trainer.args` after construction — confirm there is no public way to set them.
trainer._args.max_length = trainer.st_trainer.model.tokenizer.model_max_length

# set seeds for reproducibility
trainer._args.seed = 42
trainer.st_trainer.args.seed = 42
trainer.st_trainer.args.data_seed = 42
trainer.st_trainer.args.full_determinism = True

# don't report to wandb or other experiment trackers
trainer._args.report_to = 'none'
trainer.st_trainer.args.report_to = 'none'

In [None]:
# train with the arguments and callbacks configured above
trainer.train()

In [None]:
# verify best model loaded
# Recompute accuracy / macro F1 on the validation set with the trainer's current model.
probs = trainer.model.predict_proba(texts=val_dataset['text'], spans=val_dataset['span'], as_numpy=True)
preds = probs.argmax(axis=1)  # not used below: compute_metrics takes the argmax of `probs` itself
p = PredictionOutput(predictions=probs, label_ids=np.array(val_dataset['label']), metrics={})
compute_metrics(p)

In [None]:
# evaluate on test set
# Same computation as for the validation set, applied to the held-out test split.
probs = trainer.model.predict_proba(texts=test_dataset['text'], spans=test_dataset['span'], as_numpy=True)
preds = probs.argmax(axis=1)  # not used below: compute_metrics takes the argmax of `probs` itself
p = PredictionOutput(predictions=probs, label_ids=np.array(test_dataset['label']), metrics={})
compute_metrics(p)

In [None]:
from sklearn.metrics import classification_report
preds = trainer.model.predict(texts=test_dataset['text'], spans=test_dataset['span'], as_numpy=True)
# Class names are listed by id (via id2label) rather than by the key order of label2id,
# and `labels` is passed explicitly so that the report rows always line up with
# `target_names`, even if a class is missing from both the test labels and the predictions.
print(classification_report(
    test_dataset['label'], preds,
    labels=list(range(num_classes)),
    target_names=[id2label[i] for i in range(num_classes)],
))

In [None]:
# NOTE: only shows F1 because metric="f1" was passed to the Trainer
trainer.evaluate(test_dataset)

: 