# Group Classification (ZenML)

Train the `GroupClassifier` on parquet-based time-group data:
- Load groups from `ml_output_*.parquet` using the **C++ Arrow dataloader** (zero-copy)
- Train via the new ZenML `group_classification_pipeline`
- Export a TorchScript model (for C++/Python inference)

**Note:** the model outputs class logits *per time-group* (not per event). We aggregate them to events later via the output adapter.
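
This notebook does not show the output adapter. The sketch below illustrates one way group-to-event aggregation could work, under several assumptions:

- `group_ptr` holds CSR-style offsets, so event `i` owns groups `group_ptr[i]:group_ptr[i+1]`. This reading is consistent with the `(9,)` shape printed for 8 events further down.
- The three logits are independent per-class (sigmoid) scores.
- An event receives a class if any of its groups does.

`aggregate_groups_to_events` is a hypothetical helper, not the adapter's API.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def aggregate_groups_to_events(
    group_logits: Sequence[Sequence[float]],
    group_ptr: Sequence[int],
    threshold: float = 0.5,
) -> List[List[int]]:
    """Hypothetical aggregation: max-pool per-class probabilities over each event's groups."""
    n_classes = len(group_logits[0]) if group_logits else 0
    event_labels: List[List[int]] = []
    # Consecutive offsets in group_ptr delimit one event's slice of groups
    for start, end in zip(group_ptr[:-1], group_ptr[1:]):
        probs = [0.0] * n_classes
        for g in range(start, end):
            for c in range(n_classes):
                probs[c] = max(probs[c], sigmoid(group_logits[g][c]))
        event_labels.append([1 if p >= threshold else 0 for p in probs])
    return event_labels


# Toy input: 3 groups split across 2 events (event 0 owns groups 0-1, event 1 owns group 2)
group_logits = [[2.0, -3.0, -1.0], [-2.0, -2.0, 3.0], [-4.0, 1.5, -4.0]]
group_ptr = [0, 2, 3]
event_labels = aggregate_groups_to_events(group_logits, group_ptr)
print(event_labels)  # [[1, 0, 1], [0, 1, 0]]
```

Max-pooling is only one choice. Mean-pooling or a learned event-level head would be equally plausible, and the real adapter may differ.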


In [1]:
from pioneerml.zenml import utils as zenml_utils

PROJECT_ROOT = zenml_utils.find_project_root()
zenml_client = zenml_utils.setup_zenml_for_notebook(root_path=PROJECT_ROOT, use_in_memory=True)
print(f"ZenML ready with stack: {zenml_client.active_stack_model.name}")


Using ZenML repository root: /workspace
Ensure this is the top-level of your repo (.zen must live here).
Initializing the ZenML global configuration version to 0.92.0
Creating database tables
Creating default project 'default' ...
Creating default stack...
The current repo active project is no longer available.
Setting the repo active project to 'default'.
The current repo active stack is no longer available. Resetting the active stack to default.
Setting the global active project to 'default'.
Setting the global active stack to default.
Reloading configuration file /workspace/.zen/config.yaml
ZenML ready with stack: default


In [2]:
from pathlib import Path
from datetime import datetime
import sys

import pyarrow.parquet as pq
import torch

from pioneerml.zenml import load_step_output

# Ensure the C++ dataloader Python bindings are on the path
pml_bindings = Path(PROJECT_ROOT) / "external" / "pioneerml_dataloaders" / "build" / "bindings"
sys.path.insert(0, str(pml_bindings))

from pioneerml.zenml.pipelines.training import group_classification_pipeline


In [3]:
# Build a tiny parquet shard for quick debugging

data_dir = Path(PROJECT_ROOT) / "data"
src_parquet = data_dir / "ml_output_000.parquet"

small_parquet = Path(PROJECT_ROOT) / "artifacts" / "classify_groups_small.parquet"
small_rows = 8

table = pq.read_table(src_parquet)
table = table.slice(0, small_rows)
small_parquet.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, small_parquet)

parquet_paths = [str(small_parquet)]
print(f"Wrote small parquet: {small_parquet} (rows={small_rows})")


Wrote small parquet: /workspace/artifacts/classify_groups_small.parquet (rows=8)


In [4]:
# Run the new ZenML pipeline

run = group_classification_pipeline.with_options(enable_cache=False)(
    parquet_paths=parquet_paths,
    config_json={"time_window_ns": 1.0},
    max_epochs=1,
    lr=1e-3,
    weight_decay=1e-4,
)
print(f"Run name: {run.name}")
print(f"Run status: {run.status}")


Initiating a new run for the pipeline: group_classification_pipeline.
Registered new pipeline: group_classification_pipeline.
Caching is disabled by default for group_classification_pipeline.
Using user: default
Using stack: default
  deployer: default
  orchestrator: default
  artifact_store: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml login --local.
Step load_group_classifier_data has started.
Step load_group_classifier_data has finished in 0.164s.
Step train_group_classifier has started.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 5070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



`Trainer.fit` stopped: `max_epochs=1` reached.


[37mStep [0m[38;5;105mtrain_group_classifier[37m has finished in [0m[38;5;105m2.098s[37m.[0m
[37mPipeline run has finished in [0m[38;5;105m3.311s[37m.[0m
Run name: group_classification_pipeline-2026_02_01-03_14_06_198794
Run status: completed
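
The pipeline receives `config_json={"time_window_ns": 1.0}`, but the notebook does not show how the C++ dataloader applies it. The sketch below is purely illustrative and assumes a simple gap-based rule that may not match the real loader: sort hits by time, then start a new group whenever the gap to the previous hit exceeds the window. `group_by_time_window` is a hypothetical function.

```python
from typing import List, Optional, Sequence


def group_by_time_window(times_ns: Sequence[float], window_ns: float) -> List[int]:
    """Assign a group id to each hit (hypothetical gap-based rule, not the C++ loader's)."""
    # Visit hits in time order without reordering the caller's data
    order = sorted(range(len(times_ns)), key=lambda i: times_ns[i])
    group_ids = [0] * len(times_ns)
    current = -1
    prev_t: Optional[float] = None
    for i in order:
        t = times_ns[i]
        # Start a new group on the first hit or when the time gap exceeds the window
        if prev_t is None or t - prev_t > window_ns:
            current += 1
        group_ids[i] = current
        prev_t = t
    return group_ids


times_ns = [0.0, 0.4, 5.0, 0.9, 5.6, 20.0]
group_ids = group_by_time_window(times_ns, window_ns=1.0)
print(group_ids)  # [0, 0, 1, 0, 1, 2]
```

Gap-based grouping chains hits together, so a single group can span more than `time_window_ns` overall. A fixed-bin scheme would split such a chain differently.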


In [5]:
# Load artifacts and run a quick forward pass

trained_module = load_step_output(run, "train_group_classifier")
batch = load_step_output(run, "load_group_classifier_data")

if trained_module is None or batch is None:
    raise RuntimeError("Could not load artifacts from the group_classification_pipeline run.")

trained_module.eval()
device = next(trained_module.parameters()).device
data = batch.data.to(device)

with torch.no_grad():
    logits = trained_module.model(data)

print(f"Device: {device}")
print(f"Events (graphs): {data.num_graphs}")
print(f"Groups: {data.num_groups}")
print(f"x: {tuple(data.x.shape)}")
print(f"edge_index: {tuple(data.edge_index.shape)}")
print(f"edge_attr: {tuple(data.edge_attr.shape)}")
print(f"group_ptr: {tuple(data.group_ptr.shape)}")
print(f"time_group_ids: {tuple(data.time_group_ids.shape)}")
print(f"logits: {tuple(logits.shape)}")


Device: cpu
Events (graphs): 8
Groups: 97
x: (416, 4)
edge_index: (2, 26378)
edge_attr: (26378, 4)
group_ptr: (9,)
time_group_ids: (416,)
logits: (97, 3)
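
Under a CSR-style reading of the batch, the printed shapes are mutually consistent. The notebook does not state this layout, so treat it as an assumption:

- `group_ptr` has `num_graphs + 1` entries (9 for 8 events).
- `time_group_ids` has one entry per hit row of `x` (416).
- `edge_attr` has one row per column of `edge_index` (26378).
- `logits` has one row per group (97).

A small sanity check follows. `layout_problems` and `groups_per_event` are hypothetical helpers that operate on plain shape tuples and offset lists.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def layout_problems(
    num_graphs: int,
    num_groups: int,
    x: Shape,
    edge_index: Shape,
    edge_attr: Shape,
    group_ptr: Shape,
    time_group_ids: Shape,
    logits: Shape,
) -> List[str]:
    """Describe every violated shape invariant (an empty list means consistent)."""
    problems: List[str] = []
    n_hits = x[0]
    if group_ptr != (num_graphs + 1,):
        problems.append(f"group_ptr {group_ptr} != ({num_graphs + 1},)")
    if time_group_ids != (n_hits,):
        problems.append(f"time_group_ids {time_group_ids} != ({n_hits},)")
    if edge_index[0] != 2 or edge_attr[0] != edge_index[1]:
        problems.append(f"edge_index {edge_index} / edge_attr {edge_attr} mismatch")
    if logits[0] != num_groups:
        problems.append(f"logits rows {logits[0]} != num_groups {num_groups}")
    return problems


def groups_per_event(group_ptr: List[int]) -> List[int]:
    """Differences of consecutive CSR offsets give the group count of each event."""
    return [end - start for start, end in zip(group_ptr[:-1], group_ptr[1:])]


# Shapes copied from the output above
observed = layout_problems(
    num_graphs=8,
    num_groups=97,
    x=(416, 4),
    edge_index=(2, 26378),
    edge_attr=(26378, 4),
    group_ptr=(9,),
    time_group_ids=(416,),
    logits=(97, 3),
)
print(observed)                        # []
print(groups_per_event([0, 3, 5, 5]))  # [3, 2, 0]  (hypothetical offsets)
```

In the notebook, you could feed the same function `data.num_graphs`, `tuple(data.x.shape)`, and so on, and pass `data.group_ptr.tolist()` to `groups_per_event`.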


## Save the TorchScript Model

Export the trained model to TorchScript so it can be loaded in C++ or Python.


In [7]:
# Export TorchScript model

export_dir = Path(PROJECT_ROOT) / "trained_models" / "groupclassifier"
export_dir.mkdir(parents=True, exist_ok=True)

stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
torchscript_path = export_dir / f"groupclassifier_{stamp}_torchscript.pt"

# Move both the model and the example batch to CPU so they share a device
trained_module.model.cpu()
example = batch.data.to("cpu")
trained_module.model.export_torchscript(torchscript_path, example, strict=False)
print(f"Saved TorchScript model: {torchscript_path}")


Saved TorchScript model: /workspace/trained_models/groupclassifier/groupclassifier_20260201_031410_torchscript.pt
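
`export_torchscript` is this project's own helper, and the notebook does not show how the saved file is consumed. Below is a minimal sketch of the Python side using the standard TorchScript save/load round trip. It uses a tiny stand-in module (`TinyGroupHead`, hypothetical) in place of the real `GroupClassifier`, whose forward takes a graph batch rather than a plain tensor. In C++, LibTorch would load the same kind of file with `torch::jit::load`.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


class TinyGroupHead(nn.Module):
    """Hypothetical stand-in: maps per-group features to 3 class logits."""

    def __init__(self, in_features: int = 4, num_classes: int = 3) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


torch.manual_seed(0)
model = TinyGroupHead().eval()
example = torch.randn(5, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "tiny_torchscript.pt"
    torch.jit.script(model).save(str(path))
    # map_location keeps the loaded weights on CPU regardless of where they were saved
    loaded = torch.jit.load(str(path), map_location="cpu")

with torch.no_grad():
    reference = model(example)
    restored = loaded(example)

same = torch.allclose(reference, restored)
out_shape = tuple(restored.shape)
print(same, out_shape)
```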


In [8]:
# (Optional) Inspect a few target labels

print("targets sample:", batch.targets[:5])


targets sample: tensor([[1., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.]])
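
The target rows are multi-hot: the first sample, `[1., 0., 1.]`, has two active classes, so a group can carry more than one label. Assuming the three logits are independent sigmoid scores, you can threshold predictions per class and compare them to the targets.

The sketch below scores hypothetical logits against the five target rows shown above. `multilabel_scores` is an illustrative helper, not part of `pioneerml`.

```python
import math
from typing import List, Sequence, Tuple


def multilabel_scores(
    logits: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    threshold: float = 0.5,
) -> Tuple[float, List[float]]:
    """Return (exact-match ratio, per-class accuracy) for thresholded sigmoid outputs."""
    # sigmoid(z) >= threshold  <=>  z >= log(threshold / (1 - threshold))
    cutoff = math.log(threshold / (1.0 - threshold))
    preds = [[1.0 if z >= cutoff else 0.0 for z in row] for row in logits]
    n_rows, n_classes = len(targets), len(targets[0])
    exact = sum(p == list(t) for p, t in zip(preds, targets)) / n_rows
    per_class = [
        sum(preds[i][c] == targets[i][c] for i in range(n_rows)) / n_rows
        for c in range(n_classes)
    ]
    return exact, per_class


targets = [[1.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
# Hypothetical logits for the same five groups
logits = [[2.0, -1.0, 0.5], [-1.0, -2.0, 1.0], [0.3, -2.0, 1.0], [-1.0, -2.0, -0.2], [-3.0, -1.0, 2.0]]
exact, per_class = multilabel_scores(logits, targets)
print(exact, per_class)  # 0.6 [0.8, 1.0, 0.8]
```

In the notebook, `logits.tolist()` and `batch.targets.tolist()` produce nested lists of this shape.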
