In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the parent of the current working directory importable so that the
# `src.*` imports below resolve. The membership guard keeps this cell
# idempotent: re-running it no longer appends a duplicate entry to sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
import numpy as np
import pandas as pd
import tensorflow as tf

from src.particle_net.dataset import PointDataset, create_data_loader
from src.particle_net.model import get_particle_net
from src.utils import get_logger


def load_model(path_to_checkpoint, num_classes=2, max_constits=80, num_features=7):
    """Build a ParticleNet model and restore its weights from a checkpoint directory.

    Args:
        path_to_checkpoint: Directory searched with ``tf.train.latest_checkpoint``.
        num_classes: Forwarded to ``get_particle_net`` as the number of classes.
        max_constits: First dimension of both the "features" and "points" inputs.
        num_features: Second dimension of the "features" input.

    Returns:
        The model returned by ``get_particle_net`` after the restore call.

    Raises:
        FileNotFoundError: If no checkpoint is found in ``path_to_checkpoint``.
    """
    network = get_particle_net(
        num_classes,
        {
            "features": (max_constits, num_features),
            # The "points" input always has 2 values per constituent.
            "points": (max_constits, 2),
        },
    )
    restorer = tf.train.Checkpoint(model=network)

    checkpoint_prefix = tf.train.latest_checkpoint(path_to_checkpoint)
    if checkpoint_prefix is not None:
        # expect_partial() silences warnings about checkpoint entries that are
        # not consumed by this object graph.
        restorer.restore(checkpoint_prefix).expect_partial()
        return network

    raise FileNotFoundError(f"No checkpoint found in {path_to_checkpoint}")


def _predictions_to_frame(preds_all, targets_all):
    """Assemble stacked model outputs and targets into a results DataFrame."""
    if preds_all.ndim == 2 and preds_all.shape[1] > 1:
        # Multi-output case: one `prob_<i>` column per model output plus the
        # arg-max label.
        df = pd.DataFrame(preds_all, columns=[f"prob_{i}" for i in range(preds_all.shape[1])])
        df["pred"] = preds_all.argmax(axis=1)
        if targets_all.ndim == 2 and targets_all.shape[1] > 1:
            df["target"] = targets_all.argmax(axis=1)  # one-hot to index
        else:
            # Targets are already class indices (shape (N,) or (N, 1));
            # argmax(axis=1) would raise or collapse them all to 0.
            df["target"] = targets_all.reshape(-1)
        return df

    # Regression or single output
    return pd.DataFrame({
        "pred": preds_all.flatten(),
        "target": targets_all.flatten(),
    })


def infer_with_particle_net(
    path_to_checkpoint,
    path_to_dataset,
    batch_size=128,
    max_constits=80,
    num_classes=2,
    max_jets=1000,
):
    """Run inference with ParticleNet and return a pandas DataFrame with predictions and targets.

    Args:
        path_to_checkpoint: Checkpoint directory passed to ``load_model``.
        path_to_dataset: Dataset path passed to ``PointDataset``.
        batch_size: Batch size passed to ``create_data_loader``.
        max_constits: Passed to both ``PointDataset`` and ``load_model``.
        num_classes: Number of classes passed to ``load_model``.
        max_jets: Passed to ``PointDataset`` as ``max_jets`` (presumably a cap on
            the number of jets read -- confirm against ``PointDataset``).
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        For 2-D outputs with more than one column: a DataFrame with ``prob_<i>``
        columns, ``pred`` (arg-max of the outputs) and ``target``. Otherwise a
        DataFrame with flattened ``pred`` and ``target`` columns.

    Raises:
        FileNotFoundError: Propagated from ``load_model`` when no checkpoint exists.
        ValueError: If the data loader yields no batches.
    """
    logger = get_logger("inference")

    # Build dataset and loader
    dataset = PointDataset(path_to_dataset, max_constits=max_constits, max_jets=max_jets)
    loader = create_data_loader(dataset, batch_size=batch_size)

    # The feature count comes from the dataset so the model input matches it.
    model = load_model(
        path_to_checkpoint,
        num_classes=num_classes,
        max_constits=max_constits,
        num_features=dataset.num_features,
    )

    preds, targets = [], []
    # The loader yields (inputs, labels, weights); weights are not needed here.
    for idx, (inputs, labels, _weights) in enumerate(loader):
        logger.info(f"Processing batch {idx}")
        outputs = model(inputs, training=False)
        preds.append(outputs.numpy())
        targets.append(labels.numpy())

    if not preds:
        # np.concatenate([]) would fail with an opaque message; be explicit.
        raise ValueError(f"No batches were produced from {path_to_dataset}")

    return _predictions_to_frame(
        np.concatenate(preds, axis=0),
        np.concatenate(targets, axis=0),
    )


In [4]:
# Run inference on the preprocessed test set with the saved ParticleNet checkpoint.
df = infer_with_particle_net(
    path_to_checkpoint="../checkpoints/particle_net",
    path_to_dataset="../data/test-preprocessed.h5",
)

2025-09-16 18:10:28.286290: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M4
2025-09-16 18:10:28.286313: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-09-16 18:10:28.286317: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
I0000 00:00:1758057028.286331  412004 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1758057028.286350  412004 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2025-09-16 18:10:28 - inference - INFO - Processing batch 0
2025-09-16 18:10:29 - inference - INFO - Processing batch 1
2025-09-16 18:10:29 - inference - INFO - Processing batch 2
2025-09-16 18:10:29 - inference - INFO - Processing batch 3
2025-09-16 18:10:29 - infe

In [5]:
# Inspect the inference results: per-output scores, predicted label and target per row.
df

Unnamed: 0,prob_0,prob_1,pred,target
0,0.755441,0.244559,0,1
1,0.050219,0.949781,1,1
2,0.855395,0.144605,0,1
3,0.051961,0.948039,1,1
4,0.987510,0.012490,0,0
...,...,...,...,...
995,0.890901,0.109099,0,0
996,0.955186,0.044814,0,0
997,0.351783,0.648217,1,0
998,0.108467,0.891533,1,1
