# Compare Grad-CAM Heatmaps Across Models

This notebook compares how different trained models (DenseNet-121, DenseNet-169, MobileNetV3 Large/Small, EfficientNet-B0) highlight regions of interest in mammogram images using Grad-CAM.

In [None]:
# imports
import torch
import torch.nn as nn
from torchvision.models import densenet169, densenet121, mobilenet_v3_large, mobilenet_v3_small, efficientnet_b0

import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from matplotlib import colormaps
import numpy as np
from PIL import Image
import PIL

from torch.utils.data import Dataset, ConcatDataset, DataLoader
import random

In [None]:
# Notebook-wide compute device: the first CUDA GPU when one is available,
# otherwise the CPU. Input tensors are moved here before inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def unnormalise(img_tensor, mean, std):
    """
    Reverses per-channel normalisation, i.e. returns ``img_tensor * std + mean``.

    Parameters:
        img_tensor (torch.Tensor): Normalised image of shape (C, H, W) or
                                   (N, C, H, W).
        mean (Sequence[float]): Per-channel means used during normalisation.
        std (Sequence[float]): Per-channel standard deviations used during
                               normalisation.

    Returns:
        torch.Tensor: The unnormalised image, on the same device as
                      ``img_tensor``.
    """
    # Create the statistics on the input's device: a CPU mean/std cannot be
    # broadcast against an image that lives on the GPU.
    # view(-1, 1, 1) shapes them as (C, 1, 1) for any number of channels.
    mean = torch.tensor(mean, device=img_tensor.device).view(-1, 1, 1)
    std = torch.tensor(std, device=img_tensor.device).view(-1, 1, 1)
    return img_tensor * std + mean

# ImageNet per-channel mean / standard deviation (three values, one per colour
# channel). These are the same numbers given to T.Normalize in test_transform,
# so they are the ones needed to undo that normalisation for display.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

In [None]:
# Grad-CAM hook state, shared by every model that has hooks registered.
# Written by the hooks installed in register_gradcam_hooks() and read by
# generate_gradcam(); always reflects the most recent forward/backward pass.
gradients = None    # tuple of gradients w.r.t. the hooked layer's output(s)
activations = None  # output tensor of the hooked layer

# Function to register Grad-CAM hooks
def register_gradcam_hooks(last_conv_layer):
    """
    Registers forward and backward hooks on the last convolutional layer
    to capture activations and gradients needed for Grad-CAM.

    The forward hook stores the layer's output in the module-level
    ``activations``; the backward hook stores the ``grad_output`` tuple in the
    module-level ``gradients``.

    Parameters:
        last_conv_layer (torch.nn.Module): Layer whose output should be
                                           explained by Grad-CAM.

    Returns:
        tuple: ``(backward_handle, forward_handle)``. Call ``.remove()`` on
               each handle to detach the hooks again. Callers that do not
               need to remove the hooks can ignore the return value.
    """
    def forward_hook(module, args, output):
        global activations
        activations = output

    def backward_hook(module, grad_input, grad_output):
        global gradients
        gradients = grad_output

    # Keep the handles: without them the hooks could never be removed, and
    # registering twice on the same layer would silently stack hooks.
    backward_handle = last_conv_layer.register_full_backward_hook(backward_hook, prepend=False)
    forward_handle = last_conv_layer.register_forward_hook(forward_hook, prepend=False)
    return backward_handle, forward_handle
 
