# Importing

In [1]:
from captum.attr import (Saliency, GradientShap, DeepLift, IntegratedGradients, NoiseTunnel, GuidedGradCam, GuidedBackprop, Occlusion, LRP, visualization as viz, KernelShap, FeaturePermutation, ShapleyValueSampling)
import json
import numpy as np
import os
import pandas as pd
import torch
import torchvision
from torchvision import models, transforms
import warnings

from pneumothorax_image_dataset import PneumothoraxImageDataset
import utils

# Matplotlib
%matplotlib inline

# Silence all Python warnings for the rest of the session
warnings.filterwarnings(action='ignore')

# Devices: `cpu` is always the CPU; `device` is CUDA when available, else the CPU
cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} for inference')

Using cpu for inference


In [2]:
# Seed PyTorch's and NumPy's global random number generators with the same fixed value.
# NOTE: keep np.random.seed last; it returns None, so the cell shows no output.
torch.manual_seed(123)
np.random.seed(123)

# Setup datasets and configurations

## Select environment

In [3]:
# Resolve the dataset and checkpoint locations registered under the 'local' environment
DATASET_PATH, CHECKPOINT_PATH = utils.get_path('local')

# Report both resolved locations, one per line
print(
    f'Dataset path: {DATASET_PATH}',
    f'Checkpoint path: {CHECKPOINT_PATH}',
    sep='\n',
)

Dataset path: pneumothorax-chest-xray-dataset/siim-acr-pneumothorax
Checkpoint path: pretrained_weights/


## Load test set

In [4]:
# When True, explanations are generated for the first batch of the test loader only;
# when False, the full test loader is used.
RUN_ON_TEST_SET_SAMPLE = True
# Batch size given to the test DataLoader.
BATCH_SIZE = 1

In [5]:
# Read the test-set index, then derive the image and mask path for every file name
test_data = pd.read_csv(os.path.join(DATASET_PATH, 'stage_1_test_images.csv'))
test_data['images'] = test_data['new_filename'].map(lambda name: os.path.join(DATASET_PATH, 'png_images', name))
test_data['masks'] = test_data['new_filename'].map(lambda name: os.path.join(DATASET_PATH, 'png_masks', name))

# Plain Python lists, in CSV row order
filenames, images, masks, targets = (
    test_data[column].tolist()
    for column in ('new_filename', 'images', 'masks', 'has_pneumo')
)

In [6]:
# Wrap file names, image paths, labels and mask paths in the project dataset
dataset = PneumothoraxImageDataset(
    filenames,
    images,
    targets,
    masks,
    utils.transform,
    utils.mask_transform,
)
# Batch the dataset in its original order (no shuffling)
ds_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=4,
)

## Setup Blackbox

In [7]:
# ResNet-101: load the checkpoint into a fresh model, wrap it in DataParallel,
# switch to eval mode and move it to the inference device
resnet_cp = torch.load(f'{CHECKPOINT_PATH}/resnet_cp.pth', map_location=device)
blackbox_resnet = models.resnet101()
blackbox_resnet.load_state_dict(resnet_cp)
blackbox_resnet = torch.nn.DataParallel(blackbox_resnet).eval().to(device)

# InceptionV3: same steps with its own checkpoint
checkpoint = torch.load(f'{CHECKPOINT_PATH}/inception_cp.pth', map_location=device)
blackbox_inception = models.inception_v3()
blackbox_inception.load_state_dict(checkpoint)
blackbox_inception = torch.nn.DataParallel(blackbox_inception).eval().to(device)

In [8]:
# Per-model settings: the wrapped network and the name of the layer handed to GuidedGradCam
blackbox_configs = {
    model_name: {
        'module': module,
        'layer_name_for_guided_gc': layer_name,
    }
    for model_name, module, layer_name in (
        ('ResNet', blackbox_resnet, 'layer4'),
        ('InceptionV3', blackbox_inception, 'Mixed_7c'),
    )
}

