In [None]:
import torch
import os
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, roc_curve, auc
import torch.nn as nn
from dataloader import evaluate_singlemodel, ChestImages, device
from config import P, paths
from torch.utils.data import DataLoader
from models import ResNet18

import matplotlib.pyplot as plt
import math
import sys

In [None]:
def plot_auc(axes, probs, labels, class_names, model_name=None):
    """Draw one ROC curve per class, plus a pooled "Overall" curve, on a grid of axes.

    Panels are filled in row-major order: panel ``i`` shows class ``i`` and the
    panel after the last class shows the curve over all classes flattened together.

    Parameters
    ----------
    axes : matplotlib Axes or array of Axes
        Target grid, e.g. as returned by ``plt.subplots``. It must provide at
        least ``len(class_names) + 1`` panels.
    probs : 2-D array
        Predicted scores, indexed as ``probs[:, i]`` for class ``i``.
    labels : 2-D array
        Ground-truth labels, indexed as ``labels[:, i]`` for class ``i``.
    class_names : sequence of str
        Display name of each class column.
    model_name : str, optional
        Prefix for the legend entry. When omitted the legend shows only the AUC.

    Raises
    ------
    IndexError
        If ``axes`` has fewer panels than curves to draw.

    Notes
    -----
    Uses the module-level ``roc_curve`` and ``auc`` imported from
    ``sklearn.metrics``; rebinding the name ``auc`` to a score elsewhere in the
    notebook breaks this function.
    """
    grid = np.atleast_2d(np.asarray(axes))
    flat_axes = grid.flatten()
    n_panels = len(class_names) + 1
    # Fail before drawing anything, with a message that says what is needed.
    if flat_axes.size < n_panels:
        raise IndexError(
            "plot_auc needs {} axes ({} classes + overall) but got {}".format(
                n_panels, len(class_names), flat_axes.size
            )
        )
    n_cols = grid.shape[-1]
    legend_prefix = "{}: ".format(model_name) if model_name else ""

    for i in range(n_panels):
        ax = flat_axes[i]
        if i == len(class_names):
            # Pool every (sample, class) pair into a single curve.
            fpr, tpr, threshold = roc_curve(labels.flatten(), probs.flatten())
            title = "ROC for {}".format("Overall")
        else:
            fpr, tpr, threshold = roc_curve(labels[:, i], probs[:, i])
            title = "ROC for {}".format(class_names[i])

        roc_auc = auc(fpr, tpr)

        ax.set_title(title)
        ax.plot(fpr, tpr, label="{}AUC = {:.3f}".format(legend_prefix, roc_auc))
        ax.legend(loc='lower right')
        ax.plot([0, 1], [0, 1], 'r--')  # chance-level diagonal
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        # Label only the outer edge of the grid: first column gets the y label,
        # and a panel gets the x label when no populated panel sits below it.
        if i % n_cols == 0:
            ax.set_ylabel('True Positive Rate')
        if i + n_cols >= n_panels:
            ax.set_xlabel('False Positive Rate')

In [None]:
def get_prob_labels(path, dataloader=None):
    """Evaluate every checkpoint in ``path`` and average their predictions.

    Each checkpoint is loaded into a fresh ``ResNet18`` sized from the
    checkpoint's ``"classes"`` entry, evaluated with ``evaluate_singlemodel``,
    and the per-checkpoint predictions are averaged element-wise.

    Parameters
    ----------
    path : str
        Directory containing the checkpoint files of one training run.
        Hidden entries (names starting with ``"."``) and sub-directories are skipped.
    dataloader : optional
        Loader passed to ``evaluate_singlemodel``. Defaults to the notebook-level
        ``dataloaders["valid"]`` so existing calls keep working.

    Returns
    -------
    labels : numpy.ndarray
        Labels reported by the evaluation of the last checkpoint. This assumes
        the loader yields samples in the same order for every checkpoint,
        i.e. it is not shuffled.
    probs_mean : numpy.ndarray
        Mean of the per-checkpoint predictions.
    class_names
        The ``"classes"`` entry of the last checkpoint loaded.

    Raises
    ------
    ValueError
        If ``path`` contains no checkpoint files.
    """
    # Sort for a reproducible order; skip dotfiles such as .DS_Store, which
    # would otherwise be handed to torch.load and crash it.
    checkpoint_files = sorted(
        name for name in os.listdir(path)
        if not name.startswith(".") and os.path.isfile(os.path.join(path, name))
    )
    if not checkpoint_files:
        raise ValueError("no checkpoint files found in {!r}".format(path))

    if dataloader is None:
        dataloader = dataloaders["valid"]

    criterion = nn.BCEWithLogitsLoss()

    all_probs = []
    for file in checkpoint_files:
        print(file)
        # SECURITY: torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(os.path.join(path, file), map_location=device)
        class_names = checkpoint["classes"]
        model = ResNet18(out_size=len(class_names))
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)
        probs, labels, _loss = evaluate_singlemodel(model, criterion, dataloader)
        # evaluate_singlemodel hands back per-batch tensors; join them per checkpoint.
        all_probs.append(torch.cat(probs).numpy())

    probs_mean = np.mean(np.stack(all_probs), axis=0)
    labels = torch.cat(labels).numpy()

    return labels, probs_mean, class_names

In [7]:
# Root directory holding one sub-folder per training run. The defaults are the
# author's machine-specific locations; set CHEXPERT_MODELS_DIR to point elsewhere.
if os.name == "nt":
    default_dir_path = "C://Users/Ashok/Documents/MS/models/"
