In [None]:
"""
Idea (Grad-CAM):
0. First used for CNNs.
1. Once the model is trained, keep its parameters fixed and treat the activation maps (e.g. of the last conv layer) as the variables.
   Backpropagate the class score to these activations to get one gradient map per channel.
2. Average each channel's gradient map over its spatial positions to get one weight per channel, take the weighted sum of the activation maps,
   apply ReLU, and resize the result to the same size as the image.
"""


![alt text](<grad_cam_illustration.png>)

from: https://xai-tutorials.readthedocs.io/en/latest/_model_specific_xai/Grad-CAM.html

In [None]:
"""
Example for a CNN, adapted from: https://medium.com/@bmuskan007/grad-cam-a-beginners-guide-adf68e80f4bb
"""

# Import necessary packages and libraries
import torchvision
import torch
import numpy as np
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import cv2
import matplotlib.pyplot as plt


# VGG-16 with its pre-trained weights
vgg_model = torchvision.models.vgg16(pretrained=True)

# Pre-processing applied to an image before it is passed to the network:
# resize to a fixed size, convert to a tensor, then normalise each channel.
INPUT_SIZE = (224, 224)
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
]
transform = transforms.Compose(preprocessing_steps)

"""
MOST IMPORTANT
"""
# Sub-network holding every layer of `vgg_model.features` except the last
# four; its output is the activation map that the heatmap is built from.
image_to_heatmaps = nn.Sequential(*vgg_model.features[:-4].children())
def compute_heatmap(model, img, target_layer=None):
  """Compute a Grad-CAM heatmap for the model's most confident prediction.

  The activations of `target_layer` are captured during a single forward
  pass, the predicted class score is differentiated with respect to those
  activations, each channel is weighted by the spatial mean of its gradient,
  and the channel-averaged result is passed through ReLU.

  Args:
    model: network exposing a `features` container (as torchvision's VGG
      does) and returning logits of shape (batch, classes).
    img: input batch of shape (batch, channels, height, width). Only the
      first image of the batch is explained.
    target_layer: module whose output is visualised. Defaults to
      `model.features[-5]`, i.e. the last layer of `model.features[:-4]`.

  Returns:
    (heatmap, pred): `heatmap` is a detached, non-negative CPU tensor with
    the spatial shape of the target layer's output; `pred` holds the
    arg-max class index of every image in the batch.

  Raises:
    RuntimeError: if `target_layer` is not executed by `model(img)`.
  """
  model.eval()
  if target_layer is None:
    target_layer = model.features[-5]

  captured = {}

  def _capture(module, inputs, output):
    # Frozen parameters + an input without grad would leave the activation
    # outside the autograd graph; turn it into a differentiable leaf.
    if not output.requires_grad:
      output.requires_grad_(True)
    captured["activations"] = output
    # Pass a copy downstream so that in-place layers that follow
    # (e.g. ReLU(inplace=True)) cannot overwrite the captured tensor.
    return output.clone()

  handle = target_layer.register_forward_hook(_capture)
  try:
    with torch.enable_grad():
      logits = model(img)
  finally:
    # Always detach the hook so later forward passes are unaffected.
    handle.remove()

  if "activations" not in captured:
    raise RuntimeError("target_layer was not executed during the forward pass")
  activations = captured["activations"]

  # model's prediction (class index per image)
  pred = logits.max(-1)[-1]
  score = logits[0, pred[0]]

  # Gradient of the class score w.r.t. the activations. autograd.grad does
  # not write into any parameter's .grad, so repeated calls give the same
  # heatmap and the model is left untouched.
  grads = torch.autograd.grad(score, activations)[0]

  # One importance weight per channel: spatial mean of that channel's gradient.
  channel_weights = grads.mean(dim=(2, 3), keepdim=True)

  # Weighted activations averaged over channels; ReLU keeps only the
  # locations that speak in favour of the predicted class.
  heatmap = torch.relu((channel_weights * activations).mean(dim=1))[0]
  return heatmap.detach().cpu(), pred


def upsampleHeatmap(map, image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), alpha=0.7):
  """Colourise a heatmap, resize it to the image and blend the two.

  Args:
    map: 2-D heatmap (torch tensor or array-like).
    image: image tensor of shape (1, C, H, W) or (C, H, W).
    mean, std: per-channel statistics used to undo the normalisation of
      `image` before blending (defaults are the ImageNet values). Pass
      `None` for either one if `image` already holds values in [0, 1].
    alpha: weight of the heatmap in the blend; the image gets `1 - alpha`.

  Returns:
    uint8 RGB array of shape (H, W, 3).
  """
  # channels-last copy of the image as a NumPy array
  rgb = image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
  if mean is not None and std is not None:
    # Undo Normalize so the pixels are back in [0, 1]
    rgb = rgb * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
  rgb = np.clip(rgb, 0.0, 1.0) * 255.0
  height, width = rgb.shape[:2]

  if isinstance(map, torch.Tensor):
    map = map.detach().cpu().numpy()
  heat = np.asarray(map, dtype=np.float32)

  # Min-max normalise to [0, 1]; a constant map would divide by zero, so it
  # is rendered as "no activation" instead.
  lowest, highest = float(heat.min()), float(heat.max())
  span = highest - lowest
  if span > 0:
    heat = (heat - lowest) / span
  else:
    heat = np.zeros_like(heat)
  heat = np.uint8(255 * heat)

  # resize the heatmap to the input size; cv2.resize expects (width, height)
  heat = cv2.resize(heat, (width, height))
  # applyColorMap returns BGR; convert so the overlay matches the RGB image
  heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
  heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)

  blended = alpha * heat + (1 - alpha) * rgb
  return np.clip(blended, 0, 255).astype(np.uint8)

def display_images(upsampled_map, image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show the heatmap overlay and the input image side by side.

    Args:
        upsampled_map: image array of the overlay, shown as-is.
        image: image tensor of shape (1, C, H, W) or (C, H, W).
        mean, std: per-channel statistics used to undo the normalisation of
            `image` so it is displayed with its real colours (defaults are
            the ImageNet values). Pass `None` for either one to show
            `image` unchanged.
    """
    image = image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    if mean is not None and std is not None:
        # A normalised tensor has values outside [0, 1]; imshow would clip
        # them and show distorted colours, so map it back to [0, 1] first.
        image = np.clip(image * np.asarray(std) + np.asarray(mean), 0.0, 1.0)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((upsampled_map, "Heatmap"), (image, "Original Image"))
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
# Example usage
IMAGE_PATH = "cat_and_dog.jpg"

# Convert to 3-channel RGB so greyscale / RGBA / CMYK files also fit the
# 3-channel Normalize step; the context manager closes the file afterwards.
with Image.open(IMAGE_PATH) as pil_image:
    rgb_image = pil_image.convert("RGB")

# Pre-process, then add the leading batch dimension the model expects
cat_dog_img = transform(rgb_image).unsqueeze(0)

heatmap, pred = compute_heatmap(vgg_model, cat_dog_img)
upsampled_map = upsampleHeatmap(heatmap, cat_dog_img)
print(f"Prediction: {pred}")

display_images(upsampled_map, cat_dog_img)