In [12]:
from typing import Any, Dict, Optional, Union

from optuna import Trial
from optuna.samplers import TPESampler
from setfit import SetFitModel
from setfit import TrainingArguments

In [13]:
def model_init(params: Optional[Dict[str, Any]] = None) -> SetFitModel:
    """Create a fresh SetFit model; passed to the Trainer as its model factory.

    The body is loaded from the ``KBLab/sentence-bert-swedish-cased``
    checkpoint. Only the classification head is configurable here.

    Args:
        params: Hyperparameters supplied by the Trainer, or ``None``/empty.
            Only two keys are read:
            - ``"max_iter"`` (default 100)
            - ``"solver"`` (default ``"liblinear"``)
            Any other keys are ignored by this function. That includes the
            ``num_epochs``/``batch_size`` entries produced by ``hp_space``.

    Returns:
        A ``SetFitModel`` built with
        ``head_params={"max_iter": ..., "solver": ...}``.
    """
    # Keep the incoming hyperparameters separate from the kwargs we build,
    # instead of rebinding `params` to a differently shaped dict.
    hparams = params or {}
    # Forwarded to the classification head's constructor; presumably
    # scikit-learn LogisticRegression keyword arguments (TODO confirm).
    head_params = {
        "max_iter": hparams.get("max_iter", 100),
        "solver": hparams.get("solver", "liblinear"),
    }
    return SetFitModel.from_pretrained(
        "KBLab/sentence-bert-swedish-cased", head_params=head_params
    )


In [14]:
def hp_space(trial: Trial) -> Dict[str, Union[float, int, str]]:
    """Optuna search space used by ``trainer.hyperparameter_search``.

    Only two dimensions are currently searched:
    - the number of training epochs (1-10)
    - the batch size (16, 32 or 64)

    The keys returned here are not read by ``model_init``. It only looks at
    ``max_iter``/``solver``, and those stay disabled below.
    """
    epochs = trial.suggest_int("num_epochs", 1, 10)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64])
    # Disabled search dimensions, kept for reference.
    # To enable one, suggest it here and add it to the returned dict:
    #   trial.suggest_float("body_learning_rate", 1e-6, 1e-3, log=True)
    #   trial.suggest_int("seed", 1, 40)
    #   trial.suggest_int("max_iter", 50, 300)
    #   trial.suggest_categorical("solver", ["newton-cg", "lbfgs", "liblinear"])
    return dict(num_epochs=epochs, batch_size=batch)

In [15]:
from datasets import load_dataset

# Load the labelled pilot CSV (columns: ID, text, label). The label column is
# then encoded as a ClassLabel feature, which the stratified split below needs.
dataset = load_dataset(
    "csv", data_files="../data/target_pilot_1_2_sv.csv"
).class_encode_column("label")

  obj.co_lnotab,  # for < python 3.10 [not counted in args]


In [16]:
# Inspect the loaded DatasetDict: the CSV loader returns a single "train" split.
dataset

DatasetDict({
    train: Dataset({
        features: ['ID', 'text', 'label'],
        num_rows: 72
    })
})

In [17]:
# Stratified train/test split of the single "train" split.
# Without `seed`, every re-run (e.g. after a kernel restart) draws a different
# split, and every accuracy reported below changes with it. Pinning the seed
# makes the split reproducible. The seed value itself is arbitrary.
TRAIN_SIZE = 58  # absolute number of training rows; the rest form the test set
SPLIT_SEED = 42
dataset_tt = dataset["train"].train_test_split(
    train_size=TRAIN_SIZE, stratify_by_column="label", seed=SPLIT_SEED
)

In [18]:
# Check the split sizes (train / test) before training.
dataset_tt

DatasetDict({
    train: Dataset({
        features: ['ID', 'text', 'label'],
        num_rows: 58
    })
    test: Dataset({
        features: ['ID', 'text', 'label'],
        num_rows: 14
    })
})

In [19]:
from setfit import Trainer

In [20]:
# Baseline run configuration, used before any hyperparameter search:
# 5 epochs with 16 examples per batch.
args = TrainingArguments(num_epochs=5, batch_size=16)

In [21]:
# The Trainer is given a factory (`model_init`) rather than a model instance.
# That lets it rebuild the model from scratch, which `hyperparameter_search`
# below relies on.
trainer = Trainer(
    args=args,
    model_init=model_init,
    train_dataset=dataset_tt["train"],
    eval_dataset=dataset_tt["test"],
)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
  obj.co_lnotab,  # for < python 3.10 [not counted in args]
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58/58 [00:00<00:00, 4415.54 examples/s]


In [22]:
# Train the baseline configuration and report its metrics on
# dataset_tt["test"]. The dict returned by evaluate() is the cell's displayed output.
trainer.train()
trainer.evaluate()

***** Running training *****
  Num unique pairs = 1758
  Batch size = 16
  Num epochs = 5


Step,Training Loss
1,0.3013
50,0.2009
100,0.0029
150,0.0006
200,0.0005
250,0.0004


***** Running evaluation *****                                                                                                                                                                                                                                                                                                                         


{'accuracy': 0.7142857142857143}

In [24]:
# Persist the baseline model trained above. This runs before the
# hyperparameter search.
# NOTE(review): the model re-trained with the best hyperparameters further
# down is never saved in this notebook.
trainer.model.save_pretrained("../models/target_sv")

In [None]:
# Search `hp_space` with Optuna. Each trial presumably rebuilds the model via
# `model_init` and is scored on the trainer's eval_dataset.
# NOTE(review): that eval_dataset is dataset_tt["test"], the same rows used for
# the final evaluation. The best score is therefore optimistically biased.
# TODO: tune on a separate validation split instead.
# The seeded sampler makes the sequence of sampled configurations
# reproducible. Extra keyword arguments such as `sampler` are forwarded to
# `optuna.create_study` (TODO confirm for the installed setfit version).
N_TRIALS = 10
HP_SEARCH_SEED = 42
best_run = trainer.hyperparameter_search(
    direction="maximize",
    hp_space=hp_space,
    n_trials=N_TRIALS,
    sampler=TPESampler(seed=HP_SEARCH_SEED),
)
print(best_run)

In [12]:
# Apply the best hyperparameters found by the search and re-train on the
# training split. The output's head-initialisation warning shows the model
# is rebuilt from the pretrained checkpoint here.
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num unique pairs = 1218
  Batch size = 32
  Num epochs = 3


Step,Training Loss
1,0.2816
50,0.2749


In [13]:
# Evaluate the re-trained model on the trainer's eval_dataset.
# NOTE(review): an accuracy of 0.625 cannot be a fraction of a 14-row test
# set, and the execution counts restart here. This was likely run after a
# kernel restart on a different (unseeded) split, so it is not directly
# comparable with the baseline score. Re-run everything on one fixed split
# before comparing.
trainer.evaluate()

***** Running evaluation *****


{'accuracy': 0.625}