In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from torchvision import models, transforms
import matplotlib.pyplot as plt
import math

class GradCAM:
    """Grad-CAM heatmaps for a chosen layer of a classification model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient flowing into that output. Calling the instance runs a
    forward/backward pass and combines both into a heatmap scaled to [0, 1].

    Layer outputs with four dimensions are used as-is (channels on dim 1,
    spatial dims 2 and 3). Layer outputs with three dimensions are treated as
    token sequences: the first token is dropped and the remaining tokens are
    folded into a square grid (see `_reshape_transform`).

    Call `remove_hooks()` when finished. The registered hooks are bound methods
    of this object, so the hooked layer presumably keeps the instance alive and
    `__del__` should not be relied on for cleanup.
    """

    def __init__(self, model, target_layer):
        """Register the hooks on `target_layer`.

        Args:
            model: Module called as `model(input_tensor)`; expected to return
                scores with the class axis on dim 1.
            target_layer: Sub-module of `model` whose output is explained.
        """
        # Assigned first so cleanup is safe even if hook registration fails.
        self.hook_handles = []
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        self.hook_handles.append(target_layer.register_forward_hook(self._save_activation))
        self.hook_handles.append(target_layer.register_full_backward_hook(self._save_gradient))

    def _save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the detached gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def _reshape_transform(self, tensor):
        """Fold a token sequence into a 2D feature map.

        Args:
            tensor: Tensor indexed as (batch, tokens, channels). The first
                token is discarded (presumably the class token - confirm for
                the model in use).

        Returns:
            Tensor of shape (batch, channels, side, side) where
            side * side == tokens - 1.

        Raises:
            ValueError: If the remaining token count is not a positive
                perfect square.
        """
        num_patches = tensor.size(1) - 1
        side = math.isqrt(max(num_patches, 0))
        if num_patches < 1 or side * side != num_patches:
            raise ValueError(
                f"Cannot arrange {num_patches} patch tokens into a square grid"
            )

        # (batch, tokens - 1, channels) -> (batch, channels, tokens - 1)
        result = tensor[:, 1:, :].transpose(1, 2)
        return result.reshape(tensor.size(0), result.size(1), side, side)

    def __call__(self, input_tensor, target_category=None):
        """Compute the Grad-CAM heatmap(s) for `input_tensor`.

        Args:
            input_tensor: Model input; its dims from index 2 onward give the
                output heatmap size.
            target_category: Class index (or index expression for the class
                axis) to explain, applied to every sample. When None, each
                sample is explained for its own highest-scoring class.

        Returns:
            NumPy array with all size-1 dims squeezed out: (H, W) for a single
            sample, (batch, H, W) for a batch. Every sample is min-max scaled
            independently; a sample whose map is flat (range <= 1e-6) is all
            zeros.

        Raises:
            RuntimeError: If the hooks did not capture activations and
                gradients during this call.
        """
        # Forget the previous call so a hook that fails to fire is detected
        # instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        self.model.zero_grad()
        output = self.model(input_tensor)

        one_hot = torch.zeros_like(output)
        if target_category is None:
            # Per-sample arg-max so batches larger than one are supported.
            one_hot.scatter_(1, output.argmax(dim=1, keepdim=True), 1.0)
        else:
            one_hot[:, target_category] = 1

        target_score = (output * one_hot).sum()
        target_score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; check that "
                "target_layer is used in the forward pass and that its output "
                "requires grad"
            )

        gradients = self.gradients
        activations = self.activations

        if activations.dim() == 3:
            gradients = self._reshape_transform(gradients)
            activations = self._reshape_transform(activations)

        # Channel weights: spatial mean of the gradients.
        alpha = gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((alpha * activations).sum(dim=1, keepdim=True))

        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)

        # Min-max scale each sample on its own so one image's heatmap does not
        # depend on the other images in the batch.
        cam = cam - cam.amin(dim=(1, 2, 3), keepdim=True)
        peak = cam.amax(dim=(1, 2, 3), keepdim=True)
        cam = torch.where(peak > 1e-6, cam / peak.clamp_min(1e-6), torch.zeros_like(cam))

        return cam.squeeze().cpu().numpy()

    def remove_hooks(self):
        """Detach every hook registered by this instance (safe to call twice)."""
        for handle in getattr(self, "hook_handles", ()):
            handle.remove()
        self.hook_handles = []

    def __del__(self):
        # Best-effort fallback only; prefer calling remove_hooks() explicitly.
        self.remove_hooks()

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the bare ViT-B/16 architecture. No pretrained weights are requested:
# the checkpoint loaded below goes through load_state_dict's default strict
# matching, so it supplies every parameter and a pretrained download would be
# discarded anyway.
model = models.vit_b_16(weights=None)

