**NOTE: This notebook is written for the Google Colab platform, which provides free hardware acceleration. However, it can also be run (possibly with minor modifications) as a standard Jupyter notebook, using a local GPU.** 



In [None]:
#@title -- Installation of Packages -- { display-mode: "form" }
# Install the extra dependencies with the pip that belongs to the running
# interpreter (sys.executable), so the packages are installed for the same
# Python that executes this notebook.
import sys
# captum provides the attribution methods and visualization helpers that the
# import cell pulls in.
!{sys.executable} -m pip install captum
# class_utils provides the download helpers used by the data-download cell.
!{sys.executable} -m pip install git+https://github.com/michalgregor/class_utils.git

In [None]:
#@title -- Import of Necessary Packages -- { display-mode: "form" }
import matplotlib.pyplot as plt
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import numpy as np
import torch

import torchvision
from torchvision import models
from torchvision import transforms

from captum.attr import IntegratedGradients, Saliency, GuidedBackprop
from captum.attr import GradientShap
from captum.attr import Occlusion
from captum.attr import visualization as viz

In [None]:
#@title -- Downloading Data -- { display-mode: "form" }
from class_utils.download import download_file_maybe_extract, download_files_maybe_extract

# URL template for the hosted data files; "{}" is replaced with the path of a
# file relative to the repository's data directory.
DATA_HOME = "https://github.com/michalgregor/ml_notebooks/blob/main/data/{}?raw=1"

# The class-name list plus the example images used throughout the notebook.
remote_files = [
    "imagenet_classes",
    "images/toucan.png",
    "images/raccoon_example.jpg",
    "images/nails.jpg",
    "images/screws.jpg",
    "images/ambulance.jpg",
]

download_files_maybe_extract(
    [DATA_HOME.format(rel_path) for rel_path in remote_files],
    directory="data",
)

# Outputs produced by the notebook get a directory of their own.
import os
os.makedirs("output", exist_ok=True)

In [None]:
#@title -- Auxiliary Functions -- { display-mode: "form" }
# Load the class labels, one per line. Only the line terminator is stripped
# (instead of blindly dropping the last character), so the final label stays
# intact even when the file does not end with a newline.
with open("data/imagenet_classes", "r", encoding="utf-8") as file:
    class_names = [line.rstrip("\n") for line in file]

def resize_image(img, size=224):
    """Resize ``img`` and return the result as a NumPy array.

    Args:
        img: Image accepted by ``transforms.Resize`` (the notebook passes a
            PIL image).
        size: Target size handed to ``transforms.Resize``.

    Returns:
        The resized image converted with ``np.asarray``.
    """
    resized = transforms.Resize(size)(img)
    return np.asarray(resized)

def decode_proba(proba, top=5, verbose=True):
    """Pick the ``top`` most probable classes of the first sample in a batch.

    Args:
        proba: Tensor of class probabilities whose first dimension indexes the
            samples; only the first sample is decoded.
        top: Number of highest-probability classes to return (and print).
        verbose: When true, print one line per returned class in the form
            ``probability:<TAB>class name (class index)``, looking the name
            up in the module-level ``class_names`` list.

    Returns:
        A ``(probs, classes)`` pair of 1-D tensors of length ``top`` holding
        the probabilities and the class indices, most probable first.
    """
    # Ranking does not need gradients; no_grad also guarantees that the
    # results can be converted to NumPy below.
    with torch.no_grad():
        # Honour the requested number of classes (this used to be fixed at 5
        # regardless of ``top``).
        probs, classes = torch.topk(proba, top)
        probs = probs[0]
        classes = classes[0]

    if verbose:
        for p, c in zip(probs.cpu().numpy(),
                        classes.cpu().numpy()):
            print("{}:\t{} ({})".format(
                np.array2string(p, precision=5),
                class_names[c], c))

    return probs, classes

# Default visualization settings: three panels (original image, alpha-scaled
# image, heat map), each paired with the sign used for its attributions.
vis_kwargs_default = {
    "methods": ['original_image', 'alpha_scaling', 'heat_map'],
    "cmap": "viridis",
    "show_colorbar": True,
    "signs": ['all', 'absolute_value', 'absolute_value'],
    "fig_size": (12.5, 4),
}

