### Import necessary modules

In [1]:
import torch
from torchvision.transforms import transforms
from models import *
from constants import *
from sklearn.metrics import roc_auc_score
from segmentation_inference import *
from metrics import *
from PIL import Image
from datasets import *
from torch.utils.data import DataLoader
from custom_transformations import HistogramEqualization
from util import *


### Create Dataset and Data Loader

In [2]:
transformation_list = [
    transforms.Resize((256, 256)),
    HistogramEqualization(),
    transforms.ToTensor(),
]
image_transformation = transforms.Compose(transformation_list)

test_dataset = CheXpertDataset(data_path="data/CheXpert-v1.0-small/valid.csv",
                               uncertainty_policy="zeros", transform=image_transformation)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)


### Load Models

In [3]:
%%capture
heart_model_path = "models/heart_segmentation/unet_vgg16_simple_aug_03_noflip_lr=0.0001_batch=8_28.3_21:29.pth"
lung_model_path = "models/lung_segmentation/unet_vgg16_simple_aug_lr=0.0001_batch=8_28.3_21:0.pth"
heart_model = load_segmentation_model(heart_model_path, device="cpu")
lung_model = load_segmentation_model(lung_model_path, device="cpu")

### Calculate CTR for CheXpert images

In [49]:
ground_truth = torch.FloatTensor()
ctr = []
prediction = torch.FloatTensor()
for step, (image, label) in enumerate(test_dataloader):
    ground_truth = torch.cat((ground_truth, label), 0)
    ctr_for_image = torch.tensor([ctr_from_tensor(image, heart_model, lung_model)])
    ctr.append(ctr_for_image)
    print(f"Label: {label}, CTR: {ctr_for_image}")
    prediction_for_image = torch.ones(1) if ctr_for_image > 0.5 else torch.zeros(1)
    prediction = torch.cat((prediction, prediction_for_image), 0)
print(ctr)

Label: tensor([[0.]]), CTR: tensor([0.5114])
Label: tensor([[1.]]), CTR: tensor([0.5349])
Label: tensor([[0.]]), CTR: tensor([0.4437])
Label: tensor([[0.]]), CTR: tensor([0.4027])
Label: tensor([[0.]]), CTR: tensor([0.4010])
Label: tensor([[0.]]), CTR: tensor([0.4265])
Label: tensor([[0.]]), CTR: tensor([0.3736])
Label: tensor([[0.]]), CTR: tensor([0.4158])
Label: tensor([[0.]]), CTR: tensor([0.4421])
Label: tensor([[0.]]), CTR: tensor([0.4810])
Label: tensor([[1.]]), CTR: tensor([0.5605])
Label: tensor([[0.]]), CTR: tensor([0.3469])
Label: tensor([[0.]]), CTR: tensor([0.3238])
Label: tensor([[1.]]), CTR: tensor([0.5198])
Label: tensor([[1.]]), CTR: tensor([0.4255])
Label: tensor([[0.]]), CTR: tensor([0.4444])
Label: tensor([[0.]]), CTR: tensor([0.3460])
Label: tensor([[1.]]), CTR: tensor([0.5032])
Label: tensor([[0.]]), CTR: tensor([0.5287])
Label: tensor([[1.]]), CTR: tensor([0.6092])
Label: tensor([[0.]]), CTR: tensor([0.4346])
Label: tensor([[0.]]), CTR: tensor([0.4242])
Label: ten

### Calculate AUROC score

In [50]:
score = roc_auc_score(ground_truth.squeeze(), prediction)
print(score)

0.6133021390374331


In [51]:
ground_truth = ground_truth.squeeze()
diff = ground_truth != prediction
wrong_indices = torch.nonzero(diff)
ctr = ctr.squeeze()
print(ctr)
for index in wrong_indices:
    index = index.item()
    print(f"Index: {index}, CTR: {ctr[index]}")

AttributeError: 'list' object has no attribute 'squeeze'