In [1]:
import random
import sys
import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import transforms

from PIL import Image

import cv2

from efficientnet_pytorch import EfficientNet 

In [2]:
# Directory that holds the saved model state dictionaries (checkpoints).
model_save_dir = '/saves'

# Root folder containing the CheXpert dataset; its internal structure is the one
# outlined in the project description.
data_path = '/data'

# Dataset directory (with trailing slash) and the two label CSV files inside it.
dir_path = data_path + '/chexpert/v1.0/'
valid_csv_path = data_path + '/chexpert/v1.0/valid.csv'
train_csv_path = data_path + '/chexpert/v1.0/train.csv'

# Data preparation

Function that drops lateral image records and irrelevant columns and edits the "Path" column

In [3]:
def dropper(df):
    """Return a trimmed copy of a CheXpert label frame containing frontal views only.

    Steps, all performed on a new frame (the input is left untouched):
      * rows whose "Frontal/Lateral" value equals "Lateral" are removed;
      * the 'Sex', 'Age', 'Frontal/Lateral' and 'AP/PA' columns are dropped;
      * every occurrence of 'CheXpert-' in the "Path" column becomes 'chexpert/';
      * the index is renumbered from 0.
    """
    lateral_rows = df.index[df["Frontal/Lateral"] == "Lateral"]
    frontal = (
        df.drop(index=lateral_rows)
          .drop(columns=['Sex','Age','Frontal/Lateral','AP/PA'])
    )
    frontal = frontal.assign(
        Path=frontal["Path"].str.replace('CheXpert-','chexpert/')
    )
    return frontal.reset_index(drop=True)

In [4]:
# Read both label tables; missing labels (NaN) are replaced with 0.
valid_csv, train_csv = (
    pd.read_csv(csv_path, sep=',').fillna(0)
    for csv_path in (valid_csv_path, train_csv_path)
)

# Frontal-only frames reduced to the image path plus the label columns.
dval, dtrain = dropper(valid_csv), dropper(train_csv)

# image paths as Series
vpath, tpath = dval["Path"], dtrain["Path"]

## Transforms

In [5]:
def _resize_then_tensor(side):
    """Build a pipeline that resizes an image to ``side`` x ``side`` and converts it to a tensor."""
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
    ])


# Preprocessing used with the EfficientNet model: 456 x 456 input.
efnet_transform = _resize_then_tensor(456)

# Preprocessing used with the DenseNet model: 256 x 256 input.
dnet_transform = _resize_then_tensor(256)

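Function that receives a vector of scores (binary labels or predicted probabilities) and returns a 1-D array with the indices of the elements that are greater than the threshold `beta` (0.5 by default)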

In [6]:
def findings(z, beta=0.5):
    """Return the indices of the entries of a score vector that exceed ``beta``.

    Parameters
    ----------
    z : numpy.ndarray, torch.Tensor or sequence
        Vector of scores (binary labels or probabilities). Singleton
        dimensions such as a leading batch axis of size 1 are ignored: the
        input is flattened before thresholding, so the indices refer to the
        flattened vector.
    beta : float, default 0.5
        Threshold; an entry is reported when it is strictly greater than it.

    Returns
    -------
    numpy.ndarray
        1-D integer array of matching indices. It is empty (shape ``(0,)``)
        when nothing exceeds the threshold.
    """
    if isinstance(z, np.ndarray):
        scores = z
    elif hasattr(z, "detach"):
        # Tensor-like input: detach from the autograd graph and move to host
        # memory first, otherwise .numpy() raises for tensors that require
        # grad or live on an accelerator.
        scores = z.detach().cpu().numpy()
    else:
        scores = np.asarray(z)

    # Flatten instead of squeeze: squeezing a one-element vector yields a 0-d
    # array, for which argwhere cannot report index 0.
    return np.flatnonzero(scores.reshape(-1) > beta)

# Hooks

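For Grad-CAM to work, the feature maps of the chosen convolutional layer have to be stored during the forward pass, and the gradients with respect to them have to be stored during the backward pass. Both are captured with hooks.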

In [7]:
# Module-level slots filled by the two hooks below and read by the Grad-CAM cell:
# `stored_grads` receives whatever the backward hook is handed, `stored_fpass`
# whatever the forward hook is handed. Both start out as empty tensors.
stored_grads = torch.Tensor([])
stored_fpass = torch.Tensor([])

def bpass_hook(self, gin, gout):
    """Backward hook: remember `gout` in the global `stored_grads`.

    `select()` registers this function on the chosen model's last conv layer
    with `register_backward_hook`, so PyTorch calls it as
    `hook(module, grad_input, grad_output)`. Only `gout` is kept; the Grad-CAM
    cell later reads `stored_grads[0]`. `self` (the module) and `gin` are unused.
    """
    global stored_grads
    stored_grads = gout

def fpass_hook(self, ten_in, ten_out):
    """Forward hook: remember `ten_out` in the global `stored_fpass`.

    `select()` registers this function on the same layer with
    `register_forward_hook`, so PyTorch calls it as `hook(module, input, output)`.
    Only `ten_out` is kept (not detached here; the Grad-CAM cell detaches it
    before use). `self` (the module) and `ten_in` are unused.
    """
    global stored_fpass
    stored_fpass = ten_out

# Load model

In [None]:
def _strip_key_prefix(state_dict, prefix='net.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from its keys.

    Keys that do not start with ``prefix`` are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in state_dict.items()
    }


def select(n=0):
    """Build the chosen classifier, load its checkpoint and attach the Grad-CAM hooks.

    Sets two notebook-level globals:
      * ``model``     - the network with its saved weights, in eval mode;
      * ``transform`` - the image preprocessing pipeline matching that network.

    ``bpass_hook`` and ``fpass_hook`` are registered on the network's last
    convolutional layer so that a forward/backward pass fills ``stored_fpass``
    and ``stored_grads``.

    Parameters
    ----------
    n : int, default 0
        0 selects DenseNet121 (checkpoint 'epoch_1_score_0.81652.pth'),
        1 selects EfficientNet-B5 (checkpoint 'epoch_3_score_0.88577.pth').
        Both checkpoints are looked up inside ``model_save_dir``.

    Raises
    ------
    ValueError
        If ``n`` is neither 0 nor 1.
    """
    global model, transform

    if n == 0:
        # DenseNet121 whose classifier is replaced by a 14-way linear layer + sigmoid.
        model = torchvision.models.densenet121()
        kernel_count = model.classifier.in_features
        model.classifier = nn.Sequential(nn.Linear(kernel_count, 14), nn.Sigmoid())

        name = 'epoch_1_score_0.81652.pth'
        # The model is built directly here (no wrapper object), so the layer is
        # reached from the model itself rather than through a `.net` attribute.
        last_conv_layer = model.features.denseblock4.denselayer16.conv2
        transform = dnet_transform
    elif n == 1:
        # EfficientNet-B5: only the final fully connected layer is swapped for
        # the 14-way linear layer + sigmoid; the backbone must stay in place.
        # NOTE(review): assumes the checkpoint stores this head under `_fc` -
        # confirm against the training code.
        model = EfficientNet.from_pretrained('efficientnet-b5')
        model._fc = nn.Sequential(nn.Linear(model._fc.in_features, 14), nn.Sigmoid())

        name = 'epoch_3_score_0.88577.pth'
        last_conv_layer = model._conv_head
        transform = efnet_transform
    else:
        raise ValueError('select() expects n=0 (DenseNet121) or n=1 (EfficientNet-B5), got {!r}'.format(n))

    # os.path.join supplies the separator between directory and file name.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(os.path.join(model_save_dir, name), map_location='cpu')

    # Checkpoints saved from a wrapper that held the network as `.net` carry a
    # 'net.' prefix on every key; drop it so the keys match the bare model.
    model.load_state_dict(_strip_key_prefix(checkpoint['state_dict']))
    model.eval()

    # register hooks for gradCAM
    last_conv_layer.register_backward_hook(bpass_hook)
    last_conv_layer.register_forward_hook(fpass_hook)

In [None]:
# 1 selects the EfficientNet-B5 branch of select() (0 would select DenseNet121);
# the call sets the globals `model` and `transform` used by the cells below.
select(1)

Fetch a random training image and extract its true (positive) labels and its "uncertain" labels (encoded as -1 in the label table)

In [None]:
# get image and true labels
imid = np.random.randint(dtrain.shape[0])
print(imid)

# The "Path" entries are relative ('chexpert/v1.0/...') and data_path has no
# trailing slash, so the two must be joined with a separator.
path1 = os.path.join(data_path, dtrain.Path.iloc[imid])

# convert() returns an independent in-memory image, so the file can be closed right away.
with Image.open(path1) as image_file:
    image_orig = image_file.convert('RGB')
# Add a leading batch axis of size 1 for the model.
image_transformed = transform(image_orig).unsqueeze(0)

# Label columns are everything after "Path": 1 = positive, -1 = uncertain.
true_labels_vec = dtrain.iloc[imid,1:].to_numpy().astype(int)
true_labels = np.flatnonzero(true_labels_vec == 1)
true_labels_uncertain = np.flatnonzero(true_labels_vec == -1)

# Run gradCAM

In [None]:
# run model
out = model(image_transformed)
out_np = out.detach().numpy().squeeze()

# Indices of the classes predicted with p > 0.5 and their rounded probabilities.
pred_labels_binary = findings(out_np)
pred_labels_p = np.round(out_np[pred_labels_binary], 2)

# One one-hot backward() argument per predicted class, shaped like the model
# output, so each backward pass yields the gradients of that class score only.
args = []
for ii in pred_labels_binary:
    backward_arg = torch.zeros_like(out)
    backward_arg[0,ii] = 1
    args.append(backward_arg)

# generate gradCAMs
hmap_list = []   # normalised low-resolution heatmaps (torch tensors)
cam_list = []    # heatmaps blended with the input image (numpy arrays)

# Input image as an H x W x C array for blending and plotting.
img_hmap = np.transpose(image_transformed.squeeze().numpy(),(1,2,0))

# plot
arglen = len(args)
kwargs = dict(xticks=[],yticks=[])

if arglen > 0:
    fig, ax = plt.subplots(1, arglen, figsize=(5*arglen,5), subplot_kw=kwargs)
    if arglen==1:
        ax = [ax]
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
else:
    # plt.subplots rejects a grid with zero columns, so skip the figure entirely.
    print('No class exceeded the 0.5 threshold - no Grad-CAM to plot.')

for k, a in enumerate(args):
    print('{}/{}'.format(k+1,arglen))
    # retain_graph=True: the same forward graph is back-propagated once per class.
    out.backward(a, retain_graph=True)

    # Captured by the hooks: gradients w.r.t. the hooked layer's output and the
    # layer's feature maps, restricted to the first item of the batch.
    gradients = stored_grads[0]
    activations = stored_fpass[0].detach().unsqueeze(0)

    # Channel weights = gradients averaged over batch and spatial axes.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
    activations = activations * pooled_gradients.view(1, -1, 1, 1)

    # Sum the weighted maps over channels, keep the positive part, scale to [0, 1].
    heatmap = torch.sum(activations, dim=1).squeeze(0)
    heatmap = torch.clamp(heatmap, min=0)
    peak = torch.max(heatmap)
    if peak > 0:
        # An all-zero map is left as is; dividing by 0 would fill it with NaN.
        heatmap = heatmap / peak

    hmap_list.append(heatmap)

    hmap = heatmap.numpy()
    heatmap1 = cv2.resize(hmap, (img_hmap.shape[1], img_hmap.shape[0]))
    # The map is inverted before colouring.
    # NOTE(review): presumably this compensates for applyColorMap returning BGR
    # while imshow expects RGB, so strong activations show as red - confirm.
    heatmap1 = np.uint8(-255 * heatmap1 + 255)
    heatmap1 = cv2.applyColorMap(heatmap1, cv2.COLORMAP_JET)

    # Blend the coloured map (weight 0.002 on a 0-255 scale) with the image, rescale to max 1.
    supim = heatmap1 * 0.002 + img_hmap
    supim = supim / supim.max()

    cam_list.append(supim)

    ax[k].imshow(supim)
    ax[k].set_title(dtrain.columns[1:][pred_labels_binary[k]])

print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
print('True labels:', true_labels, dtrain.columns[1:][true_labels].to_list())
print('Uncertainties:', true_labels_uncertain, dtrain.columns[1:][true_labels_uncertain].to_list())
print('Prediction:', pred_labels_binary, dtrain.columns[1:][pred_labels_binary].to_list())
print('Probabilities:', pred_labels_p)