This notebook is used for evaluating the trained models on different splits and visualizing the batch effect.

In [88]:
import sys
import os

# Make the project packages (e.g. `training`) importable: the parent of the
# current working directory is put at the front of the module search path.
src_path = os.path.dirname(os.getcwd())
sys.path.insert(0, src_path)

import numpy as np
import torch
import json
from training.model import ResNet
from training.data import JUMPCPDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import pandas as pd


In [89]:
# Evaluation configuration: everything a reader might want to change lives here.

# Checkpoint file handed to `torch.load`; the weights are read from its "state_dict" entry.
checkpoint_path = "/system/user/studentwork/seibezed/bachelor/logs/seperated_1/checkpoints/epoch_48.pt"
# Model name; selects the JSON config `training/model_parameter/<name>.json` used to build the network.
model = "ResNet50"
# CSV file describing the test split; passed to `JUMPCPDataset`.
test_file = "/system/user/studentwork/seibezed/bachelor/data/seperated_seed1234_test.csv"
# Class-mapping JSON; passed to `JUMPCPDataset` as its mapping path.
mapping = "/system/user/studentwork/seibezed/bachelor/data/class_mapping.json"
# Image root directory; passed to `JUMPCPDataset` as its image path.
image_path = "/system/user/publicdata/jumpcp/"


In [90]:
def convert_models_to_fp32(model):
    """Cast all parameters of ``model`` (and any existing gradients) to float32.

    The conversion happens in place on the parameter tensors.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: truth-testing a tensor with more
        # than one element raises a RuntimeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

def load_model(checkpoint_path, device, model):
    """Build the network named by ``model`` and load trained weights into it.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            stores the weights under the ``"state_dict"`` key.
        device: Device the network is moved to (e.g. ``"cuda"`` or ``"cpu"``).
        model: Model name. ``/`` is replaced by ``-`` to locate the config
            file ``training/model_parameter/<name>.json`` below ``src_path``;
            its content is passed as keyword arguments to ``ResNet``.

    Returns:
        The ``ResNet`` instance with loaded weights, on ``device``, in eval mode.

    Raises:
        AssertionError: If the model config file does not exist.
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # Deserialize onto the CPU so that a checkpoint written on a GPU can also
    # be restored on a CPU-only machine; the network is moved to `device` below.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    model_config_file = os.path.join(src_path, f"training/model_parameter/{model.replace('/', '-')}.json")
    assert os.path.exists(model_config_file), f"model config not found: {model_config_file}"
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    # Use a separate name so the `model` argument (a string) is not rebound
    # to the network object.
    net = ResNet(**model_info)

    if str(device) == "cpu":
        net.float()
    print(device)

    # Strip the "module." prefix (presumably added by a DataParallel-style
    # wrapper during training - confirm) only from keys that actually carry it,
    # instead of blindly cutting the first characters off every key.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    net.load_state_dict(new_state_dict)
    net.to(device)
    net.eval()

    return net


In [91]:
def _flatten_metadata(batched_metadata):
    """Turn a list of per-batch metadata dicts into one record dict per sample."""
    rows = []
    for batch_meta in batched_metadata:  # each is expected to be a dict of equally long sequences
        if not batch_meta:
            continue
        keys = list(batch_meta)
        n_items = len(batch_meta[keys[0]])
        for i in range(n_items):
            rows.append({key: batch_meta[key][i] for key in keys})
    return rows


def main(test_file, model_path, model, img_path, mapping_path, batch_size=64, num_workers=20):
    """Run the trained model over a test split and collect its outputs.

    Args:
        test_file: Path handed to ``JUMPCPDataset`` describing the test split.
        model_path: Checkpoint path handed to ``load_model``.
        model: Model name handed to ``load_model``.
        img_path: Image root handed to ``JUMPCPDataset``.
        mapping_path: Mapping file handed to ``JUMPCPDataset``.
        batch_size: Number of samples per inference batch.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        A tuple ``(predictions, labels, metadata)``: the concatenated model
        outputs, the concatenated labels (both left on the inference device),
        and a ``pandas.DataFrame`` with one metadata row per sample, in the
        same order as the tensors.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())

    model = load_model(model_path, device, model)

    test_data = JUMPCPDataset(test_file, img_path, mapping_path)
    loader = DataLoader(test_data, num_workers=num_workers, batch_size=batch_size)

    all_predictions = []
    all_labels = []
    all_metadata = []
    with torch.no_grad():
        for imgs, labels, metadata in tqdm(loader):
            imgs = imgs.to(device)
            labels = labels.to(device)

            all_predictions.append(model(imgs))
            all_labels.append(labels)
            all_metadata.append(metadata)

    # Flatten generically over whatever keys the dataset provides, rather than
    # assuming the metadata dict has exactly two entries.
    flat_metadata = _flatten_metadata(all_metadata)

    return torch.cat(all_predictions), torch.cat(all_labels), pd.DataFrame(flat_metadata)



In [92]:
# Run inference over the whole test split; keeps the model outputs, the
# labels and the per-sample metadata for the analyses below.
pred, labels, metadata = main(
    test_file=test_file,
    model_path=checkpoint_path,
    model=model,
    img_path=image_path,
    mapping_path=mapping,
)

4
cuda


100%|██████████| 180/180 [01:30<00:00,  1.99it/s]


In [93]:
# Overall accuracy: fraction of samples whose highest-scoring class equals the label.
pred_labels = pred.argmax(dim=1)
correct = int((pred_labels == labels).sum())
acc = correct / labels.shape[0]
print(acc)

0.9168840201704052


In [94]:
id_mapping = "/system/user/studentwork/seibezed/bachelor/data/id_mapping.json"

# Accuracy per label, keyed by the name the id mapping assigns to each label index.
label_acc = {}

# The mapping file is keyed by the label index as a string.
with open(id_mapping, "r") as f:
    id_to_class = json.load(f)

for class_id in labels.unique().tolist():
    # Predicted classes for the samples whose true label is `class_id`.
    class_pred = torch.argmax(pred[labels == class_id], dim=1)
    correct = (class_pred == class_id).sum().item()
    acc = correct / len(class_pred)

    label_acc[id_to_class[str(class_id)]] = acc

print(label_acc)


{'JCP2022_035095': 0.9986111111111111, 'JCP2022_046054': 0.9652777777777778, 'JCP2022_064022': 0.9895833333333334, 'JCP2022_085227': 0.6194444444444445, 'JCP2022_012818': 0.9944095038434662, 'JCP2022_025848': 0.8791666666666667, 'JCP2022_037716': 0.9694444444444444, 'JCP2022_050797': 0.9196366177498253}


In [95]:
# Compute accuracy per experimental batch (the "batch" column of the metadata).
# The id mapping is not needed here: results are keyed by the batch name itself.
batch_acc = {}

# Predicted classes do not depend on the grouping, so compute them once.
batch_pred_labels = torch.argmax(pred, dim=1)
batch_names = metadata["batch"].to_numpy()

for i in metadata["batch"].unique():
    # Build an explicit boolean tensor mask on the tensors' device instead of
    # indexing torch tensors with a pandas Series.
    mask = torch.as_tensor(batch_names == i, device=labels.device)
    filter_pred = batch_pred_labels[mask]
    filter_labels = labels[mask]
    correct = (filter_pred == filter_labels).sum().item()
    acc = correct/len(filter_pred)

    batch_acc[i] = acc

print(batch_acc)



{'CP59': 0.9513888888888888, 'CP_31_all_Phenix1': 0.8822709857192615}
