# Importing

In [1]:
import torch
from captum.attr import (Saliency, GradientShap, DeepLift, IntegratedGradients, NoiseTunnel, GuidedGradCam, GuidedBackprop, Occlusion, LRP, visualization as viz, KernelShap, FeaturePermutation, ShapleyValueSampling)
import torchvision
from torchvision import transforms
import json
import requests
from io import BytesIO
import warnings
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as c_map
from torchvision import models
import seaborn as sns
import scipy
import skimage
warnings.filterwarnings('ignore')
import gc
from pneumothorax_image_dataset import PneumothoraxImageDataset
%matplotlib inline

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
# `cpu` is kept as a separate handle for code that needs to address the CPU explicitly.
cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device} for inference')

Using cpu for inference


In [2]:
# Seed the global RNGs of PyTorch and NumPy so that random draws made later in
# the notebook are repeatable from a fresh kernel.
torch.manual_seed(123)
np.random.seed(123)

# Setup datasets and configurations

In [3]:
# NOTE(review): USE_KAGGLE and the KAGGLE_*/COLAB_* paths below are not read by
# any cell visible in this notebook; only DATASET_PATH and CHECKPOINT_PATH are
# used. Confirm whether the flag is meant to select between them.
USE_KAGGLE = True

# Dataset locations (Kaggle, Colab, and the local directory actually used).
KAGGLE_DATASET_PATH = '/kaggle/input/pneumothorax-chest-xray-images-and-masks'
COLAB_DATASET_PATH = '/content/datasets'
DATASET_PATH = 'data'

# Checkpoint locations (Kaggle, Colab, and the local directory actually used).
KAGGLE_CHECKPOINT_PATH = '/kaggle/input/xai-pretrained-blackbox'
COLAB_CHECKPOINT_PATH = '/content/datasets/'
CHECKPOINT_PATH = 'xai-pretrained-blackbox'

In [4]:
# Spatial size that images and masks are resized to; also used as the spatial
# size of the random baseline tensor.
INPUT_SIZE = (128, 128)
# When True, only the first batch of the data loader is explained; when False,
# the whole loader is used.
RUN_ON_TEST_SET_SAMPLE = True

In [5]:
# Resize step shared by the image pipeline and the mask pipeline.
resize_image = transforms.Compose([transforms.Resize(INPUT_SIZE)])

# Images: convert to a tensor, resize, then normalise each channel
# (the mean/std values are the ones commonly used with ImageNet-pretrained models).
transform = transforms.Compose([
    transforms.ToTensor(),
    resize_image,
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks: convert to a tensor and resize only; no normalisation is applied.
mask_transform = transforms.Compose([transforms.ToTensor(), resize_image])

In [6]:
import os
import pandas as pd

DATASET_NAME = 'siim-acr-pneumothorax'

# Index of the test split: one row per image.
test_data = pd.read_csv(os.path.join(DATASET_PATH, DATASET_NAME, 'stage_1_test_images.csv'))

# All cases are used. To keep abnormal cases only, select
# test_data.loc[test_data.has_pneumo == 1] here instead.
# NOTE: test_abnormal is the same object as test_data, so the columns added
# below show up on both names.
test_abnormal = test_data

# Resolve every file name to the path of its image and of its mask.
test_abnormal['images'] = [
    os.path.join(DATASET_PATH, DATASET_NAME, 'png_images', filename)
    for filename in test_abnormal['new_filename']
]
test_abnormal['masks'] = [
    os.path.join(DATASET_PATH, DATASET_NAME, 'png_masks', filename)
    for filename in test_abnormal['new_filename']
]

# Plain Python lists consumed by the dataset object.
images = test_abnormal['images'].to_list()
masks = test_abnormal['masks'].to_list()
targets = test_abnormal['has_pneumo'].to_list()

In [7]:
# Dataset built from the image paths, labels and mask paths, with the image
# and mask transform pipelines attached.
dataset = PneumothoraxImageDataset(
    images,
    targets,
    masks,
    transform=transform,
    mask_transform=mask_transform,
)

# One sample per batch, in file order (no shuffling), loaded by 4 worker processes.
ds_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

## Setup Blackbox & XAI methods

In [8]:
# NOTE: torch.load relies on pickle; only point CHECKPOINT_PATH at trusted files.
# Parameters are left trainable (requires_grad is not switched off for either model).

# ResNet-101: restore the weights, wrap in DataParallel, switch to eval mode, move to `device`.
resnet_cp = torch.load(f'{CHECKPOINT_PATH}/resnet_cp.pth', map_location=device)
blackbox_resnet = models.resnet101()
blackbox_resnet.load_state_dict(resnet_cp)
blackbox_resnet = torch.nn.DataParallel(blackbox_resnet).eval().to(device)

# InceptionV3: same procedure.
checkpoint = torch.load(f'{CHECKPOINT_PATH}/inception_cp.pth', map_location=device)
blackbox_inception = models.inception_v3()
blackbox_inception.load_state_dict(checkpoint)
blackbox_inception = torch.nn.DataParallel(blackbox_inception).eval().to(device)


In [9]:
# Per-blackbox settings: the wrapped model and the name of the layer that
# GuidedGradCam is attached to.
blackbox_configs = {
    'ResNet': dict(module=blackbox_resnet, layer_name_for_guided_gc='layer4'),
    'InceptionV3': dict(module=blackbox_inception, layer_name_for_guided_gc='Mixed_7c'),
}

In [10]:
class ScalarOutputWrapper(torch.nn.Module):
    """Wrap a model so that ``forward`` yields the top-1 index along dim 1.

    The wrapped model is stored as ``self.m``. Its output is reduced with a
    top-1 selection over dimension 1, and only the indices are returned, so the
    result has size 1 along that dimension.
    """

    def __init__(self, module):
        super().__init__()
        self.m = module

    def forward(self, x):
        scores = self.m(x)
        # topk returns (values, indices); only the indices are of interest.
        return torch.topk(scores, 1, dim=1).indices

# Experiments

## Utility functions

In [11]:
def get_explanation(xai_config, batch, targets):
    """Compute one XAI method's attribution for a batch, averaged over dim 1.

    Args:
        xai_config: Dict with ``'method'`` (an object exposing ``attribute``)
            and ``'options'`` (extra keyword arguments for ``attribute``).
        batch: Input passed as the first argument to ``attribute``.
        targets: Passed to ``attribute`` as ``target``.

    Returns:
        The attribution with dimension 1 reduced by its mean.
    """
    attributions = xai_config['method'].attribute(
        batch, target=targets, **xai_config['options']
    )
    return attributions.mean(dim=1)

In [12]:
def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: Array-like of numbers.

    Returns:
        ``(x - min) / (max - min)``. When every element of ``x`` is equal the
        range is zero; in that case an all-zero result is returned instead of
        the NaNs that 0/0 would produce.
    """
    minimum = np.min(x)
    shifted = x - minimum
    value_range = np.max(x) - minimum
    if value_range == 0:
        # Constant input: `shifted` is already all zeros, so divide by 1 to
        # keep the usual floating-point result type without producing NaN.
        value_range = 1
    return shifted / value_range

In [13]:
def normalize_sum_to_one(x):
    """Rescale ``x`` with ``normalize`` and then divide by the total.

    Args:
        x: Array-like accepted by ``normalize``.

    Returns:
        The normalized values divided by their sum.
    """
    scaled = normalize(x)
    total = scaled.sum()
    return scaled / total

In [14]:
def compute_mean_disagreement(exp1, exp2, metric, **options):
    """Average a pairwise metric over two parallel sequences of explanations.

    Args:
        exp1: First sequence of explanations.
        exp2: Second sequence, paired element by element with ``exp1``
            (pairing stops at the shorter sequence).
        metric: Callable invoked as ``metric(e1, e2, **options)`` per pair.
        **options: Extra keyword arguments forwarded to ``metric``.

    Returns:
        The mean of the per-pair metric values.
    """
    scores = [metric(left, right, **options) for left, right in zip(exp1, exp2)]
    return np.array(scores).mean()

## Evaluation

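Choose the data the blackbox is evaluated on: a single-batch sample of the test set when `RUN_ON_TEST_SET_SAMPLE` is set, otherwise the whole test set.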

In [15]:
# Sample mode materialises only the first batch of ds_loader as a one-element
# list; otherwise the full data loader is iterated.
loader = [next(iter(ds_loader))] if RUN_ON_TEST_SET_SAMPLE else ds_loader

# Names of every explanation method that is generated and saved.
methods = [
    'Saliency',
    'GuidedGradCam',
    'GuidedBackprop',
    'LRP',
    'IntegratedGradients',
    'GradientShap',
    'DeepLift',
    'Occlusion',
]

## Generating explanations

Generate explanations for all test examples with each method of interest.

In [16]:
def get_explanations(xai_methods_dict, methods, ds_loader, indices, gpu_device):
    """Collect per-batch explanations for every requested method.

    Args:
        xai_methods_dict: Mapping from method name to the config dict accepted
            by ``get_explanation``.
        methods: Names of the methods to run.
        ds_loader: Iterable of batches; element 0 of each item is the input
            tensor. It is iterated once per method.
        indices: 1-D tensor of target indices, one per sample, in the order
            the samples are produced by ``ds_loader``.
        gpu_device: Device on which the attributions are computed.

    Returns:
        Dict mapping each method name to a list with one tensor per batch.
        Methods missing from ``xai_methods_dict`` get NaN-filled placeholders.
    """
    indices = indices.to(gpu_device)
    explanations = {method: [] for method in methods}

    for method in methods:
        # Running position in `indices`. Using i * batch_size would select the
        # wrong targets whenever a batch (typically the last) is smaller.
        offset = 0
        for data in ds_loader:
            batch = data[0]
            batch_size = batch.shape[0]
            if method in xai_methods_dict:
                batch_targets = indices[offset:offset + batch_size]
                explanation = get_explanation(
                    xai_methods_dict[method], batch.to(gpu_device), batch_targets
                )
                explanations[method].append(explanation.cpu().detach())
            else:
                # NOTE(review): the placeholder keeps the full shape of the
                # input batch, while get_explanation averages over dim 1;
                # confirm that downstream consumers expect this.
                explanations[method].append(torch.full_like(batch, float('nan')))
            offset += batch_size

    return explanations

In [17]:
# Uniform-random tensor with one sample and three channels at INPUT_SIZE,
# drawn from torch's global RNG (so it depends on the seed set earlier) and
# then moved to `device`. get_xai_configs uses it as the `baselines` value.
rand_img_dist = torch.rand((1, 3, *INPUT_SIZE)).to(device)
rand_img_dist.shape

torch.Size([1, 3, 128, 128])

In [18]:
def get_xai_configs(blackbox_configs, blackbox_name):
    """Build the Captum attribution methods for one blackbox model.

    Args:
        blackbox_configs: Mapping from blackbox name to a dict with the keys
            ``'module'`` (the model) and ``'layer_name_for_guided_gc'`` (name
            of the attribute on ``module.module`` used as the GuidedGradCam
            layer).
        blackbox_name: Key of the blackbox to configure.

    Returns:
        A tuple ``(xai_methods, blackbox)``. ``xai_methods`` maps a method name
        to ``{'method': <attribution object>, 'options': <keyword arguments
        for its attribute() call>}``; ``blackbox`` is the selected model.
    """
    config = blackbox_configs[blackbox_name]
    blackbox = config['module']
    # The layer is looked up on `blackbox.module`, i.e. the model is expected
    # to be wrapped (as torch.nn.DataParallel does) around the real network.
    layer = getattr(blackbox.module, config['layer_name_for_guided_gc'])

    # DeepLift is not registered here (its registration was disabled in the
    # original notebook), so requests for it are not served by this mapping.
    xai_methods = {
        "Saliency": {"method": Saliency(blackbox), "options": {}},
        "GuidedGradCam": {
            "method": GuidedGradCam(blackbox, layer),
            "options": {'interpolate_mode': 'area'},
        },
        "GuidedBackprop": {"method": GuidedBackprop(blackbox), "options": {}},
        "LRP": {"method": LRP(blackbox), "options": {}},
        "GradientShap": {
            "method": GradientShap(blackbox),
            "options": {'n_samples': 16, 'stdevs': 0.0001, 'baselines': rand_img_dist},
        },
        "IntegratedGradients": {
            "method": IntegratedGradients(blackbox),
            # 'baselines' must live inside "options": only that dict is
            # forwarded to attribute(). It previously sat next to "options"
            # and was silently ignored.
            "options": {'n_steps': 100, 'internal_batch_size': 1, 'baselines': rand_img_dist},
        },
        "Occlusion": {
            "method": Occlusion(blackbox),
            "options": {'sliding_window_shapes': (3, 8, 8), 'strides': (3, 4, 4)},
        },
    }

    return xai_methods, blackbox

In [19]:
def run(blackbox_configs, blackbox_name, loader):
    """Explain the predictions of one blackbox with every configured method.

    The blackbox first predicts the top-1 class for each sample in ``loader``;
    those predictions are then used as attribution targets.

    Args:
        blackbox_configs: Mapping of blackbox names to their configuration.
        blackbox_name: Key of the blackbox to explain.
        loader: Iterable of ``(input, mask, label)`` batches.

    Returns:
        Dict mapping each method name to a single tensor holding the
        explanations of all batches concatenated along dimension 0.
    """
    xai_methods, blackbox = get_xai_configs(blackbox_configs, blackbox_name)

    # Predicted class per sample. No gradients are needed for this pass, so it
    # runs under no_grad to avoid building an autograd graph for every batch.
    predicted = []
    with torch.no_grad():
        for x, _mask, _label in loader:
            output = torch.nn.functional.softmax(blackbox(x.to(device)), dim=1)
            _, index = torch.topk(output, k=1, dim=1)
            predicted.append(index.flatten())
    indices = torch.cat(predicted)

    # Gradient-based explanations
    gradient_methods = ['Saliency', 'GuidedGradCam', 'GuidedBackprop', 'LRP', 'IntegratedGradients', 'GradientShap', 'DeepLift']
    gradient_explanations = get_explanations(xai_methods, gradient_methods, loader, indices, device)

    # Perturbation-based explanations
    perturbation_methods = ['Occlusion']
    perturbation_explanations = get_explanations(xai_methods, perturbation_methods, loader, indices, device)

    # Merge both groups and collapse each per-batch list into one CPU tensor.
    explanations = {**gradient_explanations, **perturbation_explanations}
    for method in gradient_methods + perturbation_methods:
        explanations[method] = torch.cat(
            [explanation.cpu().detach() for explanation in explanations[method]]
        )

    return explanations

In [20]:
%%time
# Explain the InceptionV3 blackbox on `loader`; the result maps each method
# name to a tensor of explanations.
explanations_inceptionv3 = run(blackbox_configs, 'InceptionV3', loader)

CPU times: total: 3min 41s
Wall time: 39 s


In [21]:
%%time
# Explain the ResNet blackbox on `loader`; the result maps each method name to
# a tensor of explanations.
explanations_resnet = run(blackbox_configs, 'ResNet', loader)

CPU times: total: 9min 27s
Wall time: 1min 41s


In [22]:
def make_dir_if_not_exist(path):
    """Create directory ``path`` (and any missing parents) unless it exists.

    ``exist_ok=True`` folds the existence check into the creation call, so the
    function cannot fail when the directory appears between a separate check
    and the creation.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [23]:
def save_explanations(path, methods, explanations):
    """Write one ``<method>.pt`` file per method into ``path``.

    Args:
        path: Existing directory that receives the files.
        methods: Names of the methods to save.
        explanations: Mapping from method name to the object to serialise
            with ``torch.save``.
    """
    for name in methods:
        destination = os.path.join(path, f'{name}.pt')
        torch.save(explanations[name], destination)

In [24]:
# Output layout: one sub-directory per blackbox under a common parent.
SAVED_EXPLANATIONS_PARENT_DIR = 'explanations_test'
INCEPTIONV3_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'InceptionV3')
RESNET_DIR = os.path.join(SAVED_EXPLANATIONS_PARENT_DIR, 'ResNet')

# Create the parent first, then each per-model directory.
for _explanation_dir in (SAVED_EXPLANATIONS_PARENT_DIR, INCEPTIONV3_DIR, RESNET_DIR):
    make_dir_if_not_exist(_explanation_dir)

In [25]:
# Persist the explanations of each blackbox into its own directory,
# one file per entry of `methods`.
for _target_dir, _model_explanations in (
    (INCEPTIONV3_DIR, explanations_inceptionv3),
    (RESNET_DIR, explanations_resnet),
):
    save_explanations(_target_dir, methods, _model_explanations)