# Pion Stop Regression (ZenML)

Train the `PionStopRegressor` on real preprocessed pion time-group data:
- Load preprocessed pion time groups from the `mainTimeGroups_*.npy` files in the project's `data/` folder
- Train the `PionStopRegressor` via a ZenML pipeline with Optuna hyperparameter tuning
- Reload artifacts and evaluate regression performance


In [None]:
from pathlib import Path
from pioneerml.zenml import load_step_output
from pioneerml.zenml import utils as zenml_utils
from pioneerml.zenml.pipelines.training import pion_stop_optuna_pipeline

# Resolve the project root once; later cells build data and artifact paths from it.
PROJECT_ROOT = zenml_utils.find_project_root()

# Set up ZenML for interactive notebook use, requesting the in-memory variant.
zenml_client = zenml_utils.setup_zenml_for_notebook(
    root_path=PROJECT_ROOT,
    use_in_memory=True,
)
print(f"ZenML ready with stack: {zenml_client.active_stack_model.name}")


In [None]:
# Configure and launch the Optuna training pipeline.
# The data glob is anchored at PROJECT_ROOT so it does not depend on the notebook's cwd.
file_pattern = str(Path(PROJECT_ROOT) / 'data' / 'mainTimeGroups_*.npy')

run = pion_stop_optuna_pipeline.with_options(enable_cache=False)(
    # Data loading + datamodule construction.
    build_datamodule_params=dict(
        file_pattern=file_pattern,
        pion_pdg=1,
        max_files=10,
        limit_groups=100000,
        min_hits=3,
        min_pion_hits=3,
        use_true_time=True,
        batch_size=32,
        num_workers=0,
        val_split=0.15,
        seed=42,
    ),
    # Optuna search budget: number of trials and per-trial training limits.
    run_hparam_search_params=dict(
        n_trials=25,
        max_epochs=20,
        limit_train_batches=0.8,
        limit_val_batches=1.0,
    ),
    # Final fit with the best hyperparameters, stopped early on a plateauing val_loss.
    train_best_model_params=dict(
        max_epochs=50,
        early_stopping=True,
        early_stopping_patience=6,
        early_stopping_monitor='val_loss',
        early_stopping_mode='min',
    ),
)

for label, value in (("Run name", run.name), ("Run status", run.status)):
    print(f"{label}: {value}")


In [None]:
# Reload the artifacts produced by the pipeline run above.
trained_module = load_step_output(run, "train_best_pion_stop_regressor")
datamodule = load_step_output(run, "build_pion_stop_datamodule")
# The prediction step has two outputs; index 0 / 1 select predictions / targets.
predictions = load_step_output(run, "collect_pion_stop_predictions", index=0)
targets = load_step_output(run, "collect_pion_stop_predictions", index=1)
best_params = load_step_output(run, "run_pion_stop_hparam_search")

# Every artifact listed here is used by the evaluation cells, so fail now with a message
# naming what is missing instead of an AttributeError on ``None`` further down.
_required_artifacts = {
    "train_best_pion_stop_regressor": trained_module,
    "build_pion_stop_datamodule": datamodule,
    "collect_pion_stop_predictions[0]": predictions,
    "collect_pion_stop_predictions[1]": targets,
}
missing = [name for name, artifact in _required_artifacts.items() if artifact is None]
if missing:
    raise RuntimeError(
        "Could not load artifacts from the optuna pipeline run. "
        f"Missing step outputs: {', '.join(missing)}"
    )

datamodule.setup(stage="fit")
trained_module.eval()
device = next(trained_module.parameters()).device

if datamodule.val_dataset is not None:
    val_size = len(datamodule.val_dataset)
    print(f"Loaded module on {device}; validation samples: {val_size}")
else:
    # No validation split exists: keep the training-set fallback for val_size, but label
    # it honestly instead of reporting it as "validation samples".
    val_size = len(datamodule.train_dataset)
    print(f"Loaded module on {device}; no validation split (training samples: {val_size})")

print("Best params from Optuna:", best_params)
print("Epochs actually run:", getattr(trained_module, "final_epochs_run", None))


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from pioneerml.evaluation.plots import plot_loss_curves

# Fixed histogram window/binning for the error plots. Errors beyond the upper edge are
# not drawn by ``hist(range=...)``, so their count is reported explicitly below.
HIST_RANGE = (0.0, 2.0)
HIST_BINS = 50


