# Group Splitter (ZenML)

Train the `GroupSplitter` on real preprocessed time-group data:
- Load preprocessed splitter groups from the `mainTimeGroups_*.npy` files in the project's `data/` folder
- Train the `GroupSplitter` via a ZenML pipeline with Optuna hyperparameter tuning
- Reload artifacts and render evaluation plots (confusion matrices, ROC curves, etc.)


In [None]:
from pathlib import Path
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.pipelines.training import group_splitter_optuna_pipeline
from pioneerml.data import NODE_LABEL_TO_NAME, NUM_NODE_CLASSES

# Resolve the repository root, then initialise ZenML for this notebook session.
# NOTE(review): `use_in_memory=True` presumably selects a throwaway in-memory
# store, so runs would not survive a kernel restart -- confirm in zenml_utils.
PROJECT_ROOT = zenml_utils.find_project_root()
zenml_client = zenml_utils.setup_zenml_for_notebook(
    root_path=PROJECT_ROOT,
    use_in_memory=True,
)

# Report the active stack as a quick sanity check that setup succeeded.
active_stack_name = zenml_client.active_stack_model.name
print(f"ZenML ready with stack: {active_stack_name}")


[37mInitializing the ZenML global configuration version to 0.92.0[0m
[37mCreating database tables[0m
[37mCreating default project 'default' ...[0m
[37mCreating default stack...[0m
[37mSetting the global active project to 'default'.[0m
[33mSetting the global active stack to default.[0m
Using ZenML repository root: /home/jack/python_projects/pioneerML
Ensure this is the top-level of your repo (.zen must live here).
ZenML ready with stack: default


In [None]:
# Pipeline configuration, written as one settings dict per pipeline step and
# then handed to the pipeline in a single call.

# Absolute glob pattern (anchored at the project root) for the input .npy files.
file_pattern = str(Path(PROJECT_ROOT) / 'data' / 'mainTimeGroups_*.npy')

# Passed to the pipeline as `build_datamodule_params`.
datamodule_settings = {
    # Data loading parameters
    'file_pattern': file_pattern,
    'max_files': 20,
    'limit_groups': 1000000,
    'min_hits': 1,
    # Datamodule parameters
    'use_group_probs': False,  # Set to True if you have classifier probabilities
    'batch_size': 8,
    'num_workers': 0,
    'val_split': 0.15,
    'seed': 42,
}

# Passed to the pipeline as `run_hparam_search_params` (Optuna search budget).
hparam_search_settings = {
    'n_trials': 25,
    'max_epochs': 20,
    'limit_train_batches': 0.8,
    'limit_val_batches': 1.0,
}

# Passed to the pipeline as `train_best_model_params`: the final training run,
# with early stopping on the validation loss.
final_training_settings = {
    'max_epochs': 50,
    'early_stopping': True,
    'early_stopping_patience': 6,
    'early_stopping_monitor': 'val_loss',
    'early_stopping_mode': 'min',
}

# Caching is switched off so every step is executed again for this run.
uncached_pipeline = group_splitter_optuna_pipeline.with_options(enable_cache=False)
run = uncached_pipeline(
    build_datamodule_params=datamodule_settings,
    run_hparam_search_params=hparam_search_settings,
    train_best_model_params=final_training_settings,
)

print(f"Run name: {run.name}")
print(f"Run status: {run.status}")


[37mInitiating a new run for the pipeline: [0m[38;5;105mgroup_splitter_optuna_pipeline[37m.[0m
[37mRegistered new pipeline: [0m[38;5;105mgroup_splitter_optuna_pipeline[37m.[0m
[37mCaching is disabled by default for [0m[38;5;105mgroup_splitter_optuna_pipeline[37m.[0m
[37mUsing user: [0m[38;5;105mdefault[37m[0m
[37mUsing stack: [0m[38;5;105mdefault[37m[0m
[37m  artifact_store: [0m[38;5;105mdefault[37m[0m
[37m  deployer: [0m[38;5;105mdefault[37m[0m
[37m  orchestrator: [0m[38;5;105mdefault[37m[0m
[37mYou can visualize your pipeline runs in the [0m[38;5;105mZenML Dashboard[37m. In order to try it locally, please run [0m[38;5;105mzenml login --local[37m.[0m
[37mStep [0m[38;5;105mload_splitter_data[37m has started.[0m
[31mFailed to run step [0m[38;5;105mload_splitter_data[31m: No files matched pattern 'data/mainTimeGroups_*.npy'[0m
[31mStep load_splitter_data failed.[0m
Traceback (most recent call last):
Traceback (most recent call l

In [None]:
# Pull the artifacts produced by the pipeline run, one step output at a time.
trained_module = load_step_output(run, "train_best_splitter")
datamodule = load_step_output(run, "build_splitter_datamodule")
# The prediction-collection step has two outputs: index 0 is loaded as the
# predictions and index 1 as the targets.
predictions = load_step_output(run, "collect_splitter_predictions", index=0)
targets = load_step_output(run, "collect_splitter_predictions", index=1)
best_params = load_step_output(run, "run_splitter_hparam_search")

# Nothing below works without both the trained module and the datamodule.
artifacts_missing = any(
    artifact is None for artifact in (trained_module, datamodule)
)
if artifacts_missing:
    raise RuntimeError("Could not load artifacts from the optuna pipeline run.")

datamodule.setup(stage="fit")
trained_module.eval()
device = next(trained_module.parameters()).device

# Report the validation set size; when the datamodule has no validation
# dataset, fall back to the size of the training dataset.
if datamodule.val_dataset is not None:
    val_size = len(datamodule.val_dataset)
else:
    val_size = len(datamodule.train_dataset)

print(f"Loaded module on {device}; validation samples: {val_size}")
print("Best params from Optuna:", best_params)
print("Epochs actually run:", getattr(trained_module, "final_epochs_run", None))


In [None]:
# Names of the splitter's per-hit classes, listed in class-index order so that
# position k in the list is the name of class k.
class_names = [
    NODE_LABEL_TO_NAME[label_index] for label_index in range(NUM_NODE_CLASSES)
]
print("Class names (index-aligned):", class_names)


In [None]:
from pioneerml.evaluation.plots import (
    plot_loss_curves,
    plot_multilabel_confusion_matrix,
    plot_precision_recall_curves,
    plot_roc_curves,
)

# Loss curves for the trained module.
plot_loss_curves(trained_module, title="Group splitter: loss", show=True)

# Keyword options shared by all of the per-class plots below.
per_class_plot_options = {"class_names": class_names, "show": True}

# Row-normalised confusion matrices, binarising predictions at 0.5.
# NOTE(review): "row-normalised" is inferred from `normalize=True` -- confirm
# against plot_multilabel_confusion_matrix.
plot_multilabel_confusion_matrix(
    predictions=predictions,
    targets=targets,
    threshold=0.5,
    normalize=True,
    **per_class_plot_options,
)

# Threshold-free views of the same predictions: ROC, then precision-recall.
plot_roc_curves(predictions, targets, **per_class_plot_options)
plot_precision_recall_curves(predictions, targets, **per_class_plot_options)


## Save the Trained Model

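Save the trained model's `state_dict` and a JSON metadata file (hyperparameters, dataset info, architecture) to `artifacts/checkpoints/group_splitter/` for later use.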


In [None]:
import torch
import json
from datetime import datetime

# Persist the trained splitter: a timestamped ``state_dict`` checkpoint plus a
# JSON metadata file written next to it.

# Create checkpoints directory (no-op when it already exists).
checkpoints_dir = Path(PROJECT_ROOT) / "artifacts" / "checkpoints" / "group_splitter"
checkpoints_dir.mkdir(parents=True, exist_ok=True)

# Generate timestamped filenames so repeated runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"group_splitter_{timestamp}.pt"
metadata_filename = f"group_splitter_{timestamp}_metadata.json"

# Extract the underlying model from the Lightning module
model = trained_module.model
model.eval()

# Save only the weights (state_dict), not the pickled module object.
model_path = checkpoints_dir / model_filename
torch.save(model.state_dict(), model_path)
print(f"Saved model state_dict to: {model_path}")

# The hyperparameter-search artifact may have failed to load (None). Fall back
# to an empty mapping so the architecture summary records nulls instead of
# raising AttributeError after the checkpoint has already been written.
architecture_params = best_params if best_params is not None else {}

# Test dataset presence with `is not None` (not truthiness) so the check means
# "was a dataset built", matching how the validation set is checked when the
# artifacts are loaded.
train_dataset = datamodule.train_dataset
val_dataset = datamodule.val_dataset

# Save metadata (hyperparameters, training info, etc.)
metadata = {
    "model_type": "GroupSplitter",
    "timestamp": timestamp,
    "run_name": run.name,
    "best_hyperparameters": best_params,
    "training_config": getattr(trained_module, "training_config", {}),
    "epochs_run": getattr(trained_module, "final_epochs_run", None),
    "dataset_info": {
        "train_size": len(train_dataset) if train_dataset is not None else 0,
        "val_size": len(val_dataset) if val_dataset is not None else 0,
        "num_classes": NUM_NODE_CLASSES,
        "class_names": class_names,
        "use_group_probs": getattr(datamodule, "use_group_probs", False),
    },
    "model_architecture": {
        "num_classes": NUM_NODE_CLASSES,
        "hidden": architecture_params.get("hidden"),
        "heads": architecture_params.get("heads"),
        "layers": architecture_params.get("layers"),
        "dropout": architecture_params.get("dropout"),
    },
}

# Serialise completely before touching the file: if any value is not
# JSON-serialisable this raises here and no truncated metadata file is left
# behind. The encoding is explicit so the file does not depend on the locale.
metadata_path = checkpoints_dir / metadata_filename
metadata_json = json.dumps(metadata, indent=2)
metadata_path.write_text(metadata_json, encoding="utf-8")
print(f"Saved metadata to: {metadata_path}")

print("\nModel saved successfully!")
print(f"  Model: {model_path}")
print(f"  Metadata: {metadata_path}")
