# Importing

In [1]:
!pip install captum

[0m

In [2]:
from captum.attr import (Saliency, GradientShap, DeepLift, IntegratedGradients, NoiseTunnel, GuidedGradCam, GuidedBackprop, Occlusion, LRP, visualization as viz, KernelShap, FeaturePermutation, ShapleyValueSampling)
import json
import numpy as np
import os
import torch
from torchvision import models
import warnings

import utils

# Matplotlib
%matplotlib inline

# Silence every warning category for the rest of the notebook.
warnings.filterwarnings('ignore')

# Explicit CPU handle, plus the inference device: CUDA when available, otherwise CPU.
cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} for inference')

Using cuda for inference


In [3]:
# Fix the global RNG seeds so that everything drawn later (e.g. the random baseline
# `rand_img_dist`) is reproducible across runs.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

# Configurations

## Select environment

In [4]:
# Select environment
SELECTED_ENV = 'kaggle'
DATASET_PATH, CHECKPOINT_PATH = utils.get_path(SELECTED_ENV)

print(f'Dataset path: {DATASET_PATH}')
print(f'Checkpoint path: {CHECKPOINT_PATH}')

Dataset path: /kaggle/input/
Checkpoint path: /kaggle/input/xai-pretrained-blackbox


In [5]:
# RUN_ON_TEST_SET_SAMPLE: when True, `get_loader` keeps only the first batch of each dataset.
# BATCH_SIZE: batch size requested from `utils.get_dataset_loader`.
RUN_ON_TEST_SET_SAMPLE, BATCH_SIZE = True, 12

## Setup Blackbox

In [6]:
def load_blackbox(dataset):
    """Load the ResNet-101 and Inception-v3 black boxes for ``dataset``.

    Checkpoints are read from ``CHECKPOINT_PATH`` as ``pneumothorax-<arch>.pth`` when
    ``dataset == utils.DS_PNEUMOTHORAX`` and as ``pneumonia-<arch>.pth`` otherwise.

    Args:
        dataset: dataset identifier, compared against ``utils.DS_PNEUMOTHORAX``.

    Returns:
        ``(blackbox_resnet, blackbox_inception)``. Both are wrapped in
        ``torch.nn.DataParallel``, put in eval mode and moved to ``device``.
    """
    prefix = 'pneumothorax' if dataset == utils.DS_PNEUMOTHORAX else 'pneumonia'
    blackbox_resnet = _load_checkpointed_model(
        models.resnet101, f'{CHECKPOINT_PATH}/{prefix}-resnet101.pth')
    blackbox_inception = _load_checkpointed_model(
        models.inception_v3, f'{CHECKPOINT_PATH}/{prefix}-inceptionv3.pth')
    return blackbox_resnet, blackbox_inception


def _load_checkpointed_model(build_model, checkpoint_path):
    """Build a model, load a state dict into it and prepare it for inference on ``device``."""
    state_dict = torch.load(checkpoint_path, map_location=device)
    model = build_model()
    # strict=False is kept for compatibility with the existing checkpoints. However, it silently
    # skips every non-matching key (e.g. a leftover 'module.' prefix), which would leave those
    # weights at their random initialisation. Report any mismatch; warnings are globally
    # suppressed in this notebook, so print instead of warnings.warn.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f'[load_blackbox] {checkpoint_path}: '
              f'{len(incompatible.missing_keys)} missing and '
              f'{len(incompatible.unexpected_keys)} unexpected state-dict keys')
    model = torch.nn.DataParallel(model)
    return model.eval().to(device)

# Experiments

## Utility functions

In [7]:
def get_explanation(xai_config, batch, targets):
    """Run one attribution method on ``batch`` and average the result over dimension 1.

    Args:
        xai_config: dict with ``'method'`` (an object exposing
            ``attribute(inputs, target=..., **kwargs)``) and ``'options'`` (extra keyword
            arguments for that call).
        batch: input tensor handed to ``attribute``.
        targets: forwarded unchanged as ``target=``.

    Returns:
        The attribution tensor reduced with ``.mean(dim=1)``.
    """
    attribute = xai_config['method'].attribute
    return attribute(batch, target=targets, **xai_config['options']).mean(dim=1)

