# Grad-CAM with ResNet34

In [4]:
import sys

# Make the project's packages (e.g. `training.detectors`, imported below)
# resolvable from this notebook.
repo_root = "/Users/msrobin/GitHub Repositorys/Interpretable-Deep-Fake-Detection-2"
if repo_root not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(repo_root)

# Overwrite the kernel's argv with a script-like one.
# NOTE(review): presumably so project code that parses sys.argv at import time
# does not see the Jupyter kernel's own arguments — confirm.
sys.argv = ["train.py"]

In [6]:
import torch
import torchvision
import torch.nn as nn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import os
from training.detectors.resnet34 import ResNet34



ModuleNotFoundError: No module named 'training.detectors.resnet34'

In [None]:
# Build the detector configuration by hand instead of reading a config file.
# NOTE(review): presumably `num_classes` is the number of output classes and
# `inc` the number of input channels — confirm against the ResNet34 detector.
resnet_config = dict(num_classes=2, inc=3, mode="default")

# Instantiate the detector from that configuration.
model = ResNet34(resnet_config)


In [None]:
# Load checkpoint
weights_path = "/Users/msrobin/Downloads/bcos_resnet_minimal_b2_ckpt_best.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deserialize on the CPU first so loading works regardless of where the
# checkpoint was saved; the model is moved to `device` afterwards.
state_dict = torch.load(weights_path, map_location="cpu")

# strict=False tolerates key mismatches between checkpoint and model, which
# can leave layers at their random initialization without any error.
# Report the mismatches so a wrong checkpoint/architecture pairing is visible.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys}")

# Move to the selected device and switch to inference mode.
model.to(device)
model.eval()


In [None]:
# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)


def load_image(image_path):
    """Open an image and prepare it as model input.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image, tensor)`` pair: the PIL image converted to RGB (not
        resized), and the result of applying ``transform`` to it with a
        leading batch dimension added.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    transformed = transform(rgb_image)
    batched = transformed.unsqueeze(0)  # add batch axis at position 0
    return rgb_image, batched


In [None]:
# Pick the sample image to explain and load it as (PIL image, batched tensor).
image_path = "/Users/msrobin/GitHub Repositorys/Interpretable-Deep-Fake-Detection-2/datasets/2x2_images/sample_image.jpg"
pil_image, input_tensor = load_image(image_path)

# Move the input to the GPU when one is available; otherwise keep it on the CPU.
if torch.cuda.is_available():
    input_tensor = input_tensor.cuda()


In [None]:
# Grad-CAM setup
# Grad-CAM reads activations and gradients from the layer(s) listed here.
target_layers = [model.resnet[-1][-1].conv2]  # Last conv layer in layer4
wrapped_model = model

# NOTE: the `use_cuda` keyword was removed from the CAM constructors in newer
# pytorch-grad-cam releases (1.5.0 onward — confirm against the installed
# version) and raises a TypeError there. It is not needed on older releases
# either, because the model and the input tensor were already moved to the
# same device in the cells above.
cam = GradCAM(model=wrapped_model, target_layers=target_layers)
targets = [ClassifierOutputTarget(1)]  # Change target class index if needed

# `cam(...)` returns one heatmap per input in the batch; take the only one.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

# Convert and overlay CAM
# show_cam_on_image expects a float image scaled to [0, 1] with the same
# spatial size as the heatmap, hence the resize to the model input size.
rgb_img = np.asarray(pil_image.resize((224, 224)), dtype=np.float32) / 255.0
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
plt.imshow(visualization)
plt.axis("off")
plt.title("Grad-CAM for ResNet34")
plt.show()


Another way to visualize: Grad-CAM for every sample from a dataloader

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch.nn.functional as F

# Evaluate the model over `dataloader` and show a Grad-CAM overlay for every
# sample, explaining the class the model actually predicted.
# `dataloader` must already be defined; it is expected to yield
# (image, label, path) triples, as unpacked by the loop below.
correct = 0
total = 0

# Inputs have to live on the same device as the model's parameters.
device = next(wrapped_model.parameters()).device

for image, label, _path in dataloader:
    image = image.to(device)
    label = label.to(device)

    # Forward pass
    output = wrapped_model(image)
    pred = torch.argmax(output, dim=1)

    # Accuracy
    correct += (pred == label).sum().item()
    total += label.size(0)

    # Grad-CAM: one target per sample, so batches larger than one work too
    # (calling `.item()` on the whole batch only works for batch size 1).
    pred_classes = pred.tolist()
    true_classes = label.tolist()
    cam_targets = [ClassifierOutputTarget(cls) for cls in pred_classes]
    grayscale_cams = cam(input_tensor=image, targets=cam_targets)

    for idx, (pred_cls, true_cls) in enumerate(zip(pred_classes, true_classes)):
        # Prepare input for visualization: channels-last, min-max scaled to
        # [0, 1]; the epsilon avoids division by zero on a constant image.
        input_image = image[idx].detach().cpu().permute(1, 2, 0).numpy()
        input_image = (input_image - input_image.min()) / (input_image.max() - input_image.min() + 1e-8)

        # Overlay CAM
        cam_overlay = show_cam_on_image(input_image, grayscale_cams[idx], use_rgb=True)

        # Show
        plt.figure(figsize=(6, 6))
        plt.imshow(cam_overlay)
        plt.title(f"Prediction: {pred_cls} | Label: {true_cls}")
        plt.axis("off")
        plt.show()

# Final accuracy (guarded: an empty dataloader would otherwise divide by zero)
if total:
    print(f"✅ Accuracy: {correct}/{total} = {correct / total:.2%}")
else:
    print("No samples were evaluated; accuracy is undefined.")
