This notebook is used for evaluating the trained models on different splits and visualizing the batch effect.

In [1]:
import sys
import os

src_path = os.path.split(os.getcwd())[0]
sys.path.insert(0, src_path)

import numpy as np
import torch
import json
from training.model import ResNet
from training.data import JUMPCPDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import pandas as pd


In [2]:
# Evaluation configuration.
# With cross_validation=True, `checkpoint_path` is a list with one checkpoint per fold.
# With cross_validation=False it is passed straight to the loader, so it must be a single path, e.g.:
#checkpoint_path = "/system/user/studentwork/seibezed/bachelor/logs/imgres=499_lr=0.001_wd=0.1_agg=True_model=ResNet50_world_size=4batchsize=32_workers=8_date=2025-10-27-19-55-19__fold4/checkpoints/epoch_42.pt"
cross_validation = True
folds = 5
# Checkpoint epoch to evaluate for each fold (list index == fold number).
epochs = [42, 41, 42, 40, 42]
# zip() below would silently drop folds if the two disagree, so fail loudly instead.
assert len(epochs) == folds, f"expected {folds} epochs (one per fold), got {len(epochs)}"
checkpoint_path = [f"/system/user/studentwork/seibezed/bachelor/logs/imgres=499_lr=0.001_wd=0.1_agg=True_model=ResNet50_world_size=4batchsize=32_workers=8_date=2025-10-27-19-55-19__fold{fold}/checkpoints/epoch_{epoch}.pt" for fold, epoch in zip(range(folds),epochs)]
# Name of the model configuration; used to locate training/model_parameter/<model>.json.
model = "ResNet50"
test_file = "/system/user/studentwork/seibezed/bachelor/data/random_seed1234_test.csv"
mapping = "/system/user/studentwork/seibezed/bachelor/data/class_mapping_seed1234.json"
image_path = "/system/user/publicdata/jumpcp/"


In [3]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32 in place.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None: the truth value of a multi-element tensor is
        # ambiguous and raises RuntimeError, so `if p.grad:` cannot be used here.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

def load_model(checkpoint_path, device, model):
    """Build a ResNet from its JSON config and load weights from a checkpoint.

    Args:
        checkpoint_path: path to a checkpoint file whose top-level object has a
            ``"state_dict"`` entry.
        device: target device (e.g. ``"cuda"`` or ``"cpu"``).
        model: model name; ``/`` is replaced by ``-`` to find
            ``training/model_parameter/<name>.json`` below ``src_path``.

    Returns:
        The network with weights loaded, moved to ``device`` and set to eval mode.

    Raises:
        AssertionError: if the model configuration file does not exist.
    """
    # map_location keeps GPU-saved checkpoints loadable on CPU-only machines.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint["state_dict"]

    model_config_file = os.path.join(src_path, f"training/model_parameter/{model.replace('/', '-')}.json")
    assert os.path.exists(model_config_file), f"model config not found: {model_config_file}"
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    # Separate name so the `model` argument (a string) is not shadowed by the network.
    net = ResNet(**model_info)

    if str(device) == "cpu":
        net.float()
    print(device)

    # Strip the "module." prefix only where it is present; slicing unconditionally
    # would cut the first characters off keys saved without the prefix.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    net.load_state_dict(new_state_dict)
    net.to(device)
    net.eval()

    return net