def generate_gradcam(model, image):
    """
    Generates Grad-CAM heatmap for a given image.

    Runs one forward and one backward pass through ``model`` so that the
    Grad-CAM hooks fill the module-level ``activations`` and ``gradients``,
    then weights each activation channel by its spatially averaged gradient,
    averages over channels, applies ReLU and scales the result to [0, 1].

    Parameters:
        model (torch.nn.Module): Model with Grad-CAM hooks registered on the
                                 layer of interest.
        image (torch.Tensor): Input batch for the model; the visualisation
                              helpers in this notebook pass a single image of
                              shape (1, 3, H, W).

    Returns:
        torch.Tensor: Heatmap on the CPU, detached from the autograd graph.
                      It is all zeros when no location has positive evidence.

    Raises:
        RuntimeError: If the hooks did not fire during the forward/backward
                      pass (e.g. no hooks are registered on this model).
    """
    global gradients, activations

    # Drop state left behind by an earlier call (possibly for another model)
    # so that missing hooks are detected instead of reusing stale tensors.
    gradients = None
    activations = None

    model.zero_grad()  # Clear previous gradients
    output = model(image)

    # Backward pass to get gradients
    output.backward(torch.ones_like(output))

    if gradients is None or activations is None:
        raise RuntimeError(
            "Grad-CAM hooks did not fire; register hooks on the model before calling generate_gradcam()."
        )

    # Work on detached copies: the heatmap needs no autograd history and the
    # captured activations must not be modified in place.
    layer_grads = gradients[0].detach()
    layer_acts = activations.detach()

    # Pool gradients over batch and spatial dimensions -> one weight per channel
    pooled_gradients = layer_grads.mean(dim=(0, 2, 3))

    # Weight the channels by the corresponding gradients (broadcast over N, H, W)
    weighted = layer_acts * pooled_gradients.view(1, -1, 1, 1)

    # Compute heatmap
    heatmap = F.relu(weighted.mean(dim=1).squeeze())

    # Normalise; skip the division for an all-zero map, which would give 0/0 = NaN
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    return heatmap.cpu()

def overlay_heatmap(img_tensor, heatmap):
    """
    Prepares the original image and a colour-mapped Grad-CAM heatmap so the
    two can be drawn on top of each other.

    Parameters:
        img_tensor (torch.Tensor): ImageNet-normalised image of shape (3, H, W).
        heatmap (torch.Tensor): 2-D Grad-CAM heatmap, as returned by
                                generate_gradcam().

    Returns:
        tuple: ``(original_img, overlay)`` where ``original_img`` is the
               unnormalised image as an RGB PIL image and ``overlay`` is a
               uint8 NumPy array of shape (H, W, 3) holding the heatmap
               rendered with the 'jet' colormap at the image's resolution.
    """
    # Bring the image to the CPU first so unnormalise() never has to combine
    # CPU statistics with a GPU tensor.
    unnorm_img = unnormalise(img_tensor.detach().cpu(), mean=imagenet_mean, std=imagenet_std)

    original_img = to_pil_image(unnorm_img.clamp(0, 1), mode='RGB')  # Convert tensor to PIL image

    # Resize the heatmap to the image's own size (PIL sizes are (width, height))
    # instead of assuming 224x224.
    overlay = to_pil_image(heatmap, mode='F').resize(original_img.size, resample=PIL.Image.BICUBIC)

    # Bicubic interpolation can overshoot outside [0, 1]; clip before squaring
    # so that small negative undershoots are not turned into positive values.
    intensity = np.clip(np.asarray(overlay), 0.0, 1.0)

    # Apply colormap (squaring de-emphasises weak activations); drop the alpha channel
    cmap = colormaps['jet']
    overlay = (255 * cmap(intensity ** 2)[:, :, :3]).astype(np.uint8)

    return original_img, overlay  # Return original image and overlay

In [None]:
def load_model_d169(model_path):
    """
    Builds a DenseNet-169 binary classifier from a saved state dict and
    attaches Grad-CAM hooks to the final convolution of its last dense block.

    Parameters:
        model_path (str): Path to the saved state dict.

    Returns:
        torch.nn.Module: The model in eval mode, placed on the first CUDA
                         device when available and on the CPU otherwise.
    """
    net = densenet169(weights=None)

    # Single-logit head. Unlike the other loaders in this notebook, no Dropout
    # layer precedes the Linear layer, so it sits at index 0 of the Sequential.
    single_logit = nn.Linear(in_features=1664, out_features=1)
    net.classifier = nn.Sequential(single_logit)

    # Read the checkpoint onto the CPU, then copy the weights into the network
    checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)

    # Use the GPU when there is one
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(torch.device(device_name))
    net.eval()

    # Grad-CAM target: second convolution of the last dense layer
    register_gradcam_hooks(net.features.denseblock4.denselayer32.conv2)

    return net

