## Hyperparameter Search with Simulated Annealing (Optuna)

This example demonstrates how to perform hyperparameter search for a Simulated Annealing training setup using **Optuna** and pyperch's search utilities.  

A fixed base training configuration is defined, and Optuna is used to explore SA-specific hyperparameters such as temperature, cooling rate, step size, and training duration.  The objective is to maximize validation accuracy on a binary classification task.

In [11]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from pyperch import Trainer
from pyperch.config import TrainConfig, OptimizerConfig
from pyperch.core.metrics import Accuracy
from pyperch.models.mlp import SimpleMLP
from pyperch.core.callbacks import EarlyStopping, CallbackList

# Search components
from pyperch.search.strategy import OptunaStrategy
from pyperch.search.builder import TrainConfigBuilder
from pyperch.search.adapter import TrainerAdapter
from pyperch.search.factory import SearchFactory

# ------------------------------------------------------------
# 1. Reproducibility
# ------------------------------------------------------------
# One seed drives NumPy, PyTorch, the synthetic dataset and the split.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# ------------------------------------------------------------
# 2. Dataset
# ------------------------------------------------------------
# Synthetic binary classification problem: 1000 samples with 12 features,
# 10 of which are informative.
X, y = make_classification(
    n_samples=1000, n_features=12, n_informative=10, n_classes=2, random_state=seed
)

# float32 features and int64 class labels for the PyTorch tensors below.
X, y = X.astype(np.float32), y.astype(np.int64)

# Hold out 20% of the samples for validation.
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=seed
)

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
valid_dataset = TensorDataset(torch.tensor(X_valid), torch.tensor(y_valid))

# Only the training loader shuffles; validation keeps the default order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64)

# ------------------------------------------------------------
# 3. Base training configuration
# ------------------------------------------------------------
# Default Simulated Annealing settings used when a value is not overridden.
base_opt_cfg = OptimizerConfig(
    name="sa", t=1.0, t_min=0.01, step_size=0.05, cooling=0.95
)

# Accuracy is tracked on both splits, with a separate metric instance each.
accuracy_by_split = {"train": [Accuracy()], "valid": [Accuracy()]}

base_cfg = TrainConfig(
    device="cpu",
    seed=seed,
    max_epochs=300,
    optimizer="sa",
    optimizer_config=base_opt_cfg,
    optimizer_mode="per_batch",
    metrics=accuracy_by_split,
)

# Wraps the fixed base configuration for use by the search adapter.
builder = TrainConfigBuilder(base_cfg)

# ------------------------------------------------------------
# 4. Optuna search space
# ------------------------------------------------------------
def suggest_params(trial):
    """Sample one set of Simulated Annealing hyperparameters from ``trial``.

    Args:
        trial: Object exposing ``suggest_float(name, low, high)`` and
            ``suggest_int(name, low, high)`` (an Optuna trial).

    Returns:
        dict mapping configuration keys to the sampled values. The
        ``optimizer_config.*`` keys carry the SA settings (presumably dotted
        paths into the training config -- confirm against
        ``TrainConfigBuilder``); ``max_epochs`` carries the training length.
    """
    # Sampling order is kept stable: t, t_min, step_size, cooling, max_epochs.
    temperature = trial.suggest_float("t", 0.5, 2.0)
    min_temperature = trial.suggest_float("t_min", 0.001, 0.01)
    step = trial.suggest_float("step_size", 0.025, 0.05)
    cooling_rate = trial.suggest_float("cooling", 0.90, 0.999)
    epochs = trial.suggest_int("max_epochs", 10, 200)

    overrides = {}
    overrides["optimizer_config.t"] = temperature
    overrides["optimizer_config.t_min"] = min_temperature
    overrides["optimizer_config.step_size"] = step
    overrides["optimizer_config.cooling"] = cooling_rate
    overrides["max_epochs"] = epochs
    return overrides

# Search strategy built around suggest_params (presumably called once per
# Optuna trial to obtain that trial's overrides -- confirm in OptunaStrategy).
strategy = OptunaStrategy(suggest_params)

# ------------------------------------------------------------
# 5. Callbacks
# ------------------------------------------------------------
# Early stopping reduces search runtime
def make_early_stopping():
    """Return a new EarlyStopping callback monitoring validation loss.

    A new instance is built on every call so that each trial owns its own
    callback object instead of sharing one across trials.
    """
    return EarlyStopping(
        monitor="valid_loss",   # loss is always minimized
        patience=15,
        min_delta=1e-4,
    )


