# Imports


In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np

import pickle
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torchsummary import summary

from PIL import Image

%matplotlib inline

# Model

![ResNet34](https://miro.medium.com/max/1050/1*Y-u7dH4WC-dXyn9jOG4w0w.png)

In [None]:
# Pretrained ResNet-34 with the IMAGENET1K_V1 weights, switched to eval mode.
# The cell's last expression displays the module.
RESNET34_WEIGHTS = 'ResNet34_Weights.IMAGENET1K_V1'
resnet34 = models.resnet34(weights=RESNET34_WEIGHTS)
resnet34.eval()

In [None]:
# Layer-by-layer summary of resnet34 for a single 3x224x224 input, on CPU.
summary(resnet34, input_size=(3, 224, 224), device='cpu')

# Inputs

In [None]:
def preprocess_image(dir_path):
    """Build an ImageFolder dataset over *dir_path* with model-ready tensors.

    Every image is resized (shorter side 256), center-cropped to 224x224,
    converted to a tensor and normalized with the ImageNet channel
    mean/std constants below.

    Returns:
        The ``datasets.ImageFolder`` instance.
    """
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return datasets.ImageFolder(dir_path, pipeline)

if not os.path.exists("data"):
    os.mkdir("data")
if not os.path.exists("data/GradCAM/data"):
    os.mkdir("data/GradCAM/data")
    !cd data/GradCAM && wget "https://www.lri.fr/~gcharpia/deeppractice/2023/TP2/TP2_images.zip" && unzip TP2_images.zip

dir_path = "data" 
dataset = preprocess_image(dir_path)

print(dataset)
print(dataset[0][0].shape)

# Labels

In [None]:
# Class-index -> human-readable ImageNet label mapping, looked up by
# predict_label and gradcam.
# SECURITY: pickle.load can execute arbitrary code embedded in the payload, and
# this payload comes from a third-party gist over the network. Only run it if
# you trust that source; a plain JSON/text label list would be a safer format.
# The `with` block closes the HTTP response once it has been read.
with urllib.request.urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl') as response:
    classes = pickle.load(response)
classes

# Sample

In [None]:
def predict_label(model, image_index, top_n=1):
    """Return the *top_n* highest-scoring classes of *model* for one image.

    Reads the module-level ``dataset`` (item ``[0]`` is the preprocessed
    image tensor) and ``classes`` (class index -> name).

    Args:
        model: classifier producing one row of class scores per input image;
            presumably already in eval mode (see ``resnet34.eval()``).
        image_index: index of the image in ``dataset``.
        top_n: number of classes to return.

    Returns:
        ``(label_index, label_pred)``: a NumPy array of the ``top_n`` class
        indices, highest score first, and the list of matching class names.
    """
    # Inference only: ranking the scores needs no autograd graph.
    with torch.no_grad():
        # Add the batch dimension the model expects (1 x C x H x W).
        output = model(dataset[image_index][0].unsqueeze(0))

    _, indices = torch.topk(output, top_n)
    label_index = indices[0].numpy()
    label_pred = [classes[i] for i in label_index]

    return (label_index, label_pred)


def labeled_plot(image_index, label):
    """Show the image file behind ``dataset[image_index]``, titled with *label*.

    The picture is re-read from disk through the path in
    ``dataset.imgs[image_index][0]`` and converted to RGB. It is therefore
    displayed without the dataset's resize/crop/normalization transforms.
    *label* is a sequence of strings; they are joined with " / " for the title.
    """
    source_path = dataset.imgs[image_index][0]
    picture = Image.open(source_path).convert('RGB')

    plt.imshow(picture)
    plt.title(" / ".join(label))
    plt.show()


# Top-3 predictions for one sample image: print them, then plot the image
# with the predicted names as title.
image_index = 5
label_index, label_name = predict_label(resnet34, image_index, top_n=3)
print(label_index)
print(label_name)
labeled_plot(image_index, label_name)

# Grad-CAM 

**Overview**  
Given an image and a category (‘tiger cat’) as input, we forward-propagate the image through the model to obtain the `raw class scores` (before softmax). The gradients are set to zero for all classes except the desired class (tiger cat), whose gradient is set to 1. This signal is then back-propagated to the `rectified convolutional feature maps` of interest, which are combined to compute the coarse Grad-CAM localization (blue heatmap).

*pug, pug-dog* | *tabby, tabby cat*

![alt](https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/dog.jpg) ![alt](https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/cat.jpg)

### Layer

In [None]:
# Walk down to the module Grad-CAM will hook: the last child of the last
# block of resnet34.layer4. Only the final expression of a notebook cell is
# displayed, so the cell ends by showing that module's class name.
last_layer4_block = list(resnet34.layer4.children())[-1]
target_layer = list(last_layer4_block.children())[-1]
type(target_layer).__name__

### Loss

In [None]:
"""
Both losses keep only the score of the target class, so that back-propagating
them gives a zero gradient to every other class.

``MultiplicativeFilteringLoss`` zeroes the other classes literally, by
multiplying the output by the one-hot encoded target.

``FilteringLoss`` does it more efficiently, by indexing the target class
directly.
"""

class MultiplicativeFilteringLoss(nn.Module):
    """Target-class score selected through a one-hot (or weighting) mask.

    ``forward(output, target)`` multiplies ``output`` by ``target``; the two
    must be broadcastable, e.g. both ``(batch, num_classes)``. It then sums
    over dim 1 with ``keepdim=True``. The result has shape ``(batch, 1)``,
    the same shape ``FilteringLoss`` returns for a single target index.
    """

    def __init__(self):
        # Zero-argument super(): this class derives from nn.Module, not from
        # FilteringLoss, so super(FilteringLoss, self) raises TypeError.
        super().__init__()

    def forward(self, output, target):
        # Tensor.sum over the class axis. Python's builtin sum() would iterate
        # the batch axis and leave a (num_classes,) tensor, and .backward()
        # cannot start from a non-scalar tensor.
        return (output * target).sum(dim=1, keepdim=True)
    

class FilteringLoss(nn.Module):
    """Pick the requested class score(s) out of the model output.

    ``forward(output, target)`` returns a copy of ``output[:, target]``. With
    a ``(batch, num_classes)`` output and ``target`` a 1-element index tensor,
    this is a ``(batch, 1)`` tensor holding the target-class score.
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        selected_scores = output[:, target]
        # Copy so the result never aliases `output` (indexing with a plain
        # int would otherwise return a view).
        return selected_scores.clone()

### GradCAM

In [None]:
def gradcam(image_index, label_index, model, target_layer):
    """Compute the Grad-CAM heatmap of *target_layer* for one image and class.

    Reads the module-level ``dataset`` (item ``[0]`` is a normalized C x H x W
    image tensor) and ``classes`` (class index -> name).

    Args:
        image_index: index of the image in ``dataset``.
        label_index: class whose score is back-propagated.
        model: classifier whose forward pass runs through *target_layer*.
        target_layer: module whose output activations and output gradients
            are combined into the map.

    Returns:
        ``(reference_image, gradcam_heatmap, label)``:
        - the input image as an H x W x C array, de-normalized and clipped
          to [0, 1] so it can be passed straight to ``imshow``;
        - the non-negative (ReLU) Grad-CAM map, bilinearly upsampled to H x W;
        - the class name.

    Raises:
        RuntimeError: if the hooks never fired, i.e. *target_layer* is not
            used during ``model``'s forward/backward pass.
    """
    # HOOK
    activation = None
    gradient = None

    def get_activations_hook(module, input, output):
        nonlocal activation
        activation = output

    def get_grads_hook(module, grad_input, grad_output):
        nonlocal gradient
        gradient = grad_output[0]

    # NOTE(review): register_backward_hook is deprecated in favor of
    # register_full_backward_hook. The full variant is known to error on
    # modules that modify tensors in place (e.g. ReLU(inplace=True)), so
    # verify it against every hooked module before switching.
    hook_activation = target_layer.register_forward_hook(get_activations_hook)
    hook_gradient = target_layer.register_backward_hook(get_grads_hook)
    try:
        # FORWARD
        model.zero_grad()
        # Add the batch dimension (1 x C x H x W).
        input_tensor = dataset[image_index][0].unsqueeze(0)
        input_tensor.requires_grad = True
        output = model(input_tensor)

        # BACKWARD: only the target-class score is back-propagated, so every
        # other class receives a zero gradient.
        loss_fn = FilteringLoss()
        loss = loss_fn(output, torch.tensor([label_index]))
        loss.backward()
    finally:
        # CLEANUP: always detach the hooks, even if forward/backward raised,
        # so they do not keep firing on later calls that use this module.
        hook_activation.remove()
        hook_gradient.remove()

    if activation is None or gradient is None:
        raise RuntimeError(
            "Grad-CAM hooks did not fire: is target_layer used in model's forward pass?"
        )

    # GRADCAM: per-channel weights are the spatially averaged gradients; the
    # map is the ReLU of the weighted sum of activation channels.
    weights = torch.mean(gradient, dim=(2, 3), keepdim=True)
    grad_cam = torch.sum(weights * activation, dim=1, keepdim=True)
    grad_cam = nn.functional.relu(grad_cam)

    # Upsample to the input's spatial size so the map overlays the image.
    upsampled_grad_cam = nn.functional.interpolate(
        grad_cam,
        size=tuple(input_tensor.shape[-2:]),
        mode='bilinear',
        align_corners=False,
    )

    # VISUALIZATION
    label = classes[label_index]
    # Undo the ImageNet normalization applied by preprocess_image (these
    # constants must match it). Otherwise imshow clips the values and
    # displays distorted colors.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    reference_image = (
        (input_tensor.detach()[0] * std + mean).clamp(0, 1).numpy().transpose(1, 2, 0)
    )
    gradcam_heatmap = upsampled_grad_cam.detach().numpy()[0, 0, :, :]

    return reference_image, gradcam_heatmap, label


# Grad-CAM overlay of target_layer for each predicted class of the sample
# image, side by side.
# squeeze=False keeps `axes` 2-D even when label_index holds a single class.
# Without it, plt.subplots returns a bare Axes and axes[i] fails.
fig, axes = plt.subplots(1, len(label_index), figsize=(10, 10), squeeze=False)

for i, label_i in enumerate(label_index):

    image, heatmap, label = gradcam(image_index, label_i, resnet34, target_layer)

    ax = axes[0][i]
    ax.imshow(image)
    ax.imshow(heatmap, cmap='jet', alpha=0.5)
    ax.set_title(label)
    ax.axis('off')

plt.tight_layout()
plt.show()


### Comparisons

In [None]:
def compare_gradcams(image_index, label_indexes, model, modules):
    """Draw a grid of Grad-CAM overlays: one row per module, one column per class.

    Args:
        image_index: index of the image in the module-level ``dataset``.
        label_indexes: class indices, one subplot column each.
        model: classifier passed through to ``gradcam``.
        modules: modules of *model* to hook, one subplot row each.

    The figure is created with ``plt.subplots`` and left as the current
    figure without being shown. Callers can then add a ``plt.suptitle`` and
    call ``plt.show()``. Nothing is drawn if either sequence is empty.
    """
    n_rows, n_cols = len(modules), len(label_indexes)
    if n_rows == 0 or n_cols == 0:
        return

    # squeeze=False keeps `axes` 2-D for every grid size. A single module,
    # a single class, or both then go through the same loop as a full grid.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)

    for row, module in enumerate(modules):
        for col, label_i in enumerate(label_indexes):

            image, heatmap, label = gradcam(image_index, label_i, model, module)

            ax = axes[row][col]
            ax.imshow(image)
            ax.imshow(heatmap, cmap='jet', alpha=0.5)
            ax.set_title(label)
            ax.axis('off')

In [None]:
# Last child of the last block of each stage, from the deepest stage
# (layer4) to the shallowest (layer1).
layers_final_module = [
    list(list(stage.children())[-1].children())[-1]
    for stage in (resnet34.layer4, resnet34.layer3, resnet34.layer2, resnet34.layer1)
]

compare_gradcams(image_index, label_index, resnet34, layers_final_module)
plt.suptitle('GradCAM accross layers')
plt.tight_layout()
plt.show()

In [None]:
# Last child of the last, second-to-last and third-to-last blocks of layer4.
# NOTE(review): a block that registers a downsample branch (presumably the
# first block of the stage) ends with that branch rather than its main-path
# module. Confirm this is the intended comparison.
layer4_blocks_final_module = [
    list(list(resnet34.layer4.children())[-k].children())[-1]
    for k in (1, 2, 3)
]

compare_gradcams(image_index, label_index, resnet34, layer4_blocks_final_module)
plt.suptitle('GradCAM accross Layer 4 blocks')
plt.tight_layout()
plt.show()

In [None]:
# The five last children of layer4's last block, last child first.
layer4_block_3_modules = [
    list(list(resnet34.layer4.children())[-1].children())[-k]
    for k in range(1, 6)
]

compare_gradcams(image_index, label_index, resnet34, layer4_block_3_modules)
plt.suptitle('GradCAM accross Layer 4 block 3 modules')
plt.tight_layout()
plt.show()

In [None]:
# The five last children of layer4's second-to-last block, last child first.
layer4_block_2_modules = [
    list(list(resnet34.layer4.children())[-2].children())[-k]
    for k in range(1, 6)
]

compare_gradcams(image_index, label_index, resnet34, layer4_block_2_modules)
plt.suptitle('GradCAM accross Layer 4 Block 2 modules')
plt.tight_layout()
plt.show()

In [None]:
# The five last children of layer4's third-to-last block, last child first.
layer4_block_1_modules = [
    list(list(resnet34.layer4.children())[-3].children())[-k]
    for k in range(1, 6)
]

compare_gradcams(image_index, label_index, resnet34, layer4_block_1_modules)
plt.suptitle('GradCAM accross Layer 4 Block 1 modules')
plt.tight_layout()
plt.show()

In [None]:
# The five last children of the FIRST block of layer3, last child first.
# Index [0] rather than [-3]: [-3] is the first block only for a stage with
# exactly three blocks. ResNet-34's layer3 has six blocks, so [-3] selected
# its fourth block while the title says "Block 1".
layer3_block_1_modules = [
    list(list(resnet34.layer3.children())[0].children())[-1],
    list(list(resnet34.layer3.children())[0].children())[-2],
    list(list(resnet34.layer3.children())[0].children())[-3],
    list(list(resnet34.layer3.children())[0].children())[-4],
    list(list(resnet34.layer3.children())[0].children())[-5]
]

compare_gradcams(image_index, label_index, resnet34, layer3_block_1_modules)
plt.suptitle('GradCAM accross Layer 3 Block 1 modules')
plt.tight_layout()
plt.show()

In [None]:
# Single-module check: Grad-CAM on the fifth-from-last child of layer4's
# last block.
test_layer = []
test_layer.append(list(list(resnet34.layer4.children())[-1].children())[-5])

compare_gradcams(image_index, label_index, resnet34, test_layer)

# Conclusion

As we move back through the layers (away from the output), and through the modules inside each layer, the GradCAM results become less and less precise.
The closer the hooked module is to the output, the more precise the heatmap is; the farther away it is, the worse the heatmap becomes.
This is visible at every level.

It is interesting to note that as precision worsens, not only does the area of relevant pixels shrink, but once it has disappeared, the area of non-relevant pixels grows.


# GradCAM contributions

Grad-CAM generates visual explanations for CNN-based networks, with the following advantages:
 - it works for **any** CNN-based network
 - it requires no architectural changes or re-training
 - explanations are class-specific

GradCAM provides a proof of concept of how interpretable visualizations can help diagnose failure modes.
It shows that seemingly unreasonable predictions of an image classifier can have
reasonable explanations. Finally, it can help identify biases in datasets.

