# Generate Group Probabilities (Data Export)

Run the trained GroupClassifier over the full dataset and save per-group probabilities to `data/upstream_preds`.

Steps:
1. Load latest checkpoint.
2. Load raw hits/info data.
3. Run inference over every loaded group (the loader is called with `min_hits=2`; no val/test split is held out).
4. Save `group_probs` (independent per-class sigmoid scores, so rows need not sum to 1), `event_id`, `group_id`, and `class_names` to a timestamped `.npz` for downstream use (e.g., endpoint regressor).


In [1]:
import torch
import numpy as np
from pathlib import Path
from datetime import datetime

from pioneerml.zenml import utils as zenml_utils
from pioneerml.metadata import MetadataManager
from pioneerml.data import load_hits_and_info, CLASS_NAMES, NUM_GROUP_CLASSES
from pioneerml.training.datamodules import GroupClassificationDataModule

# Every path used below (raw data in, predictions out) is built from the
# project root, so resolve it once and echo it for the reader.
PROJECT_ROOT = zenml_utils.find_project_root()
print(f"Project root: {PROJECT_ROOT}")

# Used by the following cells to list saved checkpoints and load the model.
metadata_manager = MetadataManager(root=PROJECT_ROOT)


Project root: /home/jack/python_projects/pioneerML


## List available checkpoints

In [2]:
# print_checkpoints both prints the listing and returns the entries; the
# returned list is reused by the model-loading cell.
checkpoints = metadata_manager.print_checkpoints("GroupClassifier")

# Nothing to run inference with: stop here instead of failing later.
if not checkpoints:
    raise ValueError("No checkpoints found for GroupClassifier")

# Entry 0 is the one this notebook loads (treated as the latest checkpoint).
selected_checkpoint_file = checkpoints[0]["checkpoint_path"]
print(f"Using checkpoint: {selected_checkpoint_file.name}")


Found 8 checkpoint(s):
  1. groupclassifier_20260102_120808_group_classification_optuna_pipeline-2026_01_02-12_02_53_414220_state_dict.pt
     Timestamp:     20260102_120808
     Run:           group_classification_optuna_pipeline-2026_01_02-12_02_53_414220
     Architecture:  hidden=128, dropout=0.1956153589905551, num_blocks=3
  2. groupclassifier_20251206_111624_group_classification_optuna_pipeline-2025_12_06-01_40_30_136745_state_dict.pt
     Timestamp:     20251206_111624
     Run:           group_classification_optuna_pipeline-2025_12_06-01_40_30_136745
     Architecture:  hidden=128, dropout=0.09417833792467833, num_blocks=2
  3. groupclassifier_20251206_011522_group_classification_optuna_pipeline-2025_12_06-01_14_57_970179_state_dict.pt
     Timestamp:     20251206_011522
     Run:           group_classification_optuna_pipeline-2025_12_06-01_14_57_970179
     Architecture:  hidden=256, dropout=0.17938345001656214, num_blocks=4
  4. groupclassifier_20251206_010818_group_classifi

## Load model

In [3]:
# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for which checkpoint entry is used, so the file name
# printed below cannot drift from the weights that were actually loaded.
# NOTE(review): assumes `load_model(index=...)` indexes the same ordering that
# `print_checkpoints` returned — confirm against MetadataManager.
CHECKPOINT_INDEX = 0

model, metadata = metadata_manager.load_model(
    "GroupClassifier",
    index=CHECKPOINT_INDEX,
    device=device,
)

# The data-loading cell treats `metadata` as possibly missing
# (`... if metadata else ...`), so the summary must not assume it is present.
loaded_timestamp = metadata.timestamp if metadata else "unknown"
loaded_run_name = (metadata.run_name if metadata else None) or "unknown"

print(f"Model loaded on {device}")
print(f"  Checkpoint: {checkpoints[CHECKPOINT_INDEX]['checkpoint_path'].name}")
print(f"  Timestamp: {loaded_timestamp}")
print(f"  Run: {loaded_run_name}")


Model loaded on cuda
  Checkpoint: groupclassifier_20260102_120808_group_classification_optuna_pipeline-2026_01_02-12_02_53_414220_state_dict.pt
  Timestamp: 20260102_120808
  Run: group_classification_optuna_pipeline-2026_01_02-12_02_53_414220


## Load data (full set)

In [4]:
# Hit arrays and group-info arrays sit side by side in one directory; only the
# file-name glob differs between them.
raw_data_dir = Path(PROJECT_ROOT) / "data" / "raw_hits_info"
hits_pattern = str(raw_data_dir / "hits_batch_*.npy")
info_pattern = str(raw_data_dir / "group_info_batch_*.npy")

