# Endpoint Regressor (ZenML)

Train the `EndpointRegressor` (stereo, quantile endpoints) on raw hit graphs. Mirrors the group classifier notebook: configure ZenML, set up Optuna, run the pipeline, inspect the best params, run quick diagnostics, and save the model and metadata.

In [1]:
from pioneerml.zenml import utils as zenml_utils

# Locate the repository root (the directory holding .zen) and initialise ZenML for this
# notebook session with use_in_memory=True.
PROJECT_ROOT = zenml_utils.find_project_root()
zenml_client = zenml_utils.setup_zenml_for_notebook(
    root_path=PROJECT_ROOT,
    use_in_memory=True,
)
active_stack_name = zenml_client.active_stack_model.name
print(f"ZenML ready with stack: {active_stack_name}")


Using ZenML repository root: /home/jack/python_projects/pioneerML
Ensure this is the top-level of your repo (.zen must live here).
ZenML ready with stack: default


In [2]:
import optuna
from pathlib import Path
from pioneerml.zenml import load_step_output
from pioneerml.zenml.pipelines.training import endpoint_optuna_pipeline
from pioneerml.optuna import OptunaStudyManager

# Optuna persistence: the study manager resolves a storage URL for this study so that
# trials accumulate across notebook runs (an existing study is reused by name).
optuna_manager = OptunaStudyManager(
    study_name="endpoint_regressor",
    project_root=PROJECT_ROOT,
)
optuna_storage = optuna_manager.resolve_storage()
print(f"Using Optuna storage: {optuna_storage}")


Using Optuna storage: sqlite:////home/jack/python_projects/pioneerML/.optuna/endpoint_regressor.db


In [3]:
# Run the Optuna hyperparameter search + final training pipeline. Caching is disabled,
# so every execution of this cell re-runs all steps; nothing is reused from earlier runs.

DATA_DIR = Path(PROJECT_ROOT) / "data"
RAW_HITS_DIR = DATA_DIR / "raw_hits_info"
UPSTREAM_PREDS_DIR = DATA_DIR / "upstream_preds"
HITS_GLOB = "hits_batch_*.npy"
INFO_GLOB = "group_info_batch_*.npy"
PROBS_GLOB = "group_probs_batch_*.npz"
MAX_FILES = 20  # cap on the number of hits/info file pairs the datamodule loads

hits_pattern = str(RAW_HITS_DIR / HITS_GLOB)
info_pattern = str(RAW_HITS_DIR / INFO_GLOB)

# Optional: attach group probabilities from the upstream classifier. These batched NPZs
# are expected to line up batch-for-batch with the hits/info files, so run a cheap
# pre-flight check before the expensive pipeline: if there are fewer probability files
# than hits batches to be loaded, some groups cannot receive probabilities.
# (Matching counts are necessary, not sufficient, for correct alignment.)
probs_pattern = str(UPSTREAM_PREDS_DIR / PROBS_GLOB)
print(f"Using group_probs pattern: {probs_pattern}")

n_hits_batches = min(len(list(RAW_HITS_DIR.glob(HITS_GLOB))), MAX_FILES)
n_probs_batches = len(list(UPSTREAM_PREDS_DIR.glob(PROBS_GLOB)))
if n_probs_batches < n_hits_batches:
    print(
        f"WARNING: found {n_probs_batches} group_probs files for {n_hits_batches} hits batches; "
        "some groups may be missing group_probs or be misaligned."
    )

