In [1]:
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from src.dataset import ChestXrayDataSet, CLASS_NAMES
from src.model import DenseNet121
from src.utils import compute_AUCs, compute_score_with_logits, tile
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Autoreload modules so that changes to src automatically reflect
%load_ext autoreload
%autoreload 2

In [2]:
labels = pd.read_csv("data/labels/cleaned.csv")
# size = 1000 # only using 10k out of 30k cleaned dataset
# X, Y = labels.iloc[:size, 0], labels.iloc[:size, 1]
X, Y = labels.iloc[:, 0], labels.iloc[:, 1]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=42, shuffle=True)

In [3]:
N_CLASSES = len(CLASS_NAMES)
BATCH_SIZE = 16
DATA_DIR = "data/images"

In [4]:
normalize = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])

train_dataset = ChestXrayDataSet(data_dir=DATA_DIR, X = X_train, Y = Y_train,
                          transform=transforms.Compose([
                              transforms.Resize(256),
                              transforms.TenCrop(224),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
                              ]))

# note that workers take up some amount of VRAM   
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

test_dataset = ChestXrayDataSet(data_dir=DATA_DIR, X = X_test, Y = Y_test,
                          transform=transforms.Compose([
                              transforms.Resize(256),
                              transforms.TenCrop(224),
                              transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                              transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
                              ]))
    
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [5]:
training = True # Flip to false to simply load pre-trained model
CKPT_PATH = "model.pth.tar" # Starter model from https://github.com/arnoweng/CheXNet
CKPT_TRAINED_PATH = "model-trained.pth" # Model trained on top of ^
cudnn.benchmark = True # Fixed input size, enables tuning for optimal use

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# initialize and load the model
model = DenseNet121(N_CLASSES).to(device)
model = torch.nn.DataParallel(model).to(device)

if training:
    # if os.path.isfile(CKPT_PATH):
    #     print("=> loading checkpoint")
    #     checkpoint = torch.load(CKPT_PATH)
    #     model.load_state_dict(checkpoint['state_dict'], strict=False)
    #     print("=> loaded checkpoint")
    # else:
    #     print("=> no checkpoint found")
else:
    if os.path.isfile(CKPT_TRAINED_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_TRAINED_PATH)
        # Load directly into the module else the model gets screwed up
        # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/15
        model.module.load_state_dict(checkpoint, strict=True)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

=> loading checkpoint
=> loaded checkpoint


## Training

In [6]:
# Source: https://github.com/thibaultwillmann/CheXNet-Pytorch/blob/master/CheXnet.ipynb
# use_amp = True
# scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

num_epochs = 10

if training:
    model.train()

    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # RMSprop, Adam

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        correct = 0
        total = 0 
        for i, (images, labels) in enumerate(train_loader, 0): # get the inputs; data is a list of [images, labels]

            # images.shape -> [N, 10, 3, 224, 224]
            # labels.shape -> [N, 15]

            # zero the parameter gradients
            optimizer.zero_grad(set_to_none=True)

            images = images.cuda()

            # format input
            n_batches, n_crops, channels, height, width = images.size()
            image_batch = torch.autograd.Variable(images.view(-1, channels, height, width)) # 10N images: N batches contain 10 crops each decomposed into 10N images

            labels = tile(labels, 0, 10).cuda() # duplicate for each crop the label [1,2],[3,4] => [1,2],[1,2],[3,4],[3,4] -> 10N labels

            # forward + backward + optimize
            # TODO: possibly use torch.cuda.amp?            
            outputs = model(image_batch)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            correct += compute_score_with_logits(outputs, labels).sum()
            total += labels.size(0)

        print('Epoch: %d, loss: %.3f, Accuracy: %.3f' %
            (epoch + 1, running_loss, 100 * correct / total))
        
        # initialize the ground truth and output tensor
        gt = torch.FloatTensor().to(device)
        pred = torch.FloatTensor().to(device)

        # switch to evaluate mode
        model.eval()

        for _, (inp, target) in enumerate(test_loader):
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            bs, n_crops, c, h, w = inp.size()
            with torch.no_grad():
                input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda())
                # output = torch.sigmoid(model(input_var))
                output = model(input_var)
                output_mean = output.view(bs, n_crops, -1).mean(1)
                pred = torch.cat((pred, output_mean.data), 0)
        
        AUROCs = compute_AUCs(gt, pred, N_CLASSES)
        AUROC_avg = np.array(AUROCs).mean()
        print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
        for j in range(N_CLASSES):
            print('The AUROC of {} is {}'.format(CLASS_NAMES[j], AUROCs[j]))
        
        model.train()

    print('Finished Training')

Epoch: 1, loss: 224.051, Accuracy: 41.234
The average AUROC is 0.807
The AUROC of Atelectasis is 0.813994926795661
The AUROC of Cardiomegaly is 0.9246699662407186
The AUROC of Effusion is 0.8591235130523533
The AUROC of Infiltration is 0.7611636545753314
The AUROC of Mass is 0.8022238555624658
The AUROC of Nodule is 0.7876465730169279
The AUROC of Pneumonia is 0.6074370641676088
The AUROC of Pneumothorax is 0.8438157919945061
The AUROC of Consolidation is 0.7524459074000616
The AUROC of Edema is 0.8817876838005531
The AUROC of Emphysema is 0.839734168242244
The AUROC of Fibrosis is 0.8208841949006765
The AUROC of Pleural_Thickening is 0.7258588961894885
The AUROC of Hernia is 0.8806399625458048
Epoch: 2, loss: 205.273, Accuracy: 47.016
The average AUROC is 0.817
The AUROC of Atelectasis is 0.8156797715463254
The AUROC of Cardiomegaly is 0.9299710950414224
The AUROC of Effusion is 0.862115611121864
The AUROC of Infiltration is 0.7714124854976581
The AUROC of Mass is 0.8135803921748401
T

## Evaluation of Trained Model

In [7]:
# initialize the ground truth and output tensor
gt = torch.FloatTensor().to(device)
pred = torch.FloatTensor().to(device)

# switch to evaluate mode
model.eval()

for i, (inp, target) in enumerate(test_loader):
    target = target.cuda()
    gt = torch.cat((gt, target), 0)
    bs, n_crops, c, h, w = inp.size()
    with torch.no_grad():
        input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda())
        # output = torch.sigmoid(model(input_var))
        output = model(input_var)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        pred = torch.cat((pred, output_mean.data), 0)

In [8]:
AUROCs = compute_AUCs(gt, pred, N_CLASSES)
AUROC_avg = np.array(AUROCs).mean()
print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
for i in range(N_CLASSES):
  print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))

The average AUROC is 0.763
The AUROC of Atelectasis is 0.7222861117429183
The AUROC of Cardiomegaly is 0.8746738558072127
The AUROC of Effusion is 0.8379493738274464
The AUROC of Infiltration is 0.6803081442040271
The AUROC of Mass is 0.7702297254001896
The AUROC of Nodule is 0.742914903000876
The AUROC of Pneumonia is 0.5428527380941258
The AUROC of Pneumothorax is 0.8059443123835461
The AUROC of Consolidation is 0.6608245528395017
The AUROC of Edema is 0.8507670438921471
The AUROC of Emphysema is 0.8299607618237759
The AUROC of Fibrosis is 0.7891992435536415
The AUROC of Pleural_Thickening is 0.6801939320398793
The AUROC of Hernia is 0.893858772474768


In [9]:
if training:
    torch.save(model.module.state_dict(), "model-trained.pth")