class ProcImage:
    """Load an image file, preprocess it for ``model`` and classify it.

    Calling an instance with an image path returns everything the
    attribution cells need: the preprocessed input tensor, a resized copy of
    the image for display, and the predicted class.
    """

    def __init__(self, model, weights, device, verbose=True):
        """Store the model and derive the normalization from ``weights``.

        Args:
            model: Classifier that is called on the preprocessed tensor.
            weights: Object whose ``transforms()`` result exposes ``mean`` and
                ``std``; they parameterize the input normalization.
            device: Device the model and inputs are moved to, or ``None`` to
                leave both where they are.
            verbose: Whether ``__call__`` prints the top predictions.
        """
        self.model = model
        self.device = device
        self.verbose = verbose

        weight_transforms = weights.transforms()
        self.normalize = transforms.Normalize(
            mean=weight_transforms.mean,
            std=weight_transforms.std
        )

    def preproc_image(self, img, size=224):
        """Turn an image into a normalized tensor with a leading batch axis.

        Args:
            img: PIL image, or anything ``to_pil_image`` can convert.
            size: Target size handed to ``transforms.Resize``.

        Returns:
            The resized, normalized image tensor with a batch dimension
            added in front.
        """
        if not isinstance(img, Image.Image):
            img = torchvision.transforms.functional.to_pil_image(img)

        # Force three channels so that e.g. images with an alpha channel or
        # grayscale images match the normalization statistics.
        img = img.convert('RGB')
        transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            self.normalize
        ])

        tensor = transform(img).unsqueeze(0)
        return tensor

    def __call__(self, img_path):
        """Classify the image stored at ``img_path``.

        Args:
            img_path: Path of the image file to open.

        Returns:
            ``(img_preproc, img_resized, pred_class)``: the preprocessed
            input tensor (with ``requires_grad`` enabled), the resized image
            as a NumPy array (not converted to RGB), and the index of the
            most probable class as returned by ``decode_proba``.
        """
        # Both derived images are produced while the file is open, so the
        # handle can be released right away instead of being leaked.
        with Image.open(img_path) as img:
            img_preproc = self.preproc_image(img)
            img_resized = resize_image(img)

        if self.device is not None:
            img_preproc = img_preproc.to(self.device)
            self.model.to(self.device)

        self.model.eval()
        # Gradient-based attribution methods differentiate with respect to
        # the input, so the input has to track gradients.
        img_preproc.requires_grad = True

        output = torch.nn.functional.softmax(self.model(img_preproc), dim=1)
        # Respect the verbosity chosen at construction time (printing used to
        # be unconditional).
        pred_class = decode_proba(output, verbose=self.verbose)[1][0]

        return img_preproc, img_resized, pred_class

