In [1]:
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, classification_report
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.models.densenet import DenseNet121_Weights
from PIL import Image


In [2]:
labels = pd.read_csv("data/labels/labels.csv")

In [3]:
size = 2500
X, Y = labels.iloc[:size, 0], labels.iloc[:size, 1]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=42, shuffle=True)

In [4]:
N_CLASSES = 14
BATCH_SIZE = 64
CKPT_PATH = "model.pth.tar"
DATA_DIR = "data/images"
CLASS_NAMES = [ "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
"Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"]

In [5]:
class DenseNet121(nn.Module):
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(features, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x

In [39]:
def compute_AUCs(gt, pred):
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

In [7]:
cudnn.benchmark = True

# initialize and load the model
model = DenseNet121(N_CLASSES).cuda()
model = torch.nn.DataParallel(model).cuda()

if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    checkpoint = torch.load(CKPT_PATH)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")

=> loading checkpoint
=> loaded checkpoint


In [25]:
def binary_encoding(indices, classes):
    x = [0] * classes
    if indices is not None:
        for index in indices:
            x[index] = 1
    return x

In [26]:
class ChestXrayDataSet(Dataset):
    def __init__(self, data_dir, X, Y, transform=None):
        image_names = []
        labels = []

        for image_name, label in zip(X, Y):
            res = os.path.join(data_dir, image_name)

            # Only using subset, skip over images not downloaded
            if not os.path.exists(res):
                continue

            image_names.append(res)

            if label == "No Finding":
                indices = None
            else:
                indices = [CLASS_NAMES.index(l) for l in label.split('|')]

            labels.append(binary_encoding(indices, 14))

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.image_names)

In [35]:
normalize = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])

test_dataset = ChestXrayDataSet(data_dir=DATA_DIR, X = X_test, Y = Y_test,
                          transform=transforms.Compose([
                              transforms.Resize(256),
                              transforms.TenCrop(224),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
                              ]))
    
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

In [37]:
# initialize the ground truth and output tensor
gt = torch.FloatTensor()
pred = torch.FloatTensor()
gt = gt.cuda()
pred = pred.cuda()

# switch to evaluate mode
model.eval()

for i, (inp, target) in enumerate(test_loader):
    target = target.cuda()
    gt = torch.cat((gt, target), 0)
    bs, n_crops, c, h, w = inp.size()
    with torch.no_grad():
        input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda())
        output = model(input_var)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        pred = torch.cat((pred, output_mean.data), 0)

In [40]:
AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()
print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
  print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))

The average AUROC is 0.585
The AUROC of Atelectasis is 0.4289397863204744
The AUROC of Cardiomegaly is 0.599095744680851
The AUROC of Effusion is 0.609070384174312
The AUROC of Infiltration is 0.6193755739210285
The AUROC of Mass is 0.5652970784549731
The AUROC of Nodule is 0.6144745998608212
The AUROC of Pneumonia is 0.4826861779194436
The AUROC of Pneumothorax is 0.6274947478991597
The AUROC of Consolidation is 0.5675265931007059
The AUROC of Edema is 0.6215998712377273
The AUROC of Emphysema is 0.5451317148760331
The AUROC of Fibrosis is 0.5712239583333334
The AUROC of Pleural_Thickening is 0.4686068551571632
The AUROC of Hernia is 0.8758758758758759