In [None]:
def load_model_d121(model_path):
    """
    Builds a DenseNet-121 binary classifier from a saved state dict and
    attaches Grad-CAM hooks to the final convolution of its last dense block.

    Parameters:
        model_path (str): Path to the saved state dict.

    Returns:
        torch.nn.Module: The model in eval mode, placed on the first CUDA
                         device when available and on the CPU otherwise.
    """
    net = densenet121(weights=None)

    # Dropout followed by a single-logit Linear layer
    head_layers = [
        nn.Dropout(0.5),
        nn.Linear(in_features=1024, out_features=1),
    ]
    net.classifier = nn.Sequential(*head_layers)

    # Read the checkpoint onto the CPU, then copy the weights into the network
    checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)

    # Use the GPU when there is one
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(torch.device(device_name))
    net.eval()

    # Grad-CAM target: second convolution of the last dense layer
    register_gradcam_hooks(net.features.denseblock4.denselayer16.conv2)

    return net

In [None]:
# Function to load model (EfficientNet B0)
def load_model_eb0(model_path):
    """
    Builds an EfficientNet-B0 binary classifier from a saved state dict and
    attaches Grad-CAM hooks to the first sub-module of the last ``features``
    block (used here as the last convolutional layer).

    Parameters:
        model_path (str): Path to the saved state dict.

    Returns:
        torch.nn.Module: The model in eval mode, placed on the first CUDA
                         device when available and on the CPU otherwise.
    """
    net = efficientnet_b0(weights=None)

    # Binary-classification head: Dropout followed by a single-logit Linear layer
    head_layers = [
        nn.Dropout(0.5),
        nn.Linear(in_features=1280, out_features=1),
    ]
    net.classifier = nn.Sequential(*head_layers)

    # Read the checkpoint onto the CPU, then copy the weights into the network
    checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)

    # Use the GPU when there is one
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(torch.device(device_name))
    net.eval()

    # Grad-CAM target
    register_gradcam_hooks(net.features[-1][0])

    return net

In [None]:
def load_model_mnl(model_path):
    """
    Builds a MobileNetV3 Large binary classifier from a saved state dict and
    attaches Grad-CAM hooks to the first sub-module of the last ``features``
    block (used here as the last convolutional layer).

    Parameters:
        model_path (str): Path to the saved state dict.

    Returns:
        torch.nn.Module: The model in eval mode, placed on the first CUDA
                         device when available and on the CPU otherwise.
    """
    net = mobilenet_v3_large(weights=None)

    # The stock head must be inspected before it is replaced: its first layer
    # tells us how many features the backbone produces.
    head_inputs = net.classifier[0].in_features

    # Binary-classification head: Dropout followed by a single-logit Linear layer
    net.classifier = nn.Sequential(nn.Dropout(0.4), nn.Linear(head_inputs, 1))

    # Read the checkpoint onto the CPU, then copy the weights into the network
    checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)

    # Use the GPU when there is one
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(torch.device(device_name))
    net.eval()

    # Grad-CAM target
    register_gradcam_hooks(net.features[-1][0])

    return net

In [None]:
# Function to load model (MobileNetV3 small)
def load_model_mns(model_path):
    """
    Builds a MobileNetV3 Small binary classifier from a saved state dict and
    attaches Grad-CAM hooks to the first sub-module of the last ``features``
    block (used here as the last convolutional layer).

    Parameters:
        model_path (str): Path to the saved state dict.

    Returns:
        torch.nn.Module: The model in eval mode, placed on the first CUDA
                         device when available and on the CPU otherwise.
    """
    net = mobilenet_v3_small(weights=None)

    # The stock head must be inspected before it is replaced: its first layer
    # tells us how many features the backbone produces.
    head_inputs = net.classifier[0].in_features

    # Binary-classification head: Dropout followed by a single-logit Linear layer
    net.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(head_inputs, 1))

    # Read the checkpoint onto the CPU, then copy the weights into the network
    checkpoint = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(checkpoint)

    # Use the GPU when there is one
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = net.to(torch.device(device_name))
    net.eval()

    # Grad-CAM target
    register_gradcam_hooks(net.features[-1][0])

    return net

