In [1]:
from IPython.display import clear_output
import torch 
import torchvision 
import quantus
clear_output()

In [2]:
import pathlib
import random
import copy
import gc
import numpy as np
# import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
# import seaborn as sns
import torch
import torchvision
import quantus
import warnings
from torchvision import datasets, transforms
import os
import torch.nn as nn
import json
# sns.set() 

# Use the first CUDA device when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Suppress UserWarning and DeprecationWarning messages for the rest of the session.
for _ignored_category in (UserWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Clear whatever this cell has printed so far.
clear_output()

In [3]:
# Load the whole test split into memory as NumPy arrays:
#   x_batch -> stacked, preprocessed images; y_batch -> the matching ImageFolder labels.
# Preprocessing pipeline: Resize(256) -> CenterCrop(224) -> ToTensor -> per-channel Normalize.
transform = transforms.Compose(
    [transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

data_dir = '../data'
image_test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'dump_test'), transform)
dataloader = torch.utils.data.DataLoader(image_test_dataset, batch_size=1)
test_dataset_size = len(image_test_dataset)

# Gather the per-sample tensors in lists and concatenate ONCE at the end.
# Calling torch.cat inside the loop re-copies everything accumulated so far on
# every iteration (quadratic in the dataset size).
# The empty seed tensors keep the result shape/dtype identical to the previous
# incremental version, including the case of an empty dataset.
image_chunks = [torch.empty(0, 3, 224, 224, dtype=torch.float32)]
label_chunks = [torch.empty(0, dtype=torch.uint8)]

for image, label in dataloader:
    image_chunks.append(image)
    label_chunks.append(label)

# The arrays are consumed as NumPy on the host, so they are converted directly;
# moving them to `device` only to bring them straight back would just waste
# accelerator memory.
x_batch = torch.cat(image_chunks, dim=0).numpy()
y_batch = torch.cat(label_chunks, dim=0).numpy()

In [4]:
# Rebuild the architecture the checkpoint was saved from: a ResNet-50 whose final
# fully connected layer is replaced by a 10-output linear layer.
model = torchvision.models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# map_location remaps the stored tensors onto the device selected for this session,
# so a checkpoint written from a GPU run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load('../data/dtd_state_dict', map_location=device))
model = model.to(device)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
# Load explanations.
# For every CAM method, read all heatmap files saved for the chosen layer and
# stack them into one array per method.
layer = 4
explanations = {
    "Grad_CAM": [],
    "Grad_CAM++": [],
    "Score_CAM": [],
    "Relevance_CAM": [],
}
for method, batch in explanations.items():
    heatmap_dir = f'data/R_CAM_results_layer{layer}/np/heatmaps/{method}/'
    # os.listdir returns entries in arbitrary order; sorting makes the order of the
    # loaded heatmaps reproducible across runs and machines.
    # NOTE(review): the sort is lexicographic ("10.npy" sorts before "2.npy") --
    # confirm this matches the intended sample order for the file naming in use.
    for file_name in sorted(os.listdir(heatmap_dir)):
        batch.append(np.load(os.path.join(heatmap_dir, file_name)))

explanations = {k: np.array(v) for k, v in explanations.items()}


In [6]:
from Multi_CAM import get_CAM
from tqdm import tqdm

def explainer_wrapper(**kwargs):
    """Compute CAM heatmaps for a batch of inputs by calling ``get_CAM`` per sample.

    Keyword arguments (all required):
        inputs: indexable batch; ``inputs.shape[0]`` is the number of samples.
        targets: targets for the batch. Indexed per sample when the batch holds
            more than one sample; forwarded to ``get_CAM`` unchanged when it
            holds exactly one.
        method: CAM variant name. ``"Grad_CAM++"`` is translated to
            ``"Grad_CAMpp"`` before being passed on.
        layer: forwarded to ``get_CAM`` unchanged.
        type: ``"tensor"`` or ``"numpy"``; selects the container returned for a
            single-sample batch.

    Returns:
        More than one sample: a ``torch.Tensor`` stacking the per-sample CAMs
        (``type`` is not consulted on this path).
        Exactly one sample: the CAM reshaped to ``(1, 224, 224)``, returned as a
        ``torch.Tensor`` or as the NumPy array, depending on ``type``.

    Raises:
        KeyError: if one of the required keyword arguments is missing.
        ValueError: if the batch is empty, or if ``type`` is neither
            ``"tensor"`` nor ``"numpy"`` for a single-sample batch. These cases
            used to return an error *string*, which would then be treated as an
            explanation by the caller.
    """
    inputs = kwargs["inputs"]
    targets = kwargs["targets"]
    method = kwargs["method"]
    layer = kwargs["layer"]
    # Kept under a different local name so the builtin `type` is not shadowed.
    return_type = kwargs["type"]
    size = inputs.shape[0]

    if size < 1:
        raise ValueError("wrong inputs shape")

    # get_CAM is called with "Grad_CAMpp" for this variant.
    if method == "Grad_CAM++":
        method = "Grad_CAMpp"

    if size > 1:
        all_cams = [
            get_CAM(method=method, layer=layer, input=inputs[s], target=targets[s])
            for s in tqdm(range(size))
        ]
        # Stack into a single ndarray first: building a tensor straight from a
        # Python list of arrays is the slow path in PyTorch.
        # Assumes get_CAM returns equally shaped NumPy arrays -- TODO confirm.
        return torch.tensor(np.array(all_cams))

    # Single-sample batch: validate the requested container before doing the
    # (expensive) CAM computation.
    if return_type not in ("tensor", "numpy"):
        raise ValueError(
            f"wrong type: expected 'tensor' or 'numpy', got {return_type!r}"
        )

    # `targets` is passed through as-is here (not indexed), as before.
    s_cam = get_CAM(method=method, layer=layer, input=inputs[0], target=targets)
    s_cam = np.reshape(s_cam, (1, 224, 224))

    if return_type == "tensor":
        return torch.tensor(s_cam)
    return s_cam


In [7]:
# Read the saved heatmaps of every CAM method for the layer selected in `layer`
# and stack them into one array per method.
explanations = {
    "Grad_CAM": [],
    "Grad_CAM++": [],
    "Score_CAM": [],
    "Relevance_CAM": [],
}
for method, batch in explanations.items():
    heatmap_dir = f'data/R_CAM_results_layer{layer}/np/heatmaps/{method}/'
    # os.listdir returns entries in arbitrary order; sorting makes the order of the
    # loaded heatmaps reproducible across runs and machines.
    # NOTE(review): the sort is lexicographic ("10.npy" sorts before "2.npy") --
    # confirm this matches the intended sample order for the file naming in use.
    for file_name in sorted(os.listdir(heatmap_dir)):
        batch.append(np.load(os.path.join(heatmap_dir, file_name)))

explanations = {k: np.array(v) for k, v in explanations.items()}

In [8]:
# Resume from previously saved results when they exist.
# The existence check has to look at the very file that is read here (and that the
# save cell writes): checking "data/results.npy" meant the saved JSON was never
# reloaded, so every kernel restart started from an empty dict and the next save
# overwrote the earlier results.
results_path = "data/results.json"
xai_methods = list(explanations.keys())

if os.path.exists(results_path):
    with open(results_path, 'r') as f:
        results = json.load(f)
else:
    results = {}

# Make sure every method has an entry, including methods absent from an older file.
for xai_method in xai_methods:
    results.setdefault(xai_method, {})

results

{'Grad_CAM': {}, 'Grad_CAM++': {}, 'Score_CAM': {}, 'Relevance_CAM': {}}

NOTE:
Due to the computational demands of this task, I calculated the results for each method and metric separately by uncommenting the appropriate section and changing the `method` variable. After each calculation completed, I saved the results to `data/results.json` and then restarted the kernel before moving on to the next metric/method.

In [9]:
# available methods: 'Grad_CAM' 'Grad_CAM++' 'Score_CAM' 'Relevance_CAM'
method = "Score_CAM"

# results[method]["Faithfulness"] = quantus.RegionPerturbation(
#     patch_size=14,
#     regions_evaluation=10,
#     perturb_baseline="uniform",  
#     normalise=True,
#     aggregate_func=np.mean,
#     return_aggregate=True,
#     disable_warnings=True,
# )(model=model,
# x_batch=x_batch,
# y_batch=y_batch,
# a_batch=None,
# device=device,
# explain_func=explainer_wrapper, 
# explain_func_kwargs={"method":method,"layer": layer, "type":"tensor"})

# Continuity metric: explanations are not precomputed (a_batch=None); they are
# produced on demand through explainer_wrapper with the kwargs below.
explain_kwargs = {"method": method, "layer": layer, "type": "tensor"}

continuity_metric = quantus.Continuity(
    patch_size=56,
    nr_steps=5,
    perturb_baseline="uniform",
    similarity_func=quantus.similarity_func.correlation_spearman,
    aggregate_func=np.mean,
    return_aggregate=True,
    disable_warnings=True,
)

# Raw metric output; it is reduced to a single score in the statements that follow.
r = continuity_metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=None,
    device=device,
    explain_func=explainer_wrapper,
    explain_func_kwargs=explain_kwargs,
)

