Reference: https://github.com/Hryniewska/EnsembleXAI/blob/main/notebooks/Imagenet_tests_and_results.ipynb

In [1]:
import torch
import os
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights
import urllib.request
import json
from captum.attr import IntegratedGradients, Saliency, GradientShap, GuidedBackprop, Deconvolution, InputXGradient, Lime, Occlusion, ShapleyValueSampling, FeatureAblation, KernelShap, NoiseTunnel
from captum.attr import visualization as viz

### Image directory

In [None]:
# Root folder holding the input data.  Set the CPSC471_INPUT_DIR environment
# variable to point somewhere else; the default is the original location.
input_dir = os.environ.get("CPSC471_INPUT_DIR", "C:\\Users\\jebcu\\Desktop\\CPSC471_Project\\input")
# Folder with the images.  The trailing separator is kept on purpose because
# later cells build paths from this value by plain string concatenation.
images_dir = os.path.join(input_dir, "images") + os.sep

### Get classes

In [None]:
# Download the ImageNet class-index JSON and parse it into a dict.
with urllib.request.urlopen("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json") as url:
    imagenet_classes_dict = json.load(url)

# Download the ImageNet-S50 category list (one id per line) and map every id to
# its 1-based line number.  The response yields raw bytes, so each line is
# decoded and stripped instead of string-hacking the repr of a bytes object;
# blank lines are skipped without shifting the numbering of the other lines.
with urllib.request.urlopen("https://raw.githubusercontent.com/LUSSeg/ImageNet-S/main/data/categories/ImageNetS_categories_im50.txt") as url:
    imagenetS50_ids_dict = {
        class_id: line_number + 1
        for line_number, class_id in enumerate(raw_line.decode("utf-8").strip() for raw_line in url)
        if class_id
    }

### Process images

In [None]:
def images_list(image_paths):
    """Load every image found in the given directories as an RGB PIL image.

    Args:
        image_paths: iterable of directory paths. A trailing path separator is
            accepted but no longer required.

    Returns:
        A list of in-memory RGB images, ordered by directory (in the order
        given) and then by sorted file name.
    """
    images = []
    for image_path in image_paths:
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample indices reproducible from one run to the next.
        for image_name in sorted(os.listdir(image_path)):
            # The context manager closes the underlying file. convert() loads
            # the pixels and returns a detached copy, so every non-RGB mode
            # (not only greyscale 'L') ends up with three channels.
            with Image.open(os.path.join(image_path, image_name)) as image:
                images.append(image.convert("RGB"))
    return images

In [None]:
# Load the images of every class id in the category mapping.  os.path.join with
# a trailing "" builds "<images_dir><classid><sep>" using the platform's own
# separator instead of a hard-coded backslash.
all_images_original = images_list([os.path.join(images_dir, classid, "") for classid in imagenetS50_ids_dict])

### Do predictions with ResNet18

In [2]:
# Pretrained ResNet-18 in inference mode, plus the preprocessing transform
# that ships with the same weights.
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
model.eval()
resnet_transform = weights.transforms()


def pipeline(images):
    """Apply the ResNet preprocessing to each image and stack the results into one batch tensor."""
    return torch.stack([resnet_transform(image) for image in images])

In [None]:
# Preprocess all images into a single batch tensor.
proper_data = pipeline(all_images_original)

# Pure inference: disable autograd so the forward pass over the whole data set
# does not keep the activation graph alive.
with torch.no_grad():
    outputs = model(proper_data)
    # Index of the highest logit per image = predicted class.
    _, preds = torch.max(outputs, 1)
    probs = torch.nn.functional.softmax(outputs, dim=1)

In [None]:
# Cache the preprocessed images and predictions on disk.  torch.save does not
# create missing folders, so make sure the output directory exists first.
os.makedirs("ImageNet", exist_ok=True)
torch.save(proper_data, "ImageNet/proper_data.pt")
torch.save(preds, "ImageNet/preds.pt")

### Single Explanations

In [None]:
# Integrated Gradients for one sample (index 2), kept as a batch of size one.
integrated_gradients = IntegratedGradients(model)
single_data = proper_data[2].unsqueeze(dim=0)
single_pred = preds[2].unsqueeze(dim=0)
attributions_ig = integrated_gradients.attribute(
    single_data,
    target=single_pred,
    n_steps=200,
)

In [None]:
transformed_img = resnet_transform(all_images_original[2])

