# On Colab, restart the runtime after running the following cell so that the reinstalled packages are loaded.

In [None]:
!pip install optuna sentence-transformers -q
!pip install plotly -q
# I am reinstalling pyyaml because the default (>5.4) is incompatible with plotly
!pip uninstall -y pyyaml && pip install pyyaml==5.4.1 -q

!pip list | grep "transformers\|optuna\|torch\|plotly\|yaml"

Found existing installation: PyYAML 5.4.1
Uninstalling PyYAML-5.4.1:
  Successfully uninstalled PyYAML-5.4.1
optuna                        2.10.1
plotly                        5.5.0
sentence-transformers         2.2.2
torch                         1.12.0+cu113
torchaudio                    0.12.0+cu113
torchsummary                  1.5.1
torchtext                     0.13.0
torchvision                   0.13.0+cu113
transformers                  4.21.0


### I am using [this script](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py) as an example

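The following cell sets up the training. It configures logging, downloads the STS benchmark if it is not on disk yet, and splits it into train, dev and test samples.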

In [None]:
import csv
import gzip
import logging
import math
import os
import sys
from datetime import datetime

import optuna
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Download the gzipped STS benchmark TSV if it is not already on disk.
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'

if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

# Pretrained model whose fine-tuning hyperparameters are searched.
model_name = "sentence-transformers/all-MiniLM-L6-v2"

# Timestamped output directory; the objective writes training checkpoints under it.
model_save_path = (
    'output/training_stsbenchmark_'
    + model_name.replace("/", "-")
    + '-'
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

# Read all three splits in a single pass; each row's `split` column decides
# which list it goes to (anything other than dev/test is treated as train).
logging.info("Read STSbenchmark dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

        if row['split'] == 'dev':
            dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)
        else:
            train_samples.append(inp_example)

logging.info("Read STSbenchmark: %d train, %d dev, %d test samples",
             len(train_samples), len(dev_samples), len(test_samples))

# Training and per-trial evaluation are meaningless without these splits.
assert train_samples and dev_samples, "STSbenchmark file has no train or dev rows"

## EvaluatorWrapper

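This wraps the evaluator so that `trial.report` is called with the dev score every time evaluation runs. Reporting intermediate scores is what lets Optuna prune runs it deems not worthwhile. Note that pruning only happens when `trial.should_prune()` is checked and `optuna.TrialPruned` is raised; reporting alone never stops a run.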

In [None]:
class EvaluatorWrapper(EmbeddingSimilarityEvaluator):
    """
    Evaluator that reports every score to an Optuna trial and honours pruning.

    Subclass whichever evaluator you are using. Each call runs the parent
    evaluator, reports the returned score with ``trial.report`` and then asks
    ``trial.should_prune()``. If the pruner decides the run is not worthwhile,
    ``optuna.TrialPruned`` is raised. That exception propagates out of
    ``model.fit`` and Optuna records the trial as pruned.

    Reporting alone never stops a run: Optuna only prunes when the objective
    checks ``should_prune()`` and raises ``TrialPruned``, which is done here.

    If no trial has been attached with ``set_trial``, the wrapper behaves like
    the plain parent evaluator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Monotonic counter used as the Optuna step for each evaluation.
        self.step = 0
        # Optuna trial to report to; set via `set_trial`.
        self.trial = None

    def set_trial(self, trial):
        """Attach the Optuna trial that scores are reported to."""
        self.trial = trial

    def __call__(self, *args, **kwargs):
        """Evaluate, report the score to the trial, and prune if requested.

        Returns:
            The score returned by the parent evaluator.

        Raises:
            optuna.TrialPruned: if the attached trial's pruner says to stop.
        """
        score = super().__call__(*args, **kwargs)

        # The `steps` value passed by `fit` is -1 at epoch end and is not
        # guaranteed to be unique across epochs. A private counter gives every
        # evaluation its own increasing step, so no report is dropped as a
        # duplicate.
        step = self.step
        self.step += 1

        if self.trial is not None:
            self.trial.report(score, step)
            if self.trial.should_prune():
                raise optuna.TrialPruned()

        return score

# Create Optuna objective


Optuna calls the objective function once per trial. Inside it, use the following functions to define the search space. I chose to search over batch size, number of epochs, learning rate, weight decay, and warmup steps.

```python
trial.suggest_categorical(name, list_of_values)
trial.suggest_int(name, low, high)
trial.suggest_float(name, low, high, log=True/False)
trial.suggest_discrete_uniform(name, low, high, q)
trial.suggest_loguniform(name, low, high)
trial.suggest_uniform(name, low, high)
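```

`suggest_discrete_uniform`, `suggest_loguniform` and `suggest_uniform` are deprecated since Optuna 3.0. Use `suggest_float(name, low, high, step=q)` or `suggest_float(name, low, high, log=True)` instead.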

See details here: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna-trial-trial

In [None]:
from sentence_transformers import SentenceTransformer, losses


def objective(trial):
    """
    Fine-tune one model with hyperparameters sampled from `trial`; return its dev score.

    Searched hyperparameters: batch size, number of epochs, warmup steps,
    learning rate (log scale) and weight decay. Intermediate dev scores are
    reported to `trial` by `EvaluatorWrapper` so the study's pruner can act
    on them.

    Returns:
        The dev-set score from `EmbeddingSimilarityEvaluator`. The study
        maximizes it because it is a similarity score.
    """
    word_embedding_model = models.Transformer(model_name)
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    train_loss = losses.CosineSimilarityLoss(model=model)

    train_dataloader = DataLoader(
        train_samples,
        shuffle=True,
        batch_size=trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
    )

    evaluator = EvaluatorWrapper.from_input_examples(dev_samples, name='sts-dev')
    evaluator.set_trial(trial)

    # Give each trial its own checkpoint directory. Sharing `model_save_path`
    # would let trials with the same step counts overwrite each other's
    # checkpoints and make them impossible to attribute.
    trial_checkpoint_path = os.path.join(model_save_path, f"trial_{trial.number}")

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=trial.suggest_int("epochs", 1, 5),
        use_amp=True,
        warmup_steps=trial.suggest_categorical("warmup_steps", [10, 100, 1000]),
        optimizer_params={'lr': trial.suggest_float("lr", low=1e-6, high=7e-5, log=True)},
        weight_decay=trial.suggest_categorical("weight_decay", [0.0, 0.01, 0.009]),
        checkpoint_path=trial_checkpoint_path,
        checkpoint_save_steps=len(train_dataloader),  # one checkpoint per epoch
        show_progress_bar=False,
    )

    # Final score from a plain (non-reporting) evaluator. Going through the
    # wrapper would add an extra intermediate step whose meaning depends on the
    # trial's epoch count, and would let the pruner prune a finished trial.
    final_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
    return model.evaluate(final_evaluator)

## Run study

In [None]:
import optuna

# Seed for the hyperparameter sampler so the sequence of suggested
# configurations is reproducible. Model training itself is not seeded here.
SEED = 42
N_TRIALS = 20

# if your evaluator is a loss, then you want to minimize
# if your evaluator is a similarity score, then you want to maximize
study = optuna.create_study(
    direction="maximize",
    # TPESampler and MedianPruner are Optuna's defaults; they are spelled out
    # so the seed can be set and the pruning policy is visible.
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner(),
)

study.optimize(objective, n_trials=N_TRIALS)

[32m[I 2022-07-29 12:42:59,076][0m A new study created in memory with name: no-name-f634a2cf-529c-41b0-811b-124924e3cdf1[0m
[32m[I 2022-07-29 12:45:44,749][0m Trial 0 finished with value: 0.8912262556141459 and parameters: {'batch_size': 4, 'epochs': 2, 'warmup_steps': 1000, 'lr': 4.459899115686889e-06, 'weight_decay': 0.01}. Best is trial 0 with value: 0.8912262556141459.[0m
[32m[I 2022-07-29 12:48:20,376][0m Trial 1 finished with value: 0.893464949329911 and parameters: {'batch_size': 8, 'epochs': 4, 'warmup_steps': 100, 'lr': 4.928491858691746e-05, 'weight_decay': 0.009}. Best is trial 1 with value: 0.893464949329911.[0m
[32m[I 2022-07-29 12:49:24,453][0m Trial 2 finished with value: 0.8934982554027179 and parameters: {'batch_size': 32, 'epochs': 5, 'warmup_steps': 100, 'lr': 5.2265512600247e-05, 'weight_decay': 0.01}. Best is trial 2 with value: 0.8934982554027179.[0m
[32m[I 2022-07-29 12:50:17,903][0m Trial 3 finished with value: 0.8911205704878308 and parameters: {'

## Print best results and params

In [None]:
# Full record of the best trial, then just its hyperparameters.
best_trial = study.best_trial
print(best_trial)
print(best_trial.params)

FrozenTrial(number=9, values=[0.8966375689570919], datetime_start=datetime.datetime(2022, 7, 29, 12, 54, 14, 897064), datetime_complete=datetime.datetime(2022, 7, 29, 12, 55, 22, 927008), params={'batch_size': 16, 'epochs': 3, 'warmup_steps': 10, 'lr': 6.192407579991508e-05, 'weight_decay': 0.01}, distributions={'batch_size': CategoricalDistribution(choices=(4, 8, 16, 32, 64)), 'epochs': IntUniformDistribution(high=5, low=1, step=1), 'warmup_steps': CategoricalDistribution(choices=(10, 100, 1000)), 'lr': LogUniformDistribution(high=7e-05, low=1e-06), 'weight_decay': CategoricalDistribution(choices=(0.0, 0.01, 0.009))}, user_attrs={}, system_attrs={}, intermediate_values={0: 0.8939536952188011, 1: 0.8954105752277682, 2: 0.8966375689570919, 3: 0.8966375689570919}, trial_id=9, state=TrialState.COMPLETE, value=None)
{'batch_size': 16, 'epochs': 3, 'warmup_steps': 10, 'lr': 6.192407579991508e-05, 'weight_decay': 0.01}


## Load results into a dataframe and save

In [None]:
# One row per trial (number, value, timing, params_* columns, state),
# saved to CSV and displayed below.
df = study.trials_dataframe()
df.to_csv(path_or_buf="optuna-hp-tune-results.csv", index=False)
df

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_batch_size,params_epochs,params_lr,params_warmup_steps,params_weight_decay,state
0,0,0.891226,2022-07-29 12:42:59.085671,2022-07-29 12:45:44.749319,0 days 00:02:45.663648,4,2,4e-06,1000,0.01,COMPLETE
1,1,0.893465,2022-07-29 12:45:44.751842,2022-07-29 12:48:20.375884,0 days 00:02:35.624042,8,4,4.9e-05,100,0.009,COMPLETE
2,2,0.893498,2022-07-29 12:48:20.378069,2022-07-29 12:49:24.453098,0 days 00:01:04.075029,32,5,5.2e-05,100,0.01,COMPLETE
3,3,0.891121,2022-07-29 12:49:24.455175,2022-07-29 12:50:17.903019,0 days 00:00:53.447844,32,4,5e-06,100,0.0,COMPLETE
4,4,0.892101,2022-07-29 12:50:17.905074,2022-07-29 12:50:43.150349,0 days 00:00:25.245275,64,2,1.7e-05,100,0.0,COMPLETE
5,5,0.892116,2022-07-29 12:50:43.152520,2022-07-29 12:51:00.885523,0 days 00:00:17.733003,64,1,4.3e-05,10,0.009,COMPLETE
6,6,0.891754,2022-07-29 12:51:00.891158,2022-07-29 12:52:21.448198,0 days 00:01:20.557040,4,1,8e-06,100,0.009,COMPLETE
7,7,0.891191,2022-07-29 12:52:21.450774,2022-07-29 12:52:54.470651,0 days 00:00:33.019877,64,3,4.5e-05,1000,0.01,COMPLETE
8,8,0.894391,2022-07-29 12:52:54.473060,2022-07-29 12:54:14.894640,0 days 00:01:20.421580,4,1,3.2e-05,1000,0.0,COMPLETE
9,9,0.896638,2022-07-29 12:54:14.897064,2022-07-29 12:55:22.927008,0 days 00:01:08.029944,16,3,6.2e-05,10,0.01,COMPLETE


## Plot results in parallel coordinates plot

In [None]:
import plotly.express as px

# Axes shown in the plot: the searched hyperparameters plus the trial score.
PLOT_COLUMNS = ["params_batch_size", "params_epochs", "params_lr",
                "params_warmup_steps", "params_weight_decay", "value"]

# Axes whose values are tiny and read best in scientific notation. The score
# column ("value") is deliberately excluded. Its values differ only in the
# third decimal, so "1.1e" would collapse its ticks to near-identical labels
# such as 8.9e-1.
SCI_NOTATION_COLUMNS = {"params_lr", "params_weight_decay"}

# Only finished trials have a final score to colour by.
completed_trials = df[df["state"] == "COMPLETE"]

fig = px.parallel_coordinates(
    completed_trials[PLOT_COLUMNS],
    color="value",
    color_continuous_scale="agsunset",
    title="Optuna trials: hyperparameters vs. dev score",
)

# plotly express labels each axis with its column name.
for dimension in fig.data[0]["dimensions"]:
    if dimension["label"] in SCI_NOTATION_COLUMNS:
        dimension["tickformat"] = "1.1e"

fig.show()