# Tutorial 4: Advanced ESM training with hyper-parameter optimization

In this tutorial, we will include hyper-parameter optimization in the ESM training process. In order to this more efficiently, we first compute the EmbeddingDataset once and store it. This allows us to quickly train many ESM versions using cross-validation.

We will use the Optuna framework for hyper-parameter optimization. But feel free to use any package / implementation that suits your needs.

In [1]:
!pip install optuna
!pip install scikit-learn



In [2]:
from hfselect import Dataset, create_embedding_dataset, ESMTrainer, EmbeddingDataset
from transformers import BertModel, BertTokenizer
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import torch
from torch.utils.data import SequentialSampler, DataLoader
import optuna

In [3]:
device_name = "cpu"

base_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
tuned_model = BertModel.from_pretrained("prithivMLmods/Spam-Bert-Uncased")

tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

In [4]:
dataset = Dataset.from_hugging_face(
    name="prithivMLmods/Spam-Text-Detect-Analysis",
    split="train",
    text_col="Message",
    label_col="Category",
    is_regression=False,
    num_examples=1000
)

In the previous tutorial, the EmbeddingDataset was created internally and not saved. This time, we will create it explicitely in order to use it for many train iterations with different hyper-parameters.

In [10]:
embedding_dataset = create_embedding_dataset(
    dataset=dataset,
    base_model=base_model,
    tuned_model=tuned_model,
    tokenizer=tokenizer,
    device_name=device_name
)

Computing embedding dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:47<00:00, 20.99s/batch]


In [5]:
embedding_dataset.save("./embedding_dataset_test.npz")
#embedding_dataset = EmbeddingDataset.from_disk("./embedding_dataset_test.npz")

We define the space of hyper-parameters, the train loop and the objective.

# Hyper-parameter search

In [8]:
def objective(trial, k_folds=5):
    num_epochs = trial.suggest_int("num_epochs", 3, 20)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1, log=True)
    batch_size = trial.suggest_int('batch_size', 8, 32)

    trainer = ESMTrainer(weight_decay=weight_decay, learning_rate=learning_rate, device_name=device_name)

    kfold = KFold(n_splits=k_folds, shuffle=True)
    fold_results = []

    # Start print
    print('--------------------------------')

    # K-fold Cross Validation model evaluation
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        train_split = embedding_dataset[train_ids]
        test_split = embedding_dataset[test_ids]
        test_split_x, test_split_y = test_split.x, test_split

        esm = trainer.train_with_embeddings(
            embedding_dataset=train_split,
            num_epochs=num_epochs,
            batch_size=batch_size,
            verbose=0
        )

        sampler = SequentialSampler(test_split)
        test_loader = DataLoader(
            test_split,
            sampler=sampler,
            batch_size=128
        )

        fold_mse = 0
        for batch in test_loader:
            batch_base_embeddings, batch_transformed_embeddings = batch
            with torch.no_grad():
                transformed_embeddings = esm(batch_base_embeddings).detach().numpy()
            batch_mse = mean_squared_error(transformed_embeddings, batch_transformed_embeddings)
            fold_mse += batch_mse * len(batch) / len(test_loader)

        fold_results.append(fold_mse)

    return sum(fold_results) / len(fold_results)


In this example, we will use a Gaussian Process Sampler for selecting hyper-parameters.

In [9]:
#study = optuna.create_study(direction="minimize", sampler=optuna.samplers.RandomSampler(seed=42))
study = optuna.create_study(direction="minimize", sampler=optuna.samplers.GPSampler(seed=42))

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.GPSampler(seed=42))
[I 2025-03-02 11:43:49,540] A new study created in memory with name: no-name-1e9ccd79-a1f0-4131-a1d2-f25734ee0858


We start the training.

In [10]:
study.optimize(objective, n_trials=100, show_progress_bar=True)

  0%|          | 0/100 [00:00<?, ?it/s]

