In [39]:
import timm
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np


In [40]:

# Device every tensor and the model are placed on in this notebook.
device = 'cpu'

# Build the pretrained `vit_base_patch16_224` model from timm and move it to
# `device` in one step.
model = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)

# Switch to inference mode; `eval()` returns the module, so the cell
# displays the model architecture.
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [48]:
# Turn off the `fused_attn` flag on the attention module of every block.
# NOTE(review): presumably needed so the attention weights actually pass
# through the `attn_drop` submodule that the rollout hooks attach to --
# confirm against the installed timm version.
for block in model.blocks:
    block.attn.fused_attn = False

In [41]:
# Path of the example image, relative to the notebook's working directory.
image_path = 'assets/ILSVRC2012_val_00006597.jpg'

# Open the file inside a context manager so the underlying file handle is
# closed as soon as the pixel data has been converted; `convert` returns a
# new, fully loaded RGB image that stays valid after the file is closed.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

In [42]:
# Preprocessing pipeline applied to the PIL image before it is fed to the
# model: resize to 224x224, convert to a tensor, then normalize each channel
# with the given mean and standard deviation.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [43]:
# Preprocess the image, add a leading batch dimension and place the tensor on
# the same `device` the model was moved to (instead of a hard-coded 'cpu'
# followed by a redundant `.cpu()` call).
input_tensor = preprocess(image).unsqueeze(0).to(device)

