# Group Splitter Inference

In [1]:
import gc
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

from pioneerml.common.zenml import load_step_output
from pioneerml.common.zenml import utils as zenml_utils
from pioneerml.pipelines.inference.group_splitting import group_splitting_inference_pipeline

# Resolve the project root once; every data/model path in this notebook is built from it.
# The setup message printed below expects this to be the top-level repo directory that holds `.zen`.
PROJECT_ROOT = zenml_utils.find_project_root()
# NOTE(review): `use_in_memory=True` presumably selects a non-persistent ZenML store — confirm.
zenml_utils.setup_zenml_for_notebook(root_path=PROJECT_ROOT, use_in_memory=True)

Using ZenML repository root: /workspace
Ensure this is the top-level of your repo (.zen must live here).


<zenml.client.Client at 0x71199bb43970>

In [2]:
# Inputs

def _pick_pred(group_pred_dir: Path, main_path: Path) -> Path | None:
    primary = group_pred_dir / f"{main_path.stem}_preds.parquet"
    if primary.exists():
        return primary
    latest = group_pred_dir / f"{main_path.stem}_preds_latest.parquet"
    if latest.exists():
        return latest
    return None

# Main event files
main_dir = Path(PROJECT_ROOT) / "data"
main_paths = sorted(main_dir.glob("ml_output_*.parquet"))
if not main_paths:
    # Fail here with an accurate message instead of the pairing error further down.
    raise RuntimeError(f"No main event files matching 'ml_output_*.parquet' found in {main_dir}.")

# Cap on how many main files are processed (quick run by default).
# Set to None to process every file that was found.
max_main_files = 1
main_paths = main_paths[:max_main_files]

# Group-classifier priors: pair each main file with its prediction file.
group_pred_dir = Path(PROJECT_ROOT) / "data" / "group_classifier"
paired = []
unpaired = []
for mp in main_paths:
    gp = _pick_pred(group_pred_dir, mp)
    if gp is None:
        unpaired.append(mp)
    else:
        paired.append((str(mp.resolve()), str(gp.resolve())))

if unpaired:
    # Main files without a matching prediction file are skipped; say so rather than dropping silently.
    print(f"Skipping {len(unpaired)} main file(s) without group-classifier predictions:")
    for skipped in unpaired:
        print(" ", skipped)

if not paired:
    raise RuntimeError(
        "No aligned main/group-classifier prediction pairs found. "
        "Run group-classifier inference first."
    )

# Parallel lists: index i of each list refers to the same main/prior pair.
parquet_paths = [pair[0] for pair in paired]
group_probs_parquet_paths = [pair[1] for pair in paired]
model_path = None  # default: latest trained model
output_dir = str((Path(PROJECT_ROOT) / "data" / "group_splitter").resolve())

print(f"Input files: {len(parquet_paths)}")
print(f"Group prior files: {len(group_probs_parquet_paths)}")

Input files: 1
Group prior files: 1


In [3]:
# Run inference pipeline

# Settings placed under the loader's "config_json" entry.
loader_config_json = {
    "mode": "inference",
    "batch_size": 64,
    "chunk_row_groups": 4,
    "chunk_workers": 0,
    "use_group_probs": True,
}

# Per-step configuration handed to the pipeline as a single dict.
pipeline_config = {
    "loader": {"config_json": loader_config_json},
    "inference": {"threshold": 0.5},
    "save_predictions": {"check_accuracy": False, "write_timestamped": False},
}

# Caching is switched off for this invocation.
uncached_pipeline = group_splitting_inference_pipeline.with_options(enable_cache=False)
run = uncached_pipeline(
    parquet_paths=parquet_paths,
    group_probs_parquet_paths=group_probs_parquet_paths,
    model_path=model_path,
    output_dir=output_dir,
    pipeline_config=pipeline_config,
)

# Fetch the output of the `save_group_splitter_predictions` step and show it.
export_info = load_step_output(run, "save_group_splitter_predictions")
print(export_info)

Initiating a new run for the pipeline: group_splitting_inference_pipeline.
Caching is disabled by default for group_splitting_inference_pipeline.
Using user: default
Using stack: default
  deployer: default
  artifact_store: default
  orchestrator: default