run = endpoint_optuna_pipeline.with_options(enable_cache=False)(
    build_datamodule_params={
        "hits_pattern": hits_pattern,
        "info_pattern": info_pattern,
        "group_probs_pattern": probs_pattern,  # batched NPZs, aligned to hits/info batches
        "max_files": MAX_FILES,
        "limit_groups": 1_000_000,
        "min_hits": 2,
        "batch_size": 64,
        "val_split": 0.15,
        "seed": 42,
    },
    # Smoke-test settings (1 trial, 1 epoch); increase for a real search.
    run_hparam_search_params={
        "n_trials": 1,
        "max_epochs": 1,
        "limit_train_batches": 0.8,
        "limit_val_batches": 1.0,
        "storage": optuna_storage,
        # Reuse the manager's study name so the study always matches the storage above.
        "study_name": optuna_manager.study_name,
    },
    train_best_model_params={
        # NOTE: with max_epochs=1, early stopping (patience 4) can never trigger;
        # raise max_epochs for a real training run.
        "max_epochs": 1,
        "early_stopping": True,
        "early_stopping_patience": 4,
    },
)
print(f"Run name: {run.name}")
print(f"Run status: {run.status}")

Using group_probs pattern: /home/jack/python_projects/pioneerML/data/upstream_preds/group_probs_batch_*.npz
[37mInitiating a new run for the pipeline: [0m[38;5;105mendpoint_optuna_pipeline[37m.[0m
[37mCaching is disabled by default for [0m[38;5;105mendpoint_optuna_pipeline[37m.[0m
[37mUsing user: [0m[38;5;105mdefault[37m[0m
[37mUsing stack: [0m[38;5;105mdefault[37m[0m
[37m  artifact_store: [0m[38;5;105mdefault[37m[0m
[37m  deployer: [0m[38;5;105mdefault[37m[0m
[37m  orchestrator: [0m[38;5;105mdefault[37m[0m
[37mYou can visualize your pipeline runs in the [0m[38;5;105mZenML Dashboard[37m. In order to try it locally, please run [0m[38;5;105mzenml login --local[37m.[0m
[37mStep [0m[38;5;105mbuild_endpoint_datamodule[37m has started.[0m


[build_endpoint_datamodule] Auto-detected num_workers: 11 (from 12 CPU cores, using cores-1)
[build_endpoint_datamodule] Starting to load data from: hits=/home/jack/python_projects/pioneerML/data/raw_hits_info/hits_batch_*.npy, info=/home/jack/python_projects/pioneerML/data/raw_hits_info/group_info_batch_*.npy
[build_endpoint_datamodule] Limiting to 11 files (from 11 total files found, max_files=20)
[build_endpoint_datamodule] Limiting to 11 files (from 11 total files found, max_files=20)
[build_endpoint_datamodule] Loaded 109817 groups from 11 file pairs
[build_endpoint_datamodule] Loaded 109817 groups. Building datamodule...
[build_endpoint_datamodule] Setup complete. Train: 93345, Val: 16472


[37mStep [0m[38;5;105mbuild_endpoint_datamodule[37m has finished in [0m[38;5;105m25.431s[37m.[0m
[37mStep [0m[38;5;105mrun_endpoint_hparam_search[37m has started.[0m


[I 2026-01-04 04:29:20,866] Using an existing study with name 'endpoint_regressor' instead of creating a new one.


  0%|          | 0/1 [00:00<?, ?it/s]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
You are using a CUDA device ('NVIDIA GeForce RTX 5070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[W 2026-01-04 04:29:24,640] Trial 6 failed with parameters: {'batch_size': 64, 'hidden': 160, 'heads': 2, 'layers': 3, 'dropout': 0.06001435441815274, 'lr': 0.004995134115523524, 'weight_decay': 9.74486608633666e-06} because of the following error: KeyError(Caught KeyError in DataLoader worker process 9.
Original Traceback (most recent call last):
  File "/home/jack/miniconda3/envs/pioneerml/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/miniconda3/envs/pioneerml/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/miniconda3/envs/pioneerml/lib/python3.11/site-packages/torch_geometric/loader/dataloader.py", line 27, in __call__
    return Batch.from_data_list(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jack/mi

In [None]:
# Load the step artifacts produced by the pipeline run above and summarise the best trial.
trained_module = load_step_output(run, "train_best_endpoint")
datamodule = load_step_output(run, "build_endpoint_datamodule")
# collect_endpoint_predictions has two outputs, loaded here as predictions (0) and targets (1).
predictions = load_step_output(run, "collect_endpoint_predictions", index=0)
targets = load_step_output(run, "collect_endpoint_predictions", index=1)
best_params = load_step_output(run, "run_endpoint_hparam_search")

# Fail fast and name exactly which required step outputs are missing (e.g. when a step
# failed), together with the run status, instead of a generic error.
required_artifacts = {
    "train_best_endpoint": trained_module,
    "build_endpoint_datamodule": datamodule,
}
missing_artifacts = [name for name, artifact in required_artifacts.items() if artifact is None]
if missing_artifacts:
    raise RuntimeError(
        "Could not load artifacts from the optuna pipeline run "
        f"(status: {run.status}); missing step outputs: {', '.join(missing_artifacts)}"
    )

datamodule.setup(stage="fit")
trained_module.eval()
device = next(trained_module.parameters()).device
# Falls back to the training split size when the datamodule has no validation split.
val_size = len(datamodule.val_dataset) if datamodule.val_dataset is not None else len(datamodule.train_dataset)
print(f"Loaded module on {device}; validation samples: {val_size}")

# Hide the (potentially long) per-trial history when displaying the best params.
best_params_display = {k: v for k, v in best_params.items() if k != "trial_history"} if isinstance(best_params, dict) else best_params
print("Best params from Optuna:", best_params_display)
print("Epochs actually run:", getattr(trained_module, "final_epochs_run", None))


In [None]:
# Quick validation plots
from pioneerml.evaluation.plots import plot_loss_curves, plot_regression_diagnostics

# Loss curves from the trained module, then regression diagnostics comparing the
# predictions collected by the pipeline with their targets. The diagnostics are skipped
# with a message when those artifacts were not loaded, rather than failing inside the
# plotting helper. Wrapping the last call in an `if` also stops Jupyter from echoing a
# returned figure a second time.
plot_loss_curves(trained_module, title="Endpoint regressor: loss", show=True)
if predictions is None or targets is None:
    print("Skipping regression diagnostics: predictions/targets were not loaded from the run.")
else:
    plot_regression_diagnostics(predictions=predictions, targets=targets, show=True)


In [None]:
# Save trained model + metadata
from pioneerml.metadata import TrainingMetadata, save_model_and_metadata, timestamp_now

# Persist the trained weights together with a metadata record describing how they were
# produced (run, Optuna study, hyperparameters, dataset sizes, architecture).
save_ts = timestamp_now()

# The search result may not be a dict; an empty mapping makes every lookup below yield None.
param_lookup = best_params if isinstance(best_params, dict) else {}


def _dataset_len(dataset):
    """Return len(dataset), or 0 when the dataset is missing or falsy."""
    return len(dataset) if dataset else 0


meta = TrainingMetadata(
    model_type="EndpointRegressor",
    timestamp=save_ts,
    run_name=run.name if 'run' in locals() else None,
    best_hyperparameters=best_params,
    best_score=param_lookup.get('best_score'),
    n_trials=param_lookup.get('n_trials'),
    training_config=getattr(trained_module, 'training_config', {}),
    epochs_run=getattr(trained_module, 'final_epochs_run', None),
    dataset_info={
        'train_size': _dataset_len(datamodule.train_dataset),
        'val_size': _dataset_len(datamodule.val_dataset),
    },
    model_architecture={
        arch_key: param_lookup.get(arch_key)
        for arch_key in ('hidden', 'heads', 'layers', 'dropout')
    },
    optuna_storage=optuna_storage,
    optuna_study_name=optuna_manager.study_name,
)

paths = save_model_and_metadata(
    model=trained_module.model,
    metadata=meta,
    state_dict_only=True,
)

print("Saved artifacts:")
for artifact_kind, artifact_path in paths.items():
    print(f"  {artifact_kind}: {artifact_path}")
