In [8]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image

In [30]:
class GradCamModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradients = None
        self.tensorhook = []
        self.layerhook = []
        self.selected_out = None
        
        # PRETRAINED MODEL
        self.pretrained = models.densenet169(pretrained=True)
        self.layerhook.append(self.pretrained.features.register_forward_hook(self.forward_hook()))
        
        for p in self.pretrained.parameters():
            p.requires_grad = True
    
    def activations_hook(self, grad):
        self.gradients = grad

    def get_act_grads(self):
        return self.gradients

    def forward_hook(self):
        def hook(module, inp, out):
            self.selected_out = out
            self.tensorhook.append(out.register_hook(self.activations_hook))
        return hook

    def forward(self, x):
        out = self.pretrained.features(x)
        out = torch.relu(out, inplace=True)
        out = torch.nn.functional.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.pretrained.classifier(out)
        return out, self.selected_out

In [31]:
# Load and preprocess the image
image_path = r'DaneTest/elephant/xxx.png'
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = Image.open(image_path).convert('RGB')
preprocessed_image = preprocess(image).unsqueeze(0)

# Create an instance of GradCamModel
model = GradCamModel()

# Forward pass through the model
with torch.no_grad():
    output, features = model(preprocessed_image)
    predicted_class = torch.argmax(output, dim=1).item()

# Backward pass to compute gradients
output[:, predicted_class].backward()
grads = model.get_act_grads()

# Compute the Grad-CAM heatmap
weights = torch.mean(grads, dim=(2, 3), keepdim=True)
grad_cam = torch.sum(weights * features, dim=1, keepdim=True)
grad_cam = torch.relu_(grad_cam)  # Use relu_() for in-place ReLU activation

# Normalize the heatmap
grad_cam = torch.nn.functional.interpolate(grad_cam, size=(224, 224), mode='bilinear', align_corners=False)
grad_cam = (grad_cam - grad_cam.min()) / (grad_cam.max() - grad_cam.min())

RuntimeError: cannot register a hook on a tensor that doesn't require gradient