# Reduce the Continuity output to one score: the mean over patches of each
# patch's mean value.
# NOTE(review): only the first entry of `r` is used, as before -- confirm that
# this entry is the one intended (e.g. an aggregate rather than a single sample).
patch_scores = r[0]  # mapping: patch key -> sequence of values

# The divisor must be the number of patches in that mapping. It was previously
# taken as len(r) *before* indexing, i.e. the length of the outer sequence.
n_patches = len(patch_scores)
if n_patches == 0:
    raise ValueError("Continuity returned no patches to average over")

patch_means = [np.mean(np.array(patch_scores[patch])) for patch in patch_scores]
# Plain float so the value serialises to JSON without special handling.
global_avg = float(sum(patch_means) / n_patches)
results[method]["Robustness"] = [global_avg]

# --- Disabled metric runs -------------------------------------------------
# The blocks below are kept commented out on purpose: per the note above this
# cell, one metric is enabled at a time, its result is saved, and the kernel is
# restarted before the next run.

# Disabled: "Axiomatic" score via quantus.NonSensitivity.
# results[method]["Axiomatic"] = quantus.NonSensitivity(
#     abs=True,
#     eps=1e-5,
#     n_samples=5, 
#     perturb_baseline="black",
#     perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
#     features_in_step=6272,
#     aggregate_func=np.mean,
#     return_aggregate=True,
#     disable_warnings=True,
# )(model=model, 
#    x_batch=x_batch,
#    y_batch=y_batch,
#    a_batch=None,    
#    device=device,
#    explain_func=explainer_wrapper, 
#    explain_func_kwargs={"method":method,"layer": layer, "type":"tensor"})

