# Generate Endpoint Predictions (Data Export)

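Run the trained `EndpointRegressor` over the full dataset and save per-group endpoint quantile predictions to `data/upstream_preds/endpoint_preds_<timestamp>.npz`. If a `group_probs_*.npz` export exists in the same directory, the latest one (lexicographically last) is attached to each group before inference.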

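The output file contains:
- `pred_endpoints`: array of shape `[num_groups, 2, 3, Q]`. Axis 1 indexes the start/end point, axis 2 the x/y/z coordinate, and axis 3 the quantiles (q16, q50, q84).
- `event_id`, `group_id`: identifiers of each group, aligned with the first axis of `pred_endpoints`.
- `quantiles`: the quantile levels used (`[0.16, 0.50, 0.84]`).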


In [1]:
import torch
import numpy as np
from pathlib import Path
from datetime import datetime

from pioneerml.zenml import utils as zenml_utils
from pioneerml.metadata import MetadataManager
from pioneerml.data import load_hits_and_info
from pioneerml.training.datamodules import EndpointDataModule

# Anchor every data/checkpoint path in this notebook to the repository root.
PROJECT_ROOT = zenml_utils.find_project_root()

# Checkpoint/metadata registry rooted at the project directory.
metadata_manager = MetadataManager(root=PROJECT_ROOT)

print("Project root:", PROJECT_ROOT)


Project root: /home/jack/python_projects/pioneerML


## List available checkpoints

In [2]:
# Print all stored EndpointRegressor checkpoints; the first listed entry is the one loaded below.
checkpoints = metadata_manager.print_checkpoints("EndpointRegressor")

# Guard clause: nothing to run inference with.
if not checkpoints:
    raise ValueError("No checkpoints found for EndpointRegressor")

first_checkpoint_name = checkpoints[0]["checkpoint_path"].name
print(f"Using checkpoint: {first_checkpoint_name}")


Found 2 checkpoint(s):
  1. endpointregressor_20260104_072603_endpoint_optuna_pipeline-2026_01_04-07_19_39_754371_state_dict.pt
     Timestamp:     20260104_072603
     Run:           endpoint_optuna_pipeline-2026_01_04-07_19_39_754371
     Architecture:  hidden=128, heads=8, layers=4, dropout=0.18402435315850102
  2. endpointregressor_20260103_124027_endpoint_optuna_pipeline-2026_01_03-12_33_04_260918_state_dict.pt
     Timestamp:     20260103_124027
     Run:           endpoint_optuna_pipeline-2026_01_03-12_33_04_260918
     Architecture:  hidden=128, heads=8, layers=4, dropout=0.18402435315850102
Using checkpoint: endpointregressor_20260104_072603_endpoint_optuna_pipeline-2026_01_04-07_19_39_754371_state_dict.pt


## Load model

In [3]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load checkpoint index 0 (the entry announced in the previous cell) together with its metadata.
model, metadata = metadata_manager.load_model("EndpointRegressor", index=0, device=device)

# Summary of what was loaded, for provenance in the rendered notebook.
print(f"Model loaded on {device}")
print(f"  Checkpoint: {checkpoints[0]['checkpoint_path'].name}")
print(f"  Timestamp: {metadata.timestamp}")
print(f"  Run: {metadata.run_name or 'unknown'}")


## Load data (full set)

In [None]:
# Raw hit / group-info shards; every file matching these glob patterns is loaded.
raw_hits_dir = Path(PROJECT_ROOT) / "data" / "raw_hits_info"
hits_pattern = str(raw_hits_dir / "hits_batch_*.npy")
info_pattern = str(raw_hits_dir / "group_info_batch_*.npy")

# Optional: attach group probabilities from an upstream export.
# The lexicographically last file is taken as the latest (assumes timestamped names; verify against the exporter).
probs_dir = Path(PROJECT_ROOT) / "data" / "upstream_preds"
prob_files = sorted(probs_dir.glob("group_probs_*.npz"))
group_probs_path = str(prob_files[-1]) if prob_files else None
print(f"Using group_probs file: {group_probs_path}")

prob_dimension = 0
group_probs_lookup = {}
if group_probs_path:
    # Read the file exactly once, and validate it BEFORE the expensive hit loading below,
    # so a malformed export fails fast instead of after all shards are read.
    with np.load(group_probs_path) as npz:
        gp = npz["group_probs"]
        ev = npz["event_id"]
        gi = npz["group_id"]
    if not (len(gp) == len(ev) == len(gi)):
        # zip() would silently truncate and pair probabilities with the wrong groups.
        raise ValueError(
            f"Inconsistent group_probs file {group_probs_path}: "
            f"{len(gp)} prob rows, {len(ev)} event_ids, {len(gi)} group_ids"
        )
    # Presumably [num_groups, num_classes] when 2-D; any other rank contributes no prob features.
    prob_dimension = int(gp.shape[1]) if gp.ndim == 2 else 0
    group_probs_lookup = {(int(e), int(g)): gp[idx] for idx, (e, g) in enumerate(zip(ev, gi))}

# NOTE: despite the name, this is the FULL dataset used for inference (kept for compatibility).
validation_groups = load_hits_and_info(
    hits_pattern=hits_pattern,
    info_pattern=info_pattern,
    max_files=None,
    limit_groups=None,
    min_hits=2,
    include_hit_labels=False,
    verbose=True,
)

# Attach group_probs to every record whose (event_id, group_id) key is present in the export.
if group_probs_path:
    attached = 0
    for rec in validation_groups:
        if rec.event_id is None or rec.group_id is None:
            continue
        probs = group_probs_lookup.get((int(rec.event_id), int(rec.group_id)))
        if probs is not None:
            rec.group_probs = probs
            attached += 1
    print(f"Attached group_probs to {attached}/{len(validation_groups)} groups")

print(f"Loaded {len(validation_groups)} groups for inference")

# With val/test splits of 0.0, every record should land in train_dataset (presumably; confirm in EndpointDataModule).
datamodule = EndpointDataModule(
    records=validation_groups,
    batch_size=128,
    num_workers=0,
    val_split=0.0,
    test_split=0.0,
    seed=42,
    num_quantiles=3,
    prob_dimension=prob_dimension,
)
datamodule.setup(stage="fit")
full_dataset = datamodule.train_dataset
print(f"Dataset size: {len(full_dataset)}")


## Run inference and save outputs

In [None]:
import warnings

from torch_geometric.loader import DataLoader

# Quantile levels written alongside the predictions; must match the last axis of the model output.
quantiles = np.array([0.16, 0.50, 0.84], dtype=np.float32)


def _collect_ids(batch, attr, offset, count, fallback_attrs):
    """Return the per-graph ``attr`` ids of ``batch`` as a flat list of length ``count``.

    Parameters
    ----------
    batch : collated mini-batch from the DataLoader.
    attr : str
        Id attribute to read (``"event_id"`` or ``"group_id"``).
    offset : int
        Number of ids already collected; positional fallback ids continue from here.
    count : int
        Number of predictions produced for this batch.
    fallback_attrs : set
        Mutated: ``attr`` is added when the batch lacks it and positional ids are used.

    Raises
    ------
    ValueError
        If the batch carries ``attr`` but its length differs from ``count``; saving
        would otherwise misalign ids and predictions silently.
    """
    if not hasattr(batch, attr):
        # Positional ids do not identify real (event, group) pairs; recorded so we warn once.
        fallback_attrs.add(attr)
        return list(range(offset, offset + count))
    ids = getattr(batch, attr).reshape(-1).cpu().tolist()
    if len(ids) != count:
        raise ValueError(
            f"batch.{attr} has {len(ids)} entries but the model produced {count} predictions; "
            "ids would be misaligned with pred_endpoints"
        )
    return ids


loader = DataLoader(full_dataset, batch_size=128, shuffle=False, num_workers=0)
model.eval()
all_preds = []
event_ids = []
group_ids = []
fallback_attrs = set()

with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        preds = model(batch)  # expected [B, 2, 3, Q]
        all_preds.append(preds.cpu())  # no .detach() needed under no_grad
        n_preds = preds.size(0)
        event_ids.extend(_collect_ids(batch, "event_id", len(event_ids), n_preds, fallback_attrs))
        group_ids.extend(_collect_ids(batch, "group_id", len(group_ids), n_preds, fallback_attrs))

if not all_preds:
    raise RuntimeError("DataLoader yielded no batches; nothing to export (is full_dataset empty?)")
if fallback_attrs:
    warnings.warn(
        "Batches lacked "
        + ", ".join(sorted(fallback_attrs))
        + "; substituted positional indices. These do NOT identify real "
        "(event_id, group_id) pairs, so keyed joins on this export will be wrong.",
        RuntimeWarning,
    )

pred_endpoints = torch.cat(all_preds, dim=0).numpy()  # [N, 2, 3, Q]
event_ids = np.array(event_ids, dtype=np.int64)
group_ids = np.array(group_ids, dtype=np.int64)
if pred_endpoints.shape[-1] != len(quantiles):
    # The saved `quantiles` array labels the last axis; a mismatch would mislabel the export.
    raise ValueError(
        f"Model predicts {pred_endpoints.shape[-1]} quantiles but {len(quantiles)} labels are defined"
    )

print(f"Prediction shapes: pred_endpoints={pred_endpoints.shape}, event_ids={event_ids.shape}, group_ids={group_ids.shape}")

save_dir = Path(PROJECT_ROOT) / "data" / "upstream_preds"
save_dir.mkdir(parents=True, exist_ok=True)

ts = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = save_dir / f"endpoint_preds_{ts}.npz"
np.savez_compressed(
    save_path,
    pred_endpoints=pred_endpoints,
    event_id=event_ids,
    group_id=group_ids,
    quantiles=quantiles,
)
print(f"Wrote: {save_path}")
