In [1]:
# Optional: install the extra dependencies on a fresh environment (uncomment to run).
# The scikit-learn PyPI package is `scikit-learn`; the `sklearn` alias is deprecated.
# !pip install efficientnet-pytorch scikit-learn

In [2]:
# Record the CUDA toolkit version in the notebook output for reproducibility.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130


In [4]:
import matplotlib.pyplot as plt

import os
import random
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch import nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import transforms

import models
from dataset_generator import DatasetGenerator

# Per-channel normalisation statistics: the standard ImageNet mean/std, presumably
# because the backbones in `models` are ImageNet-pretrained (TODO confirm).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Evaluation resizes to `image_resize` and centre-crops to `image_size`;
# training crops directly to `image_size`.
image_resize = 256
image_size = 224

# Label names, in the same order as the columns of the label/prediction tensors
# (14 findings; presumably the ChestX-ray14 label set - verify against the dataset).
class_names = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
    'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia',
]

Matplotlib is building the font cache using fc-list. This may take a moment.


In [5]:
# Log library versions so results can be tied to a specific PyTorch/torchvision build.
print(torch.__version__)
print(torchvision.__version__)

SEED = 1234

# Seed every RNG the pipeline touches: Python's `random`, NumPy and PyTorch.
# `manual_seed_all` seeds every visible GPU (it is a lazy no-op without CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Deterministic cuDNN kernels need benchmark mode off as well; otherwise cuDNN may
# auto-tune and pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

1.4.0
0.5.0


In [6]:
# Image root folder and the split files listing (image, labels) entries for each split.
image_dir = 'database'
train_csv_file, test_csv_file, valid_csv_file = (
    'dataset/train.txt',
    'dataset/test.txt',
    'dataset/valid.txt',
)