--------------------------------
[I 2025-03-02 11:43:58,557] Trial 0 finished with value: 0.21443050354719162 and parameters: {'num_epochs': 9, 'learning_rate': 0.566984951147885, 'weight_decay': 0.0457056309980145, 'batch_size': 22}. Best is trial 0 with value: 0.21443050354719162.
--------------------------------
[I 2025-03-02 11:44:01,017] Trial 1 finished with value: 0.29433411210775373 and parameters: {'num_epochs': 5, 'learning_rate': 6.025215736203858e-05, 'weight_decay': 1.951722464144947e-05, 'batch_size': 29}. Best is trial 0 with value: 0.21443050354719162.
--------------------------------
[I 2025-03-02 11:44:08,637] Trial 2 finished with value: 0.08357351049780845 and parameters: {'num_epochs': 13, 'learning_rate': 0.034702669886504146, 'weight_decay': 1.2674255898937214e-05, 'batch_size': 32}. Best is trial 2 with value: 0.08357351049780845.
--------------------------------
[I 2025-03-02 11:44:22,904] Trial 3 finished with value: 0.1585923619568348 and parameters: {'num_ep

--------------------------------
[I 2025-03-02 11:50:19,470] Trial 29 finished with value: 0.12005924135446548 and parameters: {'num_epochs': 20, 'learning_rate': 0.0005864417078183384, 'weight_decay': 1e-05, 'batch_size': 32}. Best is trial 10 with value: 0.0762409970164299.
--------------------------------
[I 2025-03-02 11:50:36,187] Trial 30 finished with value: 0.07924627214670181 and parameters: {'num_epochs': 17, 'learning_rate': 0.005490299934171491, 'weight_decay': 0.0005378956837723897, 'batch_size': 8}. Best is trial 10 with value: 0.0762409970164299.
--------------------------------
[I 2025-03-02 11:50:43,416] Trial 31 finished with value: 0.08049811348319054 and parameters: {'num_epochs': 15, 'learning_rate': 0.011772579500161395, 'weight_decay': 0.0002976994447658206, 'batch_size': 15}. Best is trial 10 with value: 0.0762409970164299.
--------------------------------
[I 2025-03-02 11:50:56,428] Trial 32 finished with value: 0.07405714839696884 and parameters: {'num_epochs'

--------------------------------
[I 2025-03-02 11:56:36,299] Trial 58 finished with value: 0.08203751742839813 and parameters: {'num_epochs': 15, 'learning_rate': 0.002256811499992727, 'weight_decay': 0.0009334603570344453, 'batch_size': 8}. Best is trial 52 with value: 0.07335403822362423.
--------------------------------
[I 2025-03-02 11:56:55,408] Trial 59 finished with value: 0.26669165939092637 and parameters: {'num_epochs': 20, 'learning_rate': 1e-05, 'weight_decay': 1e-05, 'batch_size': 8}. Best is trial 52 with value: 0.07335403822362423.
--------------------------------
[I 2025-03-02 11:56:59,255] Trial 60 finished with value: 0.08196538127958775 and parameters: {'num_epochs': 11, 'learning_rate': 0.035514293168716675, 'weight_decay': 1e-05, 'batch_size': 23}. Best is trial 52 with value: 0.07335403822362423.
--------------------------------
[I 2025-03-02 11:57:12,131] Trial 61 finished with value: 0.07829885967075825 and parameters: {'num_epochs': 13, 'learning_rate': 0.00593

--------------------------------
[I 2025-03-02 12:02:41,820] Trial 87 finished with value: 0.07728773094713688 and parameters: {'num_epochs': 18, 'learning_rate': 0.017420764334501106, 'weight_decay': 1e-05, 'batch_size': 19}. Best is trial 82 with value: 0.07332753874361515.
--------------------------------
[I 2025-03-02 12:02:53,773] Trial 88 finished with value: 0.0732130877673626 and parameters: {'num_epochs': 18, 'learning_rate': 0.01409923568733672, 'weight_decay': 1e-05, 'batch_size': 8}. Best is trial 88 with value: 0.0732130877673626.
--------------------------------
[I 2025-03-02 12:03:05,341] Trial 89 finished with value: 0.07276419922709465 and parameters: {'num_epochs': 18, 'learning_rate': 0.0141766680631574, 'weight_decay': 1e-05, 'batch_size': 8}. Best is trial 89 with value: 0.07276419922709465.
--------------------------------
[I 2025-03-02 12:03:16,724] Trial 90 finished with value: 0.07326767481863498 and parameters: {'num_epochs': 18, 'learning_rate': 0.01687348814

We'll have a look at the best found hyper-parameters.

In [14]:
study.best_params

{'num_epochs': 18,
 'learning_rate': 0.013296373945166117,
 'weight_decay': 1e-05,
 'batch_size': 8}

Furthermore, we can visualize, which hyper-parameters are promising.

In [17]:
optuna.visualization.plot_slice(study)

In [18]:
optuna.visualization.plot_parallel_coordinate(study)

##

In [12]:
trainer = ESMTrainer(
    weight_decay=study.best_params["weight_decay"],
    learning_rate=study.best_params["learning_rate"],
    device_name=device_name
)

esm = trainer.train_with_embeddings(
    embedding_dataset=embedding_dataset,
    num_epochs=study.best_params["num_epochs"],
    batch_size=study.best_params["batch_size"]
)

esm

Training ESM: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:04<00:00,  3.71epoch/s, avg_train_loss=0.00497]


ESM - Task ID: prithivMLmods/Spam-Text-Detect-Analysis - Subset: None

In [13]:
esm.config

ESMConfig {
  "base_model_name": "google-bert/bert-base-uncased",
  "developers": null,
  "esm_architecture": null,
  "esm_batch_size": 8,
  "esm_embedding_dim": null,
  "esm_learning_rate": 0.013296373945166117,
  "esm_num_epochs": 18,
  "esm_optimizer": null,
  "esm_weight_decay": 1e-05,
  "label_column": "Category",
  "language": null,
  "lm_batch_size": null,
  "lm_learning_rate": null,
  "lm_num_epochs": null,
  "lm_optimizer": null,
  "lm_weight_decay": null,
  "num_examples": 1000,
  "seed": null,
  "streamed": false,
  "task_id": "prithivMLmods/Spam-Text-Detect-Analysis",
  "task_split": "train",
  "task_subset": null,
  "text_column": "Message",
  "transformers_version": "4.47.1"
}