In [None]:
# Evaluation-time preprocessing applied to every test image:
# resize to 224x224, convert the PIL image to a tensor, then normalise each
# channel with the ImageNet mean/std (the same values as imagenet_mean /
# imagenet_std, which are used to undo this step for display).
test_transform  = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
import os
import pandas as pd

class CustomImageDataset(Dataset):
    """
    Image dataset described by a CSV annotations file.

    By zero-based column position, column 1 of the CSV holds the image file
    name (relative to ``img_dir``) and column 2 holds the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """
        Parameters:
            annotations_file (str): Path to the CSV file with the annotations.
            img_dir (str): Directory that contains the image files.
            transform (callable, optional): Applied to each loaded image.
            target_transform (callable, optional): Applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Returns the number of annotated samples (rows of the CSV)."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Returns ``(image, label)`` for row ``idx``, with the image loaded as RGB."""
        annotations = self.img_labels

        # File name lives in column 1
        file_name = annotations.iloc[idx, 1]
        picture = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")

        # Label lives in column 2
        target = annotations.iloc[idx, 2]

        if self.transform:
            picture = self.transform(picture)
        if self.target_transform:
            target = self.target_transform(target)

        return picture, target

In [None]:
# Cropped calcification test set.
# NOTE: despite the "_dir" suffix, calc_test_label_dir is the path of the CSV
# annotations file; calc_test_img_dir is the directory holding the images.
calc_test_label_dir = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/Data/Data png cropped/Calc-Test-png-cropped/labels/calc-test_labels.csv"
calc_test_img_dir = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/Data/Data png cropped/Calc-Test-png-cropped/images"

# Wrap the calcification images in a dataset; test_transform is applied to
# each image when it is fetched, not here.
calc_test_data = CustomImageDataset(calc_test_label_dir, calc_test_img_dir, test_transform)

# Cropped mass test set (same layout: CSV annotations file + image directory).
mass_test_label_dir = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/Data/Data png cropped/Mass-Test-png-cropped/labels/mass-test_labels.csv"
mass_test_img_dir = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/Data/Data png cropped/Mass-Test-png-cropped/images"

# Wrap the mass images in a dataset with the same per-item transform.
mass_test_data = CustomImageDataset(mass_test_label_dir, mass_test_img_dir, test_transform)

# Merge test datasets. Calcification samples come first, so an index into the
# combined dataset is NOT interchangeable with an index into mass_test_data.
cropped_combined_test_data = ConcatDataset([calc_test_data, mass_test_data])
print(f"Total testing samples: {len(cropped_combined_test_data)}")

# Create DataLoaders
batch_size = 32  # Adjust based on your GPU memory
num_workers = 0  # 0 = load batches in the main process

