In [1]:
# Reload edited modules automatically before each cell runs, so changes to
# imported project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the parent of the current working directory importable so that the
# `src` package can be imported from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# append duplicate entries to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from src.resnet50.dataset import ImageDataset
from src.resnet50.model import ResNet, Bottleneck
from src.utils import get_device, get_logger

In [4]:
def load_model(path_to_checkpoint, device):
    """Build the ResNet (Bottleneck, [3, 4, 6, 3]) and restore its weights.

    The checkpoint may either be a mapping that stores the weights under the
    ``"model_state"`` key, or the state dict itself.

    Args:
        path_to_checkpoint: Path handed to ``torch.load``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.
    """
    loaded = torch.load(path_to_checkpoint, map_location=device)
    if "model_state" in loaded:
        weights = loaded["model_state"]
    else:
        # The file holds the bare state dict.
        weights = loaded

    network = ResNet(Bottleneck, [3, 4, 6, 3], dropout_p=0.5)
    network.load_state_dict(weights)
    network.to(device)
    network.eval()
    return network


def _predictions_to_frame(outputs, targets):
    """Convert concatenated model outputs and targets into a DataFrame.

    Args:
        outputs: CPU tensor with the raw model outputs for all samples.
        targets: Flat CPU tensor with the matching targets.

    Returns:
        For 2-D outputs with more than one column: a frame with one
        ``prob_<i>`` column per output column (softmax over dim 1), the
        arg-max index in ``pred`` and the targets in ``target``.
        Otherwise: a frame with the flattened raw outputs in ``pred`` (no
        sigmoid or other activation is applied) and the targets in ``target``.
    """
    if outputs.ndim == 2 and outputs.shape[1] > 1:
        # Multi-class outputs: softmax directly on the tensor, then go to numpy once.
        probs = torch.nn.functional.softmax(outputs, dim=1).numpy()
        df = pd.DataFrame(probs, columns=[f"prob_{i}" for i in range(probs.shape[1])])
        df["pred"] = probs.argmax(axis=1)
        df["target"] = targets.numpy()
        return df

    # Regression or binary classification: keep the raw outputs.
    return pd.DataFrame({
        "pred": outputs.numpy().flatten(),
        "target": targets.numpy().flatten(),
    })


def infer_with_resnet(path_to_checkpoint, path_to_dataset, batch_size=128,
                      device="cpu", max_jets=10000):
    """Run the checkpointed ResNet over a dataset and collect its predictions.

    Args:
        path_to_checkpoint: Checkpoint file passed to ``load_model``.
        path_to_dataset: Dataset file passed to ``ImageDataset``.
        batch_size: Batch size used by the ``DataLoader``.
        device: Device used for the model and the input batches. Defaults to
            ``"cpu"``, which matches the previous hard-coded behaviour.
        max_jets: Forwarded to ``ImageDataset`` as ``max_jets``. Defaults to
            ``10000``, which matches the previous hard-coded behaviour.

    Returns:
        A ``pandas.DataFrame`` with one row per sample; see
        ``_predictions_to_frame`` for the columns. If the loader yields no
        batches, an empty frame with ``pred`` and ``target`` columns is
        returned instead of raising from ``torch.cat``.
    """
    logger = get_logger("inference")

    dataset = ImageDataset(path_to_dataset, max_jets=max_jets)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model = load_model(path_to_checkpoint, device)

    preds, targets = [], []
    with torch.no_grad():
        for idx, batch in enumerate(loader):
            logger.info(f"Processing batch {idx}")
            x, y = batch[0], batch[1]
            out = model(x.to(device))
            # Results are gathered on the CPU so they can be converted to numpy later.
            preds.append(out.cpu())
            # reshape (unlike view) also accepts non-contiguous target tensors.
            targets.append(y.reshape(-1).cpu())

    if not preds:
        # torch.cat raises on an empty list; return an empty result instead.
        return pd.DataFrame(columns=["pred", "target"], dtype="float32")

    return _predictions_to_frame(torch.cat(preds), torch.cat(targets))


In [5]:
# Run inference on the preprocessed test set with the saved checkpoint.
df = infer_with_resnet(
    path_to_checkpoint="../checkpoints/resnet50/best_model.pt",
    path_to_dataset="../data/test-preprocessed.h5",
)

2025-09-16 18:00:01 - inference - INFO - Processing batch 0
2025-09-16 18:00:01 - inference - INFO - Processing batch 1
2025-09-16 18:00:01 - inference - INFO - Processing batch 2
2025-09-16 18:00:02 - inference - INFO - Processing batch 3
2025-09-16 18:00:02 - inference - INFO - Processing batch 4
2025-09-16 18:00:02 - inference - INFO - Processing batch 5
2025-09-16 18:00:02 - inference - INFO - Processing batch 6
2025-09-16 18:00:03 - inference - INFO - Processing batch 7
2025-09-16 18:00:03 - inference - INFO - Processing batch 8
2025-09-16 18:00:03 - inference - INFO - Processing batch 9
2025-09-16 18:00:03 - inference - INFO - Processing batch 10
2025-09-16 18:00:03 - inference - INFO - Processing batch 11
2025-09-16 18:00:04 - inference - INFO - Processing batch 12
2025-09-16 18:00:04 - inference - INFO - Processing batch 13
2025-09-16 18:00:04 - inference - INFO - Processing batch 14
2025-09-16 18:00:04 - inference - INFO - Processing batch 15
2025-09-16 18:00:04 - inference - 

In [10]:
# Display the prediction DataFrame returned by infer_with_resnet.
df

Unnamed: 0,pred,target
0,-2.330593,0.0
1,-3.459027,0.0
2,-2.031495,0.0
3,-2.542551,0.0
4,-1.899415,1.0
...,...,...
9995,-2.571396,0.0
9996,-1.797166,1.0
9997,-4.244915,0.0
9998,-1.976189,1.0
