<a href="https://colab.research.google.com/github/nyp-sit/iti121-2025s2/blob/main/L4/Grad_Cam_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Grad-CAM

This lab exercise shows how you can use Grad-CAM to visualize which parts of an image most influence a model's prediction.
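
In short, for a target class score $y^c$ and the feature maps $A^k$ of a chosen convolutional layer, Grad-CAM weights each map by its spatially averaged gradient and keeps only the positive part of the weighted sum ($Z$ is the number of spatial positions $i, j$ in the map):

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)
$$

The resulting map has the spatial resolution of the chosen layer and is upsampled to the input size for display.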

We will use the [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam) package (installed from PyPI as `grad-cam`) for this exercise.

In [None]:
%pip install grad-cam

In [None]:
from PIL import Image
import torch
import os
import cv2
import numpy as np
from torchvision import models
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import (
    show_cam_on_image, deprocess_image, preprocess_image
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST

We load the image with the OpenCV (`cv2`) library, convert it from OpenCV's BGR channel order to RGB, and scale the pixel values to the range [0, 1]. We then normalize it with the per-channel means and standard deviations expected by ResNet, since we will use a ResNet for the image classification task.
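
For intuition, the scaling and normalization boil down to a few array operations. The NumPy-only sketch below is an illustration assuming an 8-bit RGB input; `normalize_for_resnet` is a hypothetical helper. In this notebook the real work is done by `preprocess_image`, which, to our understanding, applies torchvision's `ToTensor` and `Normalize` and then adds a batch dimension.

```python
import numpy as np

# ImageNet channel statistics (the same values passed to preprocess_image in this notebook)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_for_resnet(rgb_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (H, W, 3) uint8 RGB image -> (1, 3, H, W) float32 batch."""
    rgb = rgb_uint8.astype(np.float32) / 255.0    # scale pixel values to [0, 1]
    rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD    # standardize each channel
    chw = np.transpose(rgb, (2, 0, 1))            # HWC -> CHW, the layout PyTorch expects
    return chw[np.newaxis, ...]                   # add a leading batch dimension


# Usage on a made-up 2x2 all-white image
white = np.full((2, 2, 3), 255, dtype=np.uint8)
batch = normalize_for_resnet(white)
print(batch.shape)  # (1, 3, 2, 2)
```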

Different pretrained networks expect different preprocessing (rescaling, resizing, normalization, etc.).

You can find the transformations a network requires by calling the `transforms()` method bundled with its weights:

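```python
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()  # resize, center-crop, rescale to [0, 1] and normalize
print(preprocess)                  # shows the expected crop/resize sizes, mean and std
```
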
In [None]:
image = Image.open('cockatoo.jpeg')

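# cv2 loads images in BGR order; reversing the last axis gives RGB
rgb_img = cv2.imread('cockatoo.jpeg', cv2.IMREAD_COLOR)[:, :, ::-1]
# scale pixel values from [0, 255] to [0, 1]
rgb_img = np.float32(rgb_img) / 255
# normalize with the ImageNet mean/std and add a batch dimension -> shape (1, 3, H, W)
input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]).to("cpu")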


We will use a pretrained ResNet-50 network to classify the image.

You need to choose the target layer you want to compute the visualization for. Usually this is the last convolutional layer (or block) in the model. Some common choices are:
- ResNet18 and ResNet50: `model.layer4`
- VGG, DenseNet161: `model.features[-1]`
- MnasNet1_0: `model.layers[-1]`

You can print the model to help choose the layer. You can also pass a list of several target layers; in that case, a CAM is computed for each layer and the results are aggregated.
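
To see roughly what is computed for each entry of `target_layers`, here is a NumPy sketch of the Grad-CAM step for a single layer, plus a simple average across layers. It assumes you already have the layer's activations and the gradients of the target score (pytorch-grad-cam collects these with hooks on the layer). The function names are hypothetical, and the aggregation shown (ReLU, mean, rescale, on CAMs already resized to one shape) is our simplified reading of how several layers are combined, not a copy of the library code.

```python
import numpy as np


def grad_cam_from_layer(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM for one layer and one image.

    activations, gradients: arrays of shape (K, H, W) holding the feature maps A^k
    and the gradients d(score)/dA^k. Returns an (H, W) map scaled to [0, 1].
    """
    # One weight per channel: global-average-pool the gradients
    weights = gradients.mean(axis=(1, 2))                              # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep only positive evidence
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    # Min-max scale to [0, 1]; the epsilon avoids dividing by zero on a flat map
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-7)


def aggregate_layers(cams: list[np.ndarray]) -> np.ndarray:
    """Average several per-layer CAMs, assumed already resized to the same shape."""
    mean_cam = np.maximum(np.stack(cams), 0.0).mean(axis=0)
    mean_cam = mean_cam - mean_cam.min()
    return mean_cam / (mean_cam.max() + 1e-7)


# Usage on made-up data: channel 0 fires at (1, 1) and pushes the score up,
# channel 1 fires at (0, 0) but pushes the score down
A = np.zeros((2, 3, 3))
A[0, 1, 1] = 2.0
A[1, 0, 0] = 5.0
dA = np.zeros((2, 3, 3))
dA[0] = 1.0
dA[1] = -1.0

cam = grad_cam_from_layer(A, dA)
peak = tuple(int(i) for i in np.unravel_index(cam.argmax(), cam.shape))
print(peak)  # (1, 1)
combined = aggregate_layers([cam, cam])
```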

In [None]:
# specify the output directory to write the visualization to
output_dir = "output"
output_file = "GradCam_cam.jpg"

# use CPU for computation
device = torch.device("cpu")

In [None]:
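# `pretrained=True` is deprecated in recent torchvision; pass the weights enum instead
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(device).eval()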
target_layers = [model.layer4]
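
The `targets` argument in the next cell decides which score Grad-CAM explains. `ClassifierOutputTarget(89)` selects the model's raw (pre-softmax) output for ImageNet class 89, the sulphur-crested cockatoo class, and `targets=None` falls back to the top-scoring class. The NumPy sketch below imitates that choice on a made-up logit vector; `select_target_score` is a hypothetical helper, not part of pytorch-grad-cam.

```python
from typing import Optional

import numpy as np


def select_target_score(logits: np.ndarray, category: Optional[int] = None) -> tuple[int, float]:
    """Return (class index, raw score) whose gradient Grad-CAM would use.

    category=None mimics targets=None: fall back to the top-scoring class.
    """
    if category is None:
        category = int(np.argmax(logits))
    return category, float(logits[category])


# Made-up logits for a 1000-class ImageNet classifier
logits = np.zeros(1000, dtype=np.float32)
logits[89] = 7.5   # pretend the cockatoo class scores highest
logits[88] = 6.0   # ...followed by a runner-up class

print(select_target_score(logits))      # (89, 7.5): default, top class
print(select_target_score(logits, 88))  # (88, 6.0): explicitly explain class 88 instead
```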


In [None]:
# targets = None  # If targets is None, the highest-scoring category (for every image in the batch) is used.
targets = [ClassifierOutputTarget(89)]  # Take the gradient of the score for class 89 (ImageNet "sulphur-crested cockatoo") w.r.t. the convolutional activations.
# targets = [ClassifierOutputReST(89)]  # Highlight regions that specifically distinguish one class from the rest; this often produces sharper, more discriminative heatmaps.


cam_algorithm = GradCAM
with cam_algorithm(model=model,
                    target_layers=target_layers) as cam:

    # cam.batch_size = 32
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets,
                        aug_smooth=True,   # apply test-time augmentation to smooth the CAM
                        eigen_smooth=True)  # reduce noise by projecting onto the first principal component

    # Keep the heatmap of the first (and only) image in the batch: shape (H, W), values in [0, 1]
    grayscale_cam = grayscale_cam[0, :]

    # Overlay the Grad-CAM heatmap on the original RGB image to see where the model is focusing
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # cv2.imwrite expects BGR channel order, so convert the RGB overlay to BGR before saving
    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

os.makedirs(output_dir, exist_ok=True)
cam_output_path = os.path.join(output_dir, output_file)

cv2.imwrite(cam_output_path, cam_image)
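
`eigen_smooth=True` in the cell above swaps the plain weighted sum of feature maps for a projection onto their first principal component, which tends to suppress noise. A NumPy sketch of that projection for one image, assuming `weighted` holds the channel-weighted activations with shape (K, H, W); `first_component_projection` is a hypothetical name, and this is our reading of the idea rather than the library's exact code. The sign of a principal component is arbitrary, so a projected map can come out inverted.

```python
import numpy as np


def first_component_projection(weighted: np.ndarray) -> np.ndarray:
    """Project (K, H, W) channel-weighted activations onto their first principal component.

    Returns an (H, W) map. The sign of a principal component is arbitrary.
    """
    k, h, w = weighted.shape
    # One row per spatial location, one column per channel: shape (H*W, K)
    flat = weighted.reshape(k, h * w).T
    # Center each channel before the SVD
    flat = flat - flat.mean(axis=0)
    # Rows of vt are the principal directions in channel space
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    projection = flat @ vt[0]  # shape (H*W,)
    return projection.reshape(h, w)


# Usage on made-up data: every channel is a scaled copy of one spatial pattern
rng = np.random.default_rng(0)
pattern = rng.random((7, 7))
scales = np.array([1.0, 0.5, 2.0])
weighted = scales[:, None, None] * pattern  # shape (3, 7, 7)
proj = first_component_projection(weighted)
print(abs(np.corrcoef(proj.ravel(), pattern.ravel())[0, 1]))  # close to 1.0
```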


Let's display the resulting visualization.

In [None]:
import matplotlib.pyplot as plt

img = plt.imread(cam_output_path)
plt.imshow(img)
plt.axis("off")  # hide the pixel-coordinate axes
plt.show()
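
The overlay shown above is produced by `show_cam_on_image`, which colour-maps the CAM and blends it with the original image. The NumPy-only sketch below illustrates that blending under simplifying assumptions: `overlay_heatmap` is a hypothetical helper that uses a plain red ramp instead of OpenCV's JET colormap (the library default, as we understand it), while keeping a 50/50 `image_weight` blend and a final rescaling to `uint8`.

```python
import numpy as np


def overlay_heatmap(rgb_img: np.ndarray, cam: np.ndarray, image_weight: float = 0.5) -> np.ndarray:
    """Blend an (H, W) CAM in [0, 1] over an (H, W, 3) float RGB image in [0, 1].

    Uses a plain red ramp instead of a real colormap, to stay NumPy-only.
    """
    if rgb_img.max() > 1 or not 0 <= image_weight <= 1:
        raise ValueError("expected a float image in [0, 1] and 0 <= image_weight <= 1")
    heatmap = np.zeros_like(rgb_img, dtype=np.float32)
    heatmap[..., 0] = cam  # strong activation -> bright red
    blended = (1 - image_weight) * heatmap + image_weight * rgb_img
    # Rescale so the brightest value becomes 1 (epsilon guards an all-black result)
    blended = blended / max(float(blended.max()), 1e-7)
    return (255 * blended).astype(np.uint8)


# Usage on a made-up flat grey 2x2 image with one "hot" pixel in the CAM
img = np.full((2, 2, 3), 0.5, dtype=np.float32)
cam = np.array([[1.0, 0.0], [0.0, 0.0]], dtype=np.float32)
out = overlay_heatmap(img, cam)
print(out[0, 0], out[1, 1])
```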