# Pick the sample first, move channels last (H, W, C) and hand NumPy arrays to
# the visualiser instead of converting the whole batch to nested Python lists.
_ = viz.visualize_image_attr(
    attributions_ig[0].permute(1, 2, 0).detach().cpu().numpy(),
    transformed_img.permute(1, 2, 0).detach().cpu().numpy(),
    method="blended_heat_map",
    sign="all",
    show_colorbar=True,
)

### Multiple explanations

In [None]:
# Integrated Gradients for the first two samples in a single batched call.
integrated_gradients = IntegratedGradients(model)
multiple_data, multiple_pred = proper_data[0:2], preds[0:2]
attributions_ig = integrated_gradients.attribute(
    multiple_data,
    target=multiple_pred,
    n_steps=200,
)

In [None]:
# Inspect the shape of the batched Integrated Gradients attributions.
attributions_ig.shape

In [None]:
transformed_img = resnet_transform(all_images_original[0])

# Attribution of the first sample in the batch.  Select it first, move channels
# last (H, W, C) and hand NumPy arrays to the visualiser instead of converting
# the whole batch to nested Python lists.
_ = viz.visualize_image_attr(
    attributions_ig[0].permute(1, 2, 0).detach().cpu().numpy(),
    transformed_img.permute(1, 2, 0).detach().cpu().numpy(),
    method="blended_heat_map",
    sign="all",
    show_colorbar=True,
)

In [None]:
transformed_img = resnet_transform(all_images_original[1])

# Attribution of the second sample in the batch.  Select it first, move
# channels last (H, W, C) and hand NumPy arrays to the visualiser instead of
# converting the whole batch to nested Python lists.
_ = viz.visualize_image_attr(
    attributions_ig[1].permute(1, 2, 0).detach().cpu().numpy(),
    transformed_img.permute(1, 2, 0).detach().cpu().numpy(),
    method="blended_heat_map",
    sign="all",
    show_colorbar=True,
)

### Create superpixels
Reference: https://www.kaggle.com/code/sukanyabag/lime-model-explainability-testing-pytorch and https://github.com/marcotcr/lime/blob/master/lime/lime_image.py

In [None]:
from skimage.segmentation import quickshift
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Segment every preprocessed image into superpixels with quickshift; each image
# is passed channel-first (channel_axis=0).
proper_masks = []
for image in proper_data:
    segments = quickshift(image, kernel_size=4, max_dist=200, ratio=0.2, channel_axis=0)
    proper_masks.append(segments)

In [None]:
# Turn each segmentation result into a torch tensor (still one entry per image).
proper_masks = list(map(torch.tensor, proper_masks))

In [None]:
# Preview the first channel of the first preprocessed image.
plt.imshow(proper_data[0][0])

In [None]:
# Preview the superpixel mask of the first image.
plt.imshow(proper_masks[0])

In [None]:
# Stack the per-image masks into one tensor with a leading image axis.
proper_masks = torch.stack(proper_masks)

In [None]:
# Same preview of the first image's first channel, repeated after stacking the masks.
plt.imshow(proper_data[0][0])

In [None]:
# Preview the first mask again, now read from the stacked tensor.
plt.imshow(proper_masks[0])

In [None]:
# Cache the stacked superpixel masks on disk so later sections can reload them.
torch.save(proper_masks, "ImageNet/proper_masks.pt")

The code above took 2 minutes 39.4 seconds in total, i.e. about 3 seconds per image.

### Attributions (default parameters)

In [8]:
# Load the cached tensors straight onto the GPU when one is available and fall
# back to the CPU otherwise, instead of calling .cuda() unconditionally.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
proper_data = torch.load("ImageNet/proper_data.pt", map_location=device)
preds = torch.load("ImageNet/preds.pt", map_location=device)
proper_masks = torch.load("ImageNet/proper_masks.pt", map_location=device)

In [None]:
def get_attributions(explainer, num_batches = 50):
    """Compute attributions for all of ``proper_data`` in batches.

    The module-level ``proper_data`` and ``preds`` tensors are cut into
    ``num_batches`` contiguous slices of near-equal size, and
    ``explainer.attribute(batch, target=batch_preds)`` is called once per
    non-empty slice.

    Args:
        explainer: object exposing ``attribute(inputs, target=...)``.
        num_batches: number of slices to cut the data into. When it exceeds
            the number of samples some slices are empty; those are skipped
            instead of being sent to the explainer.

    Returns:
        The per-batch results concatenated along dimension 0, or ``None`` when
        no batch was processed (``num_batches`` of 0 or no samples).
    """
    total = len(proper_data)
    chunks = []
    for i in range(num_batches):
        batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
        if batch_slice.start == batch_slice.stop:
            continue  # empty slice: nothing to attribute
        chunks.append(explainer.attribute(proper_data[batch_slice], target=preds[batch_slice]))
    # Concatenate once at the end rather than re-copying the growing result
    # on every iteration.
    return torch.cat(chunks, dim=0) if chunks else None

