# PIONEER ML Tutorial

A quick walkthrough of the pipeline framework, the Lightning utilities, and how to plug in a model. This notebook builds a tiny synthetic dataset, wraps it in a DataModule, trains a model with the `LightningTrainStage` pipeline stage, and inspects the results.

## 1) Imports

In [None]:
import numpy as np

from pioneerml.data import GraphGroupDataset
from pioneerml.models import GroupClassifier
from pioneerml.pipelines import Pipeline, Context, StageConfig
from pioneerml.pipelines.stages import LightningTrainStage
from pioneerml.training import GraphDataModule, GraphLightningModule


## 2) Create a synthetic dataset
We generate 20 fake time-group records. Each one carries the standardized per-hit fields expected by `GraphGroupDataset` (`coord`, `z`, `energy`, `view`), plus multi-label targets and event/group IDs.

In [None]:
np.random.seed(0)  # fixed seed so the synthetic records (and their labels) are reproducible


def make_record(num_hits: int, event_id: int) -> dict:
    coord = np.random.randn(num_hits).astype(np.float32)
    z = np.random.randn(num_hits).astype(np.float32)
    energy = np.abs(np.random.randn(num_hits)).astype(np.float32)
    view = np.random.randint(0, 2, num_hits).astype(np.float32)

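    # Multi-label targets, one 0/1 flag per class: [pion, muon, mip].
    # The rules are arbitrary; they only give the toy model something to learn.
    labels = [
        int(energy.mean() > 0.5),  # pion
        int(num_hits % 2 == 0),    # muon
        0,                         # mip (never set in this toy data)
    ]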

    return {
        "coord": coord,
        "z": z,
        "energy": energy,
        "view": view,
        "labels": labels,
        "event_id": event_id,
        "group_id": event_id,
    }

records = [make_record(num_hits=8 + i, event_id=i) for i in range(20)]
dataset = GraphGroupDataset(records, num_classes=3)
dataset[0]  # trigger a build and inspect
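Before handing records to a dataset, it can help to check that every per-hit field has one entry per hit and that the label vector matches `num_classes`. The sketch below is a hypothetical helper, `validate_record`, written against the fields that `make_record` builds; it is not a check that `GraphGroupDataset` is known to perform, and it uses only the standard library (`len()` works on the NumPy arrays above as well as on plain lists).

```python
from typing import Any, Mapping, Sequence

# Per-hit fields used in the toy records; each must have exactly one entry per hit.
PER_HIT_FIELDS = ("coord", "z", "energy", "view")
REQUIRED_FIELDS = (*PER_HIT_FIELDS, "labels", "event_id", "group_id")


def validate_record(record: Mapping[str, Any], num_classes: int) -> int:
    """Check a toy record's structure and return its number of hits.

    Hypothetical helper mirroring `make_record`; not part of pioneerml.
    """
    missing = [key for key in REQUIRED_FIELDS if key not in record]
    if missing:
        raise KeyError(f"record is missing fields: {missing}")

    # All per-hit arrays must agree on the number of hits.
    lengths = {name: len(record[name]) for name in PER_HIT_FIELDS}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"per-hit fields have mismatched lengths: {lengths}")

    # Multi-label targets: one 0/1 flag per class.
    labels: Sequence[int] = record["labels"]
    if len(labels) != num_classes:
        raise ValueError(f"expected {num_classes} labels, got {len(labels)}")
    if any(label not in (0, 1) for label in labels):
        raise ValueError(f"multi-label targets must be 0/1, got {list(labels)}")

    return lengths["coord"]
```

In the notebook this could be run as `[validate_record(r, num_classes=3) for r in records]`, which should return the hit counts 8 through 27.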

## 3) Wrap data with a Lightning DataModule
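`GraphDataModule` splits the dataset into training, validation, and test subsets and builds PyTorch Geometric (PyG) data loaders for each. Here 20% of the records are held out for validation and none for testing.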

In [None]:
datamodule = GraphDataModule(dataset=dataset, batch_size=4, val_split=0.2, test_split=0.0)
datamodule.setup()
datamodule.train_dataloader(), datamodule.val_dataloader()
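To see roughly how many samples and batches these settings produce, the arithmetic can be sketched as below. The functions `split_sizes` and `num_batches` are hypothetical, and the rounding rule (fractions rounded down, remainder to training) is an assumption; `GraphDataModule` may round differently.

```python
import math


def split_sizes(n: int, val_split: float, test_split: float) -> tuple[int, int, int]:
    """Return (train, val, test) counts for fractional splits.

    Assumption: val/test counts are rounded down and the remainder goes to training.
    """
    if not 0.0 <= val_split + test_split < 1.0:
        raise ValueError("val_split + test_split must be in [0, 1)")
    n_val = int(n * val_split)
    n_test = int(n * test_split)
    return n - n_val - n_test, n_val, n_test


def num_batches(n_items: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader yields for n_items samples."""
    if drop_last:
        return n_items // batch_size  # incomplete final batch is discarded
    return math.ceil(n_items / batch_size)  # incomplete final batch is kept
```

Under these assumptions, `split_sizes(20, 0.2, 0.0)` gives 16 training and 4 validation records, i.e. 4 training batches and 1 validation batch at `batch_size=4`. The `limit_train_batches=2` setting used later would then run only half of the training batches per epoch.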

## 4) Build the model and Lightning module

In [None]:
model = GroupClassifier(num_classes=3, hidden=64, num_blocks=2)
lightning_module = GraphLightningModule(model, task="classification", lr=1e-3)
lightning_module
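The targets built in section 2 are multi-label: each of `[pion, muon, mip]` is an independent yes/no flag, so a common choice is a per-class sigmoid with binary cross-entropy rather than a softmax over classes. Whether `GraphLightningModule(task="classification")` uses exactly this loss is not stated in this tutorial; the plain-Python sketch below only illustrates the technique, and `bce_with_logits` and `predict_labels` are hypothetical names.

```python
import math
from typing import List, Sequence


def bce_with_logits(logits: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy over independent labels, computed from raw logits.

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)),
    which equals -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))].
    """
    if len(logits) != len(targets) or not logits:
        raise ValueError("logits and targets must be non-empty and the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def predict_labels(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Turn logits into 0/1 predictions, thresholding each class independently."""
    if not 0.0 < threshold < 1.0:
        raise ValueError("threshold must be strictly between 0 and 1")
    # sigmoid(x) >= t  <=>  x >= log(t / (1 - t)); comparing logits avoids exp overflow.
    cutoff = math.log(threshold / (1.0 - threshold))
    return [int(x >= cutoff) for x in logits]
```

For example, `predict_labels([2.0, -1.0, 0.0])` returns `[1, 0, 1]`: more than one class can be switched on at once, which is the point of multi-label targets.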


## 5) Compose a pipeline with the Lightning training stage
`LightningTrainStage` fits the module on the provided DataModule, then records the trainer and the trained module back into the shared `Context`, where later stages (and the inspection calls below) can read them.

In [None]:
train_stage = LightningTrainStage(
    config=StageConfig(
        name="train",
        params={
            "module": lightning_module,
            "datamodule": datamodule,
            "trainer_params": {
                "max_epochs": 2,
                "limit_train_batches": 2,
                "limit_val_batches": 1,
                "logger": False,
                "enable_checkpointing": False,
            },
        },
    )
)

pipeline = Pipeline([train_stage], name="tutorial_pipeline")
ctx = pipeline.run(Context())
ctx.summary()
ctx.get("metrics", {})
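The pattern behind `Pipeline` and `Context` is that stages run in order and communicate only through a shared context. The sketch below reproduces that idea with a plain `dict`; `ToyPipeline`, `fake_train`, and `fake_evaluate` are hypothetical stand-ins, not the pioneerml API, and the keys they write are illustrative rather than the exact keys `LightningTrainStage` stores.

```python
from typing import Any, Callable, Dict, List

# Hypothetical: a "stage" is any callable that reads from and writes to a shared dict.
Stage = Callable[[Dict[str, Any]], None]


class ToyPipeline:
    """Run stages in order, all sharing one mutable context dict."""

    def __init__(self, stages: List[Stage], name: str) -> None:
        self.stages = stages
        self.name = name

    def run(self, ctx: Dict[str, Any]) -> Dict[str, Any]:
        for stage in self.stages:
            stage(ctx)  # each stage may read earlier results and add new keys
        return ctx


def fake_train(ctx: Dict[str, Any]) -> None:
    # Stand-in for a training stage: publish outputs under well-known keys.
    ctx["module"] = "trained-module"
    ctx["metrics"] = {"train_loss": 0.0}  # placeholder value, not a real result


def fake_evaluate(ctx: Dict[str, Any]) -> None:
    # A later stage consumes what the training stage left behind.
    ctx["report"] = f"evaluated {ctx['module']} with metrics {ctx['metrics']}"
```

Running `ToyPipeline([fake_train, fake_evaluate], name="toy").run({})` returns a dict holding `module`, `metrics`, and `report`. Because stages share nothing but the context, an evaluation or checkpointing stage can be appended without changing the training stage.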


## 6) Next steps
- Swap in your own datasets or DataModules.
- Add stages for preprocessing, evaluation, and checkpointing.
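- Integrate experiment tracking (e.g., Weights & Biases) by configuring the Lightning Trainer, for example by passing a Lightning logger such as `WandbLogger` as `logger` in `trainer_params` instead of `False`.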
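One preprocessing step that such a stage might wrap is standardizing the per-hit `energy` field using statistics fitted on the training records only, so validation data does not leak into the scaling. The functions below (`fit_energy_scaler`, `apply_energy_scaler`) are hypothetical and use only the standard library; how to wrap them as a pioneerml stage is not shown in this tutorial.

```python
import statistics
from typing import Any, Dict, List, Sequence


def fit_energy_scaler(train_records: Sequence[Dict[str, Any]]) -> Dict[str, float]:
    """Compute mean and population std of per-hit energy over training records only."""
    energies = [float(e) for rec in train_records for e in rec["energy"]]
    mean = statistics.fmean(energies)
    std = statistics.pstdev(energies) or 1.0  # guard against a constant field
    return {"mean": mean, "std": std}


def apply_energy_scaler(
    records: Sequence[Dict[str, Any]], scaler: Dict[str, float]
) -> List[Dict[str, Any]]:
    """Return copies of records with standardized energy; the originals are untouched."""
    out = []
    for rec in records:
        scaled = [(float(e) - scaler["mean"]) / scaler["std"] for e in rec["energy"]]
        out.append({**rec, "energy": scaled})
    return out
```

The scaled `energy` comes back as a plain list; if the dataset expects `float32` arrays like those built by `make_record`, convert it with `np.asarray(..., dtype=np.float32)` before constructing `GraphGroupDataset`.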