# **Visualization of CNNs: Grad-CAM**
* **Objective**: Convolutional Neural Networks are widely used in computer vision and are powerful for processing grid-like data. However, we hardly know how and why they work, because they cannot be decomposed into individually intuitive components. In this assignment, we use Grad-CAM, which highlights the regions of the input image that were important for the neural network's prediction.


* NB: if `PIL` is not installed, try `conda install pillow`.
* Computations are light enough to be done on CPU.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
import matplotlib.pyplot as plt
import pickle
import urllib.request

import numpy as np
from PIL import Image

%matplotlib inline

## Download the Model
We provide you with a model `DenseNet-121`, already pretrained on the `ImageNet` classification dataset.
* **ImageNet**: A large dataset of photographs with 1 000 classes.
* **DenseNet-121**: A deep architecture for image classification (https://arxiv.org/abs/1608.06993)

In [None]:
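densenet121 = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)  # replaces the deprecated pretrained=True
densenet121.eval()  # set the model to evaluation mode
pass  # as the last statement, keeps the notebook from printing the model summary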

In [None]:
classes = pickle.load(urllib.request.urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl'))

# `classes` is a dictionary mapping each class index (0-999) to its human-readable name
print(classes)

## Input Images
We provide you with 20 images from ImageNet (download link on the course webpage, or download them directly with the command in the cell below).<br>
In order to use the pretrained DenseNet-121 model, the input image should be normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`, and resized to `(224, 224)` (here: the shorter side is resized to 256, then the image is center-cropped to 224x224).
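A minimal NumPy sketch of this normalization and of its inverse, which is what is needed to display a preprocessed tensor with its true colours. It assumes a channels-first `(3, H, W)` array in `[0, 1]`, as produced by `transforms.ToTensor()`; the helper names `normalize` and `denormalize` are illustrative. In torchvision terms, the inverse is also `transforms.Normalize(mean=-m/s, std=1/s)` applied per channel.

```python
import numpy as np

# ImageNet channel statistics quoted above (RGB order), shaped for (3, H, W) arrays
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img):
    """Per-channel normalization of a (3, H, W) image in [0, 1]: (x - mean) / std."""
    return (img - MEAN) / STD

def denormalize(img):
    """Inverse of normalize(): x * std + mean, clipped to [0, 1] for display."""
    return np.clip(img * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
image = rng.random((3, 224, 224))      # stand-in for the output of ToTensor()
restored = denormalize(normalize(image))
print(np.allclose(restored, image))    # the round trip recovers the image
```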

In [None]:
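def preprocess_image(dir_path):
    """
    Build an ImageFolder dataset with the standard ImageNet preprocessing.

    Parameters
    ----------
    dir_path : str
        Root folder; each sub-folder is treated as one class by ImageFolder.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Note: to undo this normalization, compute x * std + mean (per channel)

    dataset = datasets.ImageFolder(dir_path, transforms.Compose([
            transforms.Resize(256),      # resize the shorter side to 256 pixels
            transforms.CenterCrop(224),  # crop the central 224x224 region
            transforms.ToTensor(),       # convert the PIL image to a float tensor in [0, 1]
            normalize]))                 # normalize each channel with the ImageNet statistics

    return dataset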

In [None]:
import os
if not os.path.exists("data"):
    os.mkdir("data")
if not os.path.exists("data/TP2_images"):
    os.mkdir("data/TP2_images")
    !cd data/TP2_images && wget "https://www.lri.fr/~gcharpia/deeppractice/2025/TP2/TP2_images.zip" && unzip TP2_images.zip

dir_path = "data/" 
dataset = preprocess_image(dir_path)

In [None]:
# show the original image
index = 5
input_image = Image.open(dataset.imgs[index][0]).convert('RGB')
plt.imshow(input_image)
plt.show()

In [None]:
input_image = dataset[index][0].view(1, 3, 224, 224)
output = densenet121(input_image)
values, indices = torch.topk(output, 3)
print("Top 3-classes:", indices[0].numpy(), [classes[x] for x in indices[0].numpy()])
print("Raw class scores:", values[0].detach().numpy())

# Grad-CAM 
* **Overview:** Given an image and a category (‘tiger cat’) as input, we forward-propagate the image through the model to obtain the `raw class scores` before softmax. The gradients are set to zero for all classes except the desired class (tiger cat), which is set to 1. This signal is then backpropagated to the `rectified convolutional feature map` of interest, where we can compute the coarse Grad-CAM localization (blue heatmap).


* **To Do**: Define your own function Grad_CAM to achieve the visualization of the given images. For each image, choose the top-3 possible labels as the desired classes. Compare the heatmaps of the three classes, and conclude. 

More precisely, you should provide a function `show_grad_cam(image: torch.Tensor) -> None` that displays something like this:
![output_example.png](attachment:output_example.png)
where the heatmap will be correct (here it is just an example) and the first 3 classes are the top-3 predicted classes and the last is the least probable class according to the model.

* **Comment your code**: Your code should be easy to read and follow. Please comment your code, try to use the NumPy Style Python docstrings for your functions.

* **To be submitted within 2 weeks**: this notebook, **cleaned** (i.e. without results, for file size reasons: `menu > kernel > restart and clean`), in a state ready to be executed (with or without GPU) (if one just presses 'Enter' till the end, one should obtain all the results for all images) with a few comments at the end. No additional report, just the notebook!


* **Hints**: 
 + We need to record the output and grad_output of the feature maps to achieve Grad-CAM. In PyTorch, hooks serve this purpose: `register_forward_hook` and `register_full_backward_hook` can be attached to any `nn.Module`. Read the tutorial on [hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) carefully.
 + More on [autograd](https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html) and [hooks](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks)
 + The pretrained DenseNet has no activation function after its last layer, so its output is already the `raw class scores`; you can use them directly.
 + Your heatmap will have the same size as the feature map. You need to scale up the heatmap to the resized image (224x224, not the original one, before the normalization) for better observation purposes. The function [`torch.nn.functional.interpolate`](https://pytorch.org/docs/stable/nn.functional.html?highlight=interpolate#torch.nn.functional.interpolate) may help.  
 + Here is the link to the paper: [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/pdf/1610.02391.pdf)
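A minimal sketch of the hook and interpolation hints, on a hypothetical toy CNN rather than DenseNet-121: a forward hook stores the feature maps, a full backward hook stores the gradient of one class score with respect to those maps (it has the same shape), and `F.interpolate` upsamples a coarse map to 224x224. Removing the hook handles afterwards keeps repeated calls from stacking hooks on the same module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy CNN: `conv` plays the role of the hooked convolution layer
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
net = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 10))

store = {}

def fwd_hook(module, inputs, output):
    store["acts"] = output.detach()            # feature maps A^k

def bwd_hook(module, grad_input, grad_output):
    store["grads"] = grad_output[0].detach()   # d(score) / dA^k

h_fwd = conv.register_forward_hook(fwd_hook)
h_bwd = conv.register_full_backward_hook(bwd_hook)

x = torch.randn(1, 3, 7, 7, requires_grad=True)
net(x)[0, 3].backward()                        # backpropagate the raw score of class 3

h_fwd.remove()                                 # detach the hooks once they are no longer needed
h_bwd.remove()
print(store["acts"].shape, store["grads"].shape)

# Upsample a coarse 7x7 map to the 224x224 display resolution
cam = torch.rand(1, 1, 7, 7)
up = F.interpolate(cam, size=(224, 224), mode="bilinear", align_corners=False)
print(up.shape)
```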

Class: ‘pug, pug-dog’ | Class: ‘tabby, tabby cat’
- | - 
![alt](https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/dog.jpg)| ![alt](https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/cat.jpg)
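The Overview states that the gradient is set to 1 for the desired class and to 0 for all other classes before backpropagating. The toy sketch below (a hypothetical `nn.Linear(8, 5)` standing in for DenseNet-121) checks that this is equivalent to calling `.backward()` on the selected raw score alone, which is the shortcut a Grad-CAM implementation can take.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for a classifier: 8 input features -> 5 raw class scores
model = nn.Linear(8, 5)
x = torch.randn(1, 8, requires_grad=True)
target = 2

# Variant 1: backpropagate a one-hot gradient through all raw scores
scores = model(x)
one_hot = torch.zeros_like(scores)
one_hot[0, target] = 1.0
scores.backward(gradient=one_hot)
grad_one_hot = x.grad.clone()

# Variant 2: backpropagate only the selected raw score
x.grad = None
model(x)[0, target].backward()
grad_direct = x.grad.clone()

print(torch.allclose(grad_one_hot, grad_direct))  # both variants give the same gradient
```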

## Part 1: Grad-CAM implementation

<img src="attachment:6584e0fc-c5d3-4b90-904c-98adf0570295.png" width="400" />
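Stripped of PyTorch, the Grad-CAM computation reduces to a few array operations. The NumPy sketch below (`gradcam_from_arrays` is an illustrative name, not part of the assignment) assumes activations and gradients of shape `(K, H, W)` for a single image: it averages each gradient channel into a weight, takes the ReLU of the weighted sum of the feature maps, and min-max scales the result to `[0, 1]`. In the toy input, the second channel receives a negative weight, so its region is suppressed by the ReLU.

```python
import numpy as np

def gradcam_from_arrays(acts, grads):
    """Grad-CAM map from activations and gradients of shape (K, H, W).

    alpha_k = mean over (i, j) of grads[k, i, j]   (global average pooling)
    cam     = ReLU(sum_k alpha_k * acts[k]), then min-max scaled to [0, 1].
    """
    alphas = grads.mean(axis=(1, 2))              # one weight per channel, shape (K,)
    cam = np.tensordot(alphas, acts, axes=1)      # weighted sum over channels, shape (H, W)
    cam = np.maximum(cam, 0.0)                    # keep only positive evidence
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

acts = np.array([[[1.0, 0.0], [0.0, 0.0]],        # channel 0 fires top-left
                 [[0.0, 0.0], [0.0, 1.0]]])       # channel 1 fires bottom-right
grads = np.stack([np.ones((2, 2)), -np.ones((2, 2))])
cam = gradcam_from_arrays(acts, grads)
print(np.round(cam, 3))                           # top-left cell close to 1, others 0
```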


In [None]:
class DenseNetFeatureHook:
    """
    Attaches forward/backward hooks to the last convolution layer in DenseNet121
    (within: features -> denseblock4 -> denselayer16 -> conv2).
    This captures:
        - forward activations (feature maps),
        - backward gradients (from the chosen output).

    Note: this conv2 outputs only the 32 (growth-rate) channels added by the last
    dense layer, on a 7x7 grid; the 1024-channel map seen by the classifier is the
    concatenation of the block input with the outputs of all 16 dense layers.
    """
    def __init__(self, densenet_model):
        self.model = densenet_model
        self.stored_acts = None
        self.stored_grads = None

        # Register hooks
        self._setup_hooks()

    def _setup_hooks(self):
        """
        Locates the final conv2 submodule in denseblock4/denselayer16
        and registers forward/backward hooks on it.
        """
        last_conv = self.model.features.denseblock4.denselayer16.conv2

        # forward hook: store the forward outputs (feature maps A^k)
        def _fwd_hook(module, inputs, outputs):
            self.stored_acts = outputs.detach()

        # backward hook: store the gradients of the target score wrt those outputs
        def _bwd_hook(module, grad_in, grad_out):
            self.stored_grads = grad_out[0].detach()

        last_conv.register_forward_hook(_fwd_hook)
        # register_backward_hook is deprecated in favour of register_full_backward_hook
        last_conv.register_full_backward_hook(_bwd_hook)


def compute_gradcam(img_tensor, class_idx, hook_obj, net):
    """
    Generates a Grad-CAM heatmap for a single class index, given:
       - An input tensor (1 x 3 x H x W),
       - A hooked DenseNet model,
       - A target class to backprop from.

    Returns a NumPy array of shape (224, 224) scaled to [0,1].
    """
    net.zero_grad()
    
    outputs = net(img_tensor)  # shape: [1, 1000]

    target_score = outputs[0, class_idx]
    target_score.backward()

    # Retrieve activations & gradients from the hook
    acts = hook_obj.stored_acts       # shape: [1, channels, H, W]
    grads = hook_obj.stored_grads     # shape: [1, channels, H, W]

    # Compute channel-wise weights by global-average-pooling the gradients
    weights = F.adaptive_avg_pool2d(grads, 1)  # shape: [1, channels, 1, 1]

    # Multiply feature maps by weights and sum across channels
    weighted_sum = (weights * acts).sum(dim=1, keepdim=True)  # [1, 1, H, W]

    # ReLU to keep only positive contributions
    relu_map = F.relu(weighted_sum)

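    # Upsample the coarse 7x7 map to the 224x224 input resolution
    upsampled = F.interpolate(relu_map, size=(224, 224), mode='bilinear', align_corners=False)

    # Min-max scale to [0, 1], as stated in the docstring (epsilon avoids division by zero)
    min_val, max_val = upsampled.min(), upsampled.max()
    heatmap = (upsampled - min_val) / (max_val - min_val + 1e-8)

    heatmap_np = heatmap[0, 0].cpu().numpy()
    return heatmap_np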



def display_grad_cam(img_tensor, net=None, class_dict=None):
    """
    Visualizes:
        - The original image (roughly scaled to [0,1] if normalized),
        - Grad-CAM overlays for the top-3 predicted classes,
        - Grad-CAM overlay for the lowest-scoring class.

    Args:
        img_tensor (torch.Tensor): [1, 3, 224, 224] normalized for ImageNet.
        net (torch.nn.Module): A DenseNet model. If None, loads DenseNet121 
                               with ImageNet weights.
        class_dict (dict): Optional {class_idx: "classname"} map for labeling.
    """
    if net is None:
        # Load a pretrained DenseNet121 
        from torchvision.models import densenet121, DenseNet121_Weights
        net = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
    net.eval()

    # Attach hooks to the final conv layer
    hook_obj = DenseNetFeatureHook(net)

    with torch.no_grad():
        preds = net(img_tensor)  # shape: [1, 1000]

    # top-3 classes and the minimal class
    top_vals, top_idxs = torch.topk(preds, 3, dim=1)
    min_idx = torch.argmin(preds, dim=1)

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))

    # Undo the ImageNet normalization (x * std + mean) to display the true colours
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
    disp_image = np.clip(image_np * std + mean, 0.0, 1.0)

    # (a) Show the original (de-normalized) image first
    axes[0].imshow(disp_image)
    axes[0].set_title("Original")
    axes[0].axis('off')

    # (b) Compute Grad-CAM for top-3 predicted class
    for col in range(3):
        cls_index = top_idxs[0, col].item()
        # score_val = top_vals[0, col].item()

        # Generate Grad-CAM
        heatmap_top = compute_gradcam(img_tensor, cls_index, hook_obj, net)

        # Overlay with the original
        axes[col + 1].imshow(disp_image)
        axes[col + 1].imshow(heatmap_top, cmap='jet', alpha=0.5)
        
        # If class_dict is provided, get a label
        if class_dict and (cls_index in class_dict):
            label_name = class_dict[cls_index]
        else:
            label_name = f"Class {cls_index}"

        axes[col + 1].set_title(f"{label_name} ")
        axes[col + 1].axis('off')

    # (c) Compute Grad-CAM for the lowest-scoring class
    worst_idx = min_idx.item()
    worst_score = preds[0, worst_idx].item()

    heatmap_worst = compute_gradcam(img_tensor, worst_idx, hook_obj, net)

    axes[4].imshow(disp_image)
    axes[4].imshow(heatmap_worst, cmap='jet', alpha=0.5)

    if class_dict and (worst_idx in class_dict):
        lowest_label = class_dict[worst_idx]
    else:
        lowest_label = f"Class {worst_idx}"

    axes[4].set_title(f"Lowest: {lowest_label}")
    axes[4].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
display_grad_cam(input_image, net=densenet121, class_dict=classes)

## Part 2: Try it on a few (1 to 3) images and comment

In [None]:

sample_indices = [9, 14, 16]


for sample_idx in sample_indices:
    raw_tensor = dataset[sample_idx][0]
    batched_input = raw_tensor.unsqueeze(0)
    display_grad_cam(batched_input, net=densenet121, class_dict=classes)


1. **Region of most prominent attention**: for correctly predicted classes (Top-1, or within the Top-3), the heatmap frequently focuses on the subject's main body or distinguishing features (such as an animal's face or torso). This suggests that the network uses those regions as its strongest discriminative cues for that category.
2. **Least likely class (lowest score)**: for the "worst" class, the Grad-CAM heatmap is usually weaker or looks more random. Small hotspots in unexpected places, or an almost uniform map, are occasionally visible, signifying that the network has low activation for that class. When scattered areas do light up, they usually do not correspond to meaningful content, consistent with the class having little bearing on the picture.


## Part 3: Try GradCAM on other convolutional layers, describe and comment on the results

In [None]:
class DenseNetFeatureHook_2:
    """
    Attaches forward/backward hooks to an intermediate convolution layer in DenseNet121
    (within: features -> denseblock3 -> denselayer12 -> conv2), whose output is a 14x14 map.
    This captures:
        - forward activations (feature maps),
        - backward gradients (from the chosen output).
    """
    def __init__(self, densenet_model):
        self.model = densenet_model
        self.stored_acts = None
        self.stored_grads = None

        # Register hooks
        self._setup_hooks()

    def _setup_hooks(self):
        # For reference, the last conv layer used in Part 1 is:
        # self.model.features.denseblock4.denselayer16.conv2

        # denseblock3 has 24 dense layers: pick the middle one (denselayer12), then its 'conv2'
        layer_choice = self.model.features.denseblock3.denselayer12.conv2

        def _fwd_hook(module, inputs, outputs):
            self.stored_acts = outputs.detach()

        def _bwd_hook(module, grad_in, grad_out):
            self.stored_grads = grad_out[0].detach()

        layer_choice.register_forward_hook(_fwd_hook)
        # register_backward_hook is deprecated in favour of register_full_backward_hook
        layer_choice.register_full_backward_hook(_bwd_hook)



def compute_gradcam_2(img_tensor, class_idx, hook_obj, net):
    """
    Generates a Grad-CAM heatmap for a single class index, given:
       - An input tensor (1 x 3 x H x W),
       - A hooked DenseNet model,
       - A target class to backprop from.

    Returns a NumPy array of shape (224, 224) scaled to [0,1].
    """
    net.zero_grad()
    
    outputs = net(img_tensor)  # shape: [1, 1000]

    target_score = outputs[0, class_idx]
    target_score.backward()

    # Retrieve activations & gradients from the hook
    acts = hook_obj.stored_acts       # shape: [1, channels, H, W]
    grads = hook_obj.stored_grads     # shape: [1, channels, H, W]

    # Compute channel-wise weights by global-average-pooling the gradients
    weights = F.adaptive_avg_pool2d(grads, 1)  # shape: [1, channels, 1, 1]

    # Multiply feature maps by weights and sum across channels
    weighted_sum = (weights * acts).sum(dim=1, keepdim=True)  # [1, 1, H, W]

    # ReLU to keep only positive contributions
    relu_map = F.relu(weighted_sum)

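    # Upsample the coarse 14x14 map to the 224x224 input resolution
    upsampled = F.interpolate(relu_map, size=(224, 224), mode='bilinear', align_corners=False)

    # Min-max scale to [0, 1], as stated in the docstring (epsilon avoids division by zero)
    min_val, max_val = upsampled.min(), upsampled.max()
    heatmap = (upsampled - min_val) / (max_val - min_val + 1e-8)

    heatmap_np = heatmap[0, 0].cpu().numpy()
    return heatmap_np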



def display_grad_cam_2(img_tensor, net=None, class_dict=None):
    """
    Visualizes:
        - The original image (roughly scaled to [0,1] if normalized),
        - Grad-CAM overlays for the top-3 predicted classes,
        - Grad-CAM overlay for the lowest-scoring class.

    Args:
        img_tensor (torch.Tensor): [1, 3, 224, 224] normalized for ImageNet.
        net (torch.nn.Module): A DenseNet model. If None, loads DenseNet121 
                               with ImageNet weights.
        class_dict (dict): Optional {class_idx: "classname"} map for labeling.
    """
    if net is None:
        # Load a pretrained DenseNet121 
        from torchvision.models import densenet121, DenseNet121_Weights
        net = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
    net.eval()

    # Attach hooks to the intermediate conv layer (denseblock3 -> denselayer12 -> conv2)
    hook_obj = DenseNetFeatureHook_2(net)

    with torch.no_grad():
        preds = net(img_tensor)  # shape: [1, 1000]

    # top-3 classes and the minimal class
    top_vals, top_idxs = torch.topk(preds, 3, dim=1)
    min_idx = torch.argmin(preds, dim=1)

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))

    # Undo the ImageNet normalization (x * std + mean) to display the true colours
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
    disp_image = np.clip(image_np * std + mean, 0.0, 1.0)

    # (a) Show the original (de-normalized) image first
    axes[0].imshow(disp_image)
    axes[0].set_title("Original")
    axes[0].axis('off')

    # (b) Compute Grad-CAM for top-3 predicted class
    for col in range(3):
        cls_index = top_idxs[0, col].item()
        # score_val = top_vals[0, col].item()

        # Generate Grad-CAM
        heatmap_top = compute_gradcam_2(img_tensor, cls_index, hook_obj, net)

        # Overlay with the original
        axes[col + 1].imshow(disp_image)
        axes[col + 1].imshow(heatmap_top, cmap='jet', alpha=0.5)
        
        # If class_dict is provided, get a label
        if class_dict and (cls_index in class_dict):
            label_name = class_dict[cls_index]
        else:
            label_name = f"Class {cls_index}"

        axes[col + 1].set_title(f"{label_name} ")
        axes[col + 1].axis('off')

    # (c) Compute Grad-CAM for the lowest-scoring class
    worst_idx = min_idx.item()
    worst_score = preds[0, worst_idx].item()

    heatmap_worst = compute_gradcam_2(img_tensor, worst_idx, hook_obj, net)

    axes[4].imshow(disp_image)
    axes[4].imshow(heatmap_worst, cmap='jet', alpha=0.5)

    if class_dict and (worst_idx in class_dict):
        lowest_label = class_dict[worst_idx]
    else:
        lowest_label = f"Class {worst_idx}"

    axes[4].set_title(f"Lowest: {lowest_label}")
    axes[4].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
sample_indices = [9, 14, 16]


for sample_idx in sample_indices:
    raw_tensor = dataset[sample_idx][0]
    batched_input = raw_tensor.unsqueeze(0)
    display_grad_cam_2(batched_input, net=densenet121, class_dict=classes)


The heatmaps cover a broad region of the object, occasionally also including some background. The activation is comparatively fine-grained and captures textures and edges. However, the main features that define the class do not appear to be as strongly highlighted.
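The finer-grained look of these maps is consistent with the spatial resolution of the hooked layers. The small calculation below is a sketch hand-coded from the torchvision DenseNet-121 layout for a 224x224 input (`STAGES` and `feature_map_sizes` are illustrative names, not library functions). It shows that denseblock3 works on a 14x14 grid while denseblock4 works on a 7x7 grid, so before upsampling each heatmap cell spans a 16-pixel instead of a 32-pixel step of the input.

```python
# DenseNet-121 stages (torchvision layout) and the stride of each step;
# conv0 and pool0 are padded so that each stride-2 step exactly halves 224-pixel inputs
STAGES = [
    ("conv0 (7x7 conv, stride 2)", 2),
    ("pool0 (3x3 max pool, stride 2)", 2),
    ("denseblock1", 1),
    ("transition1 (2x2 avg pool)", 2),
    ("denseblock2", 1),
    ("transition2 (2x2 avg pool)", 2),
    ("denseblock3", 1),
    ("transition3 (2x2 avg pool)", 2),
    ("denseblock4", 1),
]

def feature_map_sizes(input_size=224):
    """Return the spatial side length after each stage."""
    size, sizes = input_size, {}
    for name, stride in STAGES:
        size //= stride
        sizes[name] = size
    return sizes

sizes = feature_map_sizes()
print(sizes["denseblock3"], sizes["denseblock4"])                  # grid sizes
print(224 // sizes["denseblock3"], 224 // sizes["denseblock4"])    # pixels per heatmap cell
```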

## Part 4: Try GradCAM on `9928031928.png`, describe and comment on the results

In [None]:
target_filename = "9928031928.png"

def find_image(target_filename, dataset):
    """Return the path of the first dataset image whose path contains `target_filename`."""
    for candidate_path, _ in dataset.imgs:
        if target_filename in candidate_path:
            return candidate_path
    raise FileNotFoundError(f"{target_filename} was not found in the dataset")

matched_path = find_image(target_filename, dataset)
# Load the image from the matched path, convert to RGB
loaded_image = Image.open(matched_path).convert("RGB")

# Transform the image using the dataset's own transform pipeline, then add batch dimension
transformed_tensor = dataset.transform(loaded_image).unsqueeze(0)

# Perform Grad-CAM visualization
display_grad_cam(transformed_tensor, net=densenet121, class_dict=classes)


Top-1 prediction: "Landrover, Jeep"

The heatmap is scattered over the entire body of the elephant and partly over the background, which suggests that the texture or overall shape may have led the network to mistake it for a vehicle. The elephant's distinguishing characteristics, such as its trunk and ears, receive little attention in the activation.

## Part 5: What are the principal contributions of GradCAM (the answer is in the paper)?

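Gradient-weighted Class Activation Mapping (Grad-CAM) uses the gradients of any target concept (say ‘dog’ in a classification network, or a sequence of words in a captioning network) flowing into the final convolutional layer to produce a coarse localization map that highlights the regions of the image that are important for predicting the concept. Unlike previous approaches such as CAM, Grad-CAM is applicable to a wide variety of CNN model families: (1) CNNs with fully-connected layers (e.g. VGG), (2) CNNs used for structured outputs (e.g. captioning), and (3) CNNs used in tasks with multi-modal inputs (e.g. visual question answering) or reinforcement learning, all without architectural changes or re-training. The paper also combines Grad-CAM with fine-grained visualizations into Guided Grad-CAM, a high-resolution, class-discriminative visualization; uses these maps to explain failure modes, localize objects in a weakly-supervised setting and identify dataset bias; and reports human studies showing that the explanations help users establish appropriate trust in a model.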
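The two equations at the core of the method, as given in the paper, are reproduced below. Here $y^c$ is the raw score of class $c$ before softmax, $A^k$ is the $k$-th feature map of the chosen convolutional layer, and $Z$ is its number of spatial positions. In the notebook's code, $\alpha_k^c$ corresponds to `weights` (global average pooling of `grads`) and the ReLU of the weighted sum to `relu_map`.

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_{k}\alpha_k^c A^k\Big)
$$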

## Bonus 5: What are the main differences between DenseNet and ResNet?

DenseNet employs a dense connection mechanism, where each DenseLayer in a DenseBlock performs convolution-based feature extraction. Specifically, each DenseLayer concatenates every previous layer’s output (in the channel dimension) to form its own input, then passes this concatenated feature into a convolution operation to produce the current layer’s output. Finally, that output is also concatenated with the layer’s input as the input to the next DenseLayer.

This structure bears some resemblance to ResNet, but in a ResNet residual block, the input is added to the current layer’s output feature (rather than concatenated), and the input typically comes only from the immediately preceding layer rather than from all preceding layers.
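To make the add-versus-concatenate difference concrete, here is a toy NumPy sketch. It does not reproduce the real layers: `toy_layer` is a hypothetical random 1x1 channel mixing followed by ReLU, standing in for DenseNet's BN-ReLU-Conv bottleneck and for ResNet's residual branch. With the DenseNet-121 numbers for denseblock4 (512 input channels, 16 layers, growth rate 32), the concatenations grow the channel count to 512 + 16 x 32 = 1024, whereas a residual block keeps it unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_layer(x, out_channels):
    """Hypothetical stand-in for a conv layer: random 1x1 channel mixing + ReLU."""
    w = rng.standard_normal((out_channels, x.shape[0])) * 0.1
    return np.maximum(np.tensordot(w, x, axes=1), 0.0)   # (out_channels, H, W)

def residual_block(x):
    # ResNet: output = F(x) + x, so F must preserve the channel count
    return toy_layer(x, x.shape[0]) + x

def dense_block(x, num_layers, growth_rate):
    # DenseNet: each layer sees the concatenation of the block input and all previous outputs
    features = x
    for _ in range(num_layers):
        new = toy_layer(features, growth_rate)               # growth_rate new channels
        features = np.concatenate([features, new], axis=0)   # stack along the channel axis
    return features

x = rng.standard_normal((512, 7, 7))   # shaped like the input of DenseNet-121's denseblock4
print(residual_block(x).shape)                               # channels unchanged
print(dense_block(x, num_layers=16, growth_rate=32).shape)   # channels grow to 1024
```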

In [None]:
import torch
import torch.nn.functional as F
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

class ResNetFeatureHook:
    """
    Attaches forward/backward hooks to the last convolution layer in ResNet
    (commonly layer4[-1].conv3 for ResNet-50/101/152).

    This captures:
        - forward activations (feature maps),
        - backward gradients (from the chosen output).
    """
    def __init__(self, resnet_model):
        self.model = resnet_model
        self.stored_acts = None
        self.stored_grads = None

        self._setup_hooks()

    def _setup_hooks(self):
        """
        Locates the final conv3 submodule in layer4[-1] (the last residual block)
        and registers forward/backward hooks on it.
        """
        last_conv = self.model.layer4[-1].conv3

        def _fwd_hook(module, inputs, outputs):
            self.stored_acts = outputs.detach()

        def _bwd_hook(module, grad_in, grad_out):
            self.stored_grads = grad_out[0].detach()

        last_conv.register_forward_hook(_fwd_hook)
        # register_backward_hook is deprecated in favour of register_full_backward_hook
        last_conv.register_full_backward_hook(_bwd_hook)


def compute_gradcam_3(img_tensor, class_idx, hook_obj, net):
    """
    Generates a Grad-CAM heatmap for a single class index, given:
       - An input tensor (shape [1, 3, H, W]),
       - A hooked ResNet model,
       - A target class to backprop from.

    Returns a NumPy array of shape (224, 224) scaled via z-score normalization.
    """

    net.zero_grad()
    outputs = net(img_tensor) 
    
    target_score = outputs[0, class_idx]
    target_score.backward()
    
    acts = hook_obj.stored_acts       # [1, channels, H, W]
    grads = hook_obj.stored_grads     # [1, channels, H, W]
    
    weights = F.adaptive_avg_pool2d(grads, 1)
    
    weighted_sum = (weights * acts).sum(dim=1, keepdim=True)  # [1, 1, H, W]
    
    relu_map = F.relu(weighted_sum)
    
    upsampled = F.interpolate(relu_map, size=(224, 224), mode='bilinear', align_corners=False)
    
    mean_val = torch.mean(upsampled)
    std_val = torch.std(upsampled)
    heatmap = (upsampled - mean_val) / (std_val + 1e-8)
    
    heatmap_np = heatmap[0, 0].cpu().numpy()
    return heatmap_np

def display_grad_cam_3(img_tensor, net=None, class_dict=None):
    """
    Shows:
      - The original image (scaled to [0,1] if needed),
      - Grad-CAM for the top-3 predicted classes,
      - Grad-CAM for the lowest-scoring class.

    Args:
        img_tensor (torch.Tensor): [1, 3, 224, 224] normalized for ImageNet.
        net (torch.nn.Module): A ResNet model. If None, loads resnet152 with
                               ImageNet weights.
        class_dict (dict): Optional {class_idx: "classname"} map for labeling.
    """
    if net is None:
        from torchvision.models import resnet152, ResNet152_Weights
        net = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
    net.eval()
    
    # Attach hooks to the final conv layer in ResNet
    hook_obj = ResNetFeatureHook(net)
    
    with torch.no_grad():
        preds = net(img_tensor)  # shape: [1, 1000]
    
    top_vals, top_idxs = torch.topk(preds, 3, dim=1)
    min_idx = torch.argmin(preds, dim=1)
    
    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    
    # Undo the ImageNet normalization (x * std + mean) to display the true colours
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = img_tensor[0].permute(1, 2, 0).cpu().numpy()
    disp_image = np.clip(image_np * std + mean, 0.0, 1.0)
    
    axes[0].imshow(disp_image)
    axes[0].set_title("Original")
    axes[0].axis('off')
    
    for i in range(3):
        cls_idx = top_idxs[0, i].item()
        heatmap_top = compute_gradcam_3(img_tensor, cls_idx, hook_obj, net)
        
        axes[i + 1].imshow(disp_image)
        axes[i + 1].imshow(heatmap_top, cmap='jet', alpha=0.5)
        
        if class_dict and (cls_idx in class_dict):
            cls_name = class_dict[cls_idx]
        else:
            cls_name = f"Class {cls_idx}"
        
        axes[i + 1].set_title(cls_name)
        axes[i + 1].axis('off')
    
    worst_idx = min_idx.item()
    heatmap_worst = compute_gradcam_3(img_tensor, worst_idx, hook_obj, net)
    
    axes[4].imshow(disp_image)
    axes[4].imshow(heatmap_worst, cmap='jet', alpha=0.5)
    
    if class_dict and (worst_idx in class_dict):
        lowest_label = class_dict[worst_idx]
    else:
        lowest_label = f"Class {worst_idx}"
    
    axes[4].set_title(f"Lowest: {lowest_label}")
    axes[4].axis('off')
    
    plt.tight_layout()
    plt.show()


In [None]:
# Reuse find_image() from Part 4 to locate the same image
matched_path = find_image("9928031928.png", dataset)

loaded_image = Image.open(matched_path).convert("RGB")

transformed_tensor = dataset.transform(loaded_image).unsqueeze(0)

# Grad-CAM with ResNet-152 (display_grad_cam_3 loads it when net=None)
display_grad_cam_3(transformed_tensor, class_dict=classes)
