In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [8]:
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader

from src.bnn.dataset import TabularDataset
from src.bnn.model import BNN
from src.utils import get_device, get_logger

In [9]:
def mc_inference(model, x, n_samples=50):
    """Monte Carlo inference: average ``n_samples`` stochastic forward passes.

    The model is switched to train mode so dropout (and any other layer whose
    forward pass depends on ``module.training``) stays stochastic. Every
    batch-norm layer (1d/2d/3d/sync) is put back into eval mode, so it uses its
    stored running statistics instead of batch statistics and does not overwrite
    them. The train/eval flag of every submodule is restored before returning,
    so the caller's model is left exactly in the mode it was passed in.

    Args:
        model: ``nn.Module`` to sample from.
        x: Input batch, passed unchanged to ``model``.
        n_samples: Number of stochastic forward passes. ``0`` makes
            ``torch.stack`` raise. ``1`` yields a NaN variance, because the
            variance is the unbiased (Bessel-corrected) estimate.

    Returns:
        Tuple ``(mean, var)``: the element-wise mean and unbiased variance of the
        model outputs over the sample dimension. Each has the shape of a single
        ``model(x)`` output.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    # Snapshot per-module flags: the final state after this function must match
    # the caller's state, not a blanket model.train()/model.eval().
    previous_modes = [(m, m.training) for m in model.modules()]
    try:
        model.train()  # keep dropout active
        for m in model.modules():
            if isinstance(m, batchnorm_types):
                m.eval()  # use running stats and never update them during inference
        with torch.no_grad():
            preds = torch.stack([model(x) for _ in range(n_samples)])
    finally:
        # Assign the flag directly so each module gets back its own recorded mode.
        for m, was_training in previous_modes:
            m.training = was_training
    return preds.mean(0), preds.var(0)


def load_model(path_to_checkpoint, input_dim, device):
    """Build a ``BNN`` of the given input size and load weights into it.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the ``"model_state"`` key.

    Args:
        path_to_checkpoint: File passed to ``torch.load``.
        input_dim: Constructor argument for ``BNN``.
        device: Target passed both as ``map_location`` and to ``.to``.

    Returns:
        The loaded model, moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    net = BNN(input_dim)
    payload = torch.load(path_to_checkpoint, map_location=device)
    if "model_state" in payload:
        weights = payload["model_state"]
    else:
        weights = payload
    net.load_state_dict(weights)
    # Module.to and Module.eval both return the module itself.
    return net.to(device).eval()


def infer_with_bnn(
    path_to_checkpoint,
    path_to_dataset,
    n_samples=10,
    batch_size=128,
    device="cpu",
    max_jets=10_000,
    seed=None,
):
    """Run MC inference with a trained BNN over a dataset and tabulate results.

    Args:
        path_to_checkpoint: Checkpoint file handed to ``load_model``.
        path_to_dataset: Dataset file handed to ``TabularDataset``.
        n_samples: Stochastic forward passes per batch (see ``mc_inference``).
        batch_size: ``DataLoader`` batch size. Data order is preserved
            (``shuffle=False``).
        device: Device for the model and inputs. Defaults to ``"cpu"``; pass
            e.g. ``get_device()`` to run on an accelerator.
        max_jets: Forwarded as ``max_jets`` to ``TabularDataset``.
        seed: If not ``None``, seeds torch's global RNG before sampling so the
            stochastic predictions are reproducible across runs.

    Returns:
        ``pd.DataFrame`` with columns ``mean`` and ``var`` (MC mean / variance of
        the model output) and ``target``, flattened, in dataset order.
    """
    logger = get_logger("inference")
    dataset = TabularDataset(path_to_dataset, max_jets=max_jets)
    # The model's input size is the flattened size of one feature sample.
    x0, *_ = dataset[0]
    input_dim = x0.numel()

    model = load_model(path_to_checkpoint, input_dim, device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    if seed is not None:
        torch.manual_seed(seed)

    means, vars_, targets = [], [], []
    for idx, batch in enumerate(loader):
        logger.info("Processing batch %d", idx)
        x, y = batch[0], batch[1]
        # reshape instead of view: view raises on non-contiguous tensors.
        x = x.to(device).reshape(x.size(0), -1)
        mean, var = mc_inference(model, x, n_samples)
        means.append(mean.detach().cpu())
        vars_.append(var.detach().cpu())
        targets.append(y.reshape(-1).detach().cpu())

    # Tensors were already detached and moved to CPU above.
    mean_all = torch.cat(means).numpy()
    var_all = torch.cat(vars_).numpy()
    target_all = torch.cat(targets).numpy()

    return pd.DataFrame({
        "mean": mean_all.flatten(),
        "var": var_all.flatten(),
        "target": target_all.flatten(),
    })


In [10]:
df = infer_with_bnn(
    path_to_checkpoint="../checkpoints/bnn/best_model.pt",
    path_to_dataset="../data/test-preprocessed.h5",
)

2025-09-16 18:00:57 - inference - INFO - Processing batch 0
2025-09-16 18:00:57 - inference - INFO - Processing batch 0
2025-09-16 18:00:57 - inference - INFO - Processing batch 1
2025-09-16 18:00:57 - inference - INFO - Processing batch 1
2025-09-16 18:00:57 - inference - INFO - Processing batch 2
2025-09-16 18:00:57 - inference - INFO - Processing batch 2
2025-09-16 18:00:58 - inference - INFO - Processing batch 3
2025-09-16 18:00:58 - inference - INFO - Processing batch 3
2025-09-16 18:00:58 - inference - INFO - Processing batch 4
2025-09-16 18:00:58 - inference - INFO - Processing batch 4
2025-09-16 18:00:58 - inference - INFO - Processing batch 5
2025-09-16 18:00:58 - inference - INFO - Processing batch 5
2025-09-16 18:00:58 - inference - INFO - Processing batch 6
2025-09-16 18:00:58 - inference - INFO - Processing batch 6
2025-09-16 18:00:58 - inference - INFO - Processing batch 7
2025-09-16 18:00:58 - inference - INFO - Processing batch 7
2025-09-16 18:00:58 - inference - INFO -

In [11]:
# Columns: MC mean and variance of the model output, and the target label (rows in dataset order).
df

Unnamed: 0,mean,var,target
0,1.287266,0.124149,0.0
1,-4.286171,0.278319,0.0
2,-1.772324,0.649490,0.0
3,-1.106961,0.070902,0.0
4,1.103687,0.055539,1.0
...,...,...,...
9995,-3.027663,0.229017,0.0
9996,0.891899,0.229666,1.0
9997,-1.971348,0.087494,0.0
9998,1.525557,0.140287,1.0