In [7]:
# Training-time augmentation: four random perturbations applied in a random order,
# then tensor conversion and channel normalisation. Every ordering still ends at
# `image_size`, because the other three ops keep the image size
# (RandomRotation defaults to expand=False).
train_transforms = transforms.Compose([
    transforms.RandomOrder([
        transforms.ColorJitter(saturation=0.05, hue=0.05),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15, resample=Image.BILINEAR),
        transforms.RandomResizedCrop(size=image_size, scale=(0.9, 1.0)),
    ]),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Deterministic evaluation pipeline: resize the shorter side, centre-crop, normalise.
valid_transforms = transforms.Compose([
    transforms.Resize(size=image_resize),
    transforms.CenterCrop(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# The test split reuses the very same evaluation pipeline object.
test_transforms = valid_transforms

In [8]:
# Training split (with random augmentation). Printing the first sample's label vector is a
# quick sanity check that images load and labels decode to one 0/1 entry per class.
train_dataset = DatasetGenerator(train_csv_file, image_dir, transform=train_transforms)
image, label = next(iter(train_dataset))
print(label)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [9]:
# Test split, using the deterministic evaluation transforms; print one label as a sanity check.
test_dataset = DatasetGenerator(test_csv_file, image_dir, transform=test_transforms)
image, label = next(iter(test_dataset))
print(label)

tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [10]:
# Validation split, using the deterministic evaluation transforms; print one label as a sanity check.
valid_dataset = DatasetGenerator(valid_csv_file, image_dir, transform=valid_transforms)
image,label = next(iter(valid_dataset))
print(label)

tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])


In [11]:
# Build a network by architecture name (e.g. 'ResNet50') from the local `models` module.
def load_pretrained_model(arch):
    """Instantiate the `models` attribute named `arch` and tag the instance with that name.

    Whether the weights are actually pretrained is up to the constructor in `models`
    (not visible here). The `arch` attribute is what ``save_checkpoint`` records.
    Raises AttributeError when `models` has no attribute named `arch`.
    """
    net = getattr(models, arch)()
    net.arch = arch
    return net

In [12]:
def validate(model, valid_dataloader, loss_func, device):
    """Sum the per-batch loss of `model` over `valid_dataloader` without tracking gradients.

    The caller is responsible for putting the model in eval mode (``train`` does this)
    and for dividing the result by ``len(valid_dataloader)`` to get a mean batch loss.

    Args:
        model: module mapping a batch of images to predictions accepted by `loss_func`.
        valid_dataloader: iterable of ``(images, labels)`` batches.
        loss_func: criterion such as ``nn.BCELoss``; assumed to return a scalar tensor.
        device: device (string or ``torch.device``) each batch is moved to.

    Returns:
        The summed loss as a Python float (0.0 for an empty loader).
    """
    total_loss = 0.0
    with torch.no_grad():  # inference only: no autograd graph is built
        for images, labels in valid_dataloader:
            images, labels = images.to(device), labels.to(device)
            # Call the module itself (not .forward) so registered hooks still run.
            outputs = model(images)
            # .item() keeps the accumulator a plain float instead of a device tensor.
            total_loss += loss_func(outputs, labels).item()
    return total_loss

In [13]:
# TODO: Do validation on the test set
def test_model(model, test_loader, device):
    #track accuracy, move to device, switch on eval mode
    model.to(device)
    model.eval()
    pred_out = torch.FloatTensor().to(device)
    labels_out = torch.FloatTensor().to(device)
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device) #move a tensor to a device
            labels_out = torch.cat((labels_out,labels),0)
            ps = model.forward(images)
            pred_out = torch.cat((pred_out,ps),0)
            
    metrics = calc_roc_metrics(pred_out, labels_out)
    return metrics

In [14]:
def calc_precision_recall(pred_out, labels_out):
    """Per-class precision/recall curves and average precision for a multi-label problem.

    Args:
        pred_out: tensor of predicted scores indexed ``[sample, class]``, with one column
            per entry of ``class_names`` (may live on any device).
        labels_out: tensor of 0/1 targets with the same layout as `pred_out`.

    Returns:
        ``(precision, recall, average_precision)``. ``precision`` and ``recall`` map each
        class name to the arrays returned by ``precision_recall_curve``, and
        ``average_precision`` is a list of per-class scores in ``class_names`` order.
    """
    # https://github.com/rachellea/glassboxmedicine/blob/master/2020-07-14-AUROC-AP/main.py
    # http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
    precision = {}
    recall = {}
    average_precision = []

    y_true_all = labels_out.cpu().numpy()
    y_score_all = pred_out.cpu().numpy()

    for idx, clz in enumerate(class_names):
        y_true = y_true_all[:, idx]
        y_score = y_score_all[:, idx]
        # Pass the scores positionally: the `probas_pred` keyword was deprecated in
        # scikit-learn 1.5 (renamed `y_score`), and positional works on every version.
        precision[clz], recall[clz], _ = precision_recall_curve(y_true, y_score)
        # Single-number summary of this class's precision/recall curve.
        average_precision.append(average_precision_score(y_true, y_score))

    return precision, recall, average_precision

In [15]:
def calc_roc_metrics(pred_out, labels_out):
    """ROC curves/AUCs per class, micro and macro averages, and overall summary scores.

    Args:
        pred_out: tensor of predicted scores indexed ``[sample, class]``, with one column
            per entry of ``class_names`` (may live on any device).
        labels_out: tensor of 0/1 targets with the same layout as `pred_out`.

    Returns:
        dict with:
          'average_precision_score': scikit-learn ``average_precision_score`` over all
              classes (library-default averaging),
          'roc_auc_score': scikit-learn ``roc_auc_score`` over all classes
              (library-default averaging),
          'fpr', 'tpr', 'roc_auc': dicts keyed by class name plus 'micro' and 'macro'.
    """
    # https://github.com/rachellea/glassboxmedicine/blob/master/2020-07-14-AUROC-AP/main.py
    # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    fpr = {}
    tpr = {}
    roc_auc = {}

    y_true = labels_out.cpu().numpy()
    y_score = pred_out.cpu().numpy()

    for idx, clz in enumerate(class_names):
        fpr[clz], tpr[clz], _ = roc_curve(y_true[:, idx], y_score[:, idx])
        roc_auc[clz] = auc(fpr[clz], tpr[clz])

    # Micro-average: treat every (sample, class) cell as one pooled binary decision.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate every class curve onto the union of all FPR points,
    # average the TPRs, then take the area under the averaged curve.
    all_fpr = np.unique(np.concatenate([fpr[clz] for clz in class_names]))
    mean_tpr = np.zeros_like(all_fpr)
    for clz in class_names:
        mean_tpr += np.interp(all_fpr, fpr[clz], tpr[clz])
    mean_tpr /= len(class_names)

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Overall summary scores, each computed exactly once.
    average_precision = average_precision_score(y_true, y_score)
    roc_auc_mean = roc_auc_score(y_true, y_score)

    return {
        'average_precision_score': average_precision,
        'roc_auc_score': roc_auc_mean,
        'fpr': fpr,
        'tpr': tpr,
        'roc_auc': roc_auc
    }

In [16]:
# Save a training checkpoint: weights, optimizer/scheduler state, metrics and epoch.
def save_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, epoch=None):
    """Serialize `model` and its training state to `checkpoint_path` with ``torch.save``.

    Args:
        checkpoint_path: destination file path.
        model: module carrying ``arch`` and ``metrics`` attributes (set by
            ``load_pretrained_model`` and ``train``), and optionally ``epoch``.
        optimizer: optimizer whose ``state_dict`` is stored; defaults to the
            notebook-level ``optimizer`` (the behaviour of two-argument calls).
        scheduler: LR scheduler whose ``state_dict`` is stored; defaults to the
            notebook-level ``scheduler``.
        epoch: epoch number to record; defaults to ``model.epoch`` if the caller set one,
            otherwise None.
    """
    # Two-argument calls rely on the notebook globals, so fall back to them.
    if optimizer is None:
        optimizer = globals()['optimizer']
    if scheduler is None:
        scheduler = globals()['scheduler']
    if epoch is None:
        # No global `epoch` exists in this notebook (referencing one raised NameError);
        # use the epoch recorded on the model, if any.
        epoch = getattr(model, 'epoch', None)

    checkpoint = {
        'arch': model.arch,
        'metrics': model.metrics,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, checkpoint_path)
    print('Saved the trained model: %s' % checkpoint_path)