def visualize_attrib(attrib, *, original_image=None, **kwargs):
    """Plot an attribution tensor next to the image it belongs to.

    Args:
        attrib: Attribution tensor; after ``squeeze()`` it must have three
            dimensions, the first of which is moved to the end (presumably
            channels-first to channels-last) before plotting.
        original_image: Image to show alongside the attributions. When
            omitted, the module-level ``img_resized`` (set by the cell that
            called ``proc_image``) is used, which is what this function
            always did before the parameter existed.
        **kwargs: Overrides for the entries of ``vis_kwargs_default``; any
            further keys are passed on to ``visualize_image_attr_multiple``.
    """
    if original_image is None:
        # Backward-compatible fallback to the notebook-level global.
        original_image = img_resized

    kwargs_merged = {**vis_kwargs_default, **kwargs}
    attrib = np.transpose(attrib.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # The notebook's own visualize_image_attr_multiple is called instead of
    # viz.visualize_image_attr_multiple as a workaround for a
    # matplotlib-related bug in captum.
    visualize_image_attr_multiple(attrib, original_image, **kwargs_merged)

# workaround for a matplotlib-related bug in captum

from matplotlib.figure import Figure
from captum.attr._utils.visualization import (
    _prepare_image, ImageVisualizationMethod, _normalize_attr,
    VisualizeSign, make_axes_locatable, LinearSegmentedColormap
)

def visualize_image_attr(
    attr, original_image=None, method="heat_map", sign="absolute_value",
    plt_fig_axis=None, outlier_perc=2, cmap=None, alpha_overlay=0.5,
    show_colorbar=False, title=None, fig_size=(6, 6), use_pyplot=True
):
    """Draw a single attribution visualization on a matplotlib axis.

    Local variant of captum's function of the same name, kept in the notebook
    as a workaround for a matplotlib-related bug in captum.

    Args:
        attr: Attribution array. It is normalized by ``_normalize_attr`` with
            ``reduction_axis=2`` (presumably height x width x channels).
        original_image: Image the attributions refer to; required for every
            method except ``"heat_map"``. If its maximum is at most 1.0 it is
            multiplied by 255 and passed through ``_prepare_image``.
        method: Name of an ``ImageVisualizationMethod`` member
            (``"heat_map"``, ``"original_image"``, ``"blended_heat_map"``,
            ``"masked_image"`` or ``"alpha_scaling"``).
        sign: Name of a ``VisualizeSign`` member (``"all"``, ``"positive"``,
            ``"negative"`` or ``"absolute_value"``). Not looked at when
            ``method`` is ``"original_image"``.
        plt_fig_axis: Optional ``(figure, axis)`` pair to draw on. When
            omitted, a new figure is created.
        outlier_perc: Passed on to ``_normalize_attr``.
        cmap: Colormap for the heat map; defaults to one chosen per ``sign``.
        alpha_overlay: Opacity of the heat map for ``"blended_heat_map"``.
        show_colorbar: Whether to reserve a colorbar axis below the plot. It
            is hidden when the chosen method draws no heat map.
        title: Optional axis title.
        fig_size: Size of the figure created when ``plt_fig_axis`` is omitted.
        use_pyplot: Create the figure through pyplot and call ``plt.show()``
            at the end; otherwise a bare ``Figure`` is used.

    Returns:
        The ``(figure, axis)`` pair that was drawn on.

    Raises:
        KeyError: If ``method`` (or, when it is consulted, ``sign``) is not a
            valid member name.
        ValueError: If ``original_image`` is missing for a method other than
            ``"heat_map"``.
        AssertionError: If ``sign`` is ``"all"`` together with the
            ``"masked_image"`` or ``"alpha_scaling"`` method.
    """
    # Resolve the method and validate the inputs before touching matplotlib,
    # so that a bad call does not leave an empty figure behind.
    vis_method = ImageVisualizationMethod[method]

    if original_image is not None:
        if np.max(original_image) <= 1.0:
            original_image = _prepare_image(original_image * 255)
    elif vis_method != ImageVisualizationMethod.heat_map:
        raise ValueError(
            "Original Image must be provided for "
            "any visualization other than heatmap."
        )

    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    elif use_pyplot:
        plt_fig, plt_axis = plt.subplots(figsize=fig_size)
    else:
        plt_fig = Figure(figsize=fig_size)
        plt_axis = plt_fig.subplots()

    # Remove ticks and tick labels from plot.
    plt_axis.xaxis.set_ticks_position("none")
    plt_axis.yaxis.set_ticks_position("none")
    plt_axis.set_yticklabels([])
    plt_axis.set_xticklabels([])
    plt_axis.grid(visible=False)

    heat_map = None
    # Show original image
    if vis_method == ImageVisualizationMethod.original_image:
        assert (
            original_image is not None
        ), "Original image expected for original_image method."
        # A trailing singleton channel axis is dropped so that imshow treats
        # the array as a plain 2-D image.
        if len(original_image.shape) > 2 and original_image.shape[2] == 1:
            original_image = np.squeeze(original_image, axis=2)
        plt_axis.imshow(original_image)
    else:
        # Choose appropriate signed attributions and normalize.
        norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=2)

        # Set default colormap and bounds based on sign.
        vis_sign = VisualizeSign[sign]
        if vis_sign == VisualizeSign.all:
            default_cmap = LinearSegmentedColormap.from_list(
                "RdWhGn", ["red", "white", "green"]
            )
            vmin, vmax = -1, 1
        elif vis_sign == VisualizeSign.positive:
            default_cmap = "Greens"
            vmin, vmax = 0, 1
        elif vis_sign == VisualizeSign.negative:
            default_cmap = "Reds"
            vmin, vmax = 0, 1
        elif vis_sign == VisualizeSign.absolute_value:
            default_cmap = "Blues"
            vmin, vmax = 0, 1
        else:
            raise AssertionError("Visualize Sign type is not valid.")
        cmap = cmap if cmap is not None else default_cmap

        # Show appropriate image visualization.
        if vis_method == ImageVisualizationMethod.heat_map:
            heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax)
        elif vis_method == ImageVisualizationMethod.blended_heat_map:
            assert (
                original_image is not None
            ), "Original Image expected for blended_heat_map method."
            # Grayscale version of the image underneath a translucent heat map.
            plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
            heat_map = plt_axis.imshow(
                norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
            )
        elif vis_method == ImageVisualizationMethod.masked_image:
            assert vis_sign != VisualizeSign.all, (
                "Cannot display masked image with both positive and negative "
                "attributions, choose a different sign option."
            )
            plt_axis.imshow(
                _prepare_image(original_image * np.expand_dims(norm_attr, 2))
            )
        elif vis_method == ImageVisualizationMethod.alpha_scaling:
            assert vis_sign != VisualizeSign.all, (
                "Cannot display alpha scaling with both positive and negative "
                "attributions, choose a different sign option."
            )

            # Keep at most the first three channels, so that the attribution
            # channel appended below ends up as the fourth one (presumably
            # interpreted by imshow as the alpha channel of an RGBA image).
            original_image = original_image[..., :3]
            plt_axis.imshow(
                np.concatenate(
                    [
                        original_image,
                        _prepare_image(np.expand_dims(norm_attr, 2) * 255),
                    ],
                    axis=2,
                )
            )
        else:
            raise AssertionError("Visualize Method type is not valid.")

    # Add colorbar. If given method is not a heatmap and no colormap is relevant,
    # then a colormap axis is created and hidden. This is necessary for appropriate
    # alignment when visualizing multiple plots, some with heatmaps and some
    # without.
    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis)
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
        if heat_map is not None:
            plt_fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis)
        else:
            colorbar_axis.axis("off")
    if title:
        plt_axis.set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis


def visualize_image_attr_multiple(
    attr, original_image, methods, signs, titles=None,
    fig_size=(8, 6), use_pyplot=True, **kwargs
):
    """Draw several attribution visualizations side by side in one figure.

    One panel is drawn per entry of ``methods``, using the sign (and, if
    given, the title) at the same position. Every panel is rendered by
    ``visualize_image_attr``, which also receives ``**kwargs``.

    Args:
        attr: Attribution array shared by all panels.
        original_image: Image shared by all panels.
        methods: Sequence of visualization method names.
        signs: Sequence of sign names, one per method.
        titles: Optional sequence of panel titles, one per method.
        fig_size: Size of the whole figure.
        use_pyplot: Create the figure through pyplot and show it at the end;
            otherwise a bare ``Figure`` is created.

    Returns:
        ``(figure, axes)``; ``axes`` is a one-element list when there is
        only a single method.
    """
    panel_count = len(methods)
    assert panel_count == len(signs), "Methods and signs array lengths must match."
    if titles is not None:
        assert panel_count == len(titles), (
            "If titles list is given, length must " "match that of methods list."
        )

    plt_fig = plt.figure(figsize=fig_size) if use_pyplot else Figure(figsize=fig_size)
    plt_axis = plt_fig.subplots(1, panel_count)

    # With a single panel, subplots() hands back a bare axis; wrap it so the
    # indexing below works for any panel count.
    if panel_count == 1:
        plt_axis = [plt_axis]

    for idx, panel_method in enumerate(methods):
        panel_title = titles[idx] if titles else None
        visualize_image_attr(
            attr,
            original_image=original_image,
            method=panel_method,
            sign=signs[idx],
            plt_fig_axis=(plt_fig, plt_axis[idx]),
            # The figure is shown once, after all panels have been drawn.
            use_pyplot=False,
            title=panel_title,
            **kwargs,
        )

    plt_fig.tight_layout()
    if use_pyplot:
        plt.show()
    return plt_fig, plt_axis

## Visual Interpretation of Neural Networks

Neural nets with deep learning represent a powerful machine learning paradigm. However, they are not known for their interpretability. Nevertheless, there are a few techniques that can provide some insight into whether a neural net is doing what it should. For tabular data, predictions can, of course, be explained with methods such as LIME. But there are also a couple of good techniques that work for images, and we are going to showcase some of them in this notebook.

### Loading the Model

We will start by loading a pretrained model that we are going to be testing.



In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Pretrained ResNet-50; the same weights object also supplies the
# normalization statistics that ProcImage reads.
weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
model = model.to(device)

proc_image = ProcImage(model=model, weights=weights, device=device)

### Saliency

We will start with a concept known as visual saliency. The idea is simple: we will just use backpropagation (autodiff) to compute the sensitivity of the prediction to the various input pixels and visualize the result: the saliency map. The pixels in the saliency map will, roughly speaking, light up in proportion to their relevance to the prediction. This allows us to inspect whether the network is paying attention to the correct portions of the image: i.e. that it actually predicts label "plane" because it recognizes the plane and not because most of the pixels in the background are blue.



In [None]:
# Gradient-based saliency attribution for the loaded model.
attributor = Saliency(model)

Next we are going to load an image of a toucan, run it through our model and display the saliency map. We will see that the area around the toucan will light up the most. Other areas will light up too though. As we will see in a bit, in this case, this is more a property of the visualization method than of the model: the standard gradient is not a great indicator.