In [None]:
# Move the model to the GPU when one is available; fall back to the CPU so the
# notebook still runs on machines without CUDA.  Note that .to() moves the
# module in place, so `model` and `model_cuda` refer to the same object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cuda = model.to(device)

In [None]:
# Integrated Gradients with default parameters on model_cuda: attribute the
# whole data set batch-wise via get_attributions and cache the result on disk.
integrated_gradients = IntegratedGradients(model_cuda)
attributions_ig = get_attributions(integrated_gradients)
torch.save(attributions_ig, "ImageNet/attributions_ig.pt")

In [None]:
# Saliency with default parameters: attribute the whole data set batch-wise
# via get_attributions and cache the result on disk.
saliency = Saliency(model_cuda)
attributions_s = get_attributions(saliency)
torch.save(attributions_s, "ImageNet/attributions_s.pt")

In [None]:
# Number of slices the data set is cut into by the hand-written attribution
# loops in the following cells.
num_batches = 50

In [None]:
gradient_shap = GradientShap(model_cuda)

# Run on whatever device the model lives on instead of assuming CUDA.
device = next(model_cuda.parameters()).device
# A single all-zero baseline, built once rather than on every iteration.
zero_baseline = torch.zeros_like(proper_data[0:1]).to(device)

total = len(proper_data)
gs_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    gs_chunks.append(
        gradient_shap.attribute(
            proper_data[batch_slice].to(device),
            zero_baseline,
            target=preds[batch_slice].to(device),
        )
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_gs = torch.cat(gs_chunks, dim=0) if gs_chunks else None

torch.save(attributions_gs, "ImageNet/attributions_gs.pt")

In [None]:
# Guided Backprop with default parameters: attribute the whole data set
# batch-wise via get_attributions and cache the result on disk.
guided_backprop = GuidedBackprop(model_cuda)
attributions_gb = get_attributions(guided_backprop)
torch.save(attributions_gb, "ImageNet/attributions_gb.pt")

In [None]:
# Deconvolution with default parameters: attribute the whole data set
# batch-wise via get_attributions and cache the result on disk.
deconvolution = Deconvolution(model_cuda)
attributions_d = get_attributions(deconvolution)
torch.save(attributions_d, "ImageNet/attributions_d.pt")

In [None]:
# Input x Gradient with default parameters: attribute the whole data set
# batch-wise via get_attributions and cache the result on disk.
input_x_gradient = InputXGradient(model_cuda)
attributions_ixg = get_attributions(input_x_gradient)
torch.save(attributions_ixg, "ImageNet/attributions_ixg.pt")

In [None]:
# LIME needs a feature mask: the superpixel segmentation of each image.
lime = Lime(model_cuda)

total = len(proper_data)
lime_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    batch_masks = proper_masks[batch_slice]
    # Masks stored without a channel axis (assumed (B, H, W) - TODO confirm)
    # only broadcast against (B, C, H, W) inputs when B == 1; add the channel
    # axis so batches of any size work.
    if batch_masks.dim() == proper_data.dim() - 1:
        batch_masks = batch_masks.unsqueeze(1)
    lime_chunks.append(
        lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=batch_masks)
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_l = torch.cat(lime_chunks, dim=0) if lime_chunks else None

torch.save(attributions_l, "ImageNet/attributions_l.pt")

Above took 36.4s

In [None]:
occulsion = Occlusion(model_cuda)

# Using sliding window size (3, 15, 15) and strides = (3, 8, 8) as used in EnsembleXAI
sliding_window_shapes = (3, 15, 15)
strides = (3, 8, 8)
# Run on whatever device the model lives on instead of assuming CUDA.
device = next(model_cuda.parameters()).device

total = len(proper_data)
occlusion_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    occlusion_chunks.append(
        occulsion.attribute(
            proper_data[batch_slice].to(device),
            sliding_window_shapes,
            target=preds[batch_slice].to(device),
            strides=strides,
        )
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_o = torch.cat(occlusion_chunks, dim=0) if occlusion_chunks else None

torch.save(attributions_o, "ImageNet/attributions_o.pt")

Above took 3 minutes and 31.2 seconds

In [None]:
# Shapley Value Sampling needs a feature mask: the superpixel segmentation of each image.
shapley_value_sampling = ShapleyValueSampling(model_cuda)

total = len(proper_data)
svs_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    batch_masks = proper_masks[batch_slice]
    # Masks stored without a channel axis (assumed (B, H, W) - TODO confirm)
    # only broadcast against (B, C, H, W) inputs when B == 1; add the channel
    # axis so batches of any size work.
    if batch_masks.dim() == proper_data.dim() - 1:
        batch_masks = batch_masks.unsqueeze(1)
    svs_chunks.append(
        shapley_value_sampling.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=batch_masks)
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_svs = torch.cat(svs_chunks, dim=0) if svs_chunks else None

torch.save(attributions_svs, "ImageNet/attributions_svs.pt")

Above took 6 minutes 41.8 seconds

In [None]:
# Feature Ablation needs a feature mask: the superpixel segmentation of each image.
feature_ablation = FeatureAblation(model_cuda)

total = len(proper_data)
fa_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    batch_masks = proper_masks[batch_slice]
    # Masks stored without a channel axis (assumed (B, H, W) - TODO confirm)
    # only broadcast against (B, C, H, W) inputs when B == 1; add the channel
    # axis so batches of any size work.
    if batch_masks.dim() == proper_data.dim() - 1:
        batch_masks = batch_masks.unsqueeze(1)
    fa_chunks.append(
        feature_ablation.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=batch_masks)
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_fa = torch.cat(fa_chunks, dim=0) if fa_chunks else None

torch.save(attributions_fa, "ImageNet/attributions_fa.pt")

Above took 14.8 seconds

In [None]:
# Kernel SHAP needs a feature mask: the superpixel segmentation of each image.
kernel_shap = KernelShap(model_cuda)

total = len(proper_data)
ks_chunks = []
for i in range(num_batches):
    batch_slice = slice(i * total // num_batches, (i + 1) * total // num_batches)
    if batch_slice.start == batch_slice.stop:
        continue  # more batches than samples: nothing to attribute here
    batch_masks = proper_masks[batch_slice]
    # Masks stored without a channel axis (assumed (B, H, W) - TODO confirm)
    # only broadcast against (B, C, H, W) inputs when B == 1; add the channel
    # axis so batches of any size work.
    if batch_masks.dim() == proper_data.dim() - 1:
        batch_masks = batch_masks.unsqueeze(1)
    ks_chunks.append(
        kernel_shap.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=batch_masks)
    )

# Concatenate once at the end instead of re-copying the growing result on
# every iteration; stays None when no batch was processed.
attributions_ks = torch.cat(ks_chunks, dim=0) if ks_chunks else None

torch.save(attributions_ks, "ImageNet/attributions_ks.pt")

Above took 31.0 seconds

In [None]:
%%time
# Noise Tunnel wrapped around the Integrated Gradients explainer, attributed
# over the whole data set in 500 slices (instead of the default 50) and cached
# on disk.
noise_tunnel = NoiseTunnel(integrated_gradients) # based on EnsembleXAI
attributions_nt = get_attributions(noise_tunnel, num_batches = 500)
torch.save(attributions_nt, "ImageNet/attributions_nt.pt")

### Normalization

In [30]:
from EnsembleXAI.Normalization import mean_var_normalize

In [None]:
# Reload every cached attribution tensor, keyed by the name it was saved under.
# The dict keeps this insertion order, which later fixes the order in which the
# methods are stacked.
attribution_keys = (
    'attributions_ig',
    'attributions_s',
    'attributions_gs',
    'attributions_gb',
    'attributions_d',
    'attributions_ixg',
    'attributions_l',
    'attributions_o',
    'attributions_svs',
    'attributions_fa',
    'attributions_ks',
    'attributions_nt',
)
attributions = {key: torch.load('ImageNet/' + key + '.pt') for key in attribution_keys}

In [None]:
# Apply mean_var_normalize to each method's attributions, keeping keys and order.
normalized_attributions = {
    name: mean_var_normalize(values)
    for name, values in attributions.items()
}

### EnsembleXAI

In [29]:
from EnsembleXAI.Ensemble import normEnsembleXAI

In [None]:
# Stack the normalized attributions of all methods along a new axis 1
# (axis 0 stays the sample axis).
explanations = torch.stack(list(normalized_attributions.values()), dim=1)

In [None]:
# Aggregate the stacked explanations with normEnsembleXAI using the 'avg'
# aggregating function; detach() drops any autograd history first.
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')

In [None]:
# Cache the aggregated ensemble explanation on disk.
torch.save(agg, "ImageNet/agg.pt")