# CLIP GradCam for ViT models

Adapted from the [Mikael17125/ViT-GradCAM repository](https://github.com/Mikael17125/ViT-GradCAM).

In [28]:
import timm
import clip
import torch
from skimage import io  # Image processing from scikit-image

GradCam class

In [2]:
import cv2  # OpenCV for image processing
import numpy as np  # NumPy for numerical operations


class GradCam:
    """Grad-CAM for ViT-style models.

    Two forward hooks are attached to ``target``: one caches the layer's
    output tokens, the other attaches a tensor hook to that output so the
    gradient flowing back into it is cached during ``backward()``.

    The hooked layer's output is interpreted as ``(batch, tokens, dim)`` with
    one leading non-patch token (see ``reshape_transform``).

    Call ``remove_handlers()`` once finished to detach the hooks again.
    """

    def __init__(self, model, target):
        """Switch ``model`` to eval mode and hook ``target``.

        Args:
            model: Module invoked as ``model(inputs)``; its result is indexed
                as ``output[0][class_index]``.
            target: Sub-module of ``model`` whose output drives the heatmap.
        """
        self.model = model.eval()  # Set the model to evaluation mode
        self.feature = None  # Reshaped features from the target layer
        self.gradient = None  # Reshaped gradients w.r.t. the target output
        self.handlers = []  # Handles of the registered forward hooks
        self.target = target  # Target layer for Grad-CAM
        self._get_hook()  # Register hooks to the target layer

    # Hook to get features from the forward pass
    def _get_features_hook(self, module, input, output):
        # Store and reshape the output features
        self.feature = self.reshape_transform(output)

    # Forward hook that arranges for the gradient to be captured later.
    # The parameter names are kept from the original signature, but because
    # this runs as a *forward* hook, ``output_grad`` is the layer's output.
    def _get_grads_hook(self, module, input_grad, output_grad):
        # Drop any gradient left over from a previous forward/backward pass
        self.gradient = None

        def _store_grad(grad):
            self.gradient = self.reshape_transform(
                grad)  # Store gradients for later use

        # A tensor hook can only be attached to a tensor that requires grad;
        # skipping otherwise keeps plain inference (e.g. under
        # ``torch.no_grad()``) working while the hooks are installed.
        if output_grad.requires_grad:
            output_grad.register_hook(_store_grad)

    # Register forward hooks to the target layer
    def _get_hook(self):
        # Keep the handles so the hooks can be removed again
        self.handlers.append(
            self.target.register_forward_hook(self._get_features_hook))
        self.handlers.append(
            self.target.register_forward_hook(self._get_grads_hook))

    def remove_handlers(self):
        """Detach every forward hook this instance registered."""
        for handle in self.handlers:
            handle.remove()
        self.handlers = []

    # Function to reshape the tensor for visualization
    def reshape_transform(self, tensor, height=14, width=14):
        """Turn ``(batch, tokens, dim)`` into ``(batch, dim, height, width)``.

        The first token (presumably the class token) is dropped. When the
        remaining token count does not equal ``height * width`` but is a
        non-zero perfect square, a square grid of that size is used instead,
        so e.g. 49 patch tokens become a 7x7 map without extra arguments.
        Any other mismatch fails in ``reshape`` as before.
        """
        patches = tensor[:, 1:, :]
        n_patches = patches.size(1)
        if height * width != n_patches:
            side = int(round(n_patches ** 0.5))
            if side > 0 and side * side == n_patches:
                height = width = side
        result = patches.reshape(
            tensor.size(0), height, width, tensor.size(2))
        # Rearrange (B, H, W, C) -> (B, C, H, W)
        return result.permute(0, 3, 1, 2)

    # Function to compute the Grad-CAM heatmap
    def __call__(self, inputs):
        """Compute the Grad-CAM heatmap of the first sample in ``inputs``.

        The score that is back-propagated is the highest entry of
        ``output[0]``.

        Args:
            inputs: Tensor passed to the model; its last two dimensions are
                taken as (height, width) of the returned map.

        Returns:
            2-D float32 NumPy array scaled to [0, 1] (all zeros when the
            weighted feature sum has no positive entry).

        Raises:
            RuntimeError: If no gradient or feature was captured at the
                target layer.
        """
        self.model.zero_grad()  # Zero the gradients
        output = self.model(inputs)  # Forward pass

        # Pick the top score of the first sample only; an argmax over the
        # whole batch could yield an index that is invalid for output[0].
        scores = output[0]
        index = int(scores.detach().argmax())
        scores[index].backward()  # Backward pass to compute gradients

        if self.gradient is None or self.feature is None:
            raise RuntimeError(
                "No gradient/feature captured at the target layer; "
                "is it part of the model's forward pass?")

        # Cast to float32 so half-precision models yield arrays that the
        # NumPy/OpenCV steps below can process
        gradient = self.gradient[0].detach().float().cpu().numpy()
        weight = np.mean(gradient, axis=(1, 2))  # Average the gradients
        feature = self.feature[0].detach().float().cpu().numpy()

        # Compute the weighted sum of the features
        cam = feature * weight[:, np.newaxis, np.newaxis]
        cam = np.sum(cam, axis=0)  # Sum over the channels
        cam = np.maximum(cam, 0)  # Apply ReLU to remove negative values

        # Normalize the heatmap; skip the division for an all-zero map,
        # which would otherwise turn into NaNs
        cam -= np.min(cam)
        peak = np.max(cam)
        if peak > 0:
            cam /= peak

        # Resize to the input's spatial size (cv2 expects (width, height))
        height, width = inputs.shape[-2:]
        cam = cv2.resize(cam, (int(width), int(height)))
        return cam  # Return the Grad-CAM heatmap

Helper functions

In [3]:
def prepare_input(image):
    """Convert an (H, W, 3) image array into a (1, 3, H, W) gradient-enabled tensor.

    Each channel is shifted by 0.5 and divided by 0.5. The work happens on a
    copy, so the caller's array is left untouched, and the copy's dtype is
    retained by the in-place arithmetic.
    """
    channel_means = np.array([0.5, 0.5, 0.5])
    channel_stds = np.array([0.5, 0.5, 0.5])

    # Normalise a private copy in place
    normalized = image.copy()
    normalized -= channel_means
    normalized /= channel_stds

    # Channels-last -> channels-first, made contiguous, then a leading batch axis
    channels_first = np.ascontiguousarray(np.transpose(normalized, (2, 0, 1)))
    batch = np.expand_dims(channels_first, axis=0)

    return torch.tensor(batch, requires_grad=True)

def gen_cam(image, mask):
    """Overlay a JET-coloured rendering of ``mask`` on ``image`` in equal parts.

    ``mask`` is scaled by 255 and cast to 8-bit before colour mapping; the
    blend is divided by its own maximum and returned as an 8-bit array.
    NOTE(review): ``image`` is presumably a float array in [0, 1] with the
    same height/width as ``mask`` - confirm against the caller.
    """
    image_weight = 0.5

    # Colourise the mask and bring it back to a 0-1 float range
    mask_8bit = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_8bit, cv2.COLORMAP_JET)
    colored = np.float32(colored) / 255

    # Weighted blend of heat map and image, rescaled so the peak equals 1
    blended = (1 - image_weight) * colored + image_weight * image
    blended = blended / np.max(blended)

    return np.uint8(255 * blended)

