# Tutorial 5: Building a Custom Lightning DataModule (Non-Graph)

Learn how to implement your own `LightningDataModule`, plug it into a ZenML pipeline, and run it end-to-end.
We'll build a simple tabular classifier using a custom DataModule, then execute a ZenML pipeline that trains
and evaluates it.


In [1]:
import torch

from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.pipelines.tutorial_examples.tabular_datamodule_pipeline import (
    TabularConfig,
    TabularDataModule,
    TabularClassifier,
    tabular_datamodule_pipeline,
)

zenml_client = zenml_utils.setup_zenml_for_notebook(use_in_memory=True)
print(f"ZenML initialized with stack: {zenml_client.active_stack_model.name}")


Using ZenML repository root: /home/jack/python_projects/pioneerML
Ensure this is the top-level of your repo (.zen must live here).
ZenML initialized with stack: default


## 1. The DataModule blueprint

`TabularDataModule` inherits from `pytorch_lightning.LightningDataModule` and manages:
- Synthetic tabular dataset creation (clustered features per class)
- Train/val/test splits with deterministic seeds
- Standard PyTorch DataLoaders

It lives in `src/pioneerml/zenml/pipelines/tutorial_examples/tabular_datamodule_pipeline.py` so notebooks can import it directly.


In [2]:
config = TabularConfig(
    num_samples=300,
    num_features=8,
    num_classes=3,
    batch_size=32,
    val_split=0.2,
    test_split=0.1,
    seed=42,
)

datamodule = TabularDataModule(config)
datamodule.setup(stage="fit")

train_batch = next(iter(datamodule.train_dataloader()))
print("Train batch shapes -> x:", tuple(train_batch[0].shape), "y:", tuple(train_batch[1].shape))
if datamodule.val_dataset:
    val_batch = next(iter(datamodule.val_dataloader()))
    print("Val batch shapes   -> x:", tuple(val_batch[0].shape), "y:", tuple(val_batch[1].shape))


Train batch shapes -> x: (32, 8) y: (32,)
Val batch shapes   -> x: (32, 8) y: (32,)
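
The batch shapes follow from the configuration. As a rough check of the arithmetic, the sketch below computes split sizes and batch counts, assuming the split fractions are applied to the total sample count and rounded; the project's actual rule is not shown here, and `SplitSketch` is a hypothetical helper, not part of `pioneerml`.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SplitSketch:
    """Hypothetical helper mirroring the split-related fields of the config."""
    num_samples: int = 300
    batch_size: int = 32
    val_split: float = 0.2
    test_split: float = 0.1

    def sizes(self):
        # Assumed rule: validation and test take their fraction of all samples,
        # and training receives whatever remains.
        n_val = round(self.num_samples * self.val_split)
        n_test = round(self.num_samples * self.test_split)
        return self.num_samples - n_val - n_test, n_val, n_test

    def num_batches(self, n):
        # The last batch may be smaller than batch_size, hence the ceiling.
        return math.ceil(n / self.batch_size)


split = SplitSketch()
n_train, n_val, n_test = split.sizes()
print(n_train, n_val, n_test)                                # 210 60 30
print(split.num_batches(n_train), split.num_batches(n_val))  # 7 2
```

Under this assumption the validation split holds 60 rows, so the first validation batch is a full 32 rows and the second holds the remaining 28.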


## 2. The LightningModule

`TabularClassifier` is a tiny MLP with a cross-entropy objective. It logs train/val loss and accuracy during training.
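
As a rough illustration of such a classifier, here is a plain `nn.Module` sketch with a cross-entropy loss and an accuracy helper. The name `SketchClassifier`, the two hidden layers, and the hidden width of 32 are assumptions; the real `TabularClassifier` may be laid out differently. In a `LightningModule`, the same loss and accuracy would typically be computed in `training_step` and `validation_step` and recorded with `self.log`.

```python
import torch
from torch import nn


class SketchClassifier(nn.Module):
    """Hypothetical MLP; layer sizes are guesses, not the project's values."""

    def __init__(self, num_features=8, num_classes=3, hidden=32):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(num_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_classes),  # raw logits, no softmax
        )

    def forward(self, x):
        return self.model(x)


def loss_and_accuracy(module, x, y):
    """Cross-entropy on the logits plus the fraction of correct argmax predictions."""
    logits = module(x)
    loss = nn.functional.cross_entropy(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean()
    return loss, acc


torch.manual_seed(0)
net = SketchClassifier()
x = torch.randn(32, 8)            # one batch of 32 samples with 8 features
y = torch.randint(0, 3, (32,))    # integer class labels in {0, 1, 2}
loss, acc = loss_and_accuracy(net, x, y)
n_params = sum(p.numel() for p in net.parameters())
print(tuple(net(x).shape), n_params)  # (32, 3) 1443
```

`cross_entropy` expects unnormalised logits and integer labels, which is why the final layer has no activation.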


In [3]:
model = TabularClassifier(config)
with torch.no_grad():
    sample_logits = model(train_batch[0])
print("Logits shape:", tuple(sample_logits.shape))


Logits shape: (32, 3)


## 3. Run the ZenML pipeline

`tabular_datamodule_pipeline` wires together steps to build the DataModule, build the model, train, and collect predictions/targets.
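
The data flow between those four stages can be sketched with plain functions. This is not the project's pipeline: every function name below is hypothetical, the training loop is ordinary PyTorch standing in for `Trainer.fit`, and the data recipe is simplified. In ZenML, each stage would typically be a function decorated with `@step` and the composition a function decorated with `@pipeline` (both importable from `zenml`); the decorators are left out so the sketch runs without a ZenML stack.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def build_data(seed=42, n=300, features=8, classes=3):
    """Stage 1: synthetic data with a class-dependent shift, split into train/val."""
    gen = torch.Generator().manual_seed(seed)
    y = torch.randint(0, classes, (n,), generator=gen)
    x = torch.randn(n, features, generator=gen) + 2.0 * y.unsqueeze(1)
    n_val = n // 5  # 20% held out for validation
    return TensorDataset(x[n_val:], y[n_val:]), TensorDataset(x[:n_val], y[:n_val])


def build_model(features=8, classes=3, hidden=32):
    """Stage 2: a small MLP that outputs logits."""
    return nn.Sequential(nn.Linear(features, hidden), nn.ReLU(), nn.Linear(hidden, classes))


def train(model, train_ds, epochs=10, batch_size=32):
    """Stage 3: minimal training loop (stand-in for a Lightning Trainer)."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(epochs):
        for xb, yb in DataLoader(train_ds, batch_size=batch_size, shuffle=True):
            opt.zero_grad()
            nn.functional.cross_entropy(model(xb), yb).backward()
            opt.step()
    return model


def evaluate(model, val_ds, batch_size=32):
    """Stage 4: collect logits and targets over the validation split."""
    model.eval()
    all_logits, all_targets = [], []
    with torch.no_grad():
        for xb, yb in DataLoader(val_ds, batch_size=batch_size):
            all_logits.append(model(xb))
            all_targets.append(yb)
    return torch.cat(all_logits), torch.cat(all_targets)


def sketch_pipeline():
    """Chain the stages; each output feeds the next stage's input."""
    torch.manual_seed(0)
    train_ds, val_ds = build_data()
    model = train(build_model(), train_ds)
    return evaluate(model, val_ds)


sketch_preds, sketch_targets = sketch_pipeline()
print(tuple(sketch_preds.shape), tuple(sketch_targets.shape))  # (60, 3) (60,)
```

Keeping each stage a pure function of its inputs is what lets an orchestrator store every intermediate result as an artifact and reload it later.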


In [4]:
run = tabular_datamodule_pipeline.with_options(enable_cache=False)(config)
print(f"Pipeline run status: {run.status}")

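# Load each step's output artifact from the finished run.
trained_model = load_step_output(run, "train_tabular_model")
datamodule_run = load_step_output(run, "build_tabular_datamodule")
# The evaluation step has two outputs: predictions (output_0) and targets (output_1).
preds = load_step_output(run, "evaluate_tabular_model", output_name="output_0", index=0)
targets = load_step_output(run, "evaluate_tabular_model", output_name="output_1", index=0)

# Fallback: if either named output came back as None, load the step's
# output as a whole and unpack it when it is a (preds, targets) pair.
if preds is None or targets is None:
    outputs = load_step_output(run, "evaluate_tabular_model")
    if isinstance(outputs, (tuple, list)) and len(outputs) == 2:
        preds, targets = outputs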

print("Preds shape:", tuple(preds.shape) if preds is not None else None)
print("Targets shape:", tuple(targets.shape) if targets is not None else None)


Initiating a new run for the pipeline: tabular_datamodule_pipeline.
Caching is disabled by default for tabular_datamodule_pipeline.
Using user: default
Using stack: default
  deployer: default
  orchestrator: default
  artifact_store: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml login --local.
Step build_tabular_datamodule has started.
[build_tabular_datamodule] No materializer is registered for type <class 'pioneerml.zenml.pipelines.tutorial_examples.tabular_datamodule_pipeline.TabularDataModule'>, so the default Pickle materializer was used. Pickle is not production ready and should only be used for prototyping as t

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 1.4 K  | train
---------------------------------------------
1.4 K     Trainable params
0         Non-trainable params
1.4 K     Total params
0.006     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


[train_tabular_model] /home/jack/virtual_environments/miniconda3/envs/pioneerml/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the num_workers argument to num_workers=15 in the DataLoader to improve performance.

[train_tabular_model] /home/jack/virtual_environments/miniconda3/envs/pioneerml/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the num_workers argument to num_workers=15 in the DataLoader to improve performance.


`Trainer.fit` stopped: `max_epochs=10` reached.


Step train_tabular_model has finished in 0.546s.
Step evaluate_tabular_model has started.
Step evaluate_tabular_model has finished in 0.638s.
Pipeline run has finished in 2.153s.
Pipeline run status: completed
Preds shape: (60, 3)
Targets shape: (60,)


## 4. Compute a quick accuracy

Sanity-check the pipeline artifacts: take the argmax over the collected logits and compare the result with the targets.


In [5]:
def simple_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    preds = logits.argmax(dim=1)
    return float((preds == labels).float().mean().item())

acc = simple_accuracy(preds, targets)
print(f"Validation accuracy: {acc:.3f}")


Validation accuracy: 0.800


## 5. Recap

- Inherit from `LightningDataModule` to manage dataset creation and splits.
- Keep configuration in a dataclass (`TabularConfig`) for easy reuse.
- Wrap the DataModule/LightningModule in ZenML steps and run via a pipeline.
- Load artifacts with `load_step_output` to inspect predictions and targets, then compute metrics from them.
