# WB-XIC, Lab6: Wstęp do wyjaśnień konwolucyjnych sieci neuronowych

- `captum`, `shap`
- [IML](https://christophm.github.io/interpretable-ml-book/neural-networks.html)

Install packages

In [None]:
!pip install scikit-image -U
!pip install shap captum torchinfo

Load packages

In [None]:
import torch
import torchvision
import torchinfo
import shap
import captum
import numpy as np

Load a ResNet-34 model pretrained on ImageNet (the code below calls `torchvision.models.resnet34`)

In [None]:
model = torchvision.models.resnet34(pretrained=True)
model = model.eval()

Load [a sample of](https://shap.readthedocs.io/en/stable/generated/shap.datasets.imagenet50.html) ImageNet data

In [None]:
__X, _y = shap.datasets.imagenet50()
_X, y = torch.as_tensor(__X) / 255, torch.as_tensor(_y)

Input images are normalized (by channel)

In [None]:
preprocess = torchvision.transforms.Compose([
   torchvision.transforms.Normalize(
       mean=[0.485, 0.456, 0.406],
       std=[0.229, 0.224, 0.225]
   )
])

Make a prediction

In [None]:
# model(preprocess(_X))

In [None]:
_X.shape

In [None]:
torch.movedim(_X, 3, 1).shape

In [None]:
X = torch.movedim(_X, 3, 1)

In [None]:
torchinfo.summary(model, input_size=X.shape)

In [None]:
# model(preprocess(X))
# torch.nn.functional.softmax(model(preprocess(X)), dim=1)
# torch.nn.functional.softmax(model(preprocess(X)), dim=1).sum(axis=1)
torch.nn.functional.softmax(model(preprocess(X)), dim=1).argmax(axis=1)

Import ImageNet labels

In [None]:
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json

In [None]:
import json
with open("imagenet_class_index.json") as json_data:
    idx_to_labels = {idx: label for idx, [_, label] in json.load(json_data).items()}

Visualize images with predictions

In [None]:
import PIL
import matplotlib.pyplot as plt

def show_images(images, k = 3):
    """Plot up to ``k * k`` images in a ``k`` x ``k`` grid with the top prediction.

    Each image is titled ``"<class index>: <label> (<probability>)"``, where the
    probability is the largest softmax output of the notebook-level ``model``
    applied to ``preprocess(images)`` and the label comes from ``idx_to_labels``.

    Parameters
    ----------
    images : torch.Tensor
        Batch of images; each item is permuted with ``(1, 2, 0)`` before
        plotting, so items are presumably channel-first (C, H, W).
    k : int
        Side length of the plot grid. Images beyond the first ``k * k`` are
        ignored instead of raising an IndexError.
    """
    # Only the first k*k images fit in the grid; drop the rest before inference.
    images = images[:k * k]
    # squeeze=False keeps `ax` two-dimensional even when k == 1.
    fig, ax = plt.subplots(k, k, figsize=[6 * k, 6 * k], squeeze=False)
    # Hide every axis up front so unused grid cells do not show empty frames.
    for cell in ax.flat:
        cell.axis('off')
    # Plotting needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        y_hat = torch.nn.functional.softmax(model(preprocess(images)), dim=1)
    preds = y_hat.amax(axis=1)
    preds_idx = y_hat.argmax(axis=1)
    for i, image in enumerate(images):
        pred = preds[i].item()
        pred_idx = preds_idx[i].item()
        # Column-major fill: consecutive images go down a column first.
        cell = ax[i % k, i // k]
        cell.imshow(image.permute(1, 2, 0))
        # idx_to_labels is keyed by the class index as a string.
        cell.set_title(f"{pred_idx}: {idx_to_labels[str(pred_idx)]} ({round(pred, 3)})")

In [None]:
show_images(X[39:48], k=3)

## Local interpretable model-agnostic explanations (LIME)

- Theory: https://christophm.github.io/interpretable-ml-book/lime
- Practice: https://captum.ai/api/lime
- (Segmentation for the mask: https://scikit-image.org/docs/dev/api/skimage.segmentation)

In [None]:
from captum.attr import Lime
explainer = Lime(model)

In [None]:
from skimage import segmentation
# Build a segmentation mask for image 39; it is passed to the explainers
# below as `feature_mask`. Two alternatives are shown, but both assign to
# `mask`, so the SLIC result is overwritten and only the quickshift mask
# is used by later cells.

# Option 1: SLIC on the image averaged over its last axis after the permute
# (i.e. a single-channel image).
## https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic
mask = segmentation.slic(
    X[39].permute(1, 2, 0).mean(axis=2), 
    n_segments=100, 
    compactness=0.1, 
    start_label=0,
    # channel_axis=2 # error :(
  )
# Option 2: quickshift on the permuted image; this assignment replaces
# the SLIC mask computed above.
## https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.quickshift
mask = segmentation.quickshift(
    X[39].permute(1, 2, 0), 
    kernel_size=14, 
    max_dist=7, 
    ratio=0.5
  )

In [None]:
print(mask.max())
mask

In [None]:
attr = explainer.attribute(
    preprocess(X[39].unsqueeze(0)), 
    target=299, 
    n_samples=200, 
    feature_mask=torch.as_tensor(mask),
    show_progress=True
  )

In [None]:
attr

In [None]:
def show_image_mask_explanation(image, mask, explanation):
    """Show an image, its segmentation mask and an explanation side by side.

    Parameters
    ----------
    image : tensor
        Permuted with ``(1, 2, 0)`` before plotting (presumably channel-first).
    mask : array-like
        Segmentation mask, drawn with the ``"flag"`` colormap.
    explanation : array-like
        Attribution map, drawn with the ``"RdBu"`` colormap on a fixed
        ``[-1, 1]`` colour scale.
    """
    _, axes = plt.subplots(1, 3, figsize=[6 * 2, 6])
    # One (title, data, imshow options) entry per panel, left to right.
    panels = (
        ("image", image.permute(1, 2, 0), {}),
        ("segmentation mask", mask, {"cmap": "flag"}),
        ("explanation", explanation, {"vmin": -1, "vmax": 1, "cmap": "RdBu"}),
    )
    for axis, (title, data, imshow_kwargs) in zip(axes, panels):
        axis.imshow(data, **imshow_kwargs)
        axis.set_title(title)
    plt.show()

In [None]:
show_image_mask_explanation(X[39], mask, attr[0].mean(axis=0))

In [None]:
from captum.attr import visualization

def show_attr(attr_map):
    """Render an attribution map as a signed heat map with a colour bar.

    Parameters
    ----------
    attr_map : torch.Tensor
        Attribution tensor; it is permuted with ``(1, 2, 0)`` before being
        handed to captum's ``visualize_image_attr`` (presumably channel-first
        in, channel-last out).
    """
    # detach() and cpu() make the NumPy conversion safe for tensors that
    # require grad or live on an accelerator; both are no-ops otherwise.
    attr_array = attr_map.detach().cpu().permute(1, 2, 0).numpy()
    visualization.visualize_image_attr(
        attr_array,
        method='heat_map',
        sign='all',
        show_colorbar=True
    )

In [None]:
show_attr(attr[0])

## Integrated Gradients (IG)

* Theory: https://www.tensorflow.org/tutorials/interpretability/integrated_gradients
* Practice: https://captum.ai/api/integrated_gradients

In [None]:
from captum.attr import IntegratedGradients
exp_ig = IntegratedGradients(model)

In [None]:
attr_ig = exp_ig.attribute(preprocess(X[39].unsqueeze(0)), target=299)

In [None]:
show_attr(attr_ig[0])

## SHapley Additive exPlanations (SHAP)
- KernelSHAP theory: https://christophm.github.io/interpretable-ml-book/shap
- KernelSHAP practice: https://captum.ai/api/kernel_shap
- SHAP based on DeepLIFT: https://captum.ai/api/deep_lift_shap
- SHAP based on IG: https://captum.ai/api/gradient_shap
- https://github.com/slundberg/shap

In [None]:
from captum.attr import KernelShap
ks = KernelShap(model)

# Compute the attribution with the KernelShap explainer created above.
# Calling the LIME `explainer` here would only repeat the LIME result from
# the previous section under a KernelSHAP name.
attr_ks = ks.attribute(
    preprocess(X[39].unsqueeze(0)),
    target=299,
    n_samples=200,
    # Reuse the segmentation mask so features are segments, not single pixels.
    feature_mask=torch.as_tensor(mask),
    show_progress=True
)

show_attr(attr_ks[0])

In [None]:
# https://shap-lrjball.readthedocs.io/en/latest/generated/shap.DeepExplainer.html
exp_deep = shap.DeepExplainer(model, data=preprocess(X))

sv_deep, idx_deep = exp_deep.shap_values(preprocess(X[39:40]), ranked_outputs=2)

In [None]:
shap.image_plot(
    [sv.squeeze(0).transpose((1, 2, 0)) for sv in sv_deep], 
    X[39].permute(1, 2, 0).numpy(), 
    np.vectorize(lambda x: idx_to_labels[str(x)])(idx_deep)
  )

In [None]:
class NetWrapper(torch.nn.Module):
    """Chain ``preprocess``, ``model`` and a softmax into a single module.

    The forward pass returns softmax outputs over dimension 1, so callers can
    pass un-preprocessed inputs directly.
    """

    def __init__(self, model, preprocess):
        """Store the wrapped ``model`` and the ``preprocess`` callable."""
        super().__init__()
        self.preprocess = preprocess
        self.model = model

    def forward(self, x):
        """Return ``softmax(model(preprocess(x)), dim=1)``."""
        logits = self.model(self.preprocess(x))
        return torch.nn.functional.softmax(logits, dim=1)

model_wrapper = NetWrapper(model, preprocess)

# https://shap-lrjball.readthedocs.io/en/latest/generated/shap.GradientExplainer.html
exp_gradient = shap.GradientExplainer(model_wrapper, data=X)

sv_gradient, idx_gradient = exp_gradient.shap_values(X[39:40], ranked_outputs=2)

In [None]:
shap.image_plot(
    [sv.squeeze(0).transpose((1, 2, 0)) for sv in sv_gradient], 
    X[39].permute(1, 2, 0).numpy(), 
    np.vectorize(lambda x: idx_to_labels[str(x)])(idx_gradient)
  )

## Captum Insights

https://captum.ai/tutorials/CIFAR_TorchVision_Captum_Insights

https://github.com/aaron-xichen/pytorch-playground

In [None]:
!pip install flask_compress

In [None]:
# get class labels..
dataset = torchvision.datasets.CIFAR100(
    root="data/test", train=False, download=True, transform=torchvision.transforms.ToTensor()
) 
import pickle
def unpickle(file):
    """Return the object stored in the pickle file at path ``file``.

    The file is opened in binary mode and decoded with ``encoding='latin1'``.
    """
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(file, mode='rb') as handle:
        return pickle.load(handle, encoding='latin1')
metadata = unpickle('data/test/cifar-100-python/meta')

In [None]:
len(metadata['fine_label_names'])

In [None]:
!wget http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/cifar100-3a55a987.pth
!wget https://raw.githubusercontent.com/MI2-Education/2022L-WB-XIC/master/labs/lab5/code_cifar100.py

Import a model pretrained on CIFAR100 

In [None]:
import code_cifar100

In [None]:
model_cifar100 = code_cifar100.cifar100(n_channel=128, pretrained="cifar100-3a55a987.pth")

In [None]:
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

def baseline_func(input):
    """Return ``input`` multiplied by zero (used below as a baseline transform)."""
    zeroed = input * 0
    return zeroed

def formatted_data_iter():
    """Yield the CIFAR-100 test split as captum ``Batch`` objects.

    The dataset is read from (and downloaded to, if missing) ``data/test`` and
    served in order in batches of 4 images.

    Yields
    ------
    Batch
        ``Batch(inputs=images, labels=labels)`` for each DataLoader batch.
    """
    dataset = torchvision.datasets.CIFAR100(
        root="data/test", train=False, download=True, transform=transforms.ToTensor()
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=False, num_workers=2
    )
    # Use a for-loop rather than `while True: next(...)`: a StopIteration
    # escaping next() inside a generator is turned into RuntimeError (PEP 479),
    # whereas a for-loop lets the generator finish normally when data runs out.
    for images, labels in dataloader:
        yield Batch(inputs=images, labels=labels)


normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

visualizer = AttributionVisualizer(
    models=[model_cifar100],
    score_func=lambda o: torch.nn.functional.softmax(o, dim=1),
    classes=metadata['fine_label_names'],
    features=[
        ImageFeature(
            "Photo",
            baseline_transforms=[baseline_func],
            input_transforms=[normalize],
        )
    ],
    dataset=formatted_data_iter(),
)

In [None]:
# visualizer.render(debug=True) # doesn't work?

In [None]:
visualizer.serve()

`python -m captum.insights.example`