In [4]:
# Build timm's 'vit_base_patch32_224' model with pretrained=True.
model = timm.create_model('vit_base_patch32_224', pretrained=True)

In [None]:
# Use CUDA when torch reports it as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP 'ViT-B/32' on that device, plus the preprocessing callable returned with it.
clip_model, preprocess = clip.load('ViT-B/32', device=device)
clip_model.eval()  # Switch the CLIP model to evaluation mode

In [6]:
# Inspect the `norm1` layer of the timm model's last block (its repr is the cell output).
model.blocks[-1].norm1

LayerNorm((768,), eps=1e-06, elementwise_affine=True)

In [36]:
# Layer to hook for Grad-CAM: `ln_1` of the last residual block in CLIP's visual transformer.
# NOTE(review): GradCam.reshape_transform slices dim 1 as the token axis; CLIP's
# transformer presumably runs on (tokens, batch, dim) tensors - confirm the layout.
target_layer = clip_model.visual.transformer.resblocks[-1].ln_1

In [37]:
# Root directory of the FairFace dataset (machine-specific absolute path).
ff_root = "/home/lucasmc/Documents/ufrgs/data/datasets/FairFace"
# Image from the `val` folder used for the Grad-CAM example.
img_path = f"{ff_root}/val/216.jpg"

In [None]:
# Open the image, run CLIP's preprocessing, add a batch axis and move it to `device`.
# NOTE(review): `Image` is not imported in any visible cell - presumably
# `from PIL import Image` is required; confirm and add it to the import cell.
img = preprocess(Image.open(img_path)).unsqueeze(0).to(device)

In [39]:
# Wrap CLIP's visual encoder; constructing GradCam registers its hooks on `target_layer`.
grad_cam = GradCam(clip_model.visual, target_layer)

In [43]:
# Compute the Grad-CAM heatmap for the preprocessed image.
# NOTE(review): the recorded traceback below shows conv2d receiving a numpy.ndarray,
# i.e. `img` was not a tensor when this cell ran - presumably a stale `img` from
# another cell; re-run the preprocessing cell before this one.
mask = grad_cam(img)

TypeError: conv2d() received an invalid combination of arguments - got (numpy.ndarray, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
