In [1]:
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import quantus
import torch
from tqdm.autonotebook import tqdm

# Walk up from the current working directory until we sit in the project root
# ("T2T_ViT"), so that the project-local imports below resolve.
# If no ancestor has that name we would otherwise spin forever at the
# filesystem root (where `chdir('..')` is a no-op), so detect that case,
# restore the original directory and fail with a clear message.
_start_dir = os.getcwd()
while os.path.basename(os.getcwd()) != "T2T_ViT":
    _dir_before_step = os.getcwd()
    os.chdir('..')
    if os.getcwd() == _dir_before_step:
        os.chdir(_start_dir)
        raise RuntimeError(
            f"Could not find a 'T2T_ViT' directory among the ancestors of {_start_dir!r}; "
            "start the notebook from inside the project tree."
        )

from vit_shapley.modules.surrogate import Surrogate
from vit_shapley.modules.explainer import Explainer
from vit_shapley.modules.explainer_utils import remake_masks, quick_test_masked
from vit_shapley.CIFAR_10_Dataset import CIFAR_10_Dataset, CIFAR_10_Datamodule, PROJECT_ROOT, apply_masks_to_batch
from utils import load_transferred_model

# Pin the notebook to a specific GPU on multi-GPU machines.
# Only select the device when it actually exists: the previous guard
# (`device_count() > 1`) raised an invalid-device error on machines with
# 2-5 GPUs. On smaller machines we keep PyTorch's default device.
CUDA_DEVICE_INDEX = 5
if torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    torch.cuda.set_device(CUDA_DEVICE_INDEX)

  from tqdm.autonotebook import tqdm


In [2]:
# Build the CIFAR-10 datamodule with 196 players and a single, unpaired
# mask sample per image, then prepare its datasets.
datamodule = CIFAR_10_Datamodule(
    num_players=196,
    num_mask_samples=1,
    paired_mask_samples=False,
)
datamodule.setup()

# Fetch one training batch; every later cell works on this fixed batch.
data = next(iter(datamodule.train_dataloader()))
images, labels, masks = data['images'], data['labels'], data['masks']

# Sanity check of the batch layout (images, labels, masks).
print(images.shape, labels.shape, masks.shape)

ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 

torch.Size([32, 3, 224, 224]) torch.Size([32]) torch.Size([32, 1, 196])


# Loading models

In [3]:
# Restore the trained surrogate model from its checkpoint and switch it to
# evaluation mode.
surrogate_dir = PROJECT_ROOT / "saved_models/surrogate/cifar10"
surrogate_checkpoint = surrogate_dir / "v2/player196/t2t_vit.ckpt"
# Alternative checkpoint: surrogate_dir / "_player16_lr1e-05_wd0.0_b256_epoch28.ckpt"
surrogate = Surrogate.load_from_checkpoint(
    surrogate_checkpoint,
    map_location="cuda",
    # It's OK to ignore "target_model.*" being saved in the checkpoint but not
    # present in Surrogate when we only evaluate.
    strict=False,
    # Needs to be specified for older checkpoints.
    backbone_name="t2t_vit",
)
surrogate.eval();  # trailing ';' keeps the module repr out of the cell output

adopt performer encoder for tokens-to-token


In [4]:
# Restore the trained explainer from its checkpoint and switch it to
# evaluation mode. It receives its own deep copy of the surrogate, so moving
# the notebook-level `surrogate` between devices does not move the copy.
explainer_dir = PROJECT_ROOT / "saved_models/explainer/cifar10/"
explainer_checkpoint = explainer_dir / 'v2/player196/t2t_vit.ckpt'
explainer = Explainer.load_from_checkpoint(
    explainer_checkpoint,
    map_location="cuda",
    surrogate=deepcopy(surrogate),
    strict=False,
    # Needs to be specified for older checkpoints.
    backbone_name="t2t_vit",
    use_convolution=False,
    num_players=196,
)
explainer.eval();  # trailing ';' keeps the module repr out of the cell output

adopt performer encoder for tokens-to-token


/home/ubuntu/.conda/envs/main/lib/python3.11/site-packages/pytorch_lightning/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['head.weight', 'head.bias']