# Kept so existing references to the module-level name keep working.
# train_fn does NOT attach this instance; it builds one per trial.
early_stopping = make_early_stopping()

# ------------------------------------------------------------
# 6. Training function (Optuna objective)
# ------------------------------------------------------------
def train_fn(cfg):
    """Train one model with ``cfg`` and return the value Optuna maximizes.

    Args:
        cfg: Training configuration for this trial (presumably produced by
            ``TrainConfigBuilder`` from ``base_cfg`` plus the sampled
            overrides -- confirm in ``TrainerAdapter``).

    Returns:
        The highest validation accuracy recorded in the training history,
        or ``0.0`` when no validation accuracy was recorded at all.
    """
    model = SimpleMLP(
        input_dim=12,
        hidden=[32],
        output_dim=2,
        activation="relu",
    )

    loss_fn = nn.CrossEntropyLoss()

    trainer = Trainer(
        model=model,
        config=cfg,
        loss_fn=loss_fn,
    )

    # Ensure default callbacks (e.g., HistoryCallback) are preserved
    if trainer.callbacks is None:
        trainer.callbacks = CallbackList()

    # Use a fresh callback per trial: the search runs with n_jobs > 1, so a
    # single shared EarlyStopping object would be attached to several
    # trainers at once and any state it keeps would leak between trials.
    trainer.callbacks.append(make_early_stopping())

    history = trainer.fit(train_loader, valid_loader)

    # Optuna objective: best validation accuracy achieved.
    # `default` covers both a missing "accuracy" entry and an empty list,
    # where a bare max() would raise ValueError.
    accuracies = history["valid_metrics"].get("accuracy", [])
    return max(accuracies, default=0.0)

adapter = TrainerAdapter(builder, strategy, train_fn)

# ------------------------------------------------------------
# 7. Run search
# ------------------------------------------------------------
import optuna

# Single source of truth for the study identity used below.
STUDY_NAME = "sa_search_demo"
STORAGE_URL = "sqlite:///sa_demo.db"

# Remove any previous study so re-running the cell starts from scratch.
# On the first run (or with a fresh database file) the study does not exist
# yet and Optuna raises KeyError; that case is expected and safe to ignore.
try:
    optuna.delete_study(
        study_name=STUDY_NAME,
        storage=STORAGE_URL,
    )
except KeyError:
    pass

search = SearchFactory.optuna_sqlite(
    adapter=adapter,
    study_name=STUDY_NAME,
    storage=STORAGE_URL,
)

study = search.run(
    n_trials=150,
    n_jobs=4,  # use -1 for all available cores
)

print("Best parameters:", search.best_params)
print("Best score:", search.best_value)
print("Best trial:", search.best_trial.number)

[I 2025-12-12 03:04:25,339] A new study created in RDB with name: sa_search_demo
[I 2025-12-12 03:04:26,682] Trial 0 finished with value: 0.555 and parameters: {'t': 1.1356285009476135, 't_min': 0.0054114058870892694, 'step_size': 0.03967622302766703, 'cooling': 0.9745954523435062, 'max_epochs': 10}. Best is trial 0 with value: 0.555.
[I 2025-12-12 03:04:29,531] Trial 4 finished with value: 0.715 and parameters: {'t': 0.583053501242841, 't_min': 0.0021783023304412117, 'step_size': 0.049642719433669415, 'cooling': 0.9623734728647027, 'max_epochs': 125}. Best is trial 4 with value: 0.715.
[I 2025-12-12 03:04:30,551] Trial 1 finished with value: 0.615 and parameters: {'t': 1.8032642119217235, 't_min': 0.008247979048161512, 'step_size': 0.03395370073297856, 'cooling': 0.927371655531226, 'max_epochs': 188}. Best is trial 4 with value: 0.715.
[I 2025-12-12 03:04:30,735] Trial 2 finished with value: 0.515 and parameters: {'t': 1.8342480161390515, 't_min': 0.0028028019785117306, 'step_size': 0

Best parameters: {'t': 0.5835824769050657, 't_min': 0.001810413819371178, 'step_size': 0.0457338679570872, 'cooling': 0.9281705505474588, 'max_epochs': 92}
Best score: 0.86
Best trial: 3