In [44]:
def grad_rollout(attentions, gradients, discard_ratio):
    """Combine per-layer attention maps and their gradients into one mask.

    For every (attention, gradient) pair the element-wise product is averaged
    over dimension 1 (the head dimension), negative values are clamped to
    zero, the smallest entries are discarded, the identity is added, each row
    is normalized to sum to one, and the result is multiplied onto the
    running product of the previous layers.

    Args:
        attentions: sequence of tensors, expected shape
            (batch, heads, tokens, tokens), one per layer.
        gradients: sequence of tensors with the same shapes as `attentions`.
            The two sequences are paired by position, so `gradients[i]` must
            belong to the same layer as `attentions[i]`.
        discard_ratio: fraction (0..1) of the flattened token-by-token
            entries of each fused map that is set to zero, lowest first.

    Returns:
        A 2-D numpy array of shape (width, width) with
        width = int(sqrt(tokens - 1)): the first row of the rolled-out matrix
        of the first batch element, without its first entry, divided by its
        maximum. The reshape requires `tokens - 1` to be a perfect square.
    """
    identity = torch.eye(attentions[0].size(-1), device=attentions[0].device)
    result = identity
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            # Gradient-weighted attention, averaged over heads; negative
            # contributions are dropped.
            attention_heads_fused = (attention * grad).mean(dim=1).clamp(min=0)

            # Zero the lowest `discard_ratio` share of entries of every
            # sample. No entry is exempted, i.e. the first-token entry can be
            # discarded as well. `flat` is a view, so the in-place scatter
            # also updates `attention_heads_fused`.
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            num_discard = int(flat.size(-1) * discard_ratio)
            if num_discard > 0:
                _, indices = flat.topk(num_discard, dim=-1, largest=False)
                flat.scatter_(-1, indices, 0.0)

            # Add the identity and average with it.
            a = (attention_heads_fused + 1.0 * identity) / 2
            # keepdim=True makes every ROW sum to one; without it the sums
            # broadcast along the wrong axis (and fail for most batch sizes).
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # First row of the first sample, skipping the first token's own column.
    mask = result[0, 0, 1:]

    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).cpu().numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionGradRollout:
    """Collects attention maps and their gradients via hooks and rolls them out.

    On construction, a forward hook and a full backward hook are registered
    on every submodule of `model` whose qualified name contains
    `attention_layer_name`. Calling the instance runs one forward and one
    backward pass for the requested class and passes the recorded tensors to
    `grad_rollout`.

    Note that the hooks stay attached to `model`: creating several instances
    for the same model registers several sets of hooks.
    """

    def __init__(self, model, attention_layer_name='attn_drop',
        discard_ratio=0.9):
        """Store the settings and attach the recording hooks.

        Args:
            model: the torch module to explain.
            attention_layer_name: substring matched against the names
                returned by `model.named_modules()`.
            discard_ratio: forwarded to `grad_rollout` on every call.
        """
        self.model = model
        self.discard_ratio = discard_ratio
        # Per-call buffers, both kept in forward (first layer first) order.
        self.attentions = []
        self.attention_gradients = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                # `register_backward_hook` is deprecated in PyTorch; the full
                # variant reports gradients w.r.t. the module's inputs.
                module.register_full_backward_hook(self.get_attention_gradient)

    def get_attention(self, module, input, output):
        """Forward hook: record a detached CPU copy of the module output."""
        self.attentions.append(output.detach().cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Backward hook: record the gradient w.r.t. the module input.

        Backward hooks fire from the last hooked layer to the first, so the
        gradient is inserted at the front to keep this list in the same
        order as `self.attentions`.
        """
        self.attention_gradients.insert(0, grad_input[0].detach().cpu())

    def __call__(self, input_tensor, category_index):
        """Compute the gradient-rollout mask for one class.

        Args:
            input_tensor: batch passed unchanged to `self.model`.
            category_index: column of the model output whose sum over the
                batch is back-propagated.

        Returns:
            Whatever `grad_rollout` returns for the recorded tensors.
        """
        # Drop anything recorded by earlier forward passes (previous calls or
        # plain `model(...)` invocations made while the hooks were attached).
        self.attentions = []
        self.attention_gradients = []

        self.model.zero_grad()
        output = self.model(input_tensor)
        # zeros_like keeps the mask on the output's device and dtype.
        category_mask = torch.zeros_like(output)
        category_mask[:, category_index] = 1
        loss = (output*category_mask).sum()
        loss.backward()

        return grad_rollout(self.attentions, self.attention_gradients,
            self.discard_ratio)

In [45]:
# Attach the rollout hooks to every submodule of `model` whose name contains
# 'attn_drop'; 0.9 is the discard ratio forwarded to the rollout.
vit_grad_rollout = VITAttentionGradRollout(
    model, attention_layer_name='attn_drop', discard_ratio=0.9
)

In [46]:
# Plain forward pass without gradient tracking. Forward hooks already
# registered on `model` still fire during this pass.
with torch.no_grad():
    output = model(input_tensor)

# Index of the highest-scoring class; `.item()` requires a batch of one.
predicted_class = output.argmax(dim=1).item()

In [54]:
# Gradient-rollout mask of the input for the class the model predicted.
salience_map = vit_grad_rollout(
    input_tensor=input_tensor,
    category_index=predicted_class,
)




In [55]:
def compute_salience_scores(salience_map, num_subsets=10):
    """
    Sum the salience values of rank-ordered groups of pixels.

    The map is flattened and sorted from most to least salient, then cut
    into `num_subsets` consecutive groups of `total // num_subsets` pixels.
    Any pixels left over by the integer division are added to the last
    group.

    Args:
        salience_map (numpy array): The salience map as a 2D numpy array.
        num_subsets (int): How many groups to cut the ranked pixels into.

    Returns:
        list: One float per group, the sum of that group's salience values,
        ordered from the most salient group to the least salient one.
    """
    values = salience_map.flatten()
    # Negating before argsort yields a descending ordering of the values.
    ranked = values[np.argsort(-values)]

    pixel_count = len(ranked)
    group_size = pixel_count // num_subsets

    # Group boundaries: regular starts, with the final edge pinned to the end
    # of the array so the last group also takes the remainder.
    edges = [k * group_size for k in range(num_subsets)]
    edges.append(pixel_count)

    return [
        float(ranked[lo:hi].sum())
        for lo, hi in zip(edges[:-1], edges[1:])
    ]