In [17]:
# One DataLoader per split; only the training split is shuffled.
batch_size = 64
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size, num_workers=24, shuffle=(split == 'train'))
    for split, dataset in (('train', train_dataset),
                           ('test', test_dataset),
                           ('valid', valid_dataset))
}

In [18]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
device

'cuda'

In [19]:
# Architecture names exposed by the local `models` module: names containing 'Net',
# excluding those containing 'Model' (presumably base/helper classes).
available_models = [name for name in dir(models)
                    if 'Model' not in name and 'Net' in name]
available_models

['DenseNet121',
 'DenseNet169',
 'DenseNet201',
 'EfficientNet4',
 'EfficientNet5',
 'EfficientNet6',
 'ResNet101',
 'ResNet34',
 'ResNet50']

In [20]:
arch = 'ResNet50'
pretrained_model = load_pretrained_model(arch)

# Adam with a small learning rate plus L2 weight decay.
optimizer = optim.Adam(
    pretrained_model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-5,
)
# Binary cross-entropy per class (multi-label targets); BCELoss expects probabilities,
# so the model presumably ends in a sigmoid - confirm in `models`.
criterion = nn.BCELoss(reduction='mean')
# Cut the LR by 10x when the monitored (validation) loss stops improving for 5 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [21]:
# Resume from a saved checkpoint: restore model weights plus optimizer and scheduler state.
# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
ckp_path = 'checkpoints/ResNet50-24082020-091321.pth.tar'
model_checkpoint = torch.load(ckp_path, map_location=torch.device(device))
if model_checkpoint['arch'] != arch:
    # Stop early: `pretrained_model` was built for `arch`, so another architecture's
    # weights will almost certainly not load, and adopting the checkpoint's arch would
    # mislabel any checkpoints written later.
    raise ValueError('Mismatched arch: Model Checkpoint=%s vs Arch=%s' % (model_checkpoint['arch'], arch))

state_dict = model_checkpoint['state_dict']
saved_optimizer = model_checkpoint['optimizer']
saved_scheduler = model_checkpoint['scheduler']

# Report key differences up front so a strict-load failure below is easy to diagnose.
old_state_dict = pretrained_model.state_dict()
for k in state_dict:
    if k not in old_state_dict:
        print('Unexpected key %s in state_dict' % k)
for k in old_state_dict:
    if k not in state_dict:
        print('Missing key %s in state_dict' % k)

# Load the checkpoint into the freshly built model and its optimizer/scheduler.
pretrained_model.load_state_dict(state_dict)
optimizer.load_state_dict(saved_optimizer)
scheduler.load_state_dict(saved_scheduler)

