In [2]:
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [3]:
# Load a pre-trained model (for example, ResNet-50)
model = models.resnet50(pretrained=True)
# Put the model in evaluation mode for inference. `eval()` returns the module
# itself, so the trailing semicolon keeps the notebook from echoing the full
# layer-by-layer repr as this cell's output.
model.eval();



In [None]:
# Hook for the gradients of the target layer
class FeatureExtractor:
    """Run a model forward while capturing one layer's output and its gradient.

    The target layer is observed through a temporary forward hook and the
    model is executed through its own ``forward``. Chaining the child modules
    by hand skips any functional step ``forward`` performs between them (for
    example a flatten before a final linear layer) and breaks on such models.

    Attributes:
        model: The ``torch.nn.Module`` to run.
        target_layer: Name of the layer to observe, as listed by
            ``model.named_modules()`` (top-level names such as ``'layer4'``
            as well as dotted nested names).
        gradients: Gradient with respect to the captured activation; ``None``
            until a backward pass has gone through the captured tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None

    def save_gradient(self, grad):
        """Tensor hook: store the gradient flowing into the captured activation."""
        self.gradients = grad

    def _find_target_module(self):
        """Return the sub-module named ``target_layer`` or raise ``ValueError``."""
        named = dict(self.model.named_modules())
        if self.target_layer not in named:
            raise ValueError(
                f"Layer {self.target_layer!r} was not found in the model"
            )
        return named[self.target_layer]

    def __call__(self, x):
        """Run the forward pass and capture the target layer's output.

        Args:
            x: Input tensor accepted by ``model``.

        Returns:
            A tuple ``(outputs, result)`` where ``outputs`` is a list holding
            the target layer's activation(s) and ``result`` is the model's
            final output.

        Raises:
            ValueError: If ``target_layer`` does not name a module of the model.
        """
        outputs = []
        self.gradients = None

        def capture(module, inputs, output):
            # The tensor hook fires during backward and hands us d(target)/d(output).
            output.register_hook(self.save_gradient)
            outputs.append(output)

        handle = self._find_target_module().register_forward_hook(capture)
        try:
            x = self.model(x)
        finally:
            # Always detach the hook so repeated calls do not stack up hooks.
            handle.remove()
        return outputs, x

In [None]:
# Class to compute Grad-CAM
class GradCam:
    """Compute a Grad-CAM heatmap for one layer of a classification model.

    The score of the chosen class is back-propagated, the gradients reaching
    the target layer are averaged per channel, and those averages weight the
    layer's feature maps before they are combined into a single map.

    Only the first sample's class score (``output[0, target_class]``) is
    back-propagated, so the call is intended for a batch containing one image.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.feature_extractor = FeatureExtractor(model, target_layer)

    def __call__(self, input_tensor, target_class=None):
        """Return the normalised Grad-CAM heatmap as a NumPy array.

        Args:
            input_tensor: Input batch passed to the feature extractor.
            target_class: Index of the class to explain. When ``None`` the
                class with the highest score is used.

        Returns:
            ``numpy.ndarray`` with values in ``[0, 1]``. It is all zeros when
            no location contributes positively to the class score.

        Raises:
            ValueError: If the feature extractor captured no activation.
            RuntimeError: If no gradient reached the captured activation.
        """
        features, output = self.feature_extractor(input_tensor)
        if not features:
            raise ValueError(
                "The target layer produced no activations; check the target_layer name"
            )

        if target_class is None:
            target_class = torch.argmax(output, dim=1).item()

        # Zero gradients, then backpropagate the score of the target class
        self.model.zero_grad()
        output[0, target_class].backward()

        gradients = self.feature_extractor.gradients
        if gradients is None:
            raise RuntimeError("No gradient was recorded for the target layer")

        # Work on detached tensors: the map is a plain post-processing result
        # and the captured activation must not be modified in place.
        activations = features[0].detach()
        gradients = gradients.detach()

        # Global average pooling of the gradients -> one weight per channel
        pooled_gradients = gradients.mean(dim=(0, 2, 3))

        # Weight every channel (broadcast over batch and spatial dims), then
        # average over channels and keep only positive evidence.
        weighted = activations * pooled_gradients.view(1, -1, 1, 1)
        heatmap = F.relu(weighted.mean(dim=1).squeeze())

        # Normalise to [0, 1]; skip the division for an all-zero map, which
        # would otherwise turn every value into NaN.
        heatmap = heatmap - heatmap.min()
        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak

        return heatmap.cpu().numpy()

In [None]:
# Utility function to preprocess the image
def preprocess_image(img_path):
    """Load an image file and turn it into a normalised single-image batch.

    The image is converted to RGB before the transforms run, so grayscale,
    palette or RGBA files also yield the three channels that the 3-value
    ``Normalize`` statistics below require.

    Args:
        img_path: Path of the image file to load.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(img_path) as img:
        img_tensor = preprocess(img.convert('RGB')).unsqueeze(0)
    return img_tensor

# Utility function to overlay the heatmap on the original image
def overlay_heatmap(heatmap, img_path, alpha=0.4):
    """Blend a colour-mapped heatmap over the image stored at ``img_path``.

    Args:
        heatmap: 2-D float array, expected to hold values in ``[0, 1]``.
            NaNs are replaced by 0 and other values are clipped into range
            before the conversion to 8-bit.
        img_path: Path of the image to draw on; read with ``cv2.imread``.
        alpha: Weight of the heatmap in the blend; the image gets
            ``1 - alpha``. The default reproduces the previous 0.6 / 0.4 mix.

    Returns:
        The blended image produced by ``cv2.addWeighted``, in the channel
        order delivered by ``cv2.imread``.

    Raises:
        FileNotFoundError: If the image cannot be read from ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing or undecodable file by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # Sanitise before the uint8 cast so NaN or out-of-range values cannot wrap.
    heatmap = np.clip(np.nan_to_num(heatmap), 0.0, 1.0)
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = cv2.addWeighted(img, 1.0 - alpha, heatmap, alpha, 0)
    return superimposed_img

In [None]:
# Pick the image to explain and turn it into a model-ready batch
img_path = '../data/alzheimer_mri/test/ModerateDemented/ModerateDemented_6.png'
input_tensor = preprocess_image(img_path)

# Grad-CAM is attached to 'layer4', the last convolutional stage of ResNet-50
grad_cam = GradCam(model, 'layer4')

# No class index is supplied, so the heatmap explains the predicted class
heatmap = grad_cam(input_tensor, target_class=None)

# Blend the heatmap with the image it was computed for
superimposed_img = overlay_heatmap(heatmap, img_path)

# Reverse the channel axis (BGR -> RGB) for display, then show without axes
plt.imshow(superimposed_img[..., ::-1])
plt.axis('off')
plt.show()
