# Tutorial 0: ZenML Quickstart

Run the smallest end-to-end ZenML pipeline in this repository, load the
trained model and data module from the ZenML artifacts, and generate a few
quick diagnostic plots.

What you'll see (with detailed interpretation guidance):
- How to start ZenML with an in-memory store (no server, minimal local state).
- A minimal training run on synthetic data using our `GroupClassifier`.
- How to pull artifacts back out of ZenML and compute plots.
- How to interpret each plot (axes, computation, and what “good” looks like).



In [None]:
import torch

from pioneerml.evaluation.plots import (
    plot_multilabel_confusion_matrix,
    plot_precision_recall_curves,
    plot_roc_curves,
)
from pioneerml.training import plot_loss_curves
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.pipelines import zenml_training_pipeline

# Initialize ZenML for notebook use
# setup_zenml_for_notebook automatically finds the project root by searching
# upward for .zen or .zenml directories, ensuring we use the root configuration.
# use_in_memory=True creates a temporary in-memory SQLite store, perfect for
# tutorials where we don't need persistent artifact storage.
zenml_client = zenml_utils.setup_zenml_for_notebook(use_in_memory=True)
print(f"ZenML initialized with stack: {zenml_client.active_stack_model.name}")



## Run the Training Pipeline

Here we execute the complete ZenML training pipeline. The pipeline consists of
several steps:

1. **build_datamodule**: Creates synthetic graph data and splits it into train/val sets
2. **build_module**: Instantiates the GroupClassifier model wrapped in a Lightning module
3. **train_module**: Trains the model using PyTorch Lightning (auto-detects CPU/GPU)
4. **collect_predictions**: Runs inference on the validation set to get predictions and targets

**Why use `enable_cache=False`?** This ensures the pipeline runs fresh each time,
which is useful for tutorials. In production, you'd typically enable caching to
skip re-running unchanged steps.

After the pipeline completes, we load the artifacts (trained model, datamodule,
predictions, targets) using `load_step_output`. ZenML stores these artifacts, so
they can be reloaded without re-running the pipeline, which keeps the notebook
fast and interactive. Note that the in-memory store used here is temporary, so
this applies within the current session.



In [None]:
run = zenml_training_pipeline.with_options(enable_cache=False)()
print(f"Pipeline run status: {run.status}")

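trained_module = load_step_output(run, "train_module")
datamodule = load_step_output(run, "build_datamodule")
# collect_predictions yields (predictions, targets); load it once and index into it.
prediction_outputs = load_step_output(run, "collect_predictions")

# Check for missing artifacts before indexing so the error message is meaningful.
if trained_module is None or datamodule is None or prediction_outputs is None:
    raise RuntimeError("Could not load artifacts from the zenml_training_pipeline run.")

predictions = prediction_outputs[0]
targets = prediction_outputs[1]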

trained_module.eval()
datamodule.setup(stage="fit")
device = next(trained_module.parameters()).device
print(f"Loaded artifacts from run {run.name} (device={device})")



## Verify Predictions Were Collected

The `collect_predictions` step in the pipeline has already run inference and
collected predictions and targets. This cell simply reports how many samples
were processed. The predictions are raw logits (before sigmoid) and the targets
are one-hot encoded class labels; both are ready for evaluation metrics and plots.
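
As a minimal sketch of what "raw logits (before sigmoid)" means in practice, the snippet below converts a few made-up logits to probabilities and applies a 0.5 threshold. The values in `toy_logits` are hypothetical, and plain Python is used only to keep the arithmetic visible; on real tensors you would more likely call `torch.sigmoid`. Whether the plotting helpers compare with `>=` or `>` at the threshold is an assumption here.

```python
import math


def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical logits for two samples; columns = (pi, mu, e+).
toy_logits = [[2.0, -1.0, -0.5], [-3.0, 0.5, 4.0]]

# Element-wise sigmoid: each class gets its own independent probability.
toy_probs = [[sigmoid(v) for v in row] for row in toy_logits]

# Threshold at 0.5, which is equivalent to checking logit >= 0.
toy_labels = [[int(p >= 0.5) for p in row] for row in toy_probs]
```

Because sigmoid is applied per class rather than across classes, the probabilities in a row do not have to sum to 1.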



In [None]:
print(f"Collected predictions for {len(targets)} samples via pipeline step.")



## Visualize Training Diagnostics

We generate four diagnostic plots to understand model performance. All plots are
displayed inline in the notebook (no files saved) by setting `show=True` and
`save_path=None`. This makes the notebook self-contained and easy to share.

**1. Loss Curves** - Shows training and validation loss over epochs
- **What it shows**: How well the model is learning during training
- **Good signs**: Both curves decrease steadily and stay close together
- **Warning signs**: Large gap between train/val (overfitting), or flat lines (not learning)

**2. Confusion Matrices** - Per-class classification outcomes
- **What it shows**: For each class (π, μ, e+), the counts of true positives, false positives,
  true negatives, and false negatives at the chosen threshold
- **Good signs**: Dark diagonal (correct predictions), light off-diagonal (few errors)
- **Why normalized**: Makes it easy to compare classes with different sample sizes
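
The per-class counts described above can be sketched as follows. This is an illustrative stand-in, not the code behind `plot_multilabel_confusion_matrix`: the helper names and toy data are hypothetical, and normalizing each row by the number of true samples in that row is one common convention that the plotting function may or may not share.

```python
def binary_confusion(y_true, y_pred):
    """Return [[TN, FP], [FN, TP]] for one class (rows = true, cols = predicted)."""
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1
        elif t == 1:
            fn += 1
        elif p == 1:
            fp += 1
        else:
            tn += 1
    return [[tn, fp], [fn, tp]]


def normalize_rows(matrix):
    """Divide each row by its sum so every true-label row sums to 1."""
    return [[v / sum(row) if sum(row) else 0.0 for v in row] for row in matrix]


# Hypothetical thresholded predictions and targets; columns = (pi, mu, e+).
toy_class_names = ["pi", "mu", "e+"]
toy_targets = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
toy_preds = [[1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]

# One 2x2 matrix per class, treating that class as a binary problem.
per_class = {}
for k, name in enumerate(toy_class_names):
    y_true = [row[k] for row in toy_targets]
    y_pred = [row[k] for row in toy_preds]
    per_class[name] = normalize_rows(binary_confusion(y_true, y_pred))
```

In this toy data the single muon is predicted as a pion, so the `mu` matrix has all of its positive row in the false-negative cell while `e+` is perfectly diagonal.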

**3. ROC Curves** - Ranking quality across all possible thresholds
- **What it shows**: True Positive Rate vs False Positive Rate as we vary the
  classification threshold
- **AUC score**: Area under curve - higher is better (1.0 = perfect, 0.5 = random)
- **Good signs**: Curves in top-left corner, AUC > 0.8
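
One way to build intuition for AUC is its ranking interpretation: it equals the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counted as half. The sketch below computes that directly for a single class on made-up scores. The function name and data are hypothetical, and the quadratic pairwise loop is for clarity only; a library routine such as scikit-learn's `roc_auc_score` would be the usual choice on real data.

```python
def roc_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(pos) * len(neg))


# Hypothetical binary labels and scores for one class.
toy_y = [1, 1, 0, 0]
auc_mixed = roc_auc(toy_y, [0.9, 0.4, 0.6, 0.1])    # one of four pairs is misordered
auc_perfect = roc_auc(toy_y, [0.9, 0.8, 0.2, 0.1])  # every positive outranks every negative
auc_constant = roc_auc(toy_y, [0.5, 0.5, 0.5, 0.5]) # uninformative scores
```

Because only the ordering of scores matters, AUC is the same whether it is computed on logits or on sigmoid probabilities.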

**4. Precision-Recall Curves** - Performance on imbalanced data
- **What it shows**: Precision vs Recall trade-off as we vary the threshold
- **Average Precision**: Summarizes each curve as a single number, which is especially informative for imbalanced classes
- **Good signs**: High curves that maintain precision even at high recall
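
Average precision can be sketched as the mean of the precision values measured at each rank where a true positive appears, going down the list of samples sorted by score. The function name and data below are hypothetical, and the sketch assumes distinct scores; library implementations handle tied scores by grouping them, so results can differ slightly when ties are present.

```python
def average_precision(y_true, scores):
    """Mean precision at the rank of each positive, with samples sorted by score."""
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("average precision needs at least one positive sample")
    # Indices of samples from highest to lowest score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    total = 0.0
    for rank, i in enumerate(order, start=1):
        if y_true[i] == 1:
            hits += 1
            total += hits / rank  # precision among the top `rank` samples
    return total / n_pos


# Hypothetical labels and scores for one class.
toy_y = [1, 0, 1, 0]
toy_scores = [0.9, 0.8, 0.7, 0.1]
ap = average_precision(toy_y, toy_scores)  # positives sit at ranks 1 and 3
```

A useful reference point is that a classifier with uninformative scores has an expected average precision close to the fraction of positive samples, so "good" depends on how rare the class is.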



In [None]:
# Class labels shared by all per-class plots.
class_names = ["pi", "mu", "e+"]

plot_loss_curves(trained_module, title="Quickstart: Loss Curves", show=True)

plot_multilabel_confusion_matrix(
    predictions=predictions,
    targets=targets,
    class_names=class_names,
    threshold=0.5,
    normalize=True,
    save_path=None,
    show=True,
)

plot_roc_curves(
    predictions=predictions,
    targets=targets,
    class_names=class_names,
    save_path=None,
    show=True,
)

plot_precision_recall_curves(
    predictions=predictions,
    targets=targets,
    class_names=class_names,
    save_path=None,
    show=True,
)