# Swap the classification head for a 2-output linear layer so the module
# matches the head shape used below when loading the checkpoint.
num_features = model.heads.head.in_features
model.heads.head = nn.Linear(num_features, 2)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
state_dict = torch.load("Vit-Base_weights.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Explain the output of the first LayerNorm (`ln_1`) of the last encoder block.
target_layer = model.encoder.layers[-1].ln_1

grad_cam = GradCAM(model, target_layer)

In [None]:
def show_cam_on_image(img: np.ndarray, mask: np.ndarray,
                      image_weight: float = 0.6,
                      heatmap_weight: float = 0.4) -> np.ndarray:
    """Blend a JET-coloured heatmap over an RGB image.

    Args:
        img: uint8 image in RGB channel order, indexed (height, width, 3).
        mask: 2D heatmap with values expected in [0, 1]. Values outside that
            range are clipped, and a mask whose size differs from `img` is
            resized to it with bilinear interpolation.
        image_weight: Blend weight applied to `img`.
        heatmap_weight: Blend weight applied to the coloured heatmap.

    Returns:
        The blended uint8 RGB image with the same shape as `img`.
    """
    mask = np.asarray(mask)
    if mask.shape[:2] != img.shape[:2]:
        # cv2.resize takes the target size as (width, height).
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)

    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around and map to the wrong colours.
    mask = np.clip(mask, 0.0, 1.0)

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it matches the RGB input image.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(img, image_weight, heatmap, heatmap_weight, 0)
    return superimposed_img

def deprocess_image(img_tensor, mean, std):
    """Undo per-channel normalisation and convert a tensor to a uint8 image.

    Args:
        img_tensor: Normalised image tensor indexed (channels, height, width).
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided by.

    Returns:
        NumPy uint8 array indexed (height, width, channels) with values in
        [0, 255].
    """
    # Build the statistics in the image's own floating dtype and device so the
    # arithmetic is not promoted to float64 when `mean`/`std` are float64 arrays.
    stat_dtype = img_tensor.dtype if img_tensor.is_floating_point() else torch.float32
    mean_tensor = torch.as_tensor(mean, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(std, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)

    # detach() lets the conversion to NumPy work on tensors that track gradients.
    img = img_tensor.detach() * std_tensor + mean_tensor
    img = img.permute(1, 2, 0).cpu().numpy()

    # Round before casting: a plain uint8 cast truncates, so float round-off
    # (e.g. 127.99999) would otherwise darken a pixel by one level.
    img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    return img

# --- Load the image; its original size is kept so results can be scaled back ---
single_image_path = "DSC6772_idx129.png"
original_img_pil = Image.open(single_image_path).convert("RGB")
width, height = original_img_pil.size

# Per-channel normalisation statistics (presumably the ImageNet values the
# model was trained with - confirm against the training pipeline).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Resize to 224x224, convert to a tensor and normalise.
data_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Add the batch dimension and move the input to the model's device.
img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

# --- Grad-CAM heatmap for the class the model scores highest ---
heatmap = grad_cam(img_tensor)

# --- Overlay the heatmap on the de-normalised network input ---
denormalized_img_np = deprocess_image(img_tensor[0], mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Scale both the raw heatmap and the overlay back to the original image size.
heatmap_pil = Image.fromarray((255 * heatmap).astype(np.uint8)).resize(
    (width, height), resample=Image.BILINEAR
)
overlayed_img_pil = Image.fromarray(overlayed_img).resize(
    (width, height), resample=Image.BILINEAR
)

# --- Show the overlay without axes and save it to disk ---
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlayed_img_pil)
ax.axis('off')
fig.tight_layout()

fig.savefig(r"BC15_Vitbase_DSC6772_idx129.png", bbox_inches='tight', pad_inches=0, dpi=300)

plt.show()