# Saliency maps

A saliency map is an important tool for debugging computer vision (CV) models. It highlights the regions of an input image that most influence the model's output. There are two main techniques for convolutional models: the input gradient and GradCAM.

**Goal.** The goal of this notebook is to develop basic skills in model analysis and to build a deeper understanding of CNN models.

You need the following extra libraries beyond PyTorch:
* torchvision
* Pillow (PIL)
* NumPy
* Matplotlib
* PyYAML (used to load the label names)

In [None]:
import torch
import torchvision
import numpy as np
from PIL import Image

from matplotlib import pyplot as plt
from matplotlib.cm import jet

# Helper tools. You can skip this block.
def show_map(x, saliency):
    assert x.ndim == 3 and saliency.ndim == 2
    image = x.detach().clone().permute(1, 2, 0)
    image -= image.min()
    image /= image.max()

    s = saliency.detach().clone()
    if s.shape != image.shape[:2]:
        s = torch.nn.functional.interpolate(s[None, None], image.shape[:2], mode="bilinear")[0, 0]
    s -= s.min()
    s /= s.max()
    s = jet(s)[:, :, :3]  # Remove alpha channel.
    s = torch.as_tensor(s)  # (H, W, C).

    mixed = 0.5 * image + 0.5 * s
    print(x.shape, s.shape, mixed.shape)
    plt.imshow(torch.cat([image, s], 1))
    plt.show()
    plt.imshow(mixed)
    plt.show()

In [None]:
# Uncomment to download the image.
#! wget -O image.jpg https://raw.github.com/ivan-chai/isscai-cv-2024/master/02-images/image.jpg

In [None]:
# Load the pretrained ResNet34 model.
weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
model = torchvision.models.resnet34(weights=weights).eval()

# Load the image.
image = Image.open("image.jpg")
plt.title("Original image")
plt.imshow(image)
plt.show()

# Preprocess.
x = weights.transforms()(image)
print("Preprocessed image size:", x.shape)
plt.title("Preprocessed")
preprocessed = x.clone().permute(1, 2, 0)
preprocessed -= preprocessed.min()
preprocessed /= preprocessed.max()
plt.imshow(preprocessed)

# Recognition

The model was pretrained on the ImageNet 1K dataset and is capable of recognizing both ships and cats. Let's show the top predicted classes.

**Load label names**

In [None]:
#!wget https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt -O imagenet1000_clsidx_to_labels.txt

In [None]:
import yaml
with open("imagenet1000_clsidx_to_labels.txt") as fp:
    labels = yaml.safe_load(fp)
print("Num classes:", len(labels))

**Show predictions**

In [None]:
logits = model(x[None])[0]
probs = torch.nn.functional.softmax(logits, dim=0)
top = probs.argsort(descending=True)[:10].tolist()
for i in top:
    print(f"Class {i} {labels[i]}, prob: {probs[i] * 100:.2f}%")

**Discussion.** We see a mixture of ship and cat / tiger classes.

In [None]:
SHIP = 628
CAT = 282

# Simple gradient

The basic way to measure the impact of an image region on the model output is to compute the gradient of the output w.r.t. the input image. A large gradient magnitude indicates an important region: the output is sensitive to changes in it. Conversely, a zero gradient means that a small perturbation of the region does not affect the output.

The gradient is easy to compute with PyTorch's standard autograd engine.
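As a minimal sanity check of this idea, consider a toy linear score. This is a hypothetical stand-in, not the ResNet used in this notebook, and the names `w` and `x_toy` are chosen only for illustration. For such a score the input gradient is exactly the weight vector, so an input with zero weight gets zero saliency.

```python
import torch

# Toy linear "model": score = sum(w * x). Hypothetical, for illustration only.
w = torch.tensor([2.0, 0.0, -3.0])
x_toy = torch.tensor([1.0, 5.0, 1.0], requires_grad=True)

score = (w * x_toy).sum()
score.backward()

# For a linear score, the input gradient has the same values as w.
# The second input has zero weight, hence zero gradient: perturbing it
# does not change the score.
print(x_toy.grad)
```

For a deep network the gradient depends on the input itself, but the interpretation stays the same: it is a local, linear estimate of sensitivity.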

In [None]:
x.requires_grad = True
x.grad = None
logits = model(x[None])[0]
logits[SHIP].backward()
print("Gradient shape:", x.grad.shape)
print(f"Gradient range: [{x.grad.min().item():.2f}, {x.grad.max().item():.2f}]")

# Compute gradient norm for each pixel.
# We need clipping to measure the positive impact only.
value = torch.linalg.norm(x.grad.clip(min=0), dim=0)  # (C, H, W) -> (H, W).
plt.title(f"Saliency map for '{labels[SHIP]}'")
show_map(x, value)
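The clipping and per-pixel norm used above can be checked on a tiny hand-made gradient. The tensor `g` below is made up for illustration and is not produced by the model. The first pixel has only non-negative components and keeps its full norm. The second has only non-positive components and is zeroed out by the clipping.

```python
import torch

# Toy gradient with 3 channels and 2 pixels, shape (C, H, W) = (3, 1, 2).
g = torch.tensor([[[3.0, -4.0]],
                  [[4.0,  0.0]],
                  [[0.0, -1.0]]])

# Same operation as in the notebook: drop negative components, then take
# the norm over the channel dimension.
positive_only = torch.linalg.norm(g.clip(min=0), dim=0)  # (H, W) = (1, 2).

# Without clipping, negative components also contribute to the norm.
full = torch.linalg.norm(g, dim=0)  # (H, W) = (1, 2).

print(positive_only)
print(full)
```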

**Discussion.**
* The saliency map is noisy and unstable.
* The logit responsible for the ship class ignores the region with the cat.
* The model looks at the bottom of the ship and at the sea above.
* The decision of the model can be affected by modifying a small number of pixels in the image.

In [None]:
# Compute gradient norm for each pixel.
x.requires_grad = True
x.grad = None
logits = model(x[None])[0]
logits[CAT].backward()
value = torch.linalg.norm(x.grad.clip(min=0), dim=0)  # (C, H, W) -> (H, W).
plt.title(f"Saliency map for '{labels[CAT]}'")
show_map(x, value)

**Discussion.**
* The saliency map is also noisy and unstable.
* The model looks at the area near the cat.
* The decision of the model can be affected by modifying pixels in the background.

# GradCAM

GradCAM exploits the structure of a CNN model:
* Before the global pooling layer (highlighted), the model holds a feature vector for each region of the image.
* We can apply the classification head to each of these feature vectors (7 x 7 in total) and obtain class probabilities.
* These probabilities indicate how well each region corresponds to the class.

We will implement a simple method similar to GradCAM.
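Applying the head per region is meaningful because global average pooling and a linear head commute. The logit computed from the pooled features equals the average of the per-region logits. The sketch below checks this on random tensors with made-up sizes. The names `features` and `head` are stand-ins for the last convolutional output and the classification head, not the actual ResNet layers. The identity holds for logits only, not for probabilities after softmax.

```python
import torch

torch.manual_seed(0)
features = torch.randn(8, 7, 7)  # (D, H, W), stand-in for the last feature map.
head = torch.nn.Linear(8, 5)     # Stand-in for the classification head.

with torch.no_grad():
    # Standard path: global average pooling, then the head.
    pooled_logits = head(features.mean(dim=(1, 2)))  # (5,).
    # Per-region path: the head on each of the 49 vectors, then averaging.
    local_logits = head(features.flatten(1).T)       # (49, 5).
    averaged_logits = local_logits.mean(dim=0)       # (5,).

# Both paths agree up to floating point error.
print(torch.allclose(pooled_logits, averaged_logits, atol=1e-5))
```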

<img src="vgg16.jpg" width=500 />

In [None]:
y = x[None]
y = model.conv1(y)
y = model.bn1(y)
y = model.relu(y)
y = model.maxpool(y)
y = model.layer1(y)
y = model.layer2(y)
y = model.layer3(y)
activations = model.layer4(y)[0]
print("Activations shape:", activations.shape)
d, h, w = activations.shape
y = activations.flatten(1).T  # (HW, D): one feature vector per region.
y = model.fc(y).reshape(h, w, -1)  # (H, W, num_classes): per-region logits.
probs = torch.nn.functional.softmax(y, dim=-1)  # (H, W, num_classes).
print("Probabilities shape:", probs.shape)

In [None]:
plt.title(f"GradCAM map for '{labels[top[0]]}'")
show_map(x, probs[..., top[0]])

In [None]:
plt.title(f"GradCAM map for '{labels[top[1]]}'")
show_map(x, probs[..., top[1]])

**Discussion.**
* GradCAM is more stable than the input gradient.
* GradCAM has coarse granularity: the map has only 7 x 7 cells before upsampling.
* GradCAM can also be used for **object detection**.
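The limited granularity follows from the resolution of the last feature map. Assuming a 224 x 224 input, which is consistent with the 7 x 7 map above, each cell covers a 32 x 32 pixel block. A minimal sketch of the upsampling step with a made-up map (`coarse` is not a model output):

```python
import torch

coarse = torch.zeros(7, 7)
coarse[3, 3] = 1.0  # One active cell in the 7 x 7 map.

# Bilinear upsampling to the assumed input resolution.
fine = torch.nn.functional.interpolate(
    coarse[None, None], size=(224, 224), mode="bilinear", align_corners=False
)[0, 0]

print(fine.shape)  # torch.Size([224, 224])
print(224 // 7)    # 32 pixels per cell along each axis.
```

Bilinear interpolation smooths the transitions between cells but cannot recover detail finer than one cell.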

# Homework (optional)

See the original [GradCAM work](https://arxiv.org/pdf/1610.02391).
* What are the differences between our approach and GradCAM?
* Try to implement the original GradCAM approach. Are there any visible differences in the result?