# Unshuffled loader over the combined test set, so batches keep dataset order.
cropped_test_dataloader = DataLoader(cropped_combined_test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Maps the numeric label stored in the annotations to a display name.
labels_map = {
    0: "Benign",
    1: "Malignant",
}

# Fixed indices into mass_test_data, so the same samples are shown on every run.
# To draw fresh random samples instead, use e.g.
#   torch.randint(len(mass_test_data), size=(1,)).item()
# Index sets used previously, kept for reference:
#   sample_idxs (presumably for the combined test set):
#     [438, 587, 270, 484, 614, 475, 498, 487, 112, 185, 21, 517, 113, 459, 282]
#   mass_sample_idxs (15 samples):
#     [38, 292, 24, 69, 140, 97, 170, 63, 119, 4, 183, 276, 176, 46, 247]
mass_sample_idxs = [38, 292, 24, 69, 140, 97, 170, 63, 119]

# Show the selected mass samples in a 3x3 grid with their ground-truth label.
cols, rows = 3, 3
figure = plt.figure(figsize=(9, 9))
# Slicing keeps the loop safe if the index list is shorter than the grid.
for i, mass_sample_idx in enumerate(mass_sample_idxs[:cols * rows], start=1):
    img, label = mass_test_data[mass_sample_idx]

    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    if torch.is_tensor(img):
        # The dataset returns ImageNet-normalised tensors; undo the
        # normalisation and clamp to [0, 1] so matplotlib shows the real image
        # instead of clipping out-of-range values. imshow expects (H, W, C).
        display_img = unnormalise(img, imagenet_mean, imagenet_std).clamp(0, 1)
        plt.imshow(display_img.permute(1, 2, 0))
    else:
        plt.imshow(img)
plt.show()

In [None]:
def predict_and_visualise_single_image(model, dataset, index=None):
    """
    Displays a single image (original, heatmap, overlay) from the dataset 
    along with model predictions.

    The title of the first panel shows the sample index, the true label, the
    predicted label and a confidence score: the distance of the sigmoid
    probability from the 0.5 decision threshold, rescaled to 0-100 %.

    Parameters:
        model (torch.nn.Module): Trained model for prediction, with Grad-CAM
                                 hooks registered.
        dataset (torch.utils.data.Dataset): Dataset containing images; items
                                            are (image_tensor, label) pairs.
        index (int, optional): The specific index of the image to visualise.
                               If None, a random index is chosen.
    """
    model.eval()

    # If no index is specified, pick a random index from the dataset
    if index is None:
        index = random.randint(0, len(dataset) - 1)

    # Run on the device that holds the model's weights instead of relying on
    # a notebook-level global that may not match the model.
    model_device = next(model.parameters()).device

    # Load and prepare the single image (add a batch dimension)
    img_tensor, label = dataset[index]
    img_tensor = img_tensor.unsqueeze(0).to(model_device)

    # Get model prediction
    with torch.no_grad():
        prob_value = model(img_tensor).sigmoid().item()
    pred_label = 1 if prob_value > 0.5 else 0

    # Confidence: distance from the 0.5 threshold, mapped to 0-100 %
    confidence = abs(prob_value - 0.5) * 200

    # Generate Grad-CAM heatmap (needs gradients, whatever the ambient grad mode)
    with torch.set_grad_enabled(True):
        heatmap = generate_gradcam(model, img_tensor)

    # Convert tensors to displayable images. The image goes back to the CPU
    # first, because the overlay helper works with CPU statistics.
    # overlay_heatmap returns (original_img, overlay_img)
    original_img, overlay_img = overlay_heatmap(img_tensor.squeeze(0).cpu(), heatmap)

    # Get labels
    true_label = "Malignant" if label == 1 else "Benign"
    pred_label_str = "Malignant" if pred_label == 1 else "Benign"

    # Create a figure with 3 columns for original, heatmap, and overlay
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # 1. Original Image
    axes[0].imshow(original_img)
    axes[0].set_title(
        f"Index: {index}\nTrue: {true_label}\nPred: {pred_label_str}\nConf: {confidence:.1f}%",
        fontsize=11
    )
    axes[0].axis("off")

    # 2. Heatmap (Grad-CAM), at the hooked layer's native resolution
    axes[1].imshow(heatmap)
    axes[1].set_title("Grad-CAM Heatmap", fontsize=11)
    axes[1].axis("off")

    # 3. Overlay (Original + Heatmap)
    axes[2].imshow(original_img)
    axes[2].imshow(overlay_img, alpha=0.4)
    axes[2].set_title("Overlay", fontsize=11)
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
def predict_and_visualise(model, dataset, sample_indices):
    """
    Predicts and visualises results for selected samples using Grad-CAM.

    Samples are laid out three per row. Each sample occupies two stacked
    cells: the original image (titled with true label, predicted label and
    confidence) and, below it, the same image with the Grad-CAM overlay.

    Parameters:
        model (torch.nn.Module): Trained model for prediction, with Grad-CAM
                                 hooks registered.
        dataset (torch.utils.data.Dataset): Dataset to sample from; items are
                                            (image_tensor, label) pairs.
        sample_indices (List[int]): Indices of samples to visualise.

    Raises:
        ValueError: If ``sample_indices`` is empty.
    """
    sample_indices = list(sample_indices)
    if not sample_indices:
        # Fail early with a clear message instead of deep inside matplotlib.
        raise ValueError("sample_indices must contain at least one index")

    model.eval()

    # Run on the device that holds the model's weights instead of relying on
    # a notebook-level global that may not match the model.
    model_device = next(model.parameters()).device

    sample_images = []
    sample_heatmaps = []
    sample_labels = []
    sample_preds = []
    sample_confidences = []

    for idx in sample_indices:
        X, label = dataset[idx]
        X = X.unsqueeze(0).to(model_device)

        # Model prediction
        with torch.no_grad():
            prob_value = model(X).sigmoid().item()
        pred_label = 1 if prob_value > 0.5 else 0

        # Confidence: distance from the 0.5 threshold, mapped to 0-100 %
        confidence = abs(prob_value - 0.5) * 200

        # Grad-CAM (needs gradients, whatever the ambient grad mode)
        with torch.set_grad_enabled(True):
            heatmap = generate_gradcam(model, X)

        # Unnormalise image and overlay heatmap. The image goes back to the
        # CPU first, because the overlay helper works with CPU statistics.
        original_img, overlay_img = overlay_heatmap(X.squeeze(0).cpu(), heatmap)

        # Store everything
        sample_images.append(original_img)
        sample_heatmaps.append(overlay_img)
        sample_labels.append(int(label))
        sample_preds.append(pred_label)
        sample_confidences.append(confidence)

    # Plots
    num_examples = len(sample_indices)
    num_cols = 3
    num_rows = (num_examples + num_cols - 1) // num_cols

    # Two subplot rows per sample row: originals on top, Grad-CAM overlays below
    fig, axes = plt.subplots(num_rows * 2, num_cols, figsize=(12, 6 * num_rows), squeeze=False)

    # Hide every cell up front so that unused cells (when the number of samples
    # is not a multiple of num_cols) do not show up as empty framed plots.
    for ax in axes.flat:
        ax.axis("off")

    for i in range(num_examples):
        row = (i // num_cols) * 2
        col = i % num_cols

        # Original image
        axes[row, col].imshow(sample_images[i])
        true_label = "Malignant" if sample_labels[i] == 1 else "Benign"
        pred_label_str = "Malignant" if sample_preds[i] == 1 else "Benign"
        confidence = sample_confidences[i]
        axes[row, col].set_title(f"True: {true_label}\nPred: {pred_label_str}\nConf: {confidence:.1f}%")

        # Grad-CAM overlay
        axes[row + 1, col].imshow(sample_images[i])
        axes[row + 1, col].imshow(sample_heatmaps[i], alpha=0.4, interpolation="nearest")
        axes[row + 1, col].set_title("Grad-CAM")

    plt.tight_layout()
    plt.show()


In [None]:
# Saved state dicts, one per architecture (passed to the matching load_model_* function).
# DenseNet-169 (run 25)
model_path_run25 = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/webapp/static/models/best_model_run25.pth"
# DenseNet-121 (run 13)
model_path_run13 = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/webapp/static/models/best_model_run13.pth"
# MobileNetV3 Large (run 39)
model_path_run39 = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/webapp/static/models/best_model_run39.pth"
# MobileNetV3 Small (run 15)
model_path_run15 = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/webapp/static/models/best_model_run15.pth"
# EfficientNet-B0 (run 20)
model_path_run20 = "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/webapp/static/models/best_model_run20.pth"

In [None]:
# Instantiate every architecture from its checkpoint. Each loader also
# registers Grad-CAM hooks on its model; all hooks write to the same
# module-level gradients/activations, which therefore always describe the
# model that ran most recently.
model_d169 = load_model_d169(model_path_run25)
model_d121 = load_model_d121(model_path_run13)
model_mnl = load_model_mnl(model_path_run39)
model_mns = load_model_mns(model_path_run15)
model_eb0 = load_model_eb0(model_path_run20)

In [None]:
# DenseNet-121 on one mass test image; no index is given, so a random one is picked.
predict_and_visualise_single_image(model_d121, mass_test_data)

In [None]:
# mass_sample_idxs are indices into mass_test_data (the dataset the sample grid
# above was drawn from). In cropped_combined_test_data the calcification
# samples come first, so the same indices would select different images there.
predict_and_visualise(model_d121, mass_test_data, mass_sample_idxs)