else:
    default_dir_path = "/Users/ashok/Downloads/Chexpert/models/"
dir_path = os.environ.get("CHEXPERT_MODELS_DIR", default_dir_path)

# Keep only run directories whose name contains both "ignore" and "all".
# Hidden entries and plain files are dropped because each hit is later listed
# as a directory of checkpoints. Sorted so the order does not depend on the OS.
folders = sorted(
    folder for folder in os.listdir(dir_path)
    if not folder.startswith(".")
    and "ignore" in folder
    and "all" in folder
    and os.path.isdir(os.path.join(dir_path, folder))
)

for folder in folders:
    print(folder)

ResNet18_ignore_all_2019.12.07.16.49.53


In [8]:
# Scratch check: `folder` is whatever the preceding `for folder in folders` loop left bound.
"all" in folder

True

In [11]:
# Evaluate the checkpoint ensemble of every run folder on the validation split
# and print per-class, overall and key-pathology scores.

all_classes = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
               'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
               'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
               'Pleural Other', 'Fracture', 'Support Devices']
# The five key pathologies; also the label set used by "subset" runs.
training_classes = ['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion']

for folder in folders:
    print(folder)

    # The folder name encodes which label set the run was trained on.
    if "all" in folder:
        classes_type = "all"
        P.training_classes = list(all_classes)
    else:
        classes_type = "subset"
        P.training_classes = list(training_classes)

    train_data = ChestImages(paths[os.name][P.dataset]["train_location"],
            paths[os.name][P.dataset]["dirpath"],
            P, frac=1.0, classes_type=classes_type)

    valid_data = ChestImages(paths[os.name][P.dataset]["valid_location"],
        paths[os.name][P.dataset]["dirpath"],
        P,
        frac=1.0, classes_type=classes_type)

    # shuffle=False keeps sample order fixed, so predictions from different
    # checkpoints line up row by row when they are averaged.
    dataloaders = {
        "train": DataLoader(train_data,
            batch_size=P.batch_size,
            shuffle=False,
            num_workers=P.num_workers),

        "valid": DataLoader(valid_data,
            batch_size=P.batch_size,
            shuffle=False,
            num_workers=P.num_workers)
    }
    labels, probs, class_names = get_prob_labels(os.path.join(dir_path, folder))

    for i, class_name in enumerate(class_names):
        class_labels, class_probs = labels[:, i], probs[:, i]
        # Do not name this variable `auc`: that would rebind the function
        # imported from sklearn.metrics, which plot_auc calls later.
        if len(np.unique(class_labels)) == 1:
            # ROC AUC is undefined when only one class occurs in the labels,
            # so fall back to accuracy and say so in the output.
            class_score = accuracy_score(class_labels, np.rint(class_probs))
            note = "  (accuracy; only one class present in labels)"
        else:
            class_score = roc_auc_score(class_labels, class_probs)
            note = ""
        print("AUC for {:30} = {:.3f}{}".format(class_name, class_score, note))

    overall_auc = roc_auc_score(labels.flatten(), probs.flatten())
    print("AUC for {:30} = {:.3f}".format("Overall", overall_auc))
    # Gate on this folder's label set, not the global config value, so the
    # column lookup below always matches the model that was just evaluated.
    if classes_type == "all":
        # Look the key columns up by name instead of trusting fixed indices;
        # a missing name raises ValueError rather than scoring the wrong column.
        keys = [list(class_names).index(name) for name in training_classes]
        keyclasses_auc = roc_auc_score(labels[:, keys].flatten(), probs[:, keys].flatten())
        print("AUC for {:30} = {:.3f}".format("Key pathologies", keyclasses_auc))
    print("-"*40)

ResNet18_ignore_all_2019.12.07.16.49.53
epoch1_itr1700.pt
epoch1_itr2700.pt
epoch1_itr2800.pt
epoch1_itr2900.pt
epoch1_itr3300.pt
epoch2_itr1400.pt
epoch2_itr1500.pt
epoch2_itr1600.pt
epoch2_itr2000.pt
epoch2_itr3400.pt
AUC for No Finding                     = 0.897
AUC for Enlarged Cardiomediastinum     = 0.598
AUC for Cardiomegaly                   = 0.842
AUC for Lung Opacity                   = 0.920
AUC for Lung Lesion                    = 0.189
AUC for Edema                          = 0.910
AUC for Consolidation                  = 0.936
AUC for Pneumonia                      = 0.727
AUC for Atelectasis                    = 0.798
AUC for Pneumothorax                   = 0.850
AUC for Pleural Effusion               = 0.935
AUC for Pleural Other                  = 0.888
AUC for Fracture                       = 1.000
AUC for Support Devices                = 0.936
AUC for Overall                        = 0.870
AUC for Key pathologies                = 0.854
----------------------------

In [None]:
# One ROC panel per class plus one for the pooled "Overall" curve. The grid is
# sized from the class list; a fixed 2x3 grid only fits the five-class runs.
n_panels = len(class_names) + 1
n_cols = 3
n_rows = math.ceil(n_panels / n_cols)
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
f, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5 * n_rows), squeeze=False)
plot_auc(axes, probs, labels, class_names)
# Hide trailing panels that have no curve.
for ax in axes.flatten()[n_panels:]:
    ax.set_visible(False)

In [None]:
# Show the active dataset key from the config; it selects the `paths` entry used to build the datasets.
P.dataset

In [None]:
for data in dataloaders["valid"]:
    inputs, labels = data["image"], data["targets"]
    break

In [None]:
inputs.shape

In [None]:
labels.shape