You can visualize your pipeline runs in the ZenML Dashboard. In order to try it locally, please run zenml login --local.
Step load_group_splitter_inference_inputs has started.
Step load_group_splitter_inference_inputs has finished in 0.645s.
Step load_group_splitter_model has started.
Step load_group_splitter_model has finished in 0.107s

In [4]:
# Inspect export outputs

# Prefer the list-valued "predictions_paths"; fall back to the single
# "predictions_path" entry when the list is missing or empty.
exported_paths = export_info.get("predictions_paths") or []
single_export = export_info.get("predictions_path")
if not exported_paths and single_export:
    exported_paths = [single_export]
predictions_paths = [Path(exported) for exported in exported_paths]
metrics_path = Path(export_info["metrics_path"])

print("predictions_paths:")
for prediction_file in predictions_paths:
    print(" ", prediction_file)
print("metrics:", metrics_path)
# Dump the metrics file contents as-is.
print(metrics_path.read_text())

predictions_paths:
  /workspace/data/group_splitter/ml_output_000_preds.parquet
metrics: /workspace/data/group_splitter/metrics_latest.json
{
  "accuracy": null,
  "confusion": null,
  "exact_match": null,
  "loss": null,
  "mode": "group_splitter",
  "model_path": "/workspace/trained_models/groupsplitter/groupsplitter_20260218_231744_torchscript.pt",
  "output_path": "/workspace/data/group_splitter/ml_output_000_preds.parquet",
  "output_paths": [
    "/workspace/data/group_splitter/ml_output_000_preds.parquet"
  ],
  "threshold": 0.5,
  "validated_files": [
    "/workspace/data/ml_output_000.parquet"
  ],
  "validated_group_probs_files": [
    "/workspace/data/group_classifier/ml_output_000_preds.parquet"
  ]
}


In [5]:
# Optional: verify parquet schema + small sample (avoids loading full file).
# `gc`, `pa` and `pq` are imported in the first cell of the notebook.

if not predictions_paths:
    raise RuntimeError("No prediction parquet files were exported.")

# Opening a ParquetFile exposes metadata and schema without reading the column data.
pf = pq.ParquetFile(predictions_paths[0])
print("file:", predictions_paths[0])
print("rows:", pf.metadata.num_rows)
print(pf.schema_arrow)

if pf.num_row_groups > 0:
    # Stream at most 3 rows from row group 0 rather than materialising the
    # whole row group (which can be the entire file) just to slice it.
    first_batch = next(pf.iter_batches(batch_size=3, row_groups=[0]), None)
    if first_batch is None:
        # Row group exists but holds no rows: show an empty table with the file schema.
        sample = pf.schema_arrow.empty_table()
    else:
        sample = pa.Table.from_batches([first_batch])
    del first_batch
    print(sample)
else:
    sample = None
    print("No row groups found.")

# Release notebook-held references after inspection
del sample, pf
gc.collect()
pa.default_memory_pool().release_unused()


file: /workspace/data/group_splitter/ml_output_000_preds.parquet
rows: 1024
event_id: int64
time_group_ids: list<element: int64>
  child 0, element: int64
pred_hit_pion: list<element: float>
  child 0, element: float
pred_hit_muon: list<element: float>
  child 0, element: float
pred_hit_mip: list<element: float>
  child 0, element: float
pyarrow.Table
event_id: int64
time_group_ids: list<element: int64>
  child 0, element: int64
pred_hit_pion: list<element: float>
  child 0, element: float
pred_hit_muon: list<element: float>
  child 0, element: float
pred_hit_mip: list<element: float>
  child 0, element: float
----
event_id: [[0,1,2]]
time_group_ids: [[[0,0,0,0,0,...,2,2,2,2,2],[0,0,0,0,0,...,1,1,1,1,1],[0,0,0,0,0,...,4,4,4,4,4]]]
pred_hit_pion: [[[0.9997389,0.99739504,0.99755585,0.99986935,0.9996674,...,0.000056134355,0.000069731126,0.0000026965904,0.0000020997734,0.00006676184],[0.9998197,0.9999306,0.9997559,0.9997441,0.99968946,...,0.0007802801,0.00077271904,0.0012621538,0.000577956