# Event Splitter Event

Train the event-level edge-affinity splitter with classifier/splitter/endpoint priors.


In [None]:
from pathlib import Path
from pioneerml.common.evaluation.plots.loss import LossCurvesPlot
from pioneerml.common.zenml import load_step_output
from pioneerml.common.zenml import utils as zenml_utils
from pioneerml.pipelines.training import event_splitter_event_pipeline

# Resolve the project root through the project's ZenML helper; the data,
# Optuna storage and export paths used in later cells are derived from it.
PROJECT_ROOT = zenml_utils.find_project_root()
# Configure ZenML for notebook use, rooted at PROJECT_ROOT.
# NOTE(review): use_in_memory=True presumably avoids a persistent ZenML store;
# confirm against zenml_utils.setup_zenml_for_notebook.
zenml_utils.setup_zenml_for_notebook(root_path=PROJECT_ROOT, use_in_memory=True)


In [None]:
# Locate the inputs that must line up shard-by-shard: the main parquet files
# plus one prediction directory per prior stage (group classifier, group
# splitter, endpoint regressor).
data_dir = Path(PROJECT_ROOT, "data")
main_paths = sorted(data_dir.glob("ml_output_*.parquet"))
group_probs_dir = data_dir.joinpath("inference_outputs", "group_classifier_event")
splitter_probs_dir = data_dir.joinpath("inference_outputs", "group_splitter_event")
endpoint_dir = data_dir.joinpath("inference_outputs", "endpoint_regressor_event")

def _pick_pred(path_dir: Path, main_path: Path) -> Path | None:
    latest = path_dir / main_path.name.replace(".parquet", "_preds_latest.parquet")
    if latest.exists():
        return latest
    return None

# Keep only the shards for which all three prior-stage prediction files exist;
# every other shard is recorded in `missing` and left out of training.
paired = []
missing = []
for main_path in main_paths:
    # Look up the matching prediction file in each prior-stage directory.
    group_path, splitter_path, endpoint_path = (
        _pick_pred(pred_dir, main_path)
        for pred_dir in (group_probs_dir, splitter_probs_dir, endpoint_dir)
    )
    if group_path is None or splitter_path is None or endpoint_path is None:
        missing.append(str(main_path))
        continue
    paired.append((str(main_path), str(group_path), str(splitter_path), str(endpoint_path)))

# Without at least one complete quartet there is nothing to train on.
if not paired:
    raise RuntimeError(
        "No aligned main/group/splitter/endpoint parquet quartets found. "
        "Run prior-stage inference first."
    )
if missing:
    print(f"Warning: missing predictions for {len(missing)} shard(s); skipping those files.")

# Transpose the quartets into four parallel lists, one per pipeline input.
(
    parquet_paths,
    group_probs_parquet_paths,
    group_splitter_parquet_paths,
    endpoint_parquet_paths,
) = (list(column) for column in zip(*paired))
print(f"Using {len(parquet_paths)} shard(s) for event-splitter-event training.")


In [None]:
# Values derived from PROJECT_ROOT, named up front to keep the config readable.
hpo_storage_url = f"sqlite:///{PROJECT_ROOT}/.optuna/event_splitter_event_hpo.db"
model_export_dir = Path(PROJECT_ROOT) / "trained_models" / "event_splitter_event"

# Assemble the pipeline configuration one section at a time.
pipeline_cfg = {}
# Loader: enable all three prior inputs (group probs, splitter probs, endpoint preds).
pipeline_cfg["loader"] = {
    "config_json": {
        "time_window_ns": 1.0,
        "use_group_probs": True,
        "use_splitter_probs": True,
        "use_endpoint_preds": True,
    },
    "normalize": True,
}
# HPO: a single trial of at most 3 epochs, stored in a SQLite database.
pipeline_cfg["hpo"] = {
    "enabled": True,
    "n_trials": 1,
    "max_epochs": 3,
    "storage": hpo_storage_url,
}
pipeline_cfg["train"] = {"max_epochs": 3}
pipeline_cfg["evaluate"] = {"threshold": 0.5, "batch_size": 1}
pipeline_cfg["export"] = {
    "prefer_cuda": True,
    "export_dir": str(model_export_dir),
}

# Invoke the pipeline with enable_cache=False on the aligned shard lists.
uncached_pipeline = event_splitter_event_pipeline.with_options(enable_cache=False)
run = uncached_pipeline(
    parquet_paths=parquet_paths,
    group_probs_parquet_paths=group_probs_parquet_paths,
    group_splitter_parquet_paths=group_splitter_parquet_paths,
    endpoint_parquet_paths=endpoint_parquet_paths,
    pipeline_config=pipeline_cfg,
)


In [None]:
# Load step outputs from the pipeline run and print a concise summary
def _concise(values, limit: int = 10):
    """Return at most the last ``limit`` items of ``values`` as a list.

    Args:
        values: Any iterable; it is materialized into a list first.
        limit: Maximum number of trailing items to keep. Zero or a negative
            value yields an empty list.

    Returns:
        A list holding the final ``limit`` items, or every item when there
        are ``limit`` or fewer.
    """
    values = list(values)
    # Slice from an explicit start index: values[-limit:] would return the
    # whole list for limit == 0, because -0 == 0.
    start = max(len(values) - limit, 0)
    return values[start:]

# Fetch the output of each pipeline step, in the order train, tune, evaluate, export.
trained_module, hpo_params, metrics, export = (
    load_step_output(run, step_name)
    for step_name in (
        "train_event_splitter_event",
        "tune_event_splitter_event",
        "evaluate_event_splitter_event",
        "export_event_splitter_event",
    )
)

print("hpo_params:", hpo_params)
# Loss histories are printed only when the training step produced a module.
if trained_module is not None:
    train_losses = _concise(trained_module.train_epoch_loss_history)
    print("train_epoch_loss_history:", train_losses)
    val_losses = _concise(trained_module.val_epoch_loss_history)
    print("val_epoch_loss_history:", val_losses)
print("metrics:", metrics)
print("export:", export)


In [None]:
# Plot loss curves
# Fail early with a clear message when no trained module is available to plot.
if trained_module is None:
    raise RuntimeError("No trained module loaded from pipeline run.")
# Render the loss curves from the trained module.
# NOTE(review): show=True presumably displays the figure inline; confirm in LossCurvesPlot.render.
LossCurvesPlot().render(trained_module, show=True)
