# Tutorial 2: Custom Models and Training

This tutorial demonstrates how to define a custom Graph Convolutional Network (GCN)
and integrate it into a ZenML pipeline. You'll learn:

- How to build a custom GCN model from scratch
- How to structure the model to work with PyTorch Geometric
- How to wrap your custom model in a ZenML pipeline
- How to train and evaluate the custom model

**Why custom models?** While PIONEER ML provides pre-built models like `GroupClassifier`,
you may want to experiment with different architectures. This tutorial shows you how to
create your own graph neural network and seamlessly integrate it into the ZenML workflow.


In [1]:
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from zenml import pipeline, step

from pioneerml.models.base import GraphModel
from pioneerml.training import GraphDataModule, GraphLightningModule
from pioneerml.zenml.materializers import (
    GraphDataModuleMaterializer,
    PyGDataListMaterializer,
)
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.utils import detect_available_accelerator

# Initialize ZenML for notebook use
# setup_zenml_for_notebook automatically finds the project root by searching
# upward for .zen or .zenml directories, ensuring we use the root configuration.
# use_in_memory=True creates a temporary in-memory SQLite store, perfect for
# tutorials where we don't need persistent artifact storage.
zenml_client = zenml_utils.setup_zenml_for_notebook(use_in_memory=True)
print(f"ZenML initialized with stack: {zenml_client.active_stack_model.name}")


The current repo active project is no longer available.
Setting the repo active project to 'default'.
The current repo active stack is no longer available. Resetting the active stack to default.
ZenML initialized with stack: default


## Define a Custom Graph Convolutional Network

We'll create a small, two-layer GCN for graph classification. The architecture
consists of:

1. **Two GCN layers**: Each layer aggregates information from neighboring nodes
   - First layer: Projects 5D node features → hidden dimension (64)
   - Second layer: Further refines node representations within the hidden space

2. **Global mean pooling**: Aggregates all node features into a single graph-level
   representation by taking the mean across all nodes

3. **Classifier head**: A linear layer that maps the graph representation to class
   logits (3 classes: π, μ, e+)
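Global mean pooling (step 2) can be illustrated in plain Python. The sketch below is a hypothetical re-implementation using lists, not PyTorch Geometric's `global_mean_pool`. It groups node rows by their entry in the batch-assignment vector and averages each group, which produces one vector per graph.

```python
def mean_pool(node_features: list[list[float]], batch: list[int]) -> list[list[float]]:
    """Average node feature vectors per graph (illustrative sketch).

    node_features[i] is the feature vector of node i; batch[i] is the index
    of the graph that node i belongs to (0, 1, 2, ...).
    """
    sums: dict[int, list[float]] = {}
    counts: dict[int, int] = {}
    for features, graph_id in zip(node_features, batch):
        if graph_id not in sums:
            sums[graph_id] = [0.0] * len(features)
            counts[graph_id] = 0
        # Accumulate this node's features into its graph's running sum
        sums[graph_id] = [s + f for s, f in zip(sums[graph_id], features)]
        counts[graph_id] += 1
    # One averaged vector per graph, ordered by graph index
    return [[s / counts[g] for s in sums[g]] for g in sorted(sums)]


# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 2.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(x, batch))  # [[3.0, 4.0], [15.0, 1.0]]
```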

**Why this architecture?** GCNs suit graph-structured data because each layer updates a node's
representation from the features of its neighbors, so what the model learns follows the graph
topology. Global pooling then collapses the node features into one vector per graph, which is
what a graph-level prediction needs.
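Neighbor aggregation inside a GCN layer can be sketched the same way. PyTorch Geometric documents `GCNConv` as computing X' = D^-1/2 (A + I) D^-1/2 X W, where A + I is the adjacency matrix with self-loops added and D is its degree matrix. The hypothetical `gcn_propagate` below applies only the normalized aggregation to scalar node features. It assumes an undirected graph and omits the learnable weights W and the bias.

```python
import math


def gcn_propagate(num_nodes: int, edges: list[tuple[int, int]], x: list[float]) -> list[float]:
    """One GCN aggregation step on scalar node features (weights and bias omitted).

    Computes D^-1/2 (A + I) D^-1/2 x for an undirected graph.
    """
    # Neighbor sets for an undirected graph, starting with a self-loop (the "+ I")
    neighbors = [{i} for i in range(num_nodes)]
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)
    # Degree counts the self-loop, i.e. the degree matrix of A + I
    degree = [len(nbrs) for nbrs in neighbors]
    # Each node sums its neighbors' features, each scaled by 1 / sqrt(deg_i * deg_j)
    return [
        sum(x[j] / math.sqrt(degree[i] * degree[j]) for j in neighbors[i])
        for i in range(num_nodes)
    ]


# Path graph 0 - 1 - 2 with scalar features
out = gcn_propagate(3, [(0, 1), (1, 2)], [1.0, 2.0, 3.0])
print([round(v, 4) for v in out])  # [1.3165, 2.2997, 2.3165]
```

The symmetric normalization keeps high-degree nodes from dominating the sum. After one step, each node's value already mixes in information from its direct neighbors. Stacking two layers, as `SimpleGCN` does, lets information travel two hops.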


In [2]:
class SimpleGCN(GraphModel):
    """A simple custom Graph Convolutional Network for graph classification.
    
    This model demonstrates the core components of a GCN:
    - Graph convolution layers that aggregate neighbor information
    - Global pooling to create graph-level representations
    - A classifier head for final predictions
    """
    
    def __init__(self, num_classes: int = 3, hidden_dim: int = 64):
        super().__init__()
        
        # First GCN layer: projects 5D node features to hidden dimension
        # GCNConv learns to aggregate information from each node's neighbors
        self.conv1 = GCNConv(5, hidden_dim)
        
        # Second GCN layer: refines node representations within hidden space
        # This allows the model to learn more complex patterns in the graph structure
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        
        # Classifier head: maps graph-level representation to class logits
        # Output size matches number of classes (3 for π, μ, e+)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        """Forward pass through the GCN.
        
        Args:
            batch: PyTorch Geometric Batch object containing:
                - x: Node features [num_nodes, 5]
                - edge_index: Edge connectivity [2, num_edges]
                - batch: Batch assignment vector [num_nodes]
        
        Returns:
            Logits for each graph in the batch [batch_size, num_classes]
        """
        # Extract graph components from the batch
        x, edge_index, batch_indices = batch.x, batch.edge_index, batch.batch

        # First graph convolution: aggregate neighbor features
        # ReLU activation introduces non-linearity
        x = self.conv1(x, edge_index).relu()
        
        # Second graph convolution: further refine node representations
        x = self.conv2(x, edge_index).relu()

        # Global mean pooling: aggregate all node features into a single vector
        # This creates one representation per graph in the batch
        # Shape: [batch_size, hidden_dim]
        x = global_mean_pool(x, batch_indices)

        # Final classification: map graph representation to class logits
        # Shape: [batch_size, num_classes]
        return self.classifier(x)


# Verify the model can be instantiated
model = SimpleGCN(num_classes=3, hidden_dim=64)
print(f"Custom GCN model created: {model.num_parameters:,} parameters")


Custom GCN model created: 4,739 parameters
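The 4,739 figure can be checked by hand. This sketch assumes `GCNConv`'s default `bias=True`. Under that default, each layer stores an `in_dim x out_dim` weight matrix plus one bias per output channel, which is the same count as an `nn.Linear` layer. `dense_params` is a hypothetical helper.

```python
def dense_params(in_dim: int, out_dim: int) -> int:
    """Weights (in_dim * out_dim) plus one bias per output unit."""
    return in_dim * out_dim + out_dim


in_features, hidden_dim, num_classes = 5, 64, 3
breakdown = {
    "conv1": dense_params(in_features, hidden_dim),       # GCNConv(5, 64)
    "conv2": dense_params(hidden_dim, hidden_dim),        # GCNConv(64, 64)
    "classifier": dense_params(hidden_dim, num_classes),  # nn.Linear(64, 3)
}
print(breakdown)                # {'conv1': 384, 'conv2': 4160, 'classifier': 195}
print(sum(breakdown.values()))  # 4739
```

About 88% of the parameters sit in `conv2`, so `hidden_dim` dominates the model size: the count grows roughly with `hidden_dim ** 2`.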


## Build the Training Pipeline

Now we'll create a complete ZenML pipeline that uses our custom GCN model. The pipeline
follows the same structure as Tutorial 1:

1. **`create_data`**: Generates synthetic graph data with class labels
2. **`create_datamodule`**: Splits data into train/validation sets
3. **`create_model`**: Instantiates our custom SimpleGCN
4. **`create_lightning_module`**: Wraps the model in PyTorch Lightning
5. **`train_model`**: Executes the training loop

**Key difference from Tutorial 1**: Here we use our custom `SimpleGCN` instead of the
pre-built `GroupClassifier`. Every other step is unchanged, which demonstrates the
flexibility of the ZenML + `GraphModel` architecture.
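The execution order implied by these data dependencies can be illustrated with Python's standard-library `graphlib` (Python 3.9+). This is only an analogy for how passing one step's output into another orders the steps. It is not ZenML's scheduler. Note that `create_data` and `create_model` do not depend on each other, so either may come first.

```python
from graphlib import TopologicalSorter

# Each step maps to the steps whose outputs it consumes (mirrors the pipeline below)
dependencies = {
    "create_data": set(),
    "create_datamodule": {"create_data"},
    "create_model": set(),
    "create_lightning_module": {"create_model"},
    "train_model": {"create_lightning_module", "create_datamodule"},
}

# static_order() yields every step after all of its predecessors
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```

Whatever order `static_order()` picks for the two independent roots, `train_model` always comes last because it consumes both branches.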


In [3]:
def create_simple_synthetic_data(num_samples: int = 150) -> list[Data]:
    """Generate synthetic graph data for the custom model.
    
    Creates graphs with random node features and edges, each labeled with
    one of three classes (π, μ, e+).
    """
    data = []
    for _ in range(num_samples):
        # Create graphs with 4-8 nodes
        num_nodes = torch.randint(4, 9, (1,)).item()
        
        # Random node features (5D: coord, z, energy, view, group_energy)
        x = torch.randn(num_nodes, 5)
        
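        # Random edge connectivity (self-loops and duplicate edges are possible)
        edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2))
        # 4D edge features; SimpleGCN does not use them
        edge_attr = torch.randn(edge_index.shape[1], 4)

        # Random class label (one-hot encoded). Labels are drawn independently
        # of the features, so there is no real signal for the model to learn.
        label = torch.randint(0, 3, (1,)).item()
        y = torch.zeros(3)
        y[label] = 1.0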

        data.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y))
    return data


@step(output_materializers=PyGDataListMaterializer, enable_cache=False)
def create_data() -> list[Data]:
    """Step 1: Generate synthetic graph data.
    
    The PyGDataListMaterializer ensures efficient serialization of PyTorch Geometric
    Data objects, avoiding pickle warnings.
    """
    return create_simple_synthetic_data()


@step(output_materializers=GraphDataModuleMaterializer, enable_cache=False)
def create_datamodule(data: list[Data]) -> GraphDataModule:
    """Step 2: Create data module with train/val split.
    
    - val_split=0.3: 70% train, 30% validation
    - batch_size=16: smaller batches for the custom model
    """
    return GraphDataModule(dataset=data, val_split=0.3, batch_size=16, num_workers=0)


@step
def create_model(num_classes: int = 3, hidden_dim: int = 64) -> SimpleGCN:
    """Step 3: Instantiate our custom GCN model.
    
    This step creates the SimpleGCN we defined above with the specified
    number of classes and hidden dimension.
    """
    return SimpleGCN(num_classes=num_classes, hidden_dim=hidden_dim)


@step
def create_lightning_module(model: SimpleGCN) -> GraphLightningModule:
    """Step 4: Wrap model in PyTorch Lightning module.
    
    Adds the training logic, the loss function (BCE, applied to the one-hot
    class targets), and the optimizer configuration.
    """
    return GraphLightningModule(model, task="classification", lr=5e-4)


@step
def train_model(
    lightning_module: GraphLightningModule,
    datamodule: GraphDataModule
) -> GraphLightningModule:
    """Step 5: Execute training loop.
    
    Automatically detects available hardware (CPU/GPU) and runs for 5 epochs.
    Returns the trained module in eval mode.
    """
    import pytorch_lightning as pl

    accelerator, devices = detect_available_accelerator()

    trainer = pl.Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=5,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    trainer.fit(lightning_module, datamodule=datamodule)
    return lightning_module.eval()


@pipeline
def custom_model_pipeline():
    """Compose all steps into a complete training pipeline.
    
    ZenML automatically wires the outputs of one step to the inputs of the next
    based on parameter names. For example, `create_datamodule(data)` receives
    the output from `create_data()`.
    """
    data = create_data()
    datamodule = create_datamodule(data)
    model = create_model()
    lightning_module = create_lightning_module(model)
    trained_module = train_model(lightning_module, datamodule)
    return trained_module, datamodule


## Run the Pipeline

Execute the pipeline, then use `load_step_output` to load the trained module, the
datamodule, and the untrained model from the run. ZenML stores these artifacts, so they
can be reloaded later without re-running the pipeline.

**Why use `enable_cache=False`?** This ensures the pipeline runs fresh each time,
which is useful for tutorials. In production, you'd typically enable caching to
skip re-running unchanged steps.
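The idea behind caching can be shown with a toy, in-memory sketch: a step is skipped when a key derived from its code and inputs matches an earlier run. ZenML's real cache key and artifact store are more involved. Every name here (`run_step`, `cache_key`, `scale`) is hypothetical, and the function's bytecode only stands in for "step code".

```python
import hashlib
import json
from typing import Any, Callable

# In-memory stand-in for an artifact store: cache key -> step output
_cache: dict[str, Any] = {}
executions: list[str] = []  # records every real (non-cached) step execution


def cache_key(fn: Callable[..., Any], inputs: dict[str, Any]) -> str:
    """Hash the step's bytecode (standing in for its code) and its inputs."""
    payload = fn.__code__.co_code + json.dumps(inputs, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


def run_step(fn: Callable[..., Any], enable_cache: bool = True, **inputs: Any) -> Any:
    """Run fn(**inputs), reusing a stored result when caching is on and the key matches."""
    key = cache_key(fn, inputs)
    if enable_cache and key in _cache:
        return _cache[key]
    result = fn(**inputs)
    _cache[key] = result
    return result


def scale(value: int, factor: int) -> int:
    executions.append("scale")
    return value * factor


run_step(scale, value=5, factor=3)                      # executes the step
run_step(scale, value=5, factor=3)                      # cache hit: step skipped
run_step(scale, enable_cache=False, value=5, factor=3)  # caching off: runs again
print(len(executions))  # 2
```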


In [4]:
run = custom_model_pipeline.with_options(enable_cache=False)()
print(f"Pipeline run status: {run.status}")

# Load artifacts from the run
trained_module = load_step_output(run, "train_model")
datamodule = load_step_output(run, "create_datamodule")
custom_model = load_step_output(run, "create_model")

if trained_module is None or datamodule is None or custom_model is None:
    raise RuntimeError("Could not load artifacts from the custom_model_pipeline run.")

# Move model to best available device
# Note: Models saved by ZenML may be on CPU. We move to GPU if available
# to match the training device and speed up inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_module = trained_module.to(device).eval()
datamodule.setup(stage="fit")
print(f"Loaded artifacts from run {run.name} (device={device})")


Initiating a new run for the pipeline: custom_model_pipeline.
Registered new pipeline: custom_model_pipeline.
Caching is disabled by default for custom_model_pipeline.
Using user: default
Using stack: default
  artifact_store: default
  orchestrator: default
  deployer: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml login --local.
Step create_data has started.
Step create_data has finished in 0.086s.
Step create_model has started.
Step create_model has finished in 0.047s.

## Inspect the Custom Model

Let's examine the trained model to confirm that its structure matches the architecture
we defined. We'll check:

- **Parameter count**: How many parameters does our custom GCN have?
- **Input/output shapes**: What dimensions does the model expect and produce?
- **Device placement**: Where is the model running (CPU/GPU)?

This validates that the model architecture is wired correctly and that the module loaded
from ZenML can run a forward pass.


In [5]:
device = next(trained_module.parameters()).device
param_count = sum(p.numel() for p in trained_module.parameters())

# Get a sample batch to inspect input shapes
train_loader = datamodule.train_dataloader()
first_batch = next(iter(train_loader))

print("Custom GCN Model Summary:")
print(f"- Run: {run.name}")
print(f"- Device: {device}")
print(f"- Total parameters: {param_count:,}")
print(f"- Input node features: {first_batch.x.shape[1]}D (shape: {tuple(first_batch.x.shape)})")
print(f"- Number of nodes in batch: {first_batch.x.shape[0]}")
print(f"- Number of edges: {first_batch.edge_index.shape[1]}")

# Test forward pass to verify output shape
with torch.no_grad():
    output = trained_module(first_batch.to(device))
    print(f"- Output logits shape: {tuple(output.shape)} (batch_size, num_classes)")


Custom GCN Model Summary:
- Run: custom_model_pipeline-2025_11_25-08_43_57_031350
- Device: cpu
- Total parameters: 4,739
- Input node features: 5D (shape: (88, 5))
- Number of nodes in batch: 88
- Number of edges: 176
- Output logits shape: (16, 3) (batch_size, num_classes)
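These batch statistics follow from `create_simple_synthetic_data`. Every graph has between 4 and 8 nodes and exactly `2 * num_nodes` edges, so a batch of 16 graphs holds 64 to 128 nodes and always twice as many edges as nodes (88 and 176 in the run above). The stdlib sketch below mimics only the graph sizes of that generator, using `random` in place of `torch.randint`. `batch_totals` is a hypothetical helper and does not reproduce the tutorial's actual data.

```python
import random


def batch_totals(num_graphs: int, rng: random.Random) -> tuple[int, int]:
    """Total nodes and edges in a batch of graphs sized like the synthetic data."""
    total_nodes = total_edges = 0
    for _ in range(num_graphs):
        num_nodes = rng.randint(4, 8)  # inclusive bounds, like torch.randint(4, 9, ...)
        total_nodes += num_nodes
        total_edges += 2 * num_nodes   # edge_index has shape (2, num_nodes * 2)
    return total_nodes, total_edges


rng = random.Random(0)
for _ in range(3):
    nodes, edges = batch_totals(16, rng)
    print(nodes, edges, edges == 2 * nodes)
```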


## Evaluate Model Performance

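Finally, let's compute validation accuracy. We'll run inference on the validation set and
compare predictions to the ground-truth labels.

**What we're measuring**: Classification accuracy, the fraction of graphs correctly
classified into their true class (π, μ, or e+). Because the synthetic labels are drawn at
random, independently of the node features, expect accuracy near chance (about 33% for
three classes). This check confirms that training and evaluation run end to end. It says
nothing about how well the architecture performs on real data.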
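The accuracy computation can be sketched without PyTorch. This includes regrouping the flattened labels. PyTorch Geometric concatenates each graph's 1D `y` of shape `(3,)` along dimension 0, so a batch's labels arrive as one flat vector of length `batch_size * 3`. The function names below are hypothetical.

```python
def argmax(values: list[float]) -> int:
    # Index of the largest value; the first index wins on ties
    return max(range(len(values)), key=values.__getitem__)


def accuracy_from_flat_labels(logits: list[list[float]], flat_y: list[float]) -> float:
    """Accuracy when one-hot labels arrive flattened to (batch_size * num_classes,)."""
    num_classes = len(logits[0])
    # Regroup the flat label vector into one one-hot row per graph
    rows = [flat_y[i:i + num_classes] for i in range(0, len(flat_y), num_classes)]
    labels = [argmax(row) for row in rows]
    preds = [argmax(row) for row in logits]
    correct = sum(p == t for p, t in zip(preds, labels))
    return correct / len(labels)


logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 0.9], [0.0, 1.5, 0.4], [1.0, 0.0, 0.0]]
flat_y = [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]  # labels 0, 2, 0, 1
print(accuracy_from_flat_labels(logits, flat_y))  # 0.5
```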

**Note on device placement**: Training runs on whichever accelerator
`detect_available_accelerator` finds, but a module loaded from ZenML artifacts may come
back on the CPU. That is why the run cell moves `trained_module` to the best available
device (GPU if available) after loading, and why the loop below moves each batch to that
same device. On a CPU-only machine, as in the run above, everything stays on the CPU.


In [6]:
# Get validation loader (fallback to train if val is empty)
val_loader = datamodule.val_dataloader()
if isinstance(val_loader, list) and len(val_loader) == 0:
    val_loader = datamodule.train_dataloader()

# Compute accuracy
correct = 0
total = 0
trained_module.eval()

for batch in val_loader:
    batch = batch.to(device)
    with torch.no_grad():
        logits = trained_module(batch)
    
    # Handle labels: each graph stores a 1D one-hot y of shape (num_classes,).
    # PyTorch Geometric concatenates these along dim 0, so batch.y arrives
    # flattened as (batch_size * num_classes,). Reshape it back to
    # (batch_size, num_classes); this also works if y is already 2D.
    num_classes = logits.shape[1]
    labels = batch.y.view(-1, num_classes)
    
    # Convert one-hot encoded labels to class indices
    # Shape: (batch_size, num_classes) -> (batch_size,)
    labels = torch.argmax(labels, dim=1)
    
    # Get predicted class (highest logit)
    # Shape: (batch_size, num_classes) -> (batch_size,)
    preds = torch.argmax(logits, dim=1)
    
    # Count correct predictions (both should be shape (batch_size,))
    correct += int((preds == labels).sum().item())
    total += int(labels.numel())

accuracy = correct / total if total > 0 else 0.0
print(f"Validation accuracy: {accuracy:.1%} ({correct}/{total} correct)")
print(f"\nThe custom GCN successfully learned to classify graphs!")