# Experiments

## Utility functions

In [9]:
def get_explanation(xai_config, batch, targets):
    """Run one attribution method on a batch and average over dimension 1.

    Args:
        xai_config: mapping with a ``'method'`` entry (an object exposing
            ``attribute``) and an ``'options'`` entry (extra keyword
            arguments forwarded to ``attribute``).
        batch: input handed to ``attribute`` as its first argument.
        targets: forwarded to ``attribute`` as ``target``.

    Returns:
        The attribution returned by the method, averaged over dimension 1.
    """
    attributions = xai_config['method'].attribute(
        batch,
        target=targets,
        **xai_config['options'],
    )
    return attributions.mean(dim=1)

In [10]:
def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: numeric array-like accepted by ``np.min`` / ``np.max``.

    Returns:
        ``(x - min) / (max - min)``. When all elements are equal (zero
        range), an all-zero float array of the same shape is returned
        instead of the NaNs a 0/0 division would produce.
    """
    lowest = np.min(x)
    value_range = np.max(x) - lowest
    if value_range == 0:
        # Constant input has no spread to rescale; avoid dividing by zero.
        return np.zeros_like(x, dtype=float)
    return (x - lowest) / value_range

In [11]:
def normalize_sum_to_one(x):
    """Min-max scale ``x`` with ``normalize``, then divide by the total.

    Returns:
        The scaled values divided by their sum.
    """
    scaled = normalize(x)
    total = scaled.sum()
    return scaled / total

In [12]:
def compute_mean_disagreement(exp1, exp2, metric, **options):
    """Average a pairwise metric over two parallel sequences of explanations.

    Args:
        exp1: first sequence of explanations.
        exp2: second sequence; paired element-wise with ``exp1`` via ``zip``,
            so extra items in the longer sequence are ignored.
        metric: callable invoked as ``metric(e1, e2, **options)``.
        **options: extra keyword arguments forwarded to ``metric``.

    Returns:
        The mean of the per-pair metric values.
    """
    scores = [metric(first, second, **options) for first, second in zip(exp1, exp2)]
    return np.array(scores).mean()

In [13]:
class ScalarOutputWrapper(torch.nn.Module):
    """``nn.Module`` wrapper that passes inputs to the wrapped module and returns its output unchanged."""

    def __init__(self, module):
        """Register ``module`` as the sub-module ``m``."""
        super().__init__()
        self.m = module

    def forward(self, x):
        """Return the wrapped module's output for ``x`` as-is."""
        return self.m(x)

## Functions for generating explanations

In [14]:
def get_explanations(xai_methods_dict, methods, ds_loader, indices, gpu_device):
    """Collect per-batch explanations for every requested method.

    Args:
        xai_methods_dict: mapping from method name to its configuration,
            passed on to ``get_explanation``.
        methods: names of the methods to run, in order.
        ds_loader: iterable of batches; element 1 of each batch is the
            input tensor.
        indices: 1-D tensor of target indices, one per sample, in the same
            order the loader yields samples.
        gpu_device: device the inputs and targets are moved to before
            attribution.

    Returns:
        Dict mapping each method name to a list holding one CPU tensor per
        batch. A method missing from ``xai_methods_dict`` gets NaN-filled
        placeholders of shape ``(batch_size, *utils.INPUT_SIZE)``.
    """
    explanations = {method: [] for method in methods}
    indices = indices.to(gpu_device)

    for method in methods:
        print(f'Getting explanations of {method}...')
        # Track a running sample offset: batches may differ in size (e.g. a
        # shorter final batch), so `i * batch_size` would slice the wrong targets.
        offset = 0
        for data in ds_loader:
            batch = data[1]
            batch_size = batch.shape[0]
            if method in xai_methods_dict:
                batch_targets = indices[offset:offset + batch_size]
                explanation = get_explanation(xai_methods_dict[method], batch.to(gpu_device), batch_targets)
                explanations[method].append(explanation.cpu().detach())
            else:
                # One NaN map per sample so placeholders line up with real explanations.
                explanations[method].append(torch.empty((batch_size, *utils.INPUT_SIZE)).fill_(np.nan))
            offset += batch_size

    return explanations