In [61]:
# explainer wrapper for quantus
def shap_explain_func(model, inputs, targets, method="Shap", **kwargs):
    """Adapter exposing the notebook-level `explainer` through quantus' explain-function interface.

    Args:
        model: Unused; accepted because quantus passes the evaluated model.
            The attributions always come from the global `explainer`.
        inputs: Batch of images as a numpy array or tensor, laid out as
            (batch, channels, height, width).
        targets: Class index per sample (numpy array or tensor); selects which
            class' attributions are returned.
        method: Unused; accepted so all explain functions share one signature.
        **kwargs: Ignored; tolerates additional keyword arguments a caller may
            forward (e.g. via `explain_func_kwargs`).

    Returns:
        numpy array of shape (batch, 1, height, width) in which every player's
        value is repeated over the image patch it stands for.

    Raises:
        ValueError: If the number of players is not a perfect square or the
            image size is not a multiple of the player grid size.
    """
    # `as_tensor` accepts both numpy arrays and tensors without the copy warning
    # of `torch.tensor`; float32 guards against float64 perturbed inputs.
    inputs = torch.as_tensor(inputs, dtype=torch.float32)
    targets = torch.as_tensor(targets, dtype=torch.long, device="cpu")

    with torch.no_grad():
        # Expected layout: (batch, num_players, num_classes).
        shap_values_ = explainer(inputs.to(explainer.device)).cpu()

    batch_size, num_players = shap_values_.shape[0], shap_values_.shape[1]
    # Keep only the attributions of each sample's target class -> (batch, num_players).
    shap_values = shap_values_[torch.arange(batch_size), :, targets]

    # Players form a square grid over the image (e.g. 196 players -> 14 x 14).
    grid_size = int(round(num_players ** 0.5))
    if grid_size * grid_size != num_players:
        raise ValueError(f"num_players={num_players} is not a perfect square.")
    height, width = inputs.shape[-2], inputs.shape[-1]
    if height % grid_size or width % grid_size:
        raise ValueError(
            f"Image size {height}x{width} is not a multiple of the {grid_size}x{grid_size} player grid."
        )

    # Upsample the player grid to pixel resolution and add a channel axis.
    shap_values_quantus = (
        shap_values.reshape(batch_size, grid_size, grid_size)
        .repeat_interleave(height // grid_size, dim=1)
        .repeat_interleave(width // grid_size, dim=2)
        .unsqueeze(1)
    )
    return shap_values_quantus.numpy()

## Quantus

In [63]:
# Faithfulness correlation iteratively replaces a random subset of the given
# attributions with a baseline value and then measures the correlation between
# the sum of this attribution subset and the difference in function output.

def compute_faithfulness_correlation(explain_function, method):
    """Evaluate quantus' FaithfulnessCorrelation for one attribution method.

    The metric is run on the notebook-level `surrogate` (moved to the CPU as a
    side effect) using the notebook-level batch `images` / `labels`.

    Args:
        explain_function: Callable producing attributions; forwarded to quantus
            as `explain_func`.
        method: Name of the attribution method; forwarded to the explain
            function through `explain_func_kwargs`.

    Returns:
        The value returned by the quantus metric call. With
        `return_aggregate=False` this is presumably one score per sample
        (TODO confirm against the quantus version in use).
    """
    faithfulness_correlation = quantus.FaithfulnessCorrelation(
        nr_runs=100,
        subset_size=224,
        perturb_baseline="black",
        perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        similarity_func=quantus.similarity_func.correlation_pearson,
        abs=False,
        return_aggregate=False,
    )(model=surrogate.cpu(),
      x_batch=images.numpy(),
      y_batch=labels.numpy(),
      explain_func=explain_function,
      explain_func_kwargs={"method": method})
    # Hand the scores back; without this the caller only ever stored None.
    return faithfulness_correlation

# Attribution methods to compare: the first two entries use quantus' built-in
# explain function, the last one uses the local wrapper `shap_explain_func`.
# `explain_functions[i]` is paired with `methods[i]`.
explain_functions = [quantus.explain, quantus.explain, shap_explain_func]
methods = ["Saliency", "IntegratedGradients", "Shap"]

# Maps method name -> result of compute_faithfulness_correlation for that method.
# Filled one entry at a time, so an interrupted run keeps the entries that
# finished before the interruption.
faithfulness_correlations = {}
for explain_function, method in zip(explain_functions, methods):
    faithfulness_correlations[method] = compute_faithfulness_correlation(explain_function, method)


KeyboardInterrupt



In [59]:
# Max-Sensitivity: measures the maximum sensitivity of an explanation
# using a Monte Carlo sampling-based approximation.

def compute_max_sensitivity(explain_func, method):
    """Evaluate quantus' MaxSensitivity for one attribution method.

    The metric is run on the notebook-level `surrogate` (moved to the CPU as a
    side effect) using the notebook-level batch `images` / `labels`.

    Args:
        explain_func: Callable producing attributions; forwarded to quantus as
            `explain_func`.
        method: Name of the attribution method; forwarded to the explain
            function through `explain_func_kwargs`.

    Returns:
        The value returned by the quantus metric call (presumably one score
        per sample; TODO confirm against the quantus version in use).
    """
    max_sensitivity = quantus.MaxSensitivity(
        nr_samples=10,
        lower_bound=0.2
    )(model=surrogate.cpu(),
      x_batch=images.numpy(),
      y_batch=labels.numpy(),
      explain_func=explain_func,
      explain_func_kwargs={"method": method})
    # Hand the scores back; without this the caller only ever stored None.
    return max_sensitivity

# Attribution methods to compare: the first two entries use quantus' built-in
# explain function, the last one uses the local wrapper `shap_explain_func`.
# `explain_functions[i]` is paired with `methods[i]`.
explain_functions = [quantus.explain, quantus.explain, shap_explain_func]
methods = ["Saliency", "IntegratedGradients", "Shap"]

# Maps method name -> result of compute_max_sensitivity for that method.
# Filled one entry at a time, so an interrupted run keeps the entries that
# finished before the interruption.
max_sensitivities = {}
for explain_function, method in zip(explain_functions, methods):
    max_sensitivities[method] = compute_max_sensitivity(explain_function, method)

## Visualize

In [None]:
def normalize_image(arr) -> np.ndarray:
    """Turn a normalised channel-first image into a displayable uint8 image.

    Despite its name this function *undoes* the per-channel mean/std
    normalisation (via quantus' `denormalise`), moves the channel axis last
    and rescales to the 0-255 range.

    Args:
        arr: Channel-first image as a torch tensor or numpy array. The input
            is not modified.

    Returns:
        Channel-last `np.uint8` array with values in [0, 255].
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    if isinstance(arr, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        arr_copy = arr.detach().clone().cpu().numpy()
    else:
        arr_copy = arr.copy()

    arr_copy = quantus.normalise_func.denormalise(arr_copy, mean=mean, std=std)
    arr_copy = np.moveaxis(arr_copy, 0, -1)
    # Clip before the cast: values outside [0, 1] would otherwise wrap around
    # in uint8 and show up as speckle artefacts.
    arr_copy = np.clip(arr_copy, 0.0, 1.0)
    arr_copy = (arr_copy * 255.).astype(np.uint8)
    return arr_copy

In [None]:
import random
# Pick one sample of the loaded batch at random (no seed is set, so the
# chosen sample differs between runs).
index = random.randint(0, len(images)-1)

# Plot exemplary explanations!
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 5))
# Left panel: the de-normalised input image, titled with its integer label.
# NOTE(review): normalize_image returns a uint8 RGB array; vmin/vmax look
# like they have no effect on RGB input -- confirm and consider dropping them.
axes[0].imshow(normalize_image(images[index].detach()), vmin=0.0, vmax=1.0)
axes[0].title.set_text(f"CIFAR10 {labels[index].item()}")
# Right panel: attribution heat map for the same sample.
# NOTE(review): `a_batch_saliency` is not defined in any cell visible here;
# it has to exist in the kernel (one 224*224 attribution map per sample is
# assumed) or this line raises NameError on a fresh run -- confirm its origin.
exp = axes[1].imshow(a_batch_saliency[index].reshape(224, 224), cmap="seismic")
fig.colorbar(exp, fraction=0.03, pad=0.05); 
axes[0].axis("off"); axes[1].axis("off"); plt.show();