In [8]:
def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: array of numbers accepted by ``np.min`` / ``np.max``.

    Returns:
        ``(x - min(x)) / (max(x) - min(x))``. When ``x`` is constant the range is zero, so
        an all-zero result of the same shape is returned instead of NaNs. NaN inputs still
        propagate as before.
    """
    x_min, x_max = np.min(x), np.max(x)
    value_range = x_max - x_min
    if value_range == 0:
        # A constant map (e.g. an all-zero attribution) has no relative signal. Return zeros
        # rather than dividing 0 by 0. Multiplying by 0.0 keeps the container type and
        # yields a float result, like the division would.
        return (x - x_min) * 0.0
    return (x - x_min) / value_range

In [9]:
def normalize_sum_to_one(x):
    """Min-max scale ``x`` with ``normalize`` and then divide by the total so it sums to one.

    A constant ``x`` yields NaN, because the scaled values then sum to zero (or are
    already NaN).
    """
    scaled = normalize(x)
    total = scaled.sum()
    return scaled / total

In [10]:
def compute_mean_disagreement(exp1, exp2, metric, **options):
    """Average ``metric`` over element-wise pairs of two explanation collections.

    Args:
        exp1, exp2: iterables paired with ``zip``, so iteration stops at the shorter one.
        metric: callable invoked as ``metric(e1, e2, **options)`` for each pair.
        **options: keyword arguments forwarded to every ``metric`` call.

    Returns:
        The mean of the per-pair values (NumPy yields NaN when there are no pairs).
    """
    scores = [metric(first, second, **options) for first, second in zip(exp1, exp2)]
    return np.array(scores).mean()

In [11]:
class ScalarOutputWrapper(torch.nn.Module):
    """``nn.Module`` wrapper whose ``forward`` returns the wrapped module's output unchanged.

    Despite the name, no top-k/argmax reduction is applied; the output is passed through
    as-is. In this notebook it wraps the black box handed to Captum's ``DeepLift``.
    """

    def __init__(self, module):
        super().__init__()
        # Registered as child "m"; keep the name so submodule/state-dict keys stay the same.
        self.m = module

    def forward(self, x):
        return self.m(x)

## Functions for generating explanations

In [12]:
def get_explanations(xai_methods_dict, methods, ds_loader, indices, gpu_device):
    """Compute attribution maps for every batch of ``ds_loader`` and every name in ``methods``.

    Args:
        xai_methods_dict: mapping method name -> ``{'method': ..., 'options': {...}}``,
            as consumed by ``get_explanation``.
        methods: method names to compute. Names missing from ``xai_methods_dict`` get
            NaN-filled placeholders, so every method keeps one map per sample.
        ds_loader: iterable of batches. A 4-element batch uses element 1 as the input;
            any other batch uses element 0.
        indices: 1-D tensor of target classes, one per sample, in loader order.
        gpu_device: device the inputs and targets are moved to.

    Returns:
        dict method name -> list of per-batch CPU tensors, one entry per loader batch.
    """
    indices = indices.to(gpu_device)
    explanations = {method: [] for method in methods}

    for method in methods:
        print(f'Getting explanations of {method}...')
        # Running sample offset instead of i * batch_size, which is wrong whenever the
        # final batch is smaller than the others.
        offset = 0
        for data in ds_loader:
            batch = data[1] if len(data) == 4 else data[0]
            batch_size = batch.shape[0]
            targets = indices[offset:offset + batch_size]
            offset += batch_size
            if method in xai_methods_dict:
                explanation = get_explanation(xai_methods_dict[method], batch.to(gpu_device), targets)
                explanations[method].append(explanation.cpu().detach())
            else:
                # One NaN map per sample. Captum attributions have the input's shape, and
                # get_explanation averages over dim 1, so real maps are
                # (batch_size, *batch.shape[2:]). Matching that keeps torch.cat aligned.
                explanations[method].append(torch.full((batch_size, *batch.shape[2:]), float('nan')))

    return explanations

Define the baseline (reference) input shared by GradientShap and Integrated Gradients: a single random-noise image with the same shape as one model input.

In [13]:
# Baseline image for GradientShap and Integrated Gradients: uniform noise in [0, 1)
# drawn from the (seeded) default CPU generator, then moved to the inference device.
rand_img_dist = torch.rand(1, 3, *utils.INPUT_SIZE).to(device)
rand_img_dist.shape

torch.Size([1, 3, 128, 128])

In [14]:
def get_xai_configs(blackbox_configs, blackbox_name):
    """Build the Captum attribution methods for one black box.

    Args:
        blackbox_configs: mapping model name -> ``{'module': DataParallel-wrapped model,
            'layer_name_for_guided_gc': name of the layer attribute used by GuidedGradCam}``.
        blackbox_name: key into ``blackbox_configs`` (e.g. ``'ResNet'``, ``'InceptionV3'``).

    Returns:
        ``(xai_methods, blackbox)``. ``xai_methods`` maps method name ->
        ``{'method': attribution object, 'options': kwargs for attribute()}``, and
        ``blackbox`` is the selected module.
    """
    config = blackbox_configs[blackbox_name]
    blackbox = config['module']
    # The layer lives on the model inside the DataParallel wrapper, hence `.module`.
    layer = getattr(blackbox.module, config['layer_name_for_guided_gc'])

    # Only the "options" dict is forwarded to attribute() (see get_explanation), so every
    # attribute() argument, including `baselines`, must live inside it.
    xai_methods = {
        "Saliency": {"method": Saliency(blackbox), "options": {}},
        "GuidedGradCam": {"method": GuidedGradCam(blackbox, layer),
                          "options": {'interpolate_mode': 'area'}},
        "GuidedBackprop": {"method": GuidedBackprop(blackbox), "options": {}},
        "LRP": {"method": LRP(blackbox), "options": {}},
        "GradientShap": {"method": GradientShap(blackbox),
                         "options": {'n_samples': 16, 'stdevs': 0.0001, 'baselines': rand_img_dist}},
        "IntegratedGradients": {"method": IntegratedGradients(blackbox),
                                "options": {'n_steps': 100, 'internal_batch_size': 1,
                                            'baselines': rand_img_dist}},
        "Occlusion": {"method": Occlusion(blackbox),
                      "options": {'sliding_window_shapes': (3, 8, 8), 'strides': (3, 4, 4)}},
    }

    # DeepLift is skipped for ResNet. Presumably this is because torchvision's ResNet blocks
    # reuse one ReLU module, which Captum's DeepLift does not handle (TODO confirm).
    if blackbox_name != 'ResNet':
        xai_methods["DeepLift"] = {"method": DeepLift(ScalarOutputWrapper(blackbox)), "options": {}}

    return xai_methods, blackbox

In [15]:
def run(blackbox_configs, blackbox_name, loader):
    """Predict a target class per sample with one black box and compute all its explanations.

    Args:
        blackbox_configs: mapping passed to ``get_xai_configs``.
        blackbox_name: which black box in ``blackbox_configs`` to explain.
        loader: iterable of batches, each either ``(files, x, masks, y)`` or ``(x, y)``.

    Returns:
        ``(files, explanations)``:
        - ``files``: the file entries of every sample in loader order, with one ``None``
          per sample for ``(x, y)`` batches.
        - ``explanations``: dict method name -> tensor of all per-sample maps
          concatenated along dim 0.

    Raises:
        ValueError: if a batch has neither 2 nor 4 elements.
    """
    xai_methods, blackbox = get_xai_configs(blackbox_configs, blackbox_name)
    files, indices = _predict_targets(blackbox, loader)

    gradient_methods = ['Saliency', 'GuidedGradCam', 'GuidedBackprop', 'LRP', 'IntegratedGradients', 'GradientShap', 'DeepLift']
    perturbation_methods = ['Occlusion']
    per_batch = {
        **get_explanations(xai_methods, gradient_methods, loader, indices, device),
        **get_explanations(xai_methods, perturbation_methods, loader, indices, device),
    }

    # Concatenate each method's per-batch CPU tensors into a single tensor.
    explanations = {
        method: torch.cat([exp.cpu().detach() for exp in per_batch[method]])
        for method in gradient_methods + perturbation_methods
    }
    return files, explanations


def _predict_targets(blackbox, loader):
    """Return ``(files, top-1 predicted class per sample)`` over all batches of ``loader``."""
    files = []
    indices = []
    # Predictions only: no autograd graph is needed for this pass.
    with torch.no_grad():
        for data in loader:
            if len(data) == 4:
                f, x, _, _ = data
            elif len(data) == 2:
                x, _ = data
                # One placeholder per sample so `files` stays aligned with the explanations.
                f = [None] * x.shape[0]
            else:
                raise ValueError(f'Expected batches with 2 or 4 elements, got {len(data)}')
            output = torch.nn.functional.softmax(blackbox(x.to(device)), dim=1)
            _, index = torch.topk(output, k=1, dim=1)
            indices.append(index.flatten())
            files.extend(list(f))
    return files, torch.cat(indices)

## Generate explanations

Define the list of attribution methods, the data loaders, and the black-box configurations for each dataset.

In [16]:
# Attribution methods whose explanations are saved below: the gradient-based ones first,
# then the perturbation-based Occlusion (the same names `run` computes).
methods = [
    'Saliency', 'GuidedGradCam', 'GuidedBackprop', 'LRP',
    'IntegratedGradients', 'GradientShap', 'DeepLift',
    'Occlusion',
]

In [17]:
def get_loader(dataset):
    """Return the batches to explain for ``dataset``.

    With ``RUN_ON_TEST_SET_SAMPLE`` enabled, only the first batch of the loader from
    ``utils.get_dataset_loader`` is kept, wrapped in a one-element list. Otherwise the
    loader itself is returned.
    """
    ds_loader = utils.get_dataset_loader(SELECTED_ENV, dataset, BATCH_SIZE)
    if not RUN_ON_TEST_SET_SAMPLE:
        return ds_loader
    return [next(iter(ds_loader))]

In [18]:
def get_blackbox_config(dataset):
    """Load both black boxes for ``dataset`` and describe them for ``get_xai_configs``.

    Args:
        dataset: dataset identifier forwarded to ``load_blackbox``.

    Returns:
        dict model name -> ``{'module': model,
        'layer_name_for_guided_gc': layer attribute used by GuidedGradCam}``.
    """
    # Load the checkpoints of the requested dataset rather than a fixed one, so each
    # experiment explains the models trained on its own data.
    blackbox_resnet, blackbox_inception = load_blackbox(dataset)
    return {
        'ResNet': {
            'module': blackbox_resnet,
            'layer_name_for_guided_gc': 'layer4'
        },
        'InceptionV3': {
            'module': blackbox_inception,
            'layer_name_for_guided_gc': 'Mixed_7c'
        }
    }

In [19]:
# Black boxes and batches to explain, per dataset (pneumothorax first, then pneumonia).
pneumothorax_config = get_blackbox_config(utils.DS_PNEUMOTHORAX)
pneumothorax_loader = get_loader(utils.DS_PNEUMOTHORAX)
pneumonia_config = get_blackbox_config(utils.DS_PNEUMONIA)
pneumonia_loader = get_loader(utils.DS_PNEUMONIA)

In [20]:
%%time
# Explain the pneumothorax batches with both black boxes. Only this cell keeps the
# per-sample file list (`test_files`), which is written to test_files.json further below.
print('----- InceptionV3 explanations -----')
test_files, explanations_pneumothorax_inceptionv3 = run(pneumothorax_config, 'InceptionV3', pneumothorax_loader)
print('----- ResNet101 explanations -----')
# Same loader, so the file list would be identical; discard it.
_, explanations_pneumothorax_resnet = run(pneumothorax_config, 'ResNet', pneumothorax_loader)

----- InceptionV3 explanations -----
Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
----- ResNet101 explanations -----
Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
CPU times: user 3min 5s, sys: 9.43 s, total: 3min 15s
Wall time: 2min 31s


In [21]:
%%time
# Explain the pneumonia batches with both black boxes. The file lists returned by `run`
# are discarded here; only the pneumothorax cell keeps `test_files`.
print('----- InceptionV3 explanations -----')
_, explanations_pneumonia_inceptionv3 = run(pneumonia_config, 'InceptionV3', pneumonia_loader)
print('----- ResNet101 explanations -----')
_, explanations_pneumonia_resnet = run(pneumonia_config, 'ResNet', pneumonia_loader)

----- InceptionV3 explanations -----
Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
----- ResNet101 explanations -----
Getting explanations of Saliency...
Getting explanations of GuidedGradCam...
Getting explanations of GuidedBackprop...
Getting explanations of LRP...
Getting explanations of IntegratedGradients...
Getting explanations of GradientShap...
Getting explanations of DeepLift...
Getting explanations of Occlusion...
CPU times: user 3min 3s, sys: 7.91 s, total: 3min 11s
Wall time: 2min 28s


## Save explanations

In [23]:
def save_explanations(path, methods, explanations):
    """Serialize each method's explanations to ``<path>/<method>.pt`` with ``torch.save``.

    Args:
        path: directory to write into (it must already exist).
        methods: method names to save; each must be a key of ``explanations``.
        explanations: dict method name -> object to serialize.
    """
    for name in methods:
        target_file = os.path.join(path, f'{name}.pt')
        torch.save(explanations[name], target_file)

In [24]:
# Output layout: explanations/<Dataset>/<Model>/<method>.pt
SAVED_EXPLANATIONS_PARENT_DIR = 'explanations'
PNEUMOTHORAX_INCEPTIONV3_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'Pneumothorax/InceptionV3')
PNEUMOTHORAX_RESNET_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'Pneumothorax/ResNet')
PNEUMONIA_INCEPTIONV3_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'Pneumonia/InceptionV3')
PNEUMONIA_RESNET_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'Pneumonia/ResNet')

# Create the parent directory first, then one directory per (dataset, model) pair.
for _output_dir in (
    SAVED_EXPLANATIONS_PARENT_DIR,
    PNEUMOTHORAX_INCEPTIONV3_DIR,
    PNEUMOTHORAX_RESNET_DIR,
    PNEUMONIA_INCEPTIONV3_DIR,
    PNEUMONIA_RESNET_DIR,
):
    utils.make_dir_if_not_exist(_output_dir)

In [25]:
# Persist every method's explanations for each (dataset, model) combination.
for _target_dir, _model_explanations in (
    (PNEUMOTHORAX_INCEPTIONV3_DIR, explanations_pneumothorax_inceptionv3),
    (PNEUMOTHORAX_RESNET_DIR, explanations_pneumothorax_resnet),
    (PNEUMONIA_INCEPTIONV3_DIR, explanations_pneumonia_inceptionv3),
    (PNEUMONIA_RESNET_DIR, explanations_pneumonia_resnet),
):
    save_explanations(_target_dir, methods, _model_explanations)

In [26]:
# Store the file entries of the explained samples (from the pneumothorax run) next to the tensors.
with open(os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'test_files.json'), 'w') as test_files_fp:
    json.dump(test_files, fp=test_files_fp)

In [27]:
# Recursively archive the `explanations` directory into explanations.zip. If the archive
# already exists, zip updates its entries rather than recreating it. `-o` only adjusts
# the archive's own timestamp (see `man zip`).
!zip -o explanations.zip -r explanations

updating: explanations/ (stored 0%)
updating: explanations/Pneumonia/ (stored 0%)
updating: explanations/Pneumonia/ResNet/ (stored 0%)
updating: explanations/Pneumonia/ResNet/Occlusion.pt (deflated 94%)
updating: explanations/Pneumonia/ResNet/Saliency.pt (deflated 11%)
updating: explanations/Pneumonia/ResNet/GuidedBackprop.pt (deflated 7%)
updating: explanations/Pneumonia/ResNet/IntegratedGradients.pt (deflated 4%)
updating: explanations/Pneumonia/ResNet/LRP.pt (deflated 100%)
updating: explanations/Pneumonia/ResNet/DeepLift.pt (deflated 99%)
updating: explanations/Pneumonia/ResNet/GuidedGradCam.pt (deflated 7%)
updating: explanations/Pneumonia/ResNet/GradientShap.pt (deflated 7%)
updating: explanations/Pneumonia/InceptionV3/ (stored 0%)
updating: explanations/Pneumonia/InceptionV3/Occlusion.pt (deflated 94%)
updating: explanations/Pneumonia/InceptionV3/Saliency.pt (deflated 12%)
updating: explanations/Pneumonia/InceptionV3/GuidedBackprop.pt (deflated 9%)
updating: explanations/Pneumon