### Import necessary modules

In [3]:
# This is necessary for imports to work correctly
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
from torchvision.transforms import transforms
from materials.models import *
from materials.constants import *
from sklearn.metrics import roc_auc_score
from materials.segmentation_inference import *
from materials.metrics import *
from PIL import Image
from materials.datasets import *
from torch.utils.data import DataLoader
from materials.custom_transformations import HistogramEqualization
from materials.util import *


### Create Dataset and Data Loader

In [3]:
transformation_list = [
    transforms.Resize((256, 256)),
    HistogramEqualization(),
    transforms.ToTensor(),
]
image_transformation = transforms.Compose(transformation_list)

test_dataset = CheXpertDataset(data_path="data/CheXpert-v1.0-small/valid.csv",
                               uncertainty_policy="zeros", transform=image_transformation)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)


### Load Models

In [4]:
%%capture
heart_model_path = "models/heart_segmentation/unet_vgg16_simple_aug_lr=0.0001_batch=8_28.3_21:6.pth"
lung_model_path = "models/lung_segmentation/unet_vgg16_simple_aug_lr=0.0001_batch=8_28.3_21:0.pth"
heart_model = load_segmentation_model(heart_model_path, device="cpu")
lung_model = load_segmentation_model(lung_model_path, device="cpu")

### Calculate CTR for CheXpert images

In [5]:
ground_truth = torch.FloatTensor()
ctr = []
prediction = torch.FloatTensor()
for step, (image, label) in enumerate(test_dataloader):
    ground_truth = torch.cat((ground_truth, label), 0)
    ctr_for_image = torch.tensor([ctr_from_tensor(image, heart_model, lung_model)])
    ctr.append(ctr_for_image)
    print(f"Label: {label}, CTR: {ctr_for_image}")
    prediction_for_image = torch.ones(1) if ctr_for_image > 0.5 else torch.zeros(1)
    prediction = torch.cat((prediction, prediction_for_image), 0)
print(ctr)

Label: tensor([[0.]]), CTR: tensor([1.4144])
Label: tensor([[0.]]), CTR: tensor([1.5059])
Label: tensor([[1.]]), CTR: tensor([1.7067])
Label: tensor([[1.]]), CTR: tensor([1.4713])
Label: tensor([[0.]]), CTR: tensor([1.1852])
Label: tensor([[0.]]), CTR: tensor([1.4222])
Label: tensor([[0.]]), CTR: tensor([1.6000])
Label: tensor([[0.]]), CTR: tensor([1.7297])
Label: tensor([[0.]]), CTR: tensor([2.0317])
Label: tensor([[0.]]), CTR: tensor([1.4798])
Label: tensor([[0.]]), CTR: tensor([1.2549])
Label: tensor([[1.]]), CTR: tensor([1.3838])
Label: tensor([[1.]]), CTR: tensor([1.1278])
Label: tensor([[1.]]), CTR: tensor([1.8551])
Label: tensor([[0.]]), CTR: tensor([1.2133])
Label: tensor([[1.]]), CTR: tensor([1.1689])
Label: tensor([[0.]]), CTR: tensor([1.2075])
Label: tensor([[0.]]), CTR: tensor([1.7655])
Label: tensor([[0.]]), CTR: tensor([2.0157])
Label: tensor([[0.]]), CTR: tensor([1.4713])
Label: tensor([[1.]]), CTR: tensor([1.7297])
Label: tensor([[0.]]), CTR: tensor([1.4066])
Label: ten

### Calculate AUROC score

In [6]:
score = roc_auc_score(ground_truth.squeeze(), torch.tensor(ctr))
print(score)

0.549520944741533


In [7]:
ground_truth = ground_truth.squeeze()
diff = ground_truth != prediction
wrong_indices = torch.nonzero(diff)
ctr = ctr.squeeze()
print(ctr)
for index in wrong_indices:
    index = index.item()
    print(f"Index: {index}, CTR: {ctr[index]}")

AttributeError: 'list' object has no attribute 'squeeze'