# Tutorial 3: Hyperparameter Tuning with Optuna

This tutorial shows how to run Optuna sweeps inside a ZenML pipeline.


In [1]:
import contextlib
import io
import warnings
import torch
import optuna
import pytorch_lightning as pl
from torch_geometric.data import Data
from zenml import pipeline, step

from pioneerml.models import GroupClassifier
from pioneerml.training import GraphDataModule, GraphLightningModule
from pioneerml.zenml.materializers import (
    GraphDataModuleMaterializer,
    PyGDataListMaterializer,
)
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.utils import detect_available_accelerator

# Set up ZenML for this notebook session (requested with use_in_memory=True)
# and report which stack ended up active.
zenml_client = zenml_utils.setup_zenml_for_notebook(use_in_memory=True)
active_stack_name = zenml_client.active_stack_model.name
print(f"ZenML initialized with stack: {active_stack_name}")


Using ZenML repository root: /home/jack/python_projects/pioneerML
Ensure this is the top-level of your repo (.zen must live here).
ZenML initialized with stack: default


## Build the Tuning Pipeline

The pipeline mirrors the structure from earlier tutorials and adds an Optuna step:

1. `create_data`: synthetic graph classification dataset
2. `create_datamodule`: wraps the dataset with train/val splits
3. `run_hyperparameter_search`: Optuna study that picks hidden size, learning rate, dropout
4. `train_best_model`: trains one final model with the best parameters

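Optuna runs entirely inside the pipeline, so the whole sweep is tracked in ZenML and can be re-run as a single unit. No random seeds are fixed in this tutorial, so exact results will vary from run to run.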


In [2]:

def create_synthetic_tuning_data(num_samples: int = 400, seed: "int | None" = None) -> list[Data]:
    """Generate a deliberately tricky dataset so Optuna can't hit 100% accuracy.

    Every sample is a graph with 7-17 nodes, 5 node features, 4 edge features
    and a one-hot label over 3 classes. The task is made hard on purpose:
    node features blend the prototype of the true class with that of another
    class, a random near-identity projection is applied per graph, and roughly
    10% of the samples have their label re-drawn uniformly at random.

    Args:
        num_samples: Number of graphs to generate.
        seed: Optional seed for a private ``torch.Generator``. With the
            default ``None`` the global PyTorch RNG is used, which matches
            the previous (unseeded) behaviour.

    Returns:
        A list of ``num_samples`` ``Data`` objects carrying ``x``
        (``num_nodes x 5``), ``edge_index`` (``2 x num_edges``), ``edge_attr``
        (``num_edges x 4``) and a one-hot ``y`` of length 3.
    """
    # A private generator keeps seeded runs independent of global RNG state;
    # ``generator=None`` makes every torch call fall back to the global RNG.
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    def rand_int(low: int, high: int, size: tuple) -> torch.Tensor:
        return torch.randint(low, high, size, generator=generator)

    def rand_normal(*size: int) -> torch.Tensor:
        return torch.randn(*size, generator=generator)

    data: list[Data] = []

    # Per-class feature prototypes, drift directions and per-feature scales.
    class_means = torch.tensor([
        [0.7, -0.3, 0.2, 0.1, -0.1],
        [-0.1, 0.8, -0.2, 0.3, 0.15],
        [-0.6, -0.4, 0.4, -0.2, 0.25],
    ])
    class_drift = torch.tensor([
        [0.4, 0.3, -0.1, 0.0, 0.2],
        [-0.3, 0.2, 0.25, -0.15, -0.1],
        [0.2, -0.4, 0.15, 0.25, -0.05],
    ])
    feature_scales = torch.tensor([
        [1.0, 0.9, 1.1, 0.95, 1.05],
        [0.95, 1.05, 0.85, 1.1, 0.9],
        [1.1, 0.95, 0.9, 1.0, 1.0],
    ])

    # Loop-invariant helpers, built once instead of once per sample.
    correlation_mix = torch.ones(5, 5) * 0.1
    identity = torch.eye(5)

    for _ in range(num_samples):
        label = rand_int(0, 3, (1,)).item()
        num_nodes = rand_int(7, 18, (1,)).item()

        # Blend the true class prototype with a *different* class
        # (offset of 1 or 2 modulo 3) using a ratio in [0.2, 0.8).
        mix_label = (label + rand_int(1, 3, (1,)).item()) % 3
        mix_ratio = torch.rand(1, generator=generator).item() * 0.6 + 0.2
        prototype = mix_ratio * class_means[label] + (1 - mix_ratio) * class_means[mix_label]

        # Smooth sinusoidal structure along the node ordering.
        t = torch.linspace(0, 1, steps=num_nodes).unsqueeze(1)
        wiggles = torch.cat([
            torch.sin(2.0 * torch.pi * t + label * 0.3),
            torch.cos(3.0 * torch.pi * t + mix_ratio),
            torch.sin(4.0 * torch.pi * t - mix_ratio * 0.5),
            torch.cos(5.0 * torch.pi * t + label * 0.2),
            torch.sin(6.0 * torch.pi * t - 0.1),
        ], dim=1)

        x = prototype + 0.4 * wiggles
        x = x * feature_scales[label]

        # Correlated feature noise plus a per-node drift along a class direction.
        noise = rand_normal(num_nodes, 5)
        correlated = noise + 0.35 * torch.matmul(noise, correlation_mix)
        drift = class_drift[label] * rand_normal(num_nodes, 1)
        x = x + 0.25 * correlated + drift

        # Random near-identity projection, different for every graph.
        projection = rand_normal(5, 5) * 0.2 + identity
        x = x @ projection

        # Each class gets its own edge-generation recipe.
        if label == 0:
            # Ring over all nodes plus random edges.
            node_ids = torch.arange(num_nodes)
            ring = torch.stack([node_ids, (node_ids + 1) % num_nodes])
            random_edges = rand_int(0, num_nodes, (2, num_nodes * 2))
            edge_index = torch.cat([ring, random_edges], dim=1)
        elif label == 1:
            # Dense edges among the first `cluster` nodes plus random long jumps.
            cluster = max(3, num_nodes // 2)
            src = rand_int(0, cluster, (num_nodes * 3,))
            dst = rand_int(0, cluster, (num_nodes * 3,))
            long_jump = rand_int(0, num_nodes, (2, num_nodes))
            edge_index = torch.cat([torch.stack([src, dst], dim=0), long_jump], dim=1)
        else:
            # Edges that skip 2-6 positions ahead (modulo num_nodes) plus random extras.
            src = rand_int(0, num_nodes, (num_nodes * 2,))
            dst = (src + rand_int(2, 7, (num_nodes * 2,))) % num_nodes
            extra = rand_int(0, num_nodes, (2, num_nodes))
            edge_index = torch.cat([torch.stack([src, dst], dim=0), extra], dim=1)

        edge_attr = rand_normal(edge_index.shape[1], 4) * 0.15 + label * 0.04

        # Label noise: with probability 0.1 re-draw the label uniformly
        # (the re-draw may land on the true label again).
        noisy_label = label
        if torch.rand(1, generator=generator).item() < 0.1:
            noisy_label = rand_int(0, 3, (1,)).item()

        y = torch.zeros(3)
        y[noisy_label] = 1.0

        data.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))

    return data

def run_silently(fn):
    """Call ``fn()`` with warnings, Lightning logging and console output muted.

    The warnings filter and the stdout/stderr redirection are undone when the
    call finishes. The Lightning logger level is set to ``"ERROR"`` and is
    left at that level afterwards.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        Whatever ``fn()`` returns.
    """
    with contextlib.ExitStack() as quiet:
        # Ignore every Python warning for the duration of the call.
        quiet.enter_context(warnings.catch_warnings())
        warnings.simplefilter("ignore")

        # Raise the Lightning logger threshold; this change is global.
        pl._logger.setLevel("ERROR")

        # Capture anything written to stdout / stderr into throwaway buffers.
        quiet.enter_context(contextlib.redirect_stdout(io.StringIO()))
        quiet.enter_context(contextlib.redirect_stderr(io.StringIO()))
        return fn()



@step(output_materializers=PyGDataListMaterializer, enable_cache=False)
def create_data() -> list[Data]:
    """Step 1: Build the synthetic graph dataset used for the sweep.

    Returns:
        The list of graphs produced by ``create_synthetic_tuning_data()`` with
        its default arguments.
    """
    graphs = create_synthetic_tuning_data()
    return graphs


@step(output_materializers=GraphDataModuleMaterializer, enable_cache=False)
def create_datamodule(data: list[Data]) -> GraphDataModule:
    """Step 2: Hand the graphs to a ``GraphDataModule``.

    Args:
        data: Graphs produced by the data-creation step.

    Returns:
        A ``GraphDataModule`` built with ``val_split=0.25``, ``batch_size=32``
        and ``num_workers=0``.
    """
    loader_config = {"val_split": 0.25, "batch_size": 32, "num_workers": 0}
    return GraphDataModule(dataset=data, **loader_config)


@step(enable_cache=False)
def run_hyperparameter_search(datamodule: GraphDataModule, n_trials: int = 4) -> dict:
    """Step 3: Run an Optuna study over hidden size, dropout and learning rate.

    Each trial trains a fresh ``GroupClassifier`` for at most 2 epochs and is
    scored on the validation metrics returned by ``trainer.validate``.

    Args:
        datamodule: Data module used for fitting and validating every trial.
        n_trials: Number of Optuna trials to run.

    Returns:
        A dict with ``best_hidden_dim``, ``best_dropout``, ``best_lr``,
        ``best_accuracy`` (the best objective value of the study) and
        ``n_trials`` (the number of trials recorded by the study).
    """

    def score_from(validation_output) -> float:
        """Turn the output of ``trainer.validate`` into a value to maximise."""
        if not validation_output or not isinstance(validation_output[0], dict):
            return 0.0
        metrics = validation_output[0]

        # Prefer accuracy; otherwise map the loss to a score that grows as the
        # loss shrinks; otherwise give the trial the neutral score 0.0.
        reported_accuracy = metrics.get("val_accuracy")
        if reported_accuracy is not None:
            return float(reported_accuracy)
        reported_loss = metrics.get("val_loss")
        if reported_loss is not None:
            return 1.0 / (1.0 + float(reported_loss))
        return 0.0

    def objective(trial: optuna.Trial) -> float:
        # Sample the three hyperparameters for this trial.
        width = trial.suggest_categorical("hidden_dim", [64, 128, 256])
        drop_prob = trial.suggest_float("dropout", 0.0, 0.3)
        learning_rate = trial.suggest_float("lr", 1e-4, 5e-3, log=True)

        network = GroupClassifier(num_classes=3, hidden=width, dropout=drop_prob)
        module = GraphLightningModule(network, task="classification", lr=learning_rate)

        accelerator, devices = detect_available_accelerator()
        trainer_options = {
            "accelerator": accelerator,
            "devices": devices,
            "max_epochs": 2,
            "logger": False,
            "enable_checkpointing": False,
            "enable_progress_bar": False,
        }
        trainer = pl.Trainer(**trainer_options)

        datamodule.setup(stage="fit")

        # Train, then validate, with all console noise suppressed.
        run_silently(lambda: trainer.fit(module, datamodule=datamodule))
        validation_output = run_silently(
            lambda: trainer.validate(module, datamodule=datamodule, verbose=False)
        )
        return score_from(validation_output)

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    winner = study.best_params
    summary = {
        "best_hidden_dim": winner["hidden_dim"],
        "best_dropout": winner["dropout"],
        "best_lr": winner["lr"],
        "best_accuracy": study.best_value,
        "n_trials": len(study.trials),
    }
    return summary


@step(enable_cache=False)
def train_best_model(best_params: dict, datamodule: GraphDataModule) -> GraphLightningModule:
    """Step 4: Fit one final model using the winning Optuna parameters.

    Args:
        best_params: Dict providing ``best_hidden_dim``, ``best_dropout`` and
            ``best_lr``.
        datamodule: Data module used for the final fit.

    Returns:
        The Lightning module after training for at most 5 epochs, switched to
        eval mode.
    """
    network = GroupClassifier(
        num_classes=3,
        hidden=best_params["best_hidden_dim"],
        dropout=best_params["best_dropout"],
    )
    module = GraphLightningModule(network, task="classification", lr=best_params["best_lr"])

    accelerator, devices = detect_available_accelerator()
    trainer_options = {
        "accelerator": accelerator,
        "devices": devices,
        "max_epochs": 5,
        "logger": False,
        "enable_checkpointing": False,
        "enable_progress_bar": False,
    }
    trainer = pl.Trainer(**trainer_options)

    datamodule.setup(stage="fit")

    # Fit with console output, warnings and Lightning logs suppressed.
    run_silently(lambda: trainer.fit(module, datamodule=datamodule))
    return module.eval()


@pipeline
def tuning_pipeline(n_trials: int = 4):
    """Chain the four steps: data -> datamodule -> Optuna search -> final training.

    Args:
        n_trials: Number of Optuna trials forwarded to the search step.

    Returns:
        The tuned model, the datamodule and the best-parameter dict, in that order.
    """
    graphs = create_data()
    graph_datamodule = create_datamodule(graphs)
    search_result = run_hyperparameter_search(graph_datamodule, n_trials=n_trials)
    final_model = train_best_model(search_result, graph_datamodule)
    return final_model, graph_datamodule, search_result



## Run the Optuna Sweep

Execute the pipeline with a small number of trials (increase `n_trials` for real sweeps).
The pipeline stores all runs and best parameters in ZenML so you can reproduce the sweep later.


In [3]:
# Launch the sweep with caching switched off and 30 Optuna trials.
run = tuning_pipeline.with_options(enable_cache=False)(n_trials=30)
print(f"Pipeline run status: {run.status}")

# Pull the three artifacts needed by the following cells out of the run.
tuned_module, datamodule, best_params = (
    load_step_output(run, step_name)
    for step_name in ("train_best_model", "create_datamodule", "run_hyperparameter_search")
)
if any(artifact is None for artifact in (tuned_module, datamodule, best_params)):
    raise RuntimeError("Failed to load artifacts from the tuning pipeline.")

# Move the model to the GPU when one is available and make sure it is in eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tuned_module = tuned_module.to(device).eval()
datamodule.setup(stage="fit")

print("Best hyperparameters:")
for name, setting in best_params.items():
    print(f"- {name}: {setting}")


Initiating a new run for the pipeline: tuning_pipeline.
Caching is disabled by default for tuning_pipeline.
Using user: default
Using stack: default
  deployer: default
  orchestrator: default
  artifact_store: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml login --local.
Step create_data has started.
Step create_data has finished in 0.153s.
Step create_datamodule has started.
Step create_datamodule has finished in 0.074s.
Step run_hyperparameter_search has started.


[I 2025-11-25 19:11:20,284] A new study created in memory with name: no-name-1dc8e1d5-c88f-42c7-af90-b51aa93570d6
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
[I 2025-11-25 19:11:21,101] Trial 0 finished with value: 0.699999988079071 and parameters: {'hidden_dim': 256, 'dropout': 0.1917318643571018, 'lr': 0.00022081801915407043}. Best is trial 0 with value: 0.699999988079071.
[I 2025-11-25 19:11:21,394] Trial 1 finished with value: 0.8333333730697632 and parameters: {'hidden_dim': 256, 'dropout': 0.04549340117110281, 'lr': 0.00045642066793705655}. Best is trial 1 with value: 0.8333333730697632.
[I 2025-11-25 19:11:21,671] Trial 2 finished with value: 0.8133333325386047 and parameters: {'hidden_dim': 128, 'dropout': 0.14228022583255323, 'lr': 0.0033759559542362047}. Best is trial 1 with value: 0.8333333730697632.
[I 2025-11-25 19:11:21,965] Trial 3 finished with value: 0.7966667413711548 and parameters: {'hidden_dim': 256, 'dropout': 0.2062969983644003

[37mStep [0m[38;5;105mrun_hyperparameter_search[37m has finished in [0m[38;5;105m9.167s[37m.[0m
[37mStep [0m[38;5;105mtrain_best_model[37m has started.[0m
[37mStep [0m[38;5;105mtrain_best_model[37m has finished in [0m[38;5;105m1.160s[37m.[0m
[37mPipeline run has finished in [0m[38;5;105m12.384s[37m.[0m
Pipeline run status: completed
Best hyperparameters:
- best_hidden_dim: 256
- best_dropout: 0.08196706282398299
- best_lr: 0.0005439883955058748
- best_accuracy: 0.846666693687439
- n_trials: 30


## Inspect the Tuned Model

Check that the tuned model has the expected shape, parameter count, and device placement.


In [4]:
# Where the model's weights live and how many of them there are.
device = next(tuned_module.parameters()).device
param_count = sum(weights.numel() for weights in tuned_module.parameters())

# Take one training batch to report its size.
train_loader = datamodule.train_dataloader()
first_batch = next(iter(train_loader))
node_total, feature_total = first_batch.x.shape[0], first_batch.x.shape[1]
edge_total = first_batch.edge_index.shape[1]

print("Tuned Model Summary:")
print(f"- Run: {run.name}")
print(f"- Device: {device}")
print(f"- Parameters: {param_count:,}")
print(f"- Nodes per batch: {node_total} | Features: {feature_total}")
print(f"- Edges: {edge_total}")

# Forward pass without gradient tracking, just to inspect the output shape.
with torch.no_grad():
    logits = tuned_module(first_batch.to(device))
print(f"- Output logits shape: {tuple(logits.shape)}")


Tuned Model Summary:
- Run: tuning_pipeline-2025_11_26-00_11_19_389987
- Device: cuda:0
- Parameters: 1,848,324
- Nodes per batch: 370 | Features: 5
- Edges: 1218
- Output logits shape: (32, 3)


## Evaluate Validation Accuracy

Run the tuned model on the validation split. We handle label shapes explicitly because
batching can deliver the per-graph one-hot labels as a flat 1-D tensor, so they are reshaped
to `(batch_size, num_classes)` before being compared with the predictions.


In [5]:
val_loader = datamodule.val_dataloader()
# Fall back to the training loader when the datamodule returns an empty list
# instead of a validation loader.
if isinstance(val_loader, list) and not val_loader:
    val_loader = datamodule.train_dataloader()

correct = 0
total = 0
tuned_module.eval()

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        logits = tuned_module(batch)

        # Take the batch size and class count from the model output instead of
        # hard-coding 3 classes.
        num_graphs, num_classes = logits.shape[0], logits.shape[-1]

        labels = batch.y
        if labels.numel() == num_graphs * num_classes:
            # One-hot labels, either already (num_graphs, num_classes) or
            # flattened to 1-D by batching: restore the 2-D layout and decode.
            labels = labels.reshape(num_graphs, num_classes).argmax(dim=1)
        elif labels.numel() == num_graphs:
            # Labels are already one class index per graph.
            labels = labels.reshape(num_graphs).long()
        else:
            # Anything else cannot be aligned with the predictions; failing
            # loudly beats silently reporting a meaningless accuracy.
            raise ValueError(
                f"Cannot align labels of shape {tuple(labels.shape)} "
                f"with logits of shape {tuple(logits.shape)}."
            )

        preds = logits.argmax(dim=1)
        correct += int((preds == labels).sum().item())
        total += num_graphs

accuracy = correct / total if total > 0 else 0.0
print(f"Validation accuracy: {accuracy:.1%} ({correct}/{total})")


Validation accuracy: 77.0% (77/100)