# Disabled: "Complexity" score via quantus.Complexity.
# results[method]["Complexity"] = quantus.Complexity(
#     aggregate_func=np.mean,
#     return_aggregate=True,
#     disable_warnings=True,
# )(model=model, 
#    x_batch=x_batch,
#    y_batch=y_batch,
#    a_batch=None,    
#    device=device,
#    explain_func=explainer_wrapper, 
#    explain_func_kwargs={"method":method,"layer": layer, "type":"tensor"})

# Disabled: "Randomisation" score via quantus.RandomLogit.
# Note that this is the only block that asks explainer_wrapper for "numpy" output.
# results[method]["Randomisation"] = quantus.RandomLogit(
#     num_classes=10,
#     similarity_func=quantus.similarity_func.ssim,
#     aggregate_func=np.mean,
#     return_aggregate=True,
#     disable_warnings=True,
# )(model=model, 
#    x_batch=x_batch,
#    y_batch=y_batch,
#    a_batch=None,    
#    device=device,
#    explain_func=explainer_wrapper, 
#    explain_func_kwargs={"method":method,"layer": layer, "type":"numpy"})

# Display the results collected so far.
results

100%|██████████| 2048/2048 [10:47<00:00,  3.16it/s]
100%|██████████| 2048/2048 [11:25<00:00,  2.99it/s]
100%|██████████| 2048/2048 [10:24<00:00,  3.28it/s]
100%|██████████| 2048/2048 [11:47<00:00,  2.89it/s]
100%|██████████| 2048/2048 [11:59<00:00,  2.85it/s]
100%|██████████| 2048/2048 [11:41<00:00,  2.92it/s]
100%|██████████| 2048/2048 [10:39<00:00,  3.20it/s]
100%|██████████| 2048/2048 [12:10<00:00,  2.80it/s]
100%|██████████| 2048/2048 [11:24<00:00,  2.99it/s]
100%|██████████| 2048/2048 [11:14<00:00,  3.04it/s]
100%|██████████| 10/10 [1:54:05<00:00, 684.50s/it]
100%|██████████| 2048/2048 [10:44<00:00,  3.18it/s]
100%|██████████| 2048/2048 [10:34<00:00,  3.23it/s]
100%|██████████| 2048/2048 [11:06<00:00,  3.07it/s]
100%|██████████| 2048/2048 [11:07<00:00,  3.07it/s]
100%|██████████| 2048/2048 [10:43<00:00,  3.18it/s]
100%|██████████| 2048/2048 [10:53<00:00,  3.13it/s]
100%|██████████| 2048/2048 [11:07<00:00,  3.07it/s]
100%|██████████| 2048/2048 [11:28<00:00,  2.97it/s]
100%|████████

KeyboardInterrupt: 

In [None]:
def _to_builtin(value):
    """Convert NumPy scalars/arrays to plain Python objects for json; reject anything else."""
    if isinstance(value, (np.generic, np.ndarray)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Serialise BEFORE opening the file: opening with 'w' truncates it immediately, so a
# serialisation error raised halfway through json.dump would leave a broken file
# behind and lose the results saved by earlier runs.
results_json = json.dumps(results, default=_to_builtin)

with open('data/results.json', 'w') as f:
    f.write(results_json)