def _as_numpy(values):
    """Return *values* as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if hasattr(values, "detach"):
        return values.detach().cpu().numpy()
    return np.asarray(values)


preds_np = _as_numpy(predictions)
targets_np = _as_numpy(targets)

# Mismatched shapes (e.g. (N, 3) vs (N, 1)) would broadcast silently into wrong errors.
if preds_np.shape != targets_np.shape:
    raise ValueError(
        f"Predictions and targets have different shapes: {preds_np.shape} vs {targets_np.shape}"
    )
if preds_np.size == 0:
    raise ValueError("No predictions to evaluate: the prediction artifact is empty.")

# Per-sample Euclidean error: norm over axis 1 (one row per sample).
distances = np.linalg.norm(preds_np - targets_np, axis=1)
mean_distance = float(distances.mean())
median_distance = float(np.median(distances))

print(f"Mean Euclidean distance: {mean_distance:.4f} mm")
print(f"Median Euclidean distance: {median_distance:.4f} mm")

n_outside = int(np.count_nonzero(distances > HIST_RANGE[1]))
if n_outside:
    print(
        f"Note: {n_outside} of {distances.size} errors exceed {HIST_RANGE[1]} mm "
        "and are not shown in the histograms"
    )

# Same histogram twice: linear y-axis (left) and log y-axis (right) to expose the tail.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, log_scale in zip(axes, (False, True)):
    ax.hist(distances, bins=HIST_BINS, range=HIST_RANGE, alpha=0.75, color='teal')
    if log_scale:
        ax.set_yscale('log')
    ax.set_xlabel('Euclidean error [mm]')
    ax.set_ylabel('Counts (log scale)' if log_scale else 'Counts')
    scale_name = 'log' if log_scale else 'linear'
    ax.set_title(f'Pion stop prediction error ({scale_name} scale)')
    ax.grid(True, linestyle='--', alpha=0.3)

plt.tight_layout()
plt.show()

# Plot loss curves
plot_loss_curves(trained_module, title="Pion stop regression: loss", show=True)


## Save the Trained Model

Save the trained model's `state_dict` together with a JSON metadata file (hyperparameters, dataset sizes and evaluation metrics) for later use.


In [None]:
import torch
import json
import numpy as np
from datetime import datetime


def _json_default(obj):
    """Convert metadata values that ``json`` cannot encode natively.

    NumPy scalars/arrays and torch tensors (which may appear in ``best_params`` or the
    training config) become Python numbers/lists, and paths become strings. Any other
    value falls back to ``repr`` on purpose: this file is an informational record, and a
    single exotic value should not prevent the metadata from being written at all.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    return repr(obj)


# Create checkpoints directory
checkpoints_dir = Path(PROJECT_ROOT) / "artifacts" / "checkpoints" / "pion_stop"
checkpoints_dir.mkdir(parents=True, exist_ok=True)

# Timestamped filenames so repeated runs don't overwrite earlier checkpoints.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"pion_stop_{timestamp}.pt"
metadata_filename = f"pion_stop_{timestamp}_metadata.json"

# Extract the underlying model from the Lightning module
model = trained_module.model
model.eval()

# Save model state_dict
model_path = checkpoints_dir / model_filename
torch.save(model.state_dict(), model_path)
print(f"Saved model state_dict to: {model_path}")

# Compute evaluation metrics for metadata
preds_np = predictions.detach().cpu().numpy() if hasattr(predictions, "detach") else np.asarray(predictions)
targets_np = targets.detach().cpu().numpy() if hasattr(targets, "detach") else np.asarray(targets)
distances = np.linalg.norm(preds_np - targets_np, axis=1)
mean_distance = float(distances.mean())
median_distance = float(np.median(distances))

# The hparam-search artifact may be None; fall back to an empty mapping so the
# architecture fields are recorded as null instead of crashing on ``.get``.
hparams = best_params or {}

# Save metadata (hyperparameters, training info, etc.)
metadata = {
    "model_type": "PionStopRegressor",
    "timestamp": timestamp,
    "run_name": run.name,
    "best_hyperparameters": best_params,
    "training_config": getattr(trained_module, "training_config", {}),
    "epochs_run": getattr(trained_module, "final_epochs_run", None),
    "dataset_info": {
        "train_size": len(datamodule.train_dataset) if datamodule.train_dataset is not None else 0,
        "val_size": len(datamodule.val_dataset) if datamodule.val_dataset is not None else 0,
        "use_true_time": getattr(datamodule, "use_true_time", None),
    },
    "model_architecture": {
        "hidden": hparams.get("hidden"),
        "heads": hparams.get("heads"),
        "layers": hparams.get("layers"),
        "dropout": hparams.get("dropout"),
    },
    "evaluation_metrics": {
        "mean_euclidean_distance_mm": mean_distance,
        "median_euclidean_distance_mm": median_distance,
    },
}

# Serialize fully before touching the file: json.dump streams into the open file and
# would leave a truncated metadata file behind if encoding failed partway through.
metadata_json = json.dumps(metadata, indent=2, default=_json_default)
metadata_path = checkpoints_dir / metadata_filename
metadata_path.write_text(metadata_json, encoding="utf-8")
print(f"Saved metadata to: {metadata_path}")

print("\nModel saved successfully!")
print(f"  Model: {model_path}")
print(f"  Metadata: {metadata_path}")
