# Tutorial 3: Hyperparameter Tuning with Optuna

This tutorial shows how to run Optuna sweeps inside a ZenML pipeline. You'll learn how to:

- Generate synthetic graph data for tuning
- Define a ZenML pipeline that includes an Optuna search step
- Inspect the best hyperparameters and the tuned model
- Evaluate the tuned model on the validation split

Unlike earlier versions, the full pipeline now lives in the notebook so you can modify
every step. We assume Optuna is installed (see `requirements.txt`).


In [None]:
import torch
import optuna
from torch_geometric.data import Data
from zenml import pipeline, step

from pioneerml.models import GroupClassifier
from pioneerml.training import GraphDataModule, GraphLightningModule
from pioneerml.zenml.materializers import (
    GraphDataModuleMaterializer,
    PyGDataListMaterializer,
)
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.utils import detect_available_accelerator

# Initialize ZenML for notebook use and keep the returned client so the active
# stack can be inspected below.
# NOTE(review): `use_in_memory=True` presumably avoids a persistent store —
# confirm against pioneerml.zenml.utils.setup_zenml_for_notebook.
zenml_client = zenml_utils.setup_zenml_for_notebook(use_in_memory=True)
print(f"ZenML initialized with stack: {zenml_client.active_stack_model.name}")


## Build the Tuning Pipeline

The pipeline mirrors the structure from earlier tutorials and adds an Optuna step:

1. `create_data`: synthetic graph classification dataset
2. `create_datamodule`: wraps the dataset with train/val splits
3. `run_hyperparameter_search`: Optuna study that picks hidden size, learning rate, dropout
4. `train_best_model`: trains one final model with the best parameters

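Optuna runs entirely inside the pipeline, so the sweep is tracked in ZenML and can be re-run from a single entry point. Note that neither the synthetic data generator nor the Optuna sampler is seeded here, so individual results will vary from run to run.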


In [None]:
def create_synthetic_tuning_data(num_samples: int = 200, seed: "int | None" = None) -> list[Data]:
    """Generate clustered graphs for a 3-class classification task.

    Each graph receives a uniformly random class label, between 5 and 9 nodes
    whose 5 features are Gaussian noise centred on a class-specific offset,
    ``3 * num_nodes`` random edges with 4 Gaussian edge features each, and a
    one-hot label ``y`` of shape ``(3,)``.

    Args:
        num_samples: Number of graphs to generate.
        seed: Optional seed for a private ``torch.Generator``. When provided,
            calls with the same seed yield identical data and the global torch
            RNG state is not consumed. When ``None`` (the default) the global
            RNG is used, which matches the previous behaviour.

    Returns:
        A list with ``num_samples`` PyG ``Data`` objects.
    """
    class_offsets = torch.tensor([
        [2.0, 0.0, 0.5, 0.0, 0.0],
        [0.0, 2.0, 0.0, 0.5, 0.0],
        [-2.0, -2.0, -0.5, 0.0, 0.5],
    ])
    # Derive the sizes from the offset table so the two cannot drift apart.
    num_classes, num_features = class_offsets.shape

    # A private generator keeps seeded runs isolated from the global RNG;
    # passing generator=None makes torch fall back to the global RNG.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    data: list[Data] = []
    for _ in range(num_samples):
        label = torch.randint(0, num_classes, (1,), generator=generator).item()
        num_nodes = torch.randint(5, 10, (1,), generator=generator).item()
        x = torch.randn(num_nodes, num_features, generator=generator) * 0.4 + class_offsets[label]
        edge_index = torch.randint(0, num_nodes, (2, num_nodes * 3), generator=generator)
        edge_attr = torch.randn(edge_index.shape[1], 4, generator=generator) * 0.3
        # One-hot graph-level target.
        y = torch.zeros(num_classes)
        y[label] = 1.0
        data.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))
    return data


@step(output_materializers=PyGDataListMaterializer, enable_cache=False)
def create_data() -> list[Data]:
    """Step 1: Produce the synthetic graph dataset used for the sweep.

    Delegates to ``create_synthetic_tuning_data`` with its default arguments.
    """
    graphs = create_synthetic_tuning_data()
    return graphs


@step(output_materializers=GraphDataModuleMaterializer, enable_cache=False)
def create_datamodule(data: list[Data]) -> GraphDataModule:
    """Step 2: Build a ``GraphDataModule`` around the generated graphs.

    The module is configured with a 0.25 validation split, batches of 32,
    and no dataloader worker processes.
    """
    loader_config = {"val_split": 0.25, "batch_size": 32, "num_workers": 0}
    return GraphDataModule(dataset=data, **loader_config)


@step(enable_cache=False)
def run_hyperparameter_search(datamodule: GraphDataModule, n_trials: int = 4) -> dict:
    """Step 3: Perform an Optuna search over hidden size, dropout, and learning rate.

    Every trial builds a fresh ``GroupClassifier``, trains it for 2 epochs and
    scores it on the validation loader. The study maximizes that score.

    Search space:
        * ``hidden_dim``: one of 64, 128, 256
        * ``dropout``: uniform in [0.0, 0.3]
        * ``lr``: log-uniform in [1e-4, 5e-3]

    Args:
        datamodule: Data module providing the train and validation loaders.
        n_trials: Number of Optuna trials to run; must be at least 1.

    Returns:
        A dict with ``best_hidden_dim``, ``best_dropout``, ``best_lr``,
        ``best_accuracy`` (the best objective value) and ``n_trials``
        (the number of trials recorded by the study).

    Raises:
        ValueError: If ``n_trials`` is smaller than 1. Without this check the
            failure would only surface later, when reading the best parameters
            of a study that has no completed trials.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}.")

    import pytorch_lightning as pl

    # Hardware detection does not depend on the trial, so do it once.
    accelerator, devices = detect_available_accelerator()

    def objective(trial: optuna.Trial) -> float:
        """Train one candidate configuration and return its validation score."""
        hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256])
        dropout = trial.suggest_float("dropout", 0.0, 0.3)
        lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)

        model = GroupClassifier(num_classes=3, hidden=hidden_dim, dropout=dropout)
        lightning_module = GraphLightningModule(model, task="classification", lr=lr)

        # Short, quiet runs: no logger, checkpoints or progress bar per trial.
        trainer = pl.Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=2,
            logger=False,
            enable_checkpointing=False,
            enable_progress_bar=False,
        )

        datamodule.setup(stage="fit")
        trainer.fit(lightning_module, datamodule=datamodule)
        val_metrics = trainer.validate(lightning_module, datamodule=datamodule, verbose=False)

        # No usable metrics: score the trial as the worst possible value.
        if not val_metrics or not isinstance(val_metrics[0], dict):
            return 0.0
        metrics = val_metrics[0]

        accuracy = metrics.get("val_accuracy")
        if accuracy is not None:
            return float(accuracy)

        # Fallback when accuracy is not logged: map the loss into (0, 1] so
        # that a lower loss still means a higher objective value.
        loss = metrics.get("val_loss")
        if loss is not None:
            return 1.0 / (1.0 + float(loss))
        return 0.0

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    best = study.best_params
    return {
        "best_hidden_dim": best["hidden_dim"],
        "best_dropout": best["dropout"],
        "best_lr": best["lr"],
        "best_accuracy": study.best_value,
        "n_trials": len(study.trials),
    }


@step(enable_cache=False)
def train_best_model(best_params: dict, datamodule: GraphDataModule) -> GraphLightningModule:
    """Step 4: Fit the final model using the winning Optuna configuration.

    Reads ``best_hidden_dim``, ``best_dropout`` and ``best_lr`` from
    ``best_params``, trains for 5 epochs on ``datamodule`` and returns the
    Lightning module switched to evaluation mode.
    """
    classifier = GroupClassifier(
        num_classes=3,
        hidden=best_params["best_hidden_dim"],
        dropout=best_params["best_dropout"],
    )
    module = GraphLightningModule(classifier, task="classification", lr=best_params["best_lr"])

    import pytorch_lightning as pl

    accelerator, devices = detect_available_accelerator()
    # Same quiet trainer settings as the search, only with a longer schedule.
    trainer_options = {
        "accelerator": accelerator,
        "devices": devices,
        "max_epochs": 5,
        "logger": False,
        "enable_checkpointing": False,
        "enable_progress_bar": False,
    }
    trainer = pl.Trainer(**trainer_options)

    datamodule.setup(stage="fit")
    trainer.fit(module, datamodule=datamodule)

    trained = module.eval()
    return trained


@pipeline
def tuning_pipeline(n_trials: int = 4):
    """Chain data creation, the Optuna search, and the final training run.

    ``n_trials`` is forwarded to ``run_hyperparameter_search``. Returns the
    tuned model, the data module, and the best-parameter dict, in that order.
    """
    graphs = create_data()
    graph_datamodule = create_datamodule(graphs)
    search_result = run_hyperparameter_search(graph_datamodule, n_trials=n_trials)
    final_model = train_best_model(search_result, graph_datamodule)
    return final_model, graph_datamodule, search_result


## Run the Optuna Sweep

Execute the pipeline with a small number of trials (increase `n_trials` for real sweeps).
The pipeline stores all runs and best parameters in ZenML so you can reproduce the sweep later.


In [None]:
# Launch the sweep with caching disabled and report how the run ended.
run = tuning_pipeline.with_options(enable_cache=False)(n_trials=4)
print(f"Pipeline run status: {run.status}")

# Pull the three artifacts back out of the run, keyed by step name.
tuned_module, datamodule, best_params = (
    load_step_output(run, step_name)
    for step_name in ("train_best_model", "create_datamodule", "run_hyperparameter_search")
)

if any(artifact is None for artifact in (tuned_module, datamodule, best_params)):
    raise RuntimeError("Failed to load artifacts from the tuning pipeline.")

# Prefer the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
tuned_module = tuned_module.to(device)
tuned_module = tuned_module.eval()
datamodule.setup(stage="fit")

print("Best hyperparameters:")
for key, value in best_params.items():
    print(f"- {key}: {value}")


## Inspect the Tuned Model

Check that the tuned model has the expected shape, parameter count, and device placement.


In [None]:
# Read the device from the model's own parameters and count all parameters.
device = next(tuned_module.parameters()).device
param_count = sum(map(lambda param: param.numel(), tuned_module.parameters()))

# A single training batch is enough to report sizes and run a forward pass.
train_loader = datamodule.train_dataloader()
first_batch = next(iter(train_loader))

summary_lines = [
    "Tuned Model Summary:",
    f"- Run: {run.name}",
    f"- Device: {device}",
    f"- Parameters: {param_count:,}",
    f"- Nodes per batch: {first_batch.x.shape[0]} | Features: {first_batch.x.shape[1]}",
    f"- Edges: {first_batch.edge_index.shape[1]}",
]
print("\n".join(summary_lines))

# Forward pass without gradient tracking, only to inspect the output shape.
with torch.no_grad():
    logits = tuned_module(first_batch.to(device))
print(f"- Output logits shape: {tuple(logits.shape)}")


## Evaluate Validation Accuracy

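Run the tuned model on the validation split. We handle label shapes explicitly because
each graph stores a one-hot `y` of shape `(num_classes,)`, which PyG concatenates into a flat
`(batch_size * num_classes,)` tensor; it must be reshaped to `(batch_size, num_classes)` before taking the argmax.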


In [None]:
val_loader = datamodule.val_dataloader()
if isinstance(val_loader, list) and not val_loader:
    # No validation loader available: fall back to the training loader so the
    # cell still runs. The number printed below is then a training accuracy.
    val_loader = datamodule.train_dataloader()

correct = 0
total = 0
tuned_module.eval()

# Gradients are never needed during evaluation, so disable them once for the
# whole loop instead of once per batch.
with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        logits = tuned_module(batch)
        # The logits carry one row per graph; use that instead of a
        # hard-coded class count to recover the label layout.
        num_graphs = logits.shape[0]

        labels = batch.y
        if labels.dim() == 1 and labels.numel() == num_graphs:
            # One entry per graph: labels are already class indices.
            labels = labels.long()
        else:
            # One-hot labels, either flattened by batching or already 2-D.
            # Reshape to one row per graph; a size mismatch raises here
            # rather than silently producing misaligned labels.
            labels = labels.reshape(num_graphs, -1)
            labels = torch.argmax(labels, dim=1)

        preds = torch.argmax(logits, dim=1)
        correct += int((preds == labels).sum().item())
        total += int(labels.numel())

accuracy = correct / total if total > 0 else 0.0
print(f"Validation accuracy: {accuracy:.1%} ({correct}/{total})")
