# Understanding Attention Rollout

**Objective:** In this notebook we explore the implementation of **Attention Rollout** on a ViT model. The main goals are to understand how the attention maps are extracted from the model and how the rollout algorithm is implemented.

# Importing libraries and Loading model

In [2]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using {DEVICE}")

# Fetch the pretrained DeiT-tiny (patch 16, 224x224) model through torch.hub.
model = torch.hub.load(
    'facebookresearch/deit:main',
    'deit_tiny_patch16_224',
    pretrained=True,
)

# eval() and to() both return the module itself, so the chained call still
# leaves the model as the cell's displayed value.
model.eval().to(DEVICE)

Using cpu


Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=Fals

## Original Attention Rollout function

In [4]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def rollout(attentions, discard_ratio, head_fusion):
    """Compute the attention-rollout mask of the class token over the patches.

    Args:
        attentions: non-empty sequence of per-layer attention tensors in
            forward order, each indexed as (batch, heads, tokens, tokens).
            Token 0 is treated as the class token, and the number of
            remaining tokens must be a perfect square.
        discard_ratio: fraction (0..1) of the lowest fused attention values
            that are zeroed in every layer before the rollout step.
        head_fusion: how the heads are merged: "mean", "max" or "min".

    Returns:
        A 2-D numpy array (width x width) for the first batch element,
        scaled so that its largest value is 1.

    Raises:
        ValueError: if ``attentions`` is empty or ``head_fusion`` is unknown.
    """
    if len(attentions) == 0:
        raise ValueError(
            "No attention maps were given; check that the hooked layers "
            "were actually executed during the forward pass")

    # The rollout recurrence starts from the identity matrix.
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                # Raising a plain string is itself a TypeError in Python 3.
                raise ValueError(
                    f"Attention head fusion type Not supported: {head_fusion!r}")

            # Drop the lowest attentions, but
            # don't drop the class token
            # `flat` is a view, so zeroing it edits attention_heads_fused.
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            num_discard = int(flat.size(-1)*discard_ratio)
            if num_discard > 0:
                _, indices = flat.topk(num_discard, dim=-1, largest=False)
                indices = indices[indices != 0]
                flat[0, indices] = 0

            # Add the identity for the residual connection, then make every
            # row sum to 1. keepdim=True matters: without it the (1, N) sums
            # broadcast along the last axis and divide column j by row-sum j.
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1)**0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionRollout:
    """Collects attention outputs through forward hooks and runs `rollout`.

    A forward hook is registered on every submodule of ``model`` whose
    qualified name contains ``attention_layer_name``. Calling the instance
    runs one forward pass and passes the captured outputs to ``rollout``.
    """

    def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean",
        discard_ratio=0.9):
        """Store the settings and hook every matching submodule of ``model``."""
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = []
        # Keep the handles so the hooks can be detached again; otherwise every
        # new instance leaves another set of hooks on the same model.
        self.handles = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.handles.append(
                    module.register_forward_hook(self.get_attention))

    def get_attention(self, module, input, output):
        """Forward hook: store a detached CPU copy of the module output."""
        self.attentions.append(output.detach().cpu())

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __call__(self, input_tensor):
        """Run one forward pass and return the rollout mask for it."""
        self.attentions = []
        with torch.no_grad():
            # Only the hooked outputs are needed; the prediction is discarded.
            self.model(input_tensor)

        return rollout(self.attentions, self.discard_ratio, self.head_fusion)

## Adding notes

In [None]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def rollout(attentions, discard_ratio, head_fusion):
    """Compute the attention-rollout mask of the class token over the patches.

    Args:
        attentions: non-empty sequence of per-layer attention tensors in
            forward order, each indexed as (batch, heads, tokens, tokens).
            Token 0 is treated as the class token, and the number of
            remaining tokens must be a perfect square.
        discard_ratio: fraction (0..1) of the lowest fused attention values
            that are zeroed in every layer before the rollout step.
        head_fusion: how the heads are merged: "mean", "max" or "min".

    Returns:
        A single-channel 2-D numpy array (width x width) for the first batch
        element, scaled so that its largest value is 1.

    Raises:
        ValueError: if ``attentions`` is empty or ``head_fusion`` is unknown.
    """
    if len(attentions) == 0:
        raise ValueError(
            "No attention maps were given; check that the hooked layers "
            "were actually executed during the forward pass")

    # Start of the recurrence relation: the running product begins as the
    # identity, so after the first layer it equals that layer's matrix (see *1).
    result = torch.eye(attentions[0].size(-1))

    with torch.no_grad():
        for attention in attentions:  # iterates through all attention layers
            # For deit_tiny_patch16_224 with one 224x224 image the attention
            # tensor has shape 1 x 3 x 197 x 197 (batch x heads x tokens x tokens).
            # Here the heads are fused by taking the mean, max or min over axis 1.
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                # max() returns (values, indices); [0] keeps the values.
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                # Raising a plain string is itself a TypeError in Python 3.
                raise ValueError(
                    f"Attention head fusion type Not supported: {head_fusion!r}")

            # The head axis is now gone: the fused tensor is 1 x 197 x 197.

            # Drop the lowest attentions, but
            # don't drop the class token
            # Flattening to 1 x 197**2 makes the in-place edit easier; `flat` is a
            # view, so attention_heads_fused (1 x 197 x 197) sees the zeros too.
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            num_discard = int(flat.size(-1)*discard_ratio)
            if num_discard > 0:
                # largest=False makes topk return the k SMALLEST entries.
                _, indices = flat.topk(num_discard, dim=-1, largest=False)
                # Flat index 0 is the class-token-to-class-token entry; never drop it.
                indices = indices[indices != 0]
                flat[0, indices] = 0  # set the lowest attentions to 0

            # Add the identity for the residual connection and halve, then make
            # every row sum to 1 again. keepdim=True matters: without it the
            # (1, N) sums broadcast along the last axis and divide column j by
            # row-sum j instead of normalising each row.
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)  # (*1)

    # Look at the total attention between the class token,
    # and the image patches
    # [0: first image of the batch, 0: class-token row, 1: skips the class-token column]
    mask = result[0, 0, 1:]
    # For a 224x224 image this is 196 values, one per 16x16 patch (14 x 14 grid).

    width = int(mask.size(-1)**0.5)  # square root of 196 is 14
    mask = mask.reshape(width, width).numpy()  # back to image layout -> 14 x 14
    mask = mask / np.max(mask)  # normalize for plotting
    return mask

class VITAttentionRollout:
    """Collects attention outputs through forward hooks and runs `rollout`.

    A forward hook is registered on every submodule of ``model`` whose
    qualified name contains ``attention_layer_name``. Calling the instance
    runs one forward pass and passes the captured outputs to ``rollout``.
    """

    def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean",
        discard_ratio=0.9):
        """Store the settings and hook every matching submodule of ``model``."""
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = []
        # Keep the handles so the hooks can be detached again; otherwise every
        # new instance leaves another set of hooks on the same model.
        self.handles = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.handles.append(
                    module.register_forward_hook(self.get_attention))

    def get_attention(self, module, input, output):
        """Forward hook: store a detached CPU copy of the module output.

        The hook only reads the output of the hooked module, so what is
        stored depends on the image that is being passed through the model.
        """
        self.attentions.append(output.detach().cpu())

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def __call__(self, input_tensor):
        """Run one forward pass and return the rollout mask for it."""
        self.attentions = []
        with torch.no_grad():
            # Only the hooked outputs are needed; the prediction is discarded.
            self.model(input_tensor)

        return rollout(self.attentions, self.discard_ratio, self.head_fusion)

## Understanding how it extracts the attention layer

In [9]:
# List the qualified name of every submodule (the root module's name is '').
for qualified_name, _submodule in model.named_modules():
    print(qualified_name)


patch_embed
patch_embed.proj
patch_embed.norm
pos_drop
norm_pre
blocks
blocks.0
blocks.0.norm1
blocks.0.attn
blocks.0.attn.qkv
blocks.0.attn.attn_drop
blocks.0.attn.proj
blocks.0.attn.proj_drop
blocks.0.ls1
blocks.0.drop_path1
blocks.0.norm2
blocks.0.mlp
blocks.0.mlp.fc1
blocks.0.mlp.act
blocks.0.mlp.drop1
blocks.0.mlp.fc2
blocks.0.mlp.drop2
blocks.0.ls2
blocks.0.drop_path2
blocks.1
blocks.1.norm1
blocks.1.attn
blocks.1.attn.qkv
blocks.1.attn.attn_drop
blocks.1.attn.proj
blocks.1.attn.proj_drop
blocks.1.ls1
blocks.1.drop_path1
blocks.1.norm2
blocks.1.mlp
blocks.1.mlp.fc1
blocks.1.mlp.act
blocks.1.mlp.drop1
blocks.1.mlp.fc2
blocks.1.mlp.drop2
blocks.1.ls2
blocks.1.drop_path2
blocks.2
blocks.2.norm1
blocks.2.attn
blocks.2.attn.qkv
blocks.2.attn.attn_drop
blocks.2.attn.proj
blocks.2.attn.proj_drop
blocks.2.ls1
blocks.2.drop_path1
blocks.2.norm2
blocks.2.mlp
blocks.2.mlp.fc1
blocks.2.mlp.act
blocks.2.mlp.drop1
blocks.2.mlp.fc2
blocks.2.mlp.drop2
blocks.2.ls2
blocks.2.drop_path2
blocks.3
b

In [10]:
# Print the repr of every submodule; modules() walks the model in the same
# order as named_modules(), just without the names.
for submodule in model.modules():
    print(submodule)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=Fals

Hence only the `blocks.i.attn.attn_drop` modules, with `i` ranging from 0 to 11, are selected from the model:

In [12]:
attention_layer_name = 'attn_drop'

# Same substring filter that VITAttentionRollout.__init__ applies; the class
# registers its forward hook on each match, here we only print the name.
for name, module in model.named_modules():
    if attention_layer_name not in name:
        continue
    print(name)

blocks.0.attn.attn_drop
blocks.1.attn.attn_drop
blocks.2.attn.attn_drop
blocks.3.attn.attn_drop
blocks.4.attn.attn_drop
blocks.5.attn.attn_drop
blocks.6.attn.attn_drop
blocks.7.attn.attn_drop
blocks.8.attn.attn_drop
blocks.9.attn.attn_drop
blocks.10.attn.attn_drop
blocks.11.attn.attn_drop


# Understanding the register_forward_hook

ChatGPT explanation:

module.register_forward_hook: This method takes one argument, which is a hook function. When registered to a module (layer) in a PyTorch model, this hook function will be automatically called every time the forward method of the module has been executed.

Purpose: By registering self.get_attention as a forward hook to the attention layers (identified by attention_layer_name in the model), the VITAttentionRollout class is designed to collect and store all the attention matrices produced during the forward pass of the input tensor through the model. This can be particularly useful for interpretability and visualization, helping to understand what the model is "looking at" or considering important in the input data.

So, in summary, module.register_forward_hook(self.get_attention) allows your custom class VITAttentionRollout to automatically capture and record the outputs of the specified attention layers each time they process an input, without having to manually modify the model's forward method or its internal structure.

-> So basically, every time we do a forward pass through the model (for example, when we want to know the class of the image we are looking at), the hook function is also called, as if it had been added to the model structure itself. Therefore, the attention maps are obtained by storing the output of every hooked layer while the image we want to classify passes through the model. **(Does it change the weights?)** No: `get_attention` only appends a copy of the layer output to a list and returns nothing, and `__call__` runs the forward pass under `torch.no_grad()`, so neither the weights nor the outputs of the model are modified (see `get_attention` in `VITAttentionRollout`).

# Example of code for getting the output

In [None]:
#@title Select the rollout strategy to be used and display results
#@markdown **Note** that for `grad_attention_rollout` passing a category index is mandatory.
rollout_strategy = "attention_rollout" #@param ["attention_rollout", "grad_attention_rollout"]
category_index =  264#@param {type:"integer"}

print(f"Using {rollout_strategy}")

# NOTE(review): `preprocess_image`, `transform`, `DISCARD_RATIO`, `img` and
# `show_mask_on_image` are not defined in any cell visible in this notebook;
# they must be defined before this cell can run on a fresh kernel.
input_tensor  = preprocess_image("examples/input.png", transform) ## Gets the input tensor in the correct format

## Not interesting ##
if rollout_strategy == "grad_attention_rollout":
    if category_index < 0:
        raise Exception("Category index is mandatory when using Gradient Attention Rollout")

    # Category index 0 is accepted too: previously it matched neither branch,
    # so `mask` was never defined and the plotting code below failed.
    grad_rollout = VITAttentionGradRollout(model, discard_ratio=DISCARD_RATIO)
    mask = grad_rollout(input_tensor, category_index)
    name = "grad_rollout_{}_{:.3f}_{}.png".format(category_index,
        DISCARD_RATIO, "mean")
## Not interesting ##

######################################################################################################################
elif rollout_strategy == "attention_rollout":
    # Creating the object only registers the forward hooks; at this point the
    # input image has not yet been passed through the model.
    attention_rollout = VITAttentionRollout(model, discard_ratio=DISCARD_RATIO)
    # The forward pass happens here; the hooks capture the attention outputs
    # produced for this particular image.
    mask = attention_rollout(input_tensor)
    # `mask` is already the final rollout map.
    name = "attention_rollout_{:.3f}_{}.png".format(DISCARD_RATIO, "mean")
######################################################################################################################

else:
    raise ValueError(f"Unknown rollout strategy: {rollout_strategy!r}")

## Only plotting until the end ##

# Reverse the channel axis of the image.
# NOTE(review): presumably RGB -> BGR for `show_mask_on_image` - confirm.
np_img = np.array(img)[:, :, ::-1]
# The mask is a small grid (14 x 14 for a 224x224 input); resize it to the
# image size. cv2.resize takes the target size as (width, height).
mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
mask = show_mask_on_image(np_img, mask)

# Optional contrast tweaks that were tried on the mask:
# mask = (mask - mask.min()) / (mask.max() - mask.min())
# mask = mask.clip(0.7,1)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))

ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(img)
_ = ax2.imshow(mask)

The output of attention layers include the image information since we look at the $V$