In [None]:
# Classify the image; proc_image also prints the top predictions.
# img_resized is deliberately a notebook-level name: visualize_attrib reads it
# as the image to display next to the attributions.
img_preproc, img_resized, pred_class = proc_image('data/toucan.png')
# Attribute the predicted class to the input pixels.
attrib = attributor.attribute(img_preproc, target=pred_class)
# outlier_perc is forwarded to captum's attribution normalization.
visualize_attrib(attrib, outlier_perc=10)

### Guided Backprop

The second method we are going to examine will give us much better results. It is called guided backprop and it is actually not that different from saliency: it too relies on a version of the gradient. The difference is that the gradient propagation through the ReLU activation function is modified so that only non-negative gradients get passed through. This effectively ignores signals that contribute negatively to the activation of our class.

Since the current version of the `captum` package seems to have problems with in-place ReLUs due to some changes in `torch` and `torchvision`, we will first iterate over all modules of our model and set the `.inplace` flag of every ReLU to `False`.



In [None]:
# Collect every ReLU of the model and switch off its in-place mode, since
# captum appears to have trouble with in-place ReLUs.
relu_layers = [
    layer for layer in model.modules() if isinstance(layer, torch.nn.ReLU)
]
for layer in relu_layers:
    layer.inplace = False

In [None]:
# Replace the saliency attributor with guided backpropagation.
attributor = GuidedBackprop(model)

We will again apply our attributor to the toucan image: this time the output should be clearly focused on the toucan's outline.



In [None]:
# Re-run the classification so that the input tensor and the notebook-level
# img_resized (read by visualize_attrib) belong to this image.
img_preproc, img_resized, pred_class = proc_image('data/toucan.png')
# Attribute the predicted class using the guided-backprop attributor.
attrib = attributor.attribute(img_preproc, target=pred_class)
# A much larger outlier percentage than in the saliency cell is used here.
visualize_attrib(attrib, outlier_perc=60)

Fine, now let's try the same on a couple more images.



In [None]:
# Apply the current attributor to the remaining example images.
example_images = (
    'data/ambulance.jpg',
    'data/screws.jpg',
    'data/nails.jpg',
    'data/raccoon_example.jpg',
)

for img_path in example_images:
    # img_resized is reassigned on every pass; visualize_attrib reads it as
    # the image to display.
    img_preproc, img_resized, pred_class = proc_image(img_path)
    attrib = attributor.attribute(img_preproc, target=pred_class)
    visualize_attrib(attrib, outlier_perc=60)

### Occlusion

There are a number of other approaches based on the gradient: to explore some of them, browse through the documentation of the `captum` package, which we are using to produce the visualizations in this notebook.

We will do just one more experiment – with a visualization method based on a completely different principle: occlusion. The idea is that if a pixel (or a group of pixels) is important to the prediction, then occluding it (e.g. setting it to zeros) is going to have a very significant impact on the prediction. So if we slide an occlusion window of some size over the image and observe the impact at each position, we can determine which parts of the image the output is most sensitive to.



In [None]:
# Replace the attributor with the occlusion-based method.
attributor = Occlusion(model)

Let us again try this using our toucan image. We will use a relatively small occlusion window first. As we can see, there is clear focus on the toucan, but some other areas of the image lit up as well – such as some parts of the tree.



In [None]:
# Classify the image again; img_resized is read by visualize_attrib.
img_preproc, img_resized, pred_class = proc_image('data/toucan.png')
# Small occlusion window. The tuples presumably follow the input layout
# without the batch axis, i.e. (channels, height, width): a window covering
# all 3 channels and 30x30 pixels, moved in steps of 20 pixels -- confirm
# against the captum Occlusion documentation.
# NOTE(review): baselines=0 is applied to the already normalized input, so the
# occluded patch is presumably not black in the original image -- confirm.
attrib = attributor.attribute(img_preproc, target=pred_class,
                              strides = (3, 20, 20),
                              sliding_window_shapes=(3, 30, 30),
                              baselines=0)
visualize_attrib(attrib, outlier_perc=10)

To show the impact of the occlusion window's size, we will run the visualization again with a larger window. Now the area around the toucan's beak seems to be the most relevant one.



In [None]:
# Classify the image again; img_resized is read by visualize_attrib.
img_preproc, img_resized, pred_class = proc_image('data/toucan.png')
# Same experiment with a larger window: 60x60 pixels, moved in steps of 50
# (tuple layout presumably (channels, height, width) -- confirm against the
# captum Occlusion documentation).
attrib = attributor.attribute(img_preproc, target=pred_class,
                              strides = (3, 50, 50),
                              sliding_window_shapes=(3, 60, 60),
                              baselines=0)
visualize_attrib(attrib, outlier_perc=10)