In [1]:
import random
import sys
import os
import types
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import transforms

from PIL import Image

import cv2

from lime import lime_image
from skimage.segmentation import mark_boundaries

In [None]:
# Root folder of the CheXpert dataset (directory layout as outlined in the project description).
# Set the CHEXPERT_DATA_PATH environment variable to point elsewhere instead of editing this cell.
data_path = os.environ.get('CHEXPERT_DATA_PATH', '/home/tu-serbin/data')

# os.path.join tolerates a trailing slash in data_path; the final '' keeps dir_path ending in a separator.
dir_path = os.path.join(data_path, 'chexpert', 'v1.0', '')
train_csv_path = os.path.join(dir_path, 'train.csv')
valid_csv_path = os.path.join(dir_path, 'valid.csv')

# Directory with the saved state dictionaries (training checkpoints).
# Set the CHEXPERT_MODEL_DIR environment variable to point elsewhere.
model_save_dir = os.environ.get('CHEXPERT_MODEL_DIR', '/home/tu-serbin/group/igloo/alex/saves/')

# Data preparation

Function that drops lateral-view records and irrelevant columns, replaces -1 labels with 0, and rewrites the prefix of the "Path" column.

In [4]:
def dropper(df):
    """Return a cleaned copy of a CheXpert label table for frontal-view experiments.

    The input frame is not modified. The returned frame:
      * keeps only rows whose "Frontal/Lateral" value is not "Lateral";
      * drops the 'Sex', 'Age', 'Frontal/Lateral' and 'AP/PA' columns;
      * has every -1.0 value replaced by 0 (presumably the "uncertain" label
        encoding — confirm against the dataset description);
      * has the literal 'CheXpert-' prefix in "Path" rewritten to 'chexpert/';
      * has a fresh 0..n-1 index.

    Args:
        df: DataFrame read from a CheXpert train/valid CSV.

    Returns:
        A new DataFrame with the changes above applied.
    """
    frontal_rows = df[df['Frontal/Lateral'] != 'Lateral']
    return (
        frontal_rows
        .drop(columns=['Sex', 'Age', 'Frontal/Lateral', 'AP/PA'])
        .replace(-1.0, 0)
        # regex=False: the prefix is a literal string; also sidesteps the pandas
        # 1.x -> 2.0 change of the str.replace regex default.
        .assign(Path=lambda d: d['Path'].str.replace('CheXpert-', 'chexpert/', regex=False))
        .reset_index(drop=True)
    )

In [5]:
# Read both splits (missing labels -> 0) and apply the same cleaning to each.
valid_csv = pd.read_csv(valid_csv_path, sep=',').fillna(0)
dval = dropper(valid_csv)

train_csv = pd.read_csv(train_csv_path, sep=',').fillna(0)
dtrain = dropper(train_csv)

# Label columns with fewer than this many positives in the validation split are
# set aside as `bad_cols` (presumably too rare to evaluate on — confirm intent).
MIN_VALID_POSITIVES = 5

# Sum only the label columns; summing the whole frame would also concatenate
# every string in "Path" just to throw the result away.
label_counts = dval.drop(columns='Path').sum()
index_ = label_counts < MIN_VALID_POSITIVES
bad_cols = label_counts.index[index_.to_numpy()]
good_cols = label_counts.index[~index_.to_numpy()]

# image paths as Series
vpath = dval.Path
tpath = dtrain.Path

# Transform

In [None]:
# Preprocessing shared by LIME and the model: resize every image to 224x224,
# then convert it to a float tensor (ToTensor yields C x H x W values in [0, 1]).
dnet_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Load DenseNet-121

In [9]:
# Number of outputs of the classification head the checkpoint was trained with
# (presumably one per label column of the CheXpert CSV — confirm).
NUM_CLASSES = 14

# DenseNet-121 backbone with an untrained default head.
model = torchvision.models.densenet121()
# Input width of the original classifier layer.
kernel_count = model.classifier.in_features
# Replace the head with NUM_CLASSES independent sigmoid outputs (multi-label).
model.classifier = nn.Sequential(nn.Linear(kernel_count, NUM_CLASSES), nn.Sigmoid())

name = 'epoch_1_score_0.81652.pth'
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(os.path.join(model_save_dir, name), map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
# Inference mode for dropout/batch-norm; trailing ';' suppresses the module repr.
model.eval();


# Class for generating LIME explanations of the model's predictions

In [10]:
class lime_explainer():
    """Run LIME's image explainer against a PyTorch image classifier.

    The wrapped model receives float batches shaped (N, C, H, W) and its raw
    outputs are handed to LIME as per-label scores. The DenseNet built in this
    notebook ends in nn.Sigmoid, so those scores are already independent
    per-label probabilities and are passed through unchanged.
    """

    def __init__(self, mod):
        self.model = mod
        # GPU flag; set by explain_please().
        self.c = False

    def explain_please(self, im: "Image.Image", transform=None, cuda=False, top_labels=3, num_samples=1000):
        """Explain the model's prediction for one image with LIME.

        Args:
            im: input image.
            transform: callable turning ``im`` into a (C, H, W) torch.Tensor,
                e.g. a Compose ending in ToTensor. Required in practice: LIME is
                fed the tensor's values.
            cuda: if True, the wrapped model is moved to the GPU (in place)
                before explaining.
            top_labels: number of highest-scoring labels LIME explains.
            num_samples: number of perturbed images LIME sends to ``predictor``.

        Returns:
            The explanation object returned by ``LimeImageExplainer.explain_instance``.

        Raises:
            TypeError: if ``transform`` is missing or does not return a torch.Tensor.
        """
        self.c = cuda
        im_r = transform(im) if transform else im
        if not isinstance(im_r, torch.Tensor):
            raise TypeError('explain_please needs a transform that returns a (C, H, W) '
                            'torch.Tensor, e.g. transforms.ToTensor()')

        if self.c:
            self.model.to('cuda')
        self.model.eval()

        # LIME works on channels-last (H, W, C) float arrays.
        image_hwc = np.transpose(im_r.cpu().numpy(), (1, 2, 0)).astype(np.double)
        explainer = lime_image.LimeImageExplainer()
        return explainer.explain_instance(image_hwc,
                                          self.predictor,  # classification function
                                          top_labels=top_labels,
                                          hide_color=0,
                                          num_samples=num_samples)

    def predictor(self, im: np.ndarray):
        """Classification function handed to LIME.

        Args:
            im: batch of (perturbed) images, channels-last (N, H, W, C).

        Returns:
            numpy array of the model's outputs, one row per image.
        """
        batch = torch.from_numpy(np.ascontiguousarray(np.transpose(im, (0, 3, 1, 2)), dtype=np.float32))
        # Send the batch to wherever the model's weights live (CPU or GPU).
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        # Inference only: no autograd graph for the many LIME perturbations.
        with torch.no_grad():
            probs = self.model(batch.to(device))
        # The model already ends in a Sigmoid; applying softmax on top would squash
        # independent label probabilities into a distribution summing to one.
        return probs.cpu().numpy()
    

Function that takes an image path, runs LIME on the image, and plots the superpixel boundaries of the explanation for the top predicted label.

In [11]:
def dnet_explain(impath, num_features=10, positive_only=False, hide_rest=False, ax=None):
    """Run LIME on one image and plot superpixel boundaries for its top label.

    Uses the notebook-level ``model`` and ``dnet_transform``; runs on the CPU.

    Args:
        impath: path to the image file.
        num_features: number of superpixels included in the explanation mask.
        positive_only: forwarded to ``get_image_and_mask``; if True, only
            superpixels that support the label are included.
        hide_rest: forwarded to ``get_image_and_mask``; if True, the image
            outside the explanation is hidden.
        ax: matplotlib Axes to draw into; a new figure is created when None.

    Returns:
        The LIME explanation object, for further inspection.
    """
    # Convert inside `with` so the underlying file handle is released promptly.
    with Image.open(impath) as raw_image:
        image = raw_image.convert('RGB')

    explanation = lime_explainer(model).explain_please(image, dnet_transform, cuda=False)

    top_label = explanation.top_labels[0]
    temp, mask = explanation.get_image_and_mask(top_label,
                                                positive_only=positive_only,
                                                num_features=num_features,
                                                hide_rest=hide_rest)
    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(mark_boundaries(temp, mask))
    ax.set_title(f'LIME explanation for model output {top_label}')
    ax.axis('off')
    return explanation

Pick a random training image and run LIME on it

In [None]:
# None draws fresh entropy on every run; set an int for a reproducible pick.
SEED = None
pick_rng = np.random.default_rng(SEED)
image_id = int(pick_rng.integers(len(dtrain)))

# Path entries are relative (dropper rewrites their prefix to 'chexpert/') and
# data_path has no trailing slash, so plain `data_path + path` would glue the two
# directory names together; os.path.join inserts the separator when needed.
path1 = os.path.join(data_path, dtrain.Path[image_id])

dnet_explain(path1);

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))