# 6. Grad-CAM Visualization

This notebook generates a **Gradient-weighted Class Activation Map** (Grad-CAM)
to visualize which regions of an image the model relies on for a selected detection (`TARGET_INSTANCE`).

Both Grad-CAM and Grad-CAM++ are supported; Grad-CAM++ tends to localize better when several instances of the same class appear in the image.

**Prerequisites:** Run `1_setup.ipynb` first and make sure the trained model weights are available.

## 6.1 Configuration

In [None]:
import os
import matplotlib.pyplot as plt

from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

import config
from utils.gradcam import Detectron2GradCAM

In [None]:
# ===================== CONFIGURE =====================

# Image to analyze. Defaults to the Drive path used in Colab; set the
# GRADCAM_IMAGE_PATH environment variable to analyze a different image
# without editing the notebook.
IMAGE_PATH = os.environ.get(
    "GRADCAM_IMAGE_PATH", "/content/drive/MyDrive/TESE/images/10137.jpg"
)

# Model
MODEL_KEY = "faster_rcnn_R50"                # key into config.MODELS (model-zoo config)
TRAINED_MODEL_KEY = "total_faster_rcnn_R50"  # which trained weights to load
MODEL_SOURCE = "agar"                        # passed to config.get_model_weights

# Grad-CAM settings
TARGET_INSTANCE = 0                               # Which detection to explain
LAYER_NAME = "backbone.bottom_up.res5.2.conv3"    # Target conv layer
GRAD_CAM_TYPE = "GradCAM"                         # 'GradCAM' or 'GradCAM++'

# ====================================================

# Validate settings here so a typo or a missing file fails immediately,
# rather than after the (slow) model construction in the next cell.
SUPPORTED_CAM_TYPES = ("GradCAM", "GradCAM++")
if GRAD_CAM_TYPE not in SUPPORTED_CAM_TYPES:
    raise ValueError(
        f"GRAD_CAM_TYPE must be one of {SUPPORTED_CAM_TYPES}, got {GRAD_CAM_TYPE!r}"
    )
if not os.path.isfile(IMAGE_PATH):
    raise FileNotFoundError(f"Image not found: {IMAGE_PATH}")

# Resolve the model-zoo config file and the trained weights for the chosen model.
config_file_path = model_zoo.get_config_file(config.MODELS[MODEL_KEY])
model_file = config.get_model_weights(TRAINED_MODEL_KEY, MODEL_SOURCE)

print(f"Image: {IMAGE_PATH}")
print(f"Model: {TRAINED_MODEL_KEY}")
print(f"Layer: {LAYER_NAME}")

## 6.2 Generate Grad-CAM

In [None]:
# Build the extractor from the model-zoo config plus the trained weights
# resolved in the configuration cell.
cam_extractor = Detectron2GradCAM(config_file=config_file_path, model_file=model_file)

# Explain detection TARGET_INSTANCE using the layer named by LAYER_NAME.
# get_cam returns a pair: a dict (read by the visualization cell via the keys
# "image", "output", "cam" and "label") and `cam_orig`, which is not used by
# the cells shown here.
image_dict, cam_orig = cam_extractor.get_cam(
    img=IMAGE_PATH,
    layer_name=LAYER_NAME,
    target_instance=TARGET_INSTANCE,
    grad_cam_type=GRAD_CAM_TYPE,
)

## 6.3 Visualize

In [None]:
plt.rcParams["figure.figsize"] = (20, 8)

# NOTE(review): assumes image_dict["output"][0]["instances"] is a detectron2
# `Instances` object (supports len(), integer indexing and .to()) — confirm in
# utils.gradcam. Check the index up front so the error says how many
# detections exist; negative indices remain allowed, as with Instances itself.
all_instances = image_dict["output"][0]["instances"]
num_detections = len(all_instances)
if not -num_detections <= TARGET_INSTANCE < num_detections:
    raise IndexError(
        f"TARGET_INSTANCE={TARGET_INSTANCE} is out of range: "
        f"the model produced {num_detections} detection(s)."
    )

# Use the first training dataset's metadata (class names/colors) when the
# config names one; otherwise Visualizer falls back to empty metadata.
train_datasets = cam_extractor.cfg.DATASETS.TRAIN
metadata = MetadataCatalog.get(train_datasets[0]) if train_datasets else None

# Draw only the explained detection. The channel flip presumably converts a
# BGR image to the RGB order Visualizer expects — TODO confirm the color order
# of image_dict["image"].
v = Visualizer(image_dict["image"][:, :, ::-1], metadata, scale=1.0)
out = v.draw_instance_predictions(all_instances[TARGET_INSTANCE].to("cpu"))
detections_rgb = out.get_image()

fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Left: detections
axes[0].imshow(detections_rgb)
axes[0].set_title(f"Detections (Instance {TARGET_INSTANCE})")
axes[0].axis("off")

# Right: Grad-CAM heatmap blended over the same detection rendering
axes[1].imshow(detections_rgb)
axes[1].imshow(image_dict["cam"], cmap="jet", alpha=0.5)
axes[1].set_title(
    f"{GRAD_CAM_TYPE} — Instance {TARGET_INSTANCE} "
    f"(class: {image_dict['label']})"
)
axes[1].axis("off")

plt.tight_layout()
plt.show()