In [4]:
def main(test_file, model_path, model, img_path, mapping_path, batch_size=64, num_workers=20):
    """Run one checkpoint over the test set and collect its outputs.

    Args:
        test_file: passed to ``JUMPCPDataset`` as its first argument.
        model_path: checkpoint path handed to ``load_model``.
        model: model name handed to ``load_model``.
        img_path: passed to ``JUMPCPDataset`` as its second argument.
        mapping_path: passed to ``JUMPCPDataset`` as its third argument.
        batch_size: DataLoader batch size (default 64, the previous fixed value).
        num_workers: DataLoader worker count (default 20, the previous fixed value).

    Returns:
        Tuple ``(predictions, labels, metadata)``: the concatenated model outputs,
        the concatenated labels (both on the evaluation device) and a DataFrame
        with one row per sample holding the per-sample metadata fields.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())

    model = load_model(model_path, device, model)

    test_data = JUMPCPDataset(test_file, img_path, mapping_path)
    loader = DataLoader(test_data, num_workers=num_workers, batch_size=batch_size)

    all_predictions = []
    all_labels = []
    all_metadata = []
    with torch.no_grad():
        for batch in tqdm(loader):
            # Each batch has four entries; the third one is not needed here.
            imgs, labels, _, metadata = batch
            imgs = imgs.to(device)
            labels = labels.to(device)

            predictions = model(imgs)

            all_predictions.append(predictions)
            all_labels.append(labels)
            all_metadata.append(metadata)

    # Each batch's metadata is expected to be a dict mapping a field name to a
    # per-sample sequence (TODO confirm against JUMPCPDataset). Flatten it into one
    # record per sample, for however many fields the dataset provides.
    flat_metadata = []
    for batch_meta in all_metadata:
        if not batch_meta:
            continue
        keys = list(batch_meta)
        n_items = len(batch_meta[keys[0]])
        for i in range(n_items):
            flat_metadata.append({key: batch_meta[key][i] for key in keys})

    return torch.cat(all_predictions), torch.cat(all_labels), pd.DataFrame(flat_metadata)



In [5]:
# Evaluate every checkpoint on the test set; `data` gets one
# [predictions, labels, metadata] entry per evaluated checkpoint.
# Without cross-validation, `checkpoint_path` is a single path, so wrap it in a list.
checkpoints_to_run = checkpoint_path if cross_validation else [checkpoint_path]

data = []
for cpath in checkpoints_to_run:
    pred, labels, metadata = main(test_file, cpath, model, image_path, mapping)
    data.append([pred, labels, metadata])
    


4
cuda


  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 155/155 [01:20<00:00,  1.92it/s]


4
cuda


100%|██████████| 155/155 [01:17<00:00,  2.00it/s]


4
cuda


100%|██████████| 155/155 [01:15<00:00,  2.04it/s]


4
cuda


100%|██████████| 155/155 [01:13<00:00,  2.12it/s]


4
cuda


100%|██████████| 155/155 [01:13<00:00,  2.10it/s]


In [6]:
# Top-1 accuracy of each evaluated checkpoint, then mean/std across them.
accuracies = []
for fold_pred, fold_labels, _ in data:
    hits = torch.argmax(fold_pred, dim=1) == fold_labels
    accuracies.append(hits.sum().item() / len(fold_labels))

print(f"Individual acc: {accuracies}")
mean_acc = np.mean(accuracies)
std_acc = np.std(accuracies)
print(f"Test set: \n Mean: {mean_acc}\n Std: {std_acc}")


Individual acc: [0.969794928780685, 0.9729265582382058, 0.9698959490857663, 0.9682796242044651, 0.9713102333569047]
Test set: 
 Mean: 0.9704414587332053
 Std: 0.001569687814894813


In [16]:
id_mapping = "/system/user/studentwork/seibezed/bachelor/data/id_mapping_seed1234.json"

# Load the lookup from class id (as a string key) to class name.
with open(id_mapping, "r") as mapping_file:
    id_to_class = json.load(mapping_file)

# Accuracy per label: for every evaluated checkpoint, the fraction of samples of
# each class that were predicted as that class; one list entry per checkpoint.
label_acc = {}
for fold_pred, fold_labels, _ in data:
    for class_id in fold_labels.unique():
        class_pred = torch.argmax(fold_pred[fold_labels == class_id], dim=1)
        n_correct = (class_pred == class_id).sum().item()
        class_name = id_to_class[str(class_id.item())]
        label_acc.setdefault(class_name, []).append(n_correct / len(class_pred))

# Report the mean over checkpoints for every class.
for class_name, fold_accs in label_acc.items():
    print(f"{class_name}: {np.mean(fold_accs)}")


JCP2022_035095: 0.9946117274167987
JCP2022_046054: 0.9860728744939271
JCP2022_064022: 0.9952960259529604
JCP2022_085227: 0.904527402700556
JCP2022_012818: 0.991618734593262
JCP2022_025848: 0.9466353677621283
JCP2022_037716: 0.9869674185463658
JCP2022_050797: 0.9600985221674877


In [18]:
# Accuracy per value of the "batch" metadata column: computed for every evaluated
# checkpoint, then averaged over checkpoints.
batch_acc = {}
for pred, labels, metadata in data:
    batch_column = metadata["batch"]
    for batch_name in batch_column.unique():
        # Build the selection mask once, as a real bool tensor on the predictions'
        # device, instead of indexing the tensors with a pandas Series twice.
        mask = torch.as_tensor(
            (batch_column == batch_name).to_numpy(), dtype=torch.bool, device=pred.device
        )
        batch_pred = torch.argmax(pred[mask], dim=1)
        correct = (batch_pred == labels[mask]).sum().item()
        batch_acc.setdefault(batch_name, []).append(correct / len(batch_pred))

for key, value in batch_acc.items():
    print(f"{key}: {np.mean(value)}")




CP_27_all_Phenix1: 0.9932971014492754
CP_32_all_Phenix1: 0.9757307351638618
CP_28_all_Phenix1: 0.9829743589743589
CP_26_all_Phenix1: 0.9859525899912203
CP_31_all_Phenix1: 0.879831223628692
CP_25_all_Phenix1: 0.9949647532729105
CP60: 0.9934875749785774
CP59: 0.9460055096418734
CP_29_all_Phenix1: 0.9898032200357783


In [None]:
# TESTING PURPOSES: summary statistics over hard-coded per-fold accuracies.
# "separated" split -- epochs 45, 40, 45, 44, 42
acc_val = [0.9374835483021848, 0.9352461173993156, 0.9407738878652276, 0.9330086864964464, 0.9441884954587337]
acc_test = [0.9160146061554513, 0.9104503564597461, 0.9260998087289167, 0.9110589462702139, 0.9165362545644236]

val_mean, val_std = np.mean(acc_val), np.std(acc_val)
test_mean, test_std = np.mean(acc_test), np.std(acc_test)

print(f"Validation set: \n Mean: {val_mean}\n Std: {val_std}\n")
print(f"Test set: \n Mean: {test_mean}\n Std: {test_std}")

# "random" split -- epochs 42, 41, 42, 40, 42

Validation set: 
 Mean: 0.9381401471043818
 Std: 0.003966296158951751

Test set: 
 Mean: 0.9160319944357503
 Std: 0.00561251023737622
