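## Grad-CAM-style heatmaps (computed with EigenCAM) for EfficientNet

Loads a trained EfficientNet checkpoint, runs EigenCAM on a chosen test image, and saves the heatmap overlay as a JPEG.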

In [1]:
import copy
import os
import sys

root_path = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(root_path)

import time
import numpy as np
from PIL import Image

import torch
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl

# Grad-CAM imports
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam import EigenCAM, LayerCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    EarlyStopping,
)

from src.utils.helpers import load_config
from src.training.dataset import ImageDataModule
from src.models.classification_model import ImageClassifier

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [2]:
# Helper for ViT backbones: turns transformer token activations into a spatial map for CAM.
def vit_reshape_transform(tensor, grid_size=None, num_prefix_tokens=1):
    """Reshape ViT token activations into a (B, C, H_p, W_p) feature map.

    pytorch_grad_cam expects spatial activations, but a ViT block emits a
    sequence of tokens. This drops the leading non-patch tokens and lays the
    remaining patch tokens back out on the patch grid.

    Args:
        tensor: token activations of shape (B, num_prefix_tokens + H_p * W_p, C).
        grid_size: optional (H_p, W_p) patch grid. When None, it is read from
            the notebook-global ``net.backbone.patch_embed.grid_size``, which
            requires ``net`` to be loaded before the CAM runs.
        num_prefix_tokens: number of leading non-patch tokens to drop
            (1 = the CLS token).

    Returns:
        A contiguous tensor of shape (B, C, H_p, W_p).

    Raises:
        ValueError: if the number of patch tokens does not equal H_p * W_p.
    """
    if grid_size is None:
        grid_size = net.backbone.patch_embed.grid_size
    h, w = grid_size

    batch, num_tokens, channels = tensor.shape
    num_patches = num_tokens - num_prefix_tokens
    if num_patches != h * w:
        raise ValueError(
            f"Expected {h}x{w}={h * w} patch tokens after dropping "
            f"{num_prefix_tokens} prefix token(s), got {num_patches}"
        )

    # Drop the prefix (CLS) token(s) and keep only the patch tokens.
    patch_tokens = tensor[:, num_prefix_tokens:, :]

    # (B, H_p*W_p, C) -> (B, H_p, W_p, C) -> (B, C, H_p, W_p), as CAM expects channels first.
    patch_tokens = patch_tokens.reshape(batch, h, w, channels)
    return patch_tokens.permute(0, 3, 1, 2).contiguous()

In [3]:
# Any trained model works here -- what matters is the test split exposed by the data module.
config = load_config("sampled_efficientnet_b0.yaml")
# Deep copy: a shallow config.copy() shares the nested "data" dict, so overriding
# test_path below would silently modify `config` (used for the model) as well.
config1 = copy.deepcopy(config)

## ---- Data Set goes here:
config1["data"]["test_path"] = "../../src/data/sampled/test"
data_module = ImageDataModule(config1)
data_module.setup("test")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## ---- Model you want to test goes here:
ckpt_path = "../EfficientNet Standard/checkpoints/last.ckpt"

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net = ImageClassifier.load_from_checkpoint(ckpt_path, config=config, map_location=device)
net.to(device)
net.eval()

# Sanity check
print("Model device (first param):", next(net.parameters()).device)

[NOTICE] Found 50 corrupted files in ../../src/data/sampled/train.
[NOTICE] Found 5 corrupted files in ../../src/data/sampled/validation.
[NOTICE] Found 13 corrupted files in ../../src/data/sampled/test.
Model device (first param): cuda:0


## Per-image analysis: re-run the cells below for each image you want to inspect

In [19]:
# Position of the test image to visualise; test_results_per_image.csv lists the candidates.
idx = 5

In [20]:

# Fetch the chosen test sample and confirm it is the image you expect.
test_ds = data_module.test_dataset
img_tensor, label, path = test_ds[idx]
print(f"Using image #{idx} -> {path}, label = {label}")

# unsqueeze(0) adds a batch dimension of 1 (img_tensor is presumably (C, H, W)).
input_tensor = img_tensor.unsqueeze(0).to(device)

# Overlay background: the original file resized to the input tensor's spatial size.
# pytorch_grad_cam upsamples the CAM to the input tensor's (W, H), so using that size
# keeps the heatmap and photo the same shape even if it differs from data_module.img_size.
# NOTE(review): assumes the test transform is a plain resize (no crop) -- otherwise the
# overlay will not line up exactly with what the network saw.
overlay_size = (input_tensor.shape[-1], input_tensor.shape[-2])  # PIL expects (W, H)
# The context manager closes the file handle that Image.open keeps open.
with Image.open(path) as src_img:
    pil_img = src_img.convert("RGB").resize(overlay_size)
img_float = np.asarray(pil_img, dtype=np.float32) / 255.0  # show_cam_on_image wants [0, 1]


Using image #5 -> ../../src/data/sampled/test\ai\004_sdv5_00095.jpg, label = 0


In [21]:
# Set up the CAM. EigenCAM is used (rather than Grad-CAM) for consistency with the ViT runs.
# Pick which EfficientNet layer to visualise:
#   "conv_head"  -> net.backbone.conv_head
#   "last_block" -> net.backbone.blocks[-1]
CAM_TARGET = "conv_head"

if CAM_TARGET == "conv_head":
    target_layers = [net.backbone.conv_head]
elif CAM_TARGET == "last_block":
    target_layers = [net.backbone.blocks[-1]]
else:
    raise ValueError(f"Unknown CAM_TARGET {CAM_TARGET!r}; expected 'conv_head' or 'last_block'")

# These EfficientNet layers need no reshape_transform (unlike the ViT token outputs).
cam = EigenCAM(
    model=net,
    target_layers=target_layers,
)

In [22]:
# Do Predictions
with torch.no_grad():
    logits = net(input_tensor)
    pred_class = int(logits.argmax(dim=1).item())
print("Predicted class:", pred_class) ## 0 = AI, 1 = Real
 
grayscale_cam = cam(
    input_tensor=input_tensor,
    targets=[ClassifierOutputTarget(pred_class)],
)[0, :]   

Predicted class: 0


In [23]:
# Make Heatmap and Save
cam_image = show_cam_on_image(img_float, grayscale_cam, image_weight=0.6,  use_rgb=True)
cam_bgr = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
out_path = "gradcam_efficientnet_"+str(idx)+".jpg"
cv2.imwrite(out_path, cam_bgr)

print(f"Saved Grad-CAM overlay to {out_path}")

Saved Grad-CAM overlay to gradcam_efficientnet_5.jpg