# No cap on the number of files or groups: this export is meant to cover
# everything on disk that passes the loader's `min_hits` setting.
load_options = {
    "max_files": None,
    "limit_groups": None,
    "min_hits": 2,
    "include_hit_labels": False,
    "verbose": True,
}
groups = load_hits_and_info(
    hits_pattern=hits_pattern,
    info_pattern=info_pattern,
    **load_options,
)
print(f"Loaded {len(groups)} groups")

# Use the class count recorded with the checkpoint when metadata is available;
# otherwise fall back to the package-level default.
if metadata:
    num_classes = metadata.dataset_info.get('num_classes', NUM_GROUP_CLASSES)
else:
    num_classes = NUM_GROUP_CLASSES

# Both split fractions are zero so that nothing is held out; the dataset-size
# print below can be compared with the group count above to confirm that.
datamodule_options = {
    "num_classes": num_classes,
    "batch_size": 128,
    "num_workers": 0,
    "val_split": 0.0,
    "test_split": 0.0,
    "seed": 42,
}
datamodule = GroupClassificationDataModule(records=groups, **datamodule_options)
datamodule.setup(stage="fit")

# With no validation/test split, the "train" dataset is the one exported.
full_dataset = datamodule.train_dataset
print(f"Dataset size: {len(full_dataset)}")


Loaded 109817 groups from 11 file pairs


Loaded 109817 groups
Dataset size: 109817


## Run inference and save outputs

In [5]:
# Only this cell needs the PyG loader, so it is imported here to keep the cell
# self-contained; if the notebook's import cell is reorganised, move it there.
from torch_geometric.loader import DataLoader

# Number of groups scored per forward pass.
INFERENCE_BATCH_SIZE = 128


def _append_batch_ids(batch, attr_name, collected, num_rows):
    """Append one id per probability row of ``batch`` to ``collected``.

    Reads ``batch.<attr_name>`` when the batch carries that attribute.
    Otherwise running row indices are appended (continuing from
    ``len(collected)``) so the id array still has one entry per row.

    Args:
        batch: Collated batch yielded by the loader.
        attr_name: Name of the id attribute to read, e.g. ``"event_id"``.
        collected: List of ids gathered so far; extended in place.
        num_rows: Number of probability rows produced for this batch.

    Returns:
        bool: True if ids were read from the batch, False if placeholder row
        indices were generated instead.
    """
    if hasattr(batch, attr_name):
        collected.extend(getattr(batch, attr_name).cpu().tolist())
        return True
    start = len(collected)
    collected.extend(range(start, start + num_rows))
    return False


# shuffle=False keeps the output rows in dataset order.
loader = DataLoader(full_dataset, batch_size=INFERENCE_BATCH_SIZE, shuffle=False, num_workers=0)
model.eval()

prob_chunks = []
event_id_list = []
group_id_list = []
placeholder_fields = set()  # id fields that had to be synthesised

with torch.no_grad():
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)  # [B, num_classes]
        # Independent per-class sigmoid: each column is its own probability.
        probs = torch.sigmoid(logits).detach().cpu()
        prob_chunks.append(probs)

        if not _append_batch_ids(batch, "event_id", event_id_list, probs.size(0)):
            placeholder_fields.add("event_id")
        if not _append_batch_ids(batch, "group_id", group_id_list, probs.size(0)):
            placeholder_fields.add("group_id")

all_probs = torch.cat(prob_chunks, dim=0).numpy()
event_ids = np.array(event_id_list, dtype=np.int64)
group_ids = np.array(group_id_list, dtype=np.int64)

# Placeholder ids are only row positions, not identifiers from the raw data;
# say so loudly because downstream code may try to join on them.
if placeholder_fields:
    print(
        f"WARNING: {sorted(placeholder_fields)} not present on batches; "
        "running row indices were written instead of real ids"
    )

# Downstream consumers align the three arrays row by row, so refuse to write a
# file in which they do not line up one-to-one.
expected_id_shape = (all_probs.shape[0],)
if event_ids.shape != expected_id_shape or group_ids.shape != expected_id_shape:
    raise ValueError(
        "Id arrays do not align with probabilities: "
        f"group_probs={all_probs.shape}, event_ids={event_ids.shape}, group_ids={group_ids.shape}"
    )

print(f"Saved tensors shapes: group_probs={all_probs.shape}, event_ids={event_ids.shape}, group_ids={group_ids.shape}")

save_dir = Path(PROJECT_ROOT) / "data" / "upstream_preds"
save_dir.mkdir(parents=True, exist_ok=True)

# Timestamped file name so repeated exports never overwrite each other.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = save_dir / f"group_probs_{ts}.npz"
np.savez_compressed(
    save_path,
    group_probs=all_probs,
    event_id=event_ids,
    group_id=group_ids,
    class_names=np.array(CLASS_NAMES),
)
print(f"Wrote: {save_path}")


Saved tensors shapes: group_probs=(109817, 3), event_ids=(109817,), group_ids=(109817,)
Wrote: /home/jack/python_projects/pioneerML/data/upstream_preds/group_probs_20260104_014641.npz
