# 2) CNN Model Metrics
Updated on 9 April 2020

Generate the receiver operating characteristic (ROC) curves and precision-recall curves (PRC) for both Tang et al.'s model (recreation) and the fresh model we trained on their data. This is simply a comparison to show that we can retrain the model on our system and that it produces equally good results.

In [None]:
import warnings
warnings.filterwarnings('ignore')
import time, os
from os.path import join
import torch
torch.manual_seed(42)
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import transforms
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
from sklearn.metrics import roc_curve, auc, precision_recall_curve

In [None]:
# Global parameters
DATA_DIR = '/mnt/Data/'

# Settings handed to every torch DataLoader built in this notebook
batch_size, num_workers = 48, 12

# Label CSV template; '{}' is filled with a name from `datasets`
csv_path = '../CSVs/{}_oversampled.csv'

# Folder that receives the generated plots and CSV files
save_dir = join(DATA_DIR, 'outputs/CNN_metrics/')

# Dataset names and, position for position, the tile folder of each one
datasets = ['test', 'train', 'dev']
img_dirs = [
    join(DATA_DIR, 'Tiles/{}/'.format(tile_folder))
    for tile_folder in ('hold-out', 'train_and_val', 'train_and_val')
]

# Class names, in the column order of the label tensors
image_classes = ['cored', 'diffuse', 'CAA']

In [None]:
# set-up model and data loaders
# Create the output folder up front; exist_ok=True makes this a no-op when
# the folder is already there.
os.makedirs(save_dir, exist_ok=True)

class MultilabelDataset(Dataset):
    """Image dataset with one soft label per class, described by a CSV file.

    The CSV must contain an ``imagename`` column plus one numeric column for
    every entry of ``classes``. Each item is the 4-tuple
    ``(image, binary_label, raw_label, image_name)``.
    """

    def __init__(self, csv_path, img_path, transform=None,
                 classes=('cored', 'diffuse', 'CAA'), threshold=0.99):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
            classes: names of the label columns, in the order they should
                appear along dim 1 of the label tensors
            threshold (float): a raw label strictly greater than this value
                becomes 1.0 in the binary labels, anything else 0.0
        """
        self.data_info = pd.read_csv(csv_path)
        self.img_path = img_path
        self.transform = transform
        self.classes = list(classes)

        # One (n_samples, n_classes) float32 matrix, built in a single
        # conversion instead of one tensor per column.
        label_matrix = self.data_info[self.classes].values.astype('float32')
        self.raw_labels = torch.from_numpy(label_matrix)
        self.labels = (self.raw_labels > threshold).float()

    def __getitem__(self, index):
        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.labels[index]
        raw_label = self.raw_labels[index]
        # Get image name from the pandas df
        single_image_name = str(self.data_info.loc[index, 'imagename'])
        # Open image. os.path.join (rather than '+') also works when
        # img_path has no trailing separator.
        img_as_img = Image.open(os.path.join(self.img_path, single_image_name))
        # Transform image to tensor
        if self.transform is not None:
            img_as_img = self.transform(img_as_img)
        # Return image and the label
        return (img_as_img, single_image_label, raw_label, single_image_name)

    def __len__(self):
        return len(self.data_info.index)
    
    
def imshow(inp, title=None):
    """Show a normalized image tensor in a new matplotlib figure.

    The tensor's first axis is moved to the end, the normalization is undone
    with the module-level ``norm['mean']`` / ``norm['std']`` values, and the
    result is clipped to [0, 1] before plotting. ``title`` is optional.
    """
    # First axis goes last: (0, 1, 2) -> (1, 2, 0)
    image = inp.numpy().transpose((1, 2, 0))
    # Invert the normalization: x * std + mean
    image = np.array(norm['std']) * image + np.array(norm['mean'])

    plt.figure()
    plt.imshow(np.clip(image, 0, 1))
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
    
    
class Net(nn.Module):
    """Skeleton of the CNN: a forward pass with no layers of its own.

    The constructor accepts ``fc_nodes``, ``num_classes`` and ``dropout`` but
    builds nothing. ``forward`` needs ``features`` and ``classifier``
    attributes to exist on the instance.
    NOTE(review): presumably those attributes arrive when a pickled model is
    restored with ``torch.load`` -- confirm.
    """

    def __init__(self, fc_nodes=512, num_classes=3, dropout=0.5):
        super().__init__()

    def forward(self, x):
        feature_maps = self.features(x)
        # Collapse everything after the batch dimension into one axis
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)


def dev_model(model, criterion, phase='test', gpu_id=None):
    """Run ``model`` over one dataset split and collect its predictions.

    Reads the module-level globals ``image_datasets``, ``dataset_sizes``,
    ``image_classes``, ``batch_size``, ``num_workers`` and ``use_gpu``.
    The model is switched to evaluation mode and left in that mode.

    Args:
        model: network called as ``model(inputs)``; its outputs are passed
            through a sigmoid to obtain per-class probabilities.
        criterion: loss called as ``criterion(outputs, labels)``. Its scalar
            result is summed over batches and divided by the dataset size for
            the printed loss (assumes the criterion returns a per-batch sum
            -- verify against the caller).
        phase (str): key into ``image_datasets`` / ``dataset_sizes``.
        gpu_id: device index passed to ``.cuda()`` when ``use_gpu`` is true.

    Returns:
        tuple: ``(epoch_acc, running_preds, running_predictions,
        running_labels)`` -- per-class accuracy, sigmoid probabilities, the
        probabilities thresholded at 0.5, and the binary targets. All are CPU
        tensors; the last three have one row per sample.
    """
    since = time.time()

    loader = torch.utils.data.DataLoader(image_datasets[phase],
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

    model.eval()

    running_loss = 0.0
    running_corrects = torch.zeros(len(image_classes))
    # Batches are collected in lists and concatenated once at the end.
    pred_batches = []
    prediction_batches = []
    label_batches = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels, _raw_labels, _names in loader:
            targets = labels  # CPU copy kept for the statistics below
            label_batches.append(targets)

            if use_gpu:
                inputs = inputs.cuda(gpu_id)
                labels = labels.cuda(gpu_id)

            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            preds = torch.sigmoid(outputs).cpu()  # probability for each class
            predictions = (preds > 0.5).float()

            # statistics
            # .item() replaces loss.data[0], which raises on 0-dim tensors in
            # PyTorch >= 0.5.
            running_loss += loss.item()
            running_corrects += torch.sum(predictions == targets, 0).float()
            pred_batches.append(preds)
            prediction_batches.append(predictions)

    def _concat(batches):
        # An empty loader yields an empty tensor rather than a torch.cat error.
        return torch.cat(batches) if batches else torch.Tensor(0)

    running_preds = _concat(pred_batches)
    running_predictions = _concat(prediction_batches)
    running_labels = _concat(label_batches)

    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects / dataset_sizes[phase]

    print('{} Loss: {:.4f}\n Cored: {:.4f} Diffuse: {:.4f} CAA: {:.4f}'.format(
                phase, epoch_loss, epoch_acc[0], epoch_acc[1], epoch_acc[2]))

    print()

    time_elapsed = time.time() - since
    print('Prediction complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return epoch_acc, running_preds, running_predictions, running_labels


def plot_roc(preds, label, image_classes, phase, size=20, path=None):
    """Plot one ROC curve per class on a single figure.

    Args:
        preds: 2-D array of scores, one column per class.
        label: 2-D array of binary targets with the same column layout.
        image_classes: class names used in the legend, indexed by column.
        phase: dataset name shown in the figure title.
        size: base size from which the figure size, fonts and line widths
            are derived.
        path: when given and no file exists there yet, the figure is saved
            to it and closed; an existing file is never overwritten.
    """
    colors = ['pink', 'c', 'deeppink', 'b', 'g', 'm', 'y', 'r', 'k']
    lw = 0.2*size
    n_classes = preds.shape[1]

    fig = plt.figure(figsize=(1.2*size, size))
    ax = plt.axes()

    # Chance diagonal, drawn once so the legend has a single 'random' entry
    # instead of one per class.
    ax.plot([0, 1], [0, 1], 'k--', lw=lw, label='random')

    # Plot all ROC curves
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve(label[:, i].ravel(), preds[:, i].ravel())
        curve_label = ('ROC-curve of {}'.format(image_classes[i])
                       + '( area = {0:0.3f})'.format(auc(fpr, tpr)))
        ax.plot(fpr, tpr, label=curve_label,
                color=colors[(i + n_classes) % len(colors)], linewidth=lw)

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=1.8*size)
    ax.set_ylabel('True Positive Rate', fontsize=1.8*size)
    ax.set_title('Receiver operating characteristic Curve ({})'.format(phase), fontsize=1.8*size, y=1.01)
    ax.legend(loc=0, fontsize=1.5*size)
    ax.xaxis.set_tick_params(labelsize=1.6*size, size=size/2, width=0.2*size)
    ax.yaxis.set_tick_params(labelsize=1.6*size, size=size/2, width=0.2*size)

    if path is not None:
        if os.path.isfile(path):
            print('{} already exists, not saving image'.format(path))
        else:
            fig.savefig(path)
            plt.close(fig)
            print('saved')

    
def plot_prc(preds, label, image_classes, phase, size=20, path=None):
    """Plot one precision-recall curve per class, each with its baseline.

    Args:
        preds: 2-D array of scores, one column per class.
        label: 2-D array of binary targets with the same column layout.
        image_classes: class names used in the legend, indexed by column.
        phase: dataset name shown in the figure title.
        size: base size from which the figure size, fonts and line widths
            are derived.
        path: when given and no file exists there yet, the figure is saved
            to it and closed; an existing file is never overwritten.
    """
    colors = ['pink', 'c', 'deeppink', 'b', 'g', 'm', 'y', 'r', 'k']
    lw = 0.2*size
    n_classes = preds.shape[1]

    fig = plt.figure(figsize=(1.2*size, size))
    ax = plt.axes()

    for i in range(n_classes):
        class_color = colors[(i + n_classes) % len(colors)]
        # Fraction of samples that are positive for this class; drawn below
        # as the horizontal 'random' baseline.
        rp = (label[:, i] > 0).sum()/len(label)
        precision, recall, _ = precision_recall_curve(label[:, i].ravel(), preds[:, i].ravel())

        curve_label = ('PR-curve of {}'.format(image_classes[i])
                       + '( area = {0:0.3f})'.format(auc(recall, precision)))
        ax.plot(recall, precision, label=curve_label,
                color=class_color, linewidth=lw)

        # Dashed baseline in the class colour. linestyle= is used instead of
        # the 'k--' format string, whose colour clashed with color=.
        ax.plot([0, 1], [rp, rp], linestyle='--', color=class_color, lw=lw, label='random')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=1.8*size)
    ax.set_ylabel('Precision', fontsize=1.8*size)
    ax.set_title('Precision-Recall curve ({})'.format(phase), fontsize=1.8*size, y=1.01)
    ax.legend(loc="lower left", bbox_to_anchor=(0.01, 0.1), fontsize=1.5*size)
    ax.xaxis.set_tick_params(labelsize=1.6*size, size=size/2, width=0.2*size)
    ax.yaxis.set_tick_params(labelsize=1.6*size, size=size/2, width=0.2*size)

    if path is not None:
        if os.path.isfile(path):
            print('{} already exists, not saving image'.format(path))
        else:
            fig.savefig(path)
            plt.close(fig)
            print('saved')
    
    
def auc_roc(preds, label):
    """Return the area under the ROC curve for each column of ``preds``.

    ``preds`` and ``label`` are 2-D arrays with matching columns; the result
    is a list with one AUC value per column, in column order.
    """
    def _column_auc(col):
        fpr, tpr, _ = roc_curve(label[:, col].ravel(), preds[:, col].ravel())
        return auc(fpr, tpr)

    return [_column_auc(col) for col in range(preds.shape[1])]
    
    
def auc_prc(preds, label):
    """Return the area under the precision-recall curve per column of ``preds``.

    ``preds`` and ``label`` are 2-D arrays with matching columns; the result
    is a list with one area value per column, in column order.
    """
    def _column_area(col):
        precision, recall, _ = precision_recall_curve(
            label[:, col].ravel(), preds[:, col].ravel())
        return auc(recall, precision)

    return [_column_area(col) for col in range(preds.shape[1])]

# Normalization statistics. The .npy file holds a single pickled object
# (hence allow_pickle=True and .item()); its 'mean' and 'std' entries are
# read below and in imshow().
# NOTE: allow_pickle=True unpickles the file, so only load a trusted file.
norm = np.load("../modules/normalization.npy", allow_pickle=True).item()

os.makedirs(save_dir, exist_ok=True)


# Same pipeline for every dataset: convert to tensor, then normalize with the
# loaded mean/std.
data_transforms = { 
    x: transforms.Compose([transforms.ToTensor(), transforms.Normalize(norm['mean'], norm['std'])]) \
        for x in datasets
}

# One MultilabelDataset per dataset name; `datasets` and `img_dirs` are paired
# by position.
image_datasets = {
    x : MultilabelDataset(csv_path.format(x), d, data_transforms[x]) \
        for x, d in zip(datasets, img_dirs)
}

# Loaders used only for the preview below; dev_model() builds its own loader.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], 
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
               for x in datasets}

# Number of samples per dataset; dev_model() divides by these values.
dataset_sizes = {x: len(image_datasets[x]) for x in datasets}

# Global flag read by dev_model() and plot_CNN_metrics()
use_gpu = torch.cuda.is_available()

# Visual sanity check: show the first batch of every dataset as an image grid.
for x in datasets:
    print(x)
    # Get a batch of training data
    inputs, labels, raw_labels, names = next(iter(dataloaders[x]))
    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)
    imshow(out)


In [None]:
def plot_CNN_metrics(model_path, description):
    """Evaluate a saved model on every dataset and save its ROC/PRC output.

    For each name in the global ``datasets`` this runs ``dev_model`` and then
    writes three files into ``save_dir``: a CSV with the per-class
    fpr/tpr/precision/recall arrays, the ROC plot and the PRC plot.

    Args:
        model_path (str): file passed to ``torch.load``.
        description (str): prefix used in the output file names.
    """
    criterion = nn.MultiLabelSoftMarginLoss(size_average=False)
    # NOTE: torch.load unpickles the file -- only load trusted model files.
    # map_location keeps every tensor on the CPU while loading.
    model = torch.load(model_path, map_location=lambda storage, loc: storage)

    # Unwrap a `.module` wrapper (e.g. nn.DataParallel) exactly once. This
    # replaces a bare `except:` retry that hid real evaluation errors and ran
    # the evaluation a second time, and it also works for an unwrapped model.
    model = getattr(model, 'module', model)
    if use_gpu:
        print("using GPU\n")
        model = model.cuda()

    # run for each
    for dataset in datasets:
        _, pred, _, target = dev_model(model, criterion, phase=dataset, gpu_id=0)

        label = target.numpy()
        preds = pred.numpy()

        output = {}
        for i, class_name in enumerate(image_classes):
            fpr, tpr, _ = roc_curve(label[:, i].ravel(), preds[:, i].ravel())
            precision, recall, _ = precision_recall_curve(label[:, i].ravel(), preds[:, i].ravel())

            output['{} fpr'.format(class_name)] = fpr
            output['{} tpr'.format(class_name)] = tpr
            output['{} precision'.format(class_name)] = precision
            output['{} recall'.format(class_name)] = recall

        # The arrays differ in length; wrapping each in a Series lets pandas
        # pad the shorter columns with NaN.
        outcsv = pd.DataFrame({k: pd.Series(v) for k, v in output.items()})

        csv_file_path = join(save_dir, '{}_CNN_{}_output.csv'.format(description, dataset))
        outcsv.to_csv(csv_file_path, index=False)

        plot_roc(preds, label, image_classes, dataset, size=30,
                 path=join(save_dir, '{}_CNN_{}_roc.png'.format(description, dataset)))
        plot_prc(preds, label, image_classes, dataset, size=30,
                 path=join(save_dir, '{}_CNN_{}_prc.png'.format(description, dataset)))

        
# run CNN metrics for both models, saves ROC and PRC plots (+ plotting data)
# Each pair is (saved model file, prefix used in the output file names).
for model_path, run_tag in (
        ('../models/CNN_model_parameters.pkl', 'tang'),
        ('../models/CNN_fresh_model_parameters.pkl', 'emory'),
):
    plot_CNN_metrics(model_path, run_tag)