In [None]:
import sys
sys.path.append("../../")

import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.nn as nn

from torchvision import models
from utils.data import get_example_image

In [None]:
# Fetch example #650: `tensor` is fed to the model below, `img` is the
# displayable picture, and `label` is shown as the figure title.
tensor, img, label = get_example_image(650)

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_axis_off()
ax.set_title(f"Label: {label}")
plt.show()

In [None]:
NUM_CLASSES = 10  # size of the fine-tuned classification head
CHECKPOINT_PATH = "../../assets/models/finetuned_resnet18.pth"

# No pretrained download needed: the strict load_state_dict below overwrites
# every parameter and buffer (including the replaced fc head) anyway.
model = models.resnet18(weights=None)

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (the notebook never moves the model to another device).
# NOTE(review): torch.load unpickles the file; on torch>=1.13 consider
# passing weights_only=True since only a state dict is expected here.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval();  # trailing ';' suppresses the full module repr in the output

In [None]:
class GradCAM:
    """Grad-CAM heatmap generator for one layer of a model.

    On construction, a forward hook (captures the layer's output) and a full
    backward hook (captures the gradient w.r.t. that output) are registered on
    ``target_layer``. Call :meth:`remove` (or use the instance as a context
    manager) when finished; otherwise the hooks stay attached to the layer and
    accumulate every time a new ``GradCAM`` is built on it, e.g. when a notebook
    cell is re-run.

    Args:
        model: the network to explain; called as ``model(x)``.
        target_layer: a submodule of ``model`` whose output is used for the map.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

        # keep the handles so the hooks can be detached later
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, input, output):
        """Forward hook: store the target layer's output (detached)."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from the target layer. Safe to call more than once."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove()

    def __call__(self, x, class_idx=None):
        """Compute a normalized Grad-CAM map for the first sample of ``x``.

        Args:
            x: model input batch; only sample 0 is explained.
            class_idx: index of the output score to explain. ``None`` uses the
                model's highest-scoring class for sample 0.

        Returns:
            The map for sample 0 at the target layer's spatial resolution,
            ReLU-ed and min-max scaled to ``[0, 1]`` (all zeros if nothing is
            positive).

        Raises:
            RuntimeError: if the hooks did not capture activations/gradients
                (e.g. ``remove()`` was called or the layer is not used in the
                forward pass).
        """
        # clear results from a previous call so a silent hook failure can't
        # return a stale map
        self.activations = None
        self.gradients = None

        output = self.model(x)
        if class_idx is None:
            class_idx = int(output[0].argmax())

        self.model.zero_grad()
        output[0, class_idx].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire: check that target_layer is used "
                "in model's forward pass and that remove() has not been called."
            )

        # global average pooling of gradients -> one weight per channel
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        # weighted sum of activations over the channel dimension
        cam = (weights * self.activations).sum(dim=1)

        # keep only regions that increase the class score
        cam = F.relu(cam)

        # only output[0] was back-propagated, so only sample 0 is meaningful;
        # indexing (rather than squeeze) keeps both spatial dims even if one is 1
        cam = cam[0]

        # min-max normalize to [0, 1]; epsilon avoids division by zero
        cam -= cam.min()
        cam /= cam.max() + 1e-8

        return cam

In [None]:
# Explain via the second conv of the last block in layer4, i.e. the final
# convolution of ResNet-18.
target_layer = model.layer4[-1].conv2
input_tensor = tensor.unsqueeze(0)  # prepend a batch dimension of size 1
# NOTE(review): confirm index 0 really is the "dog" class in this dataset's label order.
target_class = 0  # dog class

grad_cam = GradCAM(model, target_layer)
cam = grad_cam(input_tensor, target_class)

# Upsample the coarse (h, w) map, viewed as (1, 1, h, w), to the image's
# height and width so it can be overlaid pixel-for-pixel.
height, width = img.shape[0], img.shape[1]
cam = F.interpolate(
    cam[None, None],
    size=(height, width),
    mode="bilinear",
    align_corners=False,
)
cam = cam.squeeze().cpu().numpy()

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(img)
# Translucent heatmap drawn on top of the photo.
ax.imshow(cam, cmap="jet", alpha=0.35)
ax.set_axis_off()
ax.set_title("Grad-CAM Overlay")
plt.show()