In [None]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision import models, transforms
from torchvision.datasets import CIFAR10

In [None]:
# ResNet was trained on ImageNet, so inputs are normalised with the ImageNet
# per-channel statistics.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std  = (0.229, 0.224, 0.225)

# The pipeline gets its own name. Binding it to `transforms` would shadow the
# torchvision module imported at the top of the notebook, so re-running this
# cell (or any later `transforms.X` call) would raise AttributeError.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    # NOTE(review): random augmentation -- the same index may come back
    # mirrored from one access to the next.
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# CIFAR-10 training split; downloaded into `root` if it is not already there.
ds = CIFAR10(
    root="../../assets/cifar10",
    train=True,
    download=True,
    transform=preprocess,
)

In [None]:
# Take one example from the dataset; the dataset transform has already been
# applied, so `input_image` is a normalised (3, 224, 224) tensor.
input_image, label = ds[40]

# Channel statistics shaped (3, 1, 1) so they broadcast across height and width.
mean = torch.tensor(imagenet_mean).reshape(3, 1, 1)
std = torch.tensor(imagenet_std).reshape(3, 1, 1)

# Undo the normalisation, clip to the displayable [0, 1] range and move the
# channel axis last, which is the layout matplotlib expects.
img = (input_image * std + mean).clamp(0, 1).permute(1, 2, 0)  # (224, 224, 3)

plt.imshow(img)
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()

In [None]:
# Build the architecture only. The checkpoint loaded below replaces every
# parameter and buffer (load_state_dict is strict by default), so fetching the
# ImageNet weights first would just be a wasted download.
model = models.resnet18(weights=None)

# Replace the classification head with a 10-way linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 10 classes

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model itself is never moved off the CPU here.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load("../../assets/models/finetuned_resnet18.pth", map_location="cpu")
)
model.eval()

In [None]:
# Inference mode, so the gradient is taken through the model's evaluation
# behaviour. Harmless if the model is already in eval mode.
model.eval()

target_class = 5 - 1  # example target class

# Build the batched leaf tensor under its own name instead of rebinding
# `input_image`. The original rebinding added another batch axis on every
# re-run of this cell, so a second run fed a 5-D tensor to the model.
saliency_input = input_image.unsqueeze(0)  # (1, 3, 224, 224)
saliency_input = saliency_input.clone().detach()
saliency_input.requires_grad_(True)

output = model(saliency_input)

# Clear any parameter gradients left over from earlier backward passes, then
# back-propagate the target-class logit down to the input pixels.
model.zero_grad()
output[0, target_class].backward()

# get gradients w.r.t. input
gradients = saliency_input.grad[0]  # (3, H, W)

# saliency map: absolute gradient summed across the colour channels
saliency_map = gradients.abs().sum(dim=0)

# Min-max scale to [0, 1] for display; the epsilon guards against division by
# zero when the map is constant.
saliency_map = saliency_map.cpu().numpy()
saliency_map = (saliency_map - saliency_map.min()) / (
    saliency_map.max() - saliency_map.min() + 1e-8
)

In [None]:
# Draw the de-normalised image first, then lay the saliency heat-map over it
# at 50% opacity so the underlying picture stays visible.
overlay_fig, overlay_ax = plt.subplots(figsize=(4, 4))
overlay_ax.imshow(img)
overlay_ax.imshow(saliency_map, alpha=0.5, cmap="hot")
overlay_ax.set_title("Saliency Overlay")
overlay_ax.axis("off")
plt.show()