In [None]:
def train(model, loss_func, optimizer, scheduler, dataloaders, device, epochs, checkpoint_prefix, print_every=100):
    """Train `model`, evaluating after every epoch and checkpointing the best epoch.

    Each epoch runs one pass over ``dataloaders['train']``, logging running losses every
    `print_every` steps. It then runs a validation pass, whose mean batch loss drives
    `scheduler`, and a test-set evaluation via ``test_model``. Whenever the test ROC AUC
    improves, the metrics and the 1-based epoch number are attached to the model
    (``model.metrics`` / ``model.epoch``), and ``save_checkpoint`` writes
    ``'<checkpoint_prefix>.pth.tar'``.

    Args:
        model: network to train; moved to `device` and updated in place.
        loss_func: criterion applied as ``loss_func(predictions, labels)``.
        optimizer: optimizer over `model`'s parameters.
        scheduler: scheduler whose ``step`` takes the mean validation loss
            (e.g. ``ReduceLROnPlateau``).
        dataloaders: dict with 'train', 'valid' and 'test' iterables of (images, labels).
        device: device string or ``torch.device``.
        epochs: number of passes over the training data.
        checkpoint_prefix: checkpoint path without the ``.pth.tar`` suffix.
        print_every: logging interval, in training steps.

    Returns:
        The same `model` object, left in train mode.
    """
    device = torch.device(device)
    model.to(device)
    model.train()  # the model may arrive in eval mode (e.g. after a previous evaluation)

    train_loader = dataloaders['train']
    valid_loader = dataloaders['valid']
    test_loader = dataloaders['test']

    best_roc_auc = 0.0

    for e in range(epochs):
        epoch_start = time.time()  # reset every epoch so "Epoch Time" is per-epoch
        running_loss = 0.0
        steps = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks still run.
            ps = model(images)
            loss = loss_func(ps, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            steps += 1

            if steps % print_every == 0:
                model.eval()
                valid_loss = validate(model, valid_loader, loss_func, device)
                model.train()
                batch_time = time.time() - epoch_start
                print(
                    "Epoch: {}/{}..".format(e + 1, epochs),
                    "Step: {}..".format(steps),
                    # Average over the steps taken so far, not over the whole epoch.
                    "Training Loss: {:.5f}..".format(running_loss / steps),
                    "Valid Loss: {:.5f}..".format(valid_loss / len(valid_loader)),
                    "Batch Time: {:.1f}, avg: {:.3f}".format(batch_time, batch_time / steps)
                )

        # Timing of the training pass, always defined (even if nothing was logged).
        batch_time = time.time() - epoch_start
        avg_step_time = batch_time / steps if steps else 0.0

        model.eval()
        valid_loss = validate(model, valid_loader, loss_func, device)
        # validate() returns summed batch losses; the scheduler monitors the mean.
        scheduler.step(valid_loss / len(valid_loader))

        # NOTE(review): checkpoint selection uses the *test* split; selecting on the
        # validation split would keep the reported test metrics unbiased.
        metrics = test_model(model, test_loader, device)
        model.train()
        roc_auc_mean = metrics['roc_auc_score']
        average_precision_mean = metrics['average_precision_score']
        epoch_time = time.time() - epoch_start

        improved = roc_auc_mean > best_roc_auc
        if improved:
            best_roc_auc = roc_auc_mean
            model.metrics = metrics
            model.epoch = e + 1
            save_checkpoint('%s.pth.tar' % checkpoint_prefix, model)

        print(
            "Epoch: {}/{} {} Epoch Time={:.1f}..".format(
                e + 1, epochs, '[save]' if improved else '[----]', epoch_time),
            "Average Precision: {:.3f}..".format(average_precision_mean),
            "ROC: {:.3f}..".format(roc_auc_mean),
            "Batch Time: {:.1f}, avg: {:.3f}".format(batch_time, avg_step_time)
        )

    return model

In [None]:
epochs = 7
training_start = time.time()
# Checkpoints are named '<arch>-<unix start time>' under checkpoints/.
checkpoint_prefix = os.path.join('checkpoints', '%s-%d' % (arch, training_start))

trained_model = train(
    model=pretrained_model,
    loss_func=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloaders=dataloaders,
    device=device,
    epochs=epochs,
    checkpoint_prefix=checkpoint_prefix,
    print_every=300,
)
print('%.2f seconds taken for model training' % (time.time() - training_start))

In [None]:
# Metrics of the best epoch: `train` attaches them to the model whenever it saves a checkpoint.
metrics = trained_model.metrics

In [None]:
# Plot micro/macro-average and per-class ROC curves for the best epoch's test metrics.
# See https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html

colors = ['navy', 'deeppink', 'turquoise', 'darkorange', 'cornflowerblue', 'teal']
fpr, tpr, roc_auc = metrics['fpr'], metrics['tpr'], metrics['roc_auc']
lw = 2

fig, ax = plt.subplots()

# Averaged curves: dotted and thick so they stand out from the per-class lines.
ax.plot(fpr["micro"], tpr["micro"],
        label='Micro-average ROC (area = {0:0.2f})'.format(roc_auc["micro"]),
        color=colors[0], linestyle=':', linewidth=4)
ax.plot(fpr["macro"], tpr["macro"],
        label='Macro-average ROC (area = {0:0.2f})'.format(roc_auc["macro"]),
        color=colors[1], linestyle=':', linewidth=4)

# One curve per class, in ascending AUC order. Iterate over class_names (not roc_auc's
# keys, which also hold 'micro'/'macro' and would plot those curves a second time), and
# use a 20-colour qualitative map so the 14 class curves are distinguishable.
class_cmap = plt.get_cmap('tab20')
for i, clz in enumerate(sorted(class_names, key=roc_auc.get)):
    ax.plot(fpr[clz], tpr[clz], color=class_cmap(i % class_cmap.N), lw=lw,
            label='Class {0} (area = {1:0.2f})'.format(clz, roc_auc[clz]))

ax.plot([0, 1], [0, 1], 'k--', lw=lw)  # chance diagonal
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Multi-class ROC Curve')
# Legend placed below the axes so its 16 entries do not cover the curves.
ax.legend(loc=(0, -1.5))
plt.show()