In [1]:
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, classification_report
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.models.densenet import DenseNet121_Weights
from PIL import Image


In [2]:
labels = pd.read_csv("data/labels/labels.csv")

In [3]:
size = 2500
X, Y = labels.iloc[:size, 0], labels.iloc[:size, 1]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=42, shuffle=True)

In [4]:
N_CLASSES = 14
BATCH_SIZE = 16
CKPT_PATH = "model.pth.tar"
CKPT_TRAINED_PATH = "model-trained.pth"
DATA_DIR = "data/images"
CLASS_NAMES = [ "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
"Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"]

In [5]:
class DenseNet121(nn.Module):
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(features, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x

In [6]:
def compute_AUCs(gt, pred):
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

In [23]:
cudnn.benchmark = True
load_trained = True

# initialize and load the model
model = DenseNet121(N_CLASSES).cuda()
model = torch.nn.DataParallel(model).cuda()

if load_trained:
    if os.path.isfile(CKPT_TRAINED_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_TRAINED_PATH)
        model.load_state_dict(checkpoint, strict=False)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")
else:
    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

=> loading checkpoint
=> loaded checkpoint


In [24]:
def binary_encoding(indices, classes):
    x = [0] * classes
    if indices is not None:
        for index in indices:
            x[index] = 1
    return x

In [25]:
class ChestXrayDataSet(Dataset):
    def __init__(self, data_dir, X, Y, transform=None):
        image_names = []
        labels = []

        for image_name, label in zip(X, Y):
            res = os.path.join(data_dir, image_name)

            # Only using subset, skip over images not downloaded
            if not os.path.exists(res):
                continue

            image_names.append(res)

            if label == "No Finding":
                indices = None
            else:
                indices = [CLASS_NAMES.index(l) for l in label.split('|')]

            labels.append(binary_encoding(indices, 14))

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.image_names)

In [10]:
normalize = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])

train_dataset = ChestXrayDataSet(data_dir=DATA_DIR, X = X_train, Y = Y_train,
                          transform=transforms.Compose([
                              transforms.Resize(256),
                              transforms.TenCrop(224),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
                              ]))
    
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

test_dataset = ChestXrayDataSet(data_dir=DATA_DIR, X = X_test, Y = Y_test,
                          transform=transforms.Compose([
                              transforms.Resize(256),
                              transforms.TenCrop(224),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
                              ]))
    
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

In [11]:
def compute_score_with_logits(logits, labels):
    logits = torch.max(logits, 1)[1].data # argmax
    one_hots = torch.zeros(*labels.size()).cuda()
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    scores = (one_hots * labels)

    return scores
  
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index)

In [12]:
# Source: https://github.com/thibaultwillmann/CheXNet-Pytorch/blob/master/CheXnet.ipynb
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # RMSprop, Adam

for epoch in range(15):  # loop over the dataset multiple times

    running_loss = 0.0
    correct = 0
    total = 0 
    for i, (images, labels) in enumerate(train_loader, 0): # get the inputs; data is a list of [images, labels]

        #images.shape -> [64, 10, 3, 224, 224]
        #labels.shape -> [64, 15]

        # zero the parameter gradients
        optimizer.zero_grad()

        images = images.cuda()

        #format input
        n_batches, n_crops, channels, height, width = images.size()
        image_batch = torch.autograd.Variable(images.view(-1, channels, height, width)) #640 images: 64 batches contain 10 crops each decomposed into 640 images

        labels = tile(labels, 0, 10).cuda() #duplicate for each crop the label [1,2],[3,4] => [1,2],[1,2],[3,4],[3,4] -> 640 labels

        # forward + backward + optimize
        outputs = model(image_batch)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()

        correct += compute_score_with_logits(outputs, labels).sum()
        total += labels.size(0)

    print('Epoch: %d, loss: %.3f, Accuracy: %.3f' %
          (epoch + 1, running_loss, 100 * correct / total))

print('Finished Training')

Epoch: 1, loss: 22.563, Accuracy: 18.807
Epoch: 2, loss: 17.523, Accuracy: 20.747
Epoch: 3, loss: 16.629, Accuracy: 21.660
Epoch: 4, loss: 16.166, Accuracy: 22.240
Epoch: 5, loss: 15.833, Accuracy: 22.767
Epoch: 6, loss: 15.560, Accuracy: 23.280
Epoch: 7, loss: 15.320, Accuracy: 23.773
Epoch: 8, loss: 15.104, Accuracy: 24.247
Epoch: 9, loss: 14.904, Accuracy: 24.660
Epoch: 10, loss: 14.715, Accuracy: 25.113
Epoch: 11, loss: 14.536, Accuracy: 25.540
Epoch: 12, loss: 14.363, Accuracy: 25.867
Epoch: 13, loss: 14.196, Accuracy: 26.213
Epoch: 14, loss: 14.032, Accuracy: 26.647
Epoch: 15, loss: 13.872, Accuracy: 27.113
Finished Training


In [26]:
# initialize the ground truth and output tensor
gt = torch.FloatTensor()
pred = torch.FloatTensor()
gt = gt.cuda()
pred = pred.cuda()

# switch to evaluate mode
model.eval()

for i, (inp, target) in enumerate(test_loader):
    target = target.cuda()
    gt = torch.cat((gt, target), 0)
    bs, n_crops, c, h, w = inp.size()
    with torch.no_grad():
        input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda())
        output = model(input_var)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        pred = torch.cat((pred, output_mean.data), 0)

In [27]:
AUROCs = compute_AUCs(gt, pred)
AUROC_avg = np.array(AUROCs).mean()
print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
  print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))

The average AUROC is 0.513
The AUROC of Atelectasis is 0.5436159598201774
The AUROC of Cardiomegaly is 0.46781914893617027
The AUROC of Effusion is 0.4490754013761467
The AUROC of Infiltration is 0.49265766912825737
The AUROC of Mass is 0.5061549403654667
The AUROC of Nodule is 0.6003578884580972
The AUROC of Pneumonia is 0.4671109823239641
The AUROC of Pneumothorax is 0.5702687324929973
The AUROC of Consolidation is 0.5957848692712993
The AUROC of Edema is 0.48119534309780565
The AUROC of Emphysema is 0.4777569731404958
The AUROC of Fibrosis is 0.4702864583333334
The AUROC of Pleural_Thickening is 0.42169483493918813
The AUROC of Hernia is 0.6436436436436437


In [22]:
torch.save(model.module.state_dict(), "model-trained.pth")