# Group Splitter Inference

In [None]:
from pathlib import Path

import pyarrow.parquet as pq

from pioneerml.common.zenml import load_step_output
from pioneerml.common.zenml import utils as zenml_utils
from pioneerml.pipelines.inference.group_splitting import group_splitting_inference_pipeline

# Resolve the project root once; it anchors the data paths used in this
# notebook and is handed to the ZenML notebook setup helper.
PROJECT_ROOT = zenml_utils.find_project_root()
zenml_utils.setup_zenml_for_notebook(
    use_in_memory=True,
    root_path=PROJECT_ROOT,
)

In [None]:
# Inputs

def _pick_pred(group_pred_dir: Path, main_path: Path) -> Path | None:
    latest = group_pred_dir / f"{main_path.stem}_preds_latest.parquet"
    if latest.exists():
        return latest
    alt = group_pred_dir / f"{main_path.stem}_preds.parquet"
    if alt.exists():
        return alt
    return None

# Main event files: every ml_output_*.parquet directly under <project>/data,
# in sorted order.
main_dir = Path(PROJECT_ROOT, "data")
main_paths = sorted(main_dir.glob("ml_output_*.parquet"))

# Example: uncomment to use fewer files
# main_paths = main_paths[:1]

# Group-classifier priors: pair each main file with its prediction file.
# Main files without a matching prediction file are left out of the run.
group_pred_dir = Path(PROJECT_ROOT, "data", "group_classifier")
paired = []
for main_file in main_paths:
    prior_file = _pick_pred(group_pred_dir, main_file)
    if prior_file is None:
        continue
    paired.append((str(main_file.resolve()), str(prior_file.resolve())))

# Fail early with an actionable message when nothing could be paired.
if not paired:
    no_pairs_message = (
        "No aligned main/group-classifier prediction pairs found. "
        "Run group-classifier inference first."
    )
    raise RuntimeError(no_pairs_message)

# Split the pairs into two index-aligned lists of absolute path strings.
parquet_paths = [main for main, _ in paired]
group_probs_parquet_paths = [prior for _, prior in paired]
model_path = None  # default: latest trained model
output_dir = str(Path(PROJECT_ROOT, "data", "group_splitter").resolve())

print(f"Input files: {len(parquet_paths)}")
print(f"Group prior files: {len(group_probs_parquet_paths)}")

In [None]:
# Run inference pipeline

# Loader settings; nested below under pipeline_config["loader"]["config_json"].
loader_config = {
    "mode": "inference",
    "batch_size": 64,
    "chunk_row_groups": 4,
    "chunk_workers": 0,
    "use_group_probs": True,
}
pipeline_config = {
    "loader": {"config_json": loader_config},
    "inference": {"threshold": 0.5},
    "export": {"check_accuracy": False, "write_timestamped": False},
}

# Caching is switched off for this invocation via with_options.
uncached_pipeline = group_splitting_inference_pipeline.with_options(enable_cache=False)
run = uncached_pipeline(
    parquet_paths=parquet_paths,
    group_probs_parquet_paths=group_probs_parquet_paths,
    model_path=model_path,
    output_dir=output_dir,
    pipeline_config=pipeline_config,
)

# Fetch the output of the export step from the finished run and show it.
export_info = load_step_output(run, "export_group_splitter_predictions")
print(export_info)

In [None]:
# Inspect export outputs: show where the predictions and metrics were written,
# then print the contents of the metrics file.
pred_path = Path(export_info["predictions_path"])
metrics_path = Path(export_info["metrics_path"])
print(f"predictions: {pred_path}\nmetrics: {metrics_path}")
metrics_text = metrics_path.read_text()
print(metrics_text)

In [None]:
# Optional: verify parquet columns by printing the schema of the exported
# predictions and a preview of the first three rows.
tbl = pq.read_table(pred_path)
schema = tbl.schema
print(schema)
preview = tbl.slice(offset=0, length=3)
print(preview)