Define the starting baseline (a single uniform-random image) used by GradientShap and Integrated Gradients.

In [15]:
# One uniform-random 3-channel image (batch of 1), moved to the inference device
rand_img_dist = torch.rand(1, 3, *utils.INPUT_SIZE).to(device)
rand_img_dist.size()

torch.Size([1, 3, 128, 128])

In [16]:
def get_xai_configs(blackbox_configs, blackbox_name):
    """Build the attribution-method configurations for one black-box model.

    Args:
        blackbox_configs: mapping from model name to a dict with a
            ``'module'`` entry (the wrapped model, whose ``.module``
            attribute exposes the underlying network) and a
            ``'layer_name_for_guided_gc'`` entry (attribute name of the
            layer given to GuidedGradCam).
        blackbox_name: key of the model to configure.

    Returns:
        ``(xai_methods, blackbox)`` where ``xai_methods`` maps a method name
        to ``{"method": <attribution object>, "options": <kwargs for
        attribute()>}`` and ``blackbox`` is the selected model.
    """
    config = blackbox_configs[blackbox_name]
    blackbox = config['module']
    layer = getattr(blackbox.module, config['layer_name_for_guided_gc'])

    # Initialize attribution methods
    sl = Saliency(blackbox)
    guided_gc = GuidedGradCam(blackbox, layer)
    gbp = GuidedBackprop(blackbox)
    lrp = LRP(blackbox)
    gs = GradientShap(blackbox)
    ig = IntegratedGradients(blackbox)
    occlusion = Occlusion(blackbox)

    # Define configurations. Only the "options" dict is forwarded to
    # attribute(), so every attribute() argument (including baselines) must
    # live inside it.
    xai_methods = {
        "Saliency": {"method": sl, "options": {}},
        "GuidedGradCam": {"method": guided_gc, "options": {'interpolate_mode': 'area'}},
        "GuidedBackprop": {"method": gbp, "options": {}},
        "LRP": {"method": lrp, "options": {}},
        "GradientShap": {"method": gs, "options": {'n_samples': 16, 'stdevs': 0.0001, 'baselines': rand_img_dist}},
        "IntegratedGradients": {"method": ig, "options": {'n_steps': 100, 'internal_batch_size': 1, 'baselines': rand_img_dist}},
        "Occlusion": {"method": occlusion, "options": {'sliding_window_shapes': (3, 8, 8), 'strides': (3, 4, 4)}},
    }

    # DeepLift is registered for every model except 'ResNet'.
    if blackbox_name != 'ResNet':
        xai_methods["DeepLift"] = {"method": DeepLift(ScalarOutputWrapper(blackbox)), "options": {}}

    return xai_methods, blackbox

In [17]:
def run(blackbox_configs, blackbox_name, loader):
    """Generate explanations from every configured method for one model.

    Args:
        blackbox_configs: per-model configuration mapping, passed to
            ``get_xai_configs``.
        blackbox_name: key of the model to explain.
        loader: iterable of ``(filenames, inputs, masks, labels)`` batches.
            It is iterated several times, so it must be re-iterable.

    Returns:
        ``(files, explanations)`` where ``files`` lists the file names in
        loader order and ``explanations`` maps each method name to a single
        CPU tensor of all per-batch explanations concatenated along
        dimension 0.
    """
    # Get configurations
    xai_methods, blackbox = get_xai_configs(blackbox_configs, blackbox_name)

    # Predict the class to explain for every sample.
    files = []
    predicted = []
    # Inference only: disabling autograd avoids building and retaining a
    # graph for every batch. The attribution calls below run outside this
    # context and still get gradients.
    with torch.no_grad():
        for names, inputs, _masks, _labels in loader:
            output = torch.nn.functional.softmax(blackbox(inputs.to(device)), dim=1)
            _, index = torch.topk(output, k=1, dim=1)
            predicted.append(index.flatten())
            files.extend(list(names))
    indices = torch.cat(predicted)

    # Get gradient-based explanations
    gradient_methods = ['Saliency', 'GuidedGradCam', 'GuidedBackprop', 'LRP', 'IntegratedGradients', 'GradientShap', 'DeepLift']
    gradient_explanations = get_explanations(xai_methods, gradient_methods, loader, indices, device)

    # Get perturbation-based explanations
    perturbation_methods = ['Occlusion']
    perturbation_explanations = get_explanations(xai_methods, perturbation_methods, loader, indices, device)

    # Combine: one detached CPU tensor per method
    explanations = {**gradient_explanations, **perturbation_explanations}
    for method in gradient_methods + perturbation_methods:
        explanations[method] = torch.cat([part.cpu().detach() for part in explanations[method]])

    return files, explanations

## Generate explanations

Define data loader and methods

In [18]:
# With RUN_ON_TEST_SET_SAMPLE set, explain only the first batch; otherwise use the whole test loader
loader = [next(iter(ds_loader))] if RUN_ON_TEST_SET_SAMPLE else ds_loader

# Names of all explanation methods
methods = [
    'Saliency',
    'GuidedGradCam',
    'GuidedBackprop',
    'LRP',
    'IntegratedGradients',
    'GradientShap',
    'DeepLift',
    'Occlusion',
]

Generate explanations for InceptionV3

In [19]:
%%time
# Keep the file names from this run; they are the same loader order used for every model
test_files, explanations_inceptionv3 = run(
    blackbox_configs=blackbox_configs,
    blackbox_name='InceptionV3',
    loader=loader,
)

Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
CPU times: total: 3min 51s
Wall time: 1min 27s


Generate explanations for ResNet

In [20]:
%%time
# The returned file names are discarded here
_, explanations_resnet = run(
    blackbox_configs=blackbox_configs,
    blackbox_name='ResNet',
    loader=loader,
)

Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
CPU times: total: 10min 42s
Wall time: 2min 12s


## Save explanations

In [21]:
def make_dir_if_not_exist(path):
    """Create directory ``path`` (and any missing parents) unless it already exists.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok removes the check-then-create race of a separate os.path.exists test.
    os.makedirs(path, exist_ok=True)

In [22]:
def save_explanations(path, methods, explanations):
    """Write each listed method's explanations to ``<path>/<method>.pt``.

    Args:
        path: existing directory that receives the files.
        methods: names of the methods to save.
        explanations: mapping from method name to the object to serialize
            with ``torch.save``.
    """
    for name in methods:
        destination = os.path.join(path, f'{name}.pt')
        torch.save(explanations[name], destination)

In [23]:
# Output layout: one sub-directory per model under a common parent
SAVED_EXPLANATIONS_PARENT_DIR = 'explanations'
INCEPTIONV3_DIR, RESNET_DIR = (
    os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, model_name)
    for model_name in ('InceptionV3', 'ResNet')
)

# Make sure the parent and both model directories exist
make_dir_if_not_exist(SAVED_EXPLANATIONS_PARENT_DIR)
make_dir_if_not_exist(INCEPTIONV3_DIR)
make_dir_if_not_exist(RESNET_DIR)

In [24]:
# Persist every method's explanations, one directory per model
save_explanations(path=INCEPTIONV3_DIR, methods=methods, explanations=explanations_inceptionv3)
save_explanations(path=RESNET_DIR, methods=methods, explanations=explanations_resnet)

In [25]:
# Store the ordered list of explained file names next to the saved explanations
with open(os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'test_files.json'), 'w') as f:
    f.write(json.dumps(test_files))