# Attention matrices analysis

Authors: Carlos CUEVAS VILLARMIN and Javier Alejandro LOPETEGUI GONZALEZ

This notebook was written as part of the final project for the Object Recognition and Computer Vision course in the MVA master's program.

The purpose of this notebook is to gain insight into what the CoVR-BLIP-2 model has learned, based on its attention matrices across layers and heads. Furthermore, we propose a way of weighting the learnable-query and modification-text token embeddings according to their relevance, as measured by the attention outputs.

**Remark:** To use this notebook, the attention tensors for the test set must already be downloaded. The examples used here are the ones shown in the project report.

In [None]:
# Mount Google Drive and symlink the CIRR images and the saved attention tensors
# under /content/datasets/CIRR (the attentions are read from
# /content/datasets/CIRR/attentions in the next cells).
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
!mkdir -p "/content/datasets/CIRR"
# NOTE: replace the "path-to-..." placeholders with the actual Drive folders.
!ln -s "/content/drive/MyDrive/path-to-images" "/content/datasets/CIRR/images"
!ln -s "/content/drive/MyDrive/path-to-attentions" "/content/datasets/CIRR/attentions"

In [None]:
import torch
import os
from tqdm import tqdm
import re

# Load the per-batch attention files ("attentions_<k>.pt") from the attentions folder.
att_directory = "/content/datasets/CIRR/attentions"
att_file_pattern = re.compile(r'^attentions_(\d+)\.pt$')

# os.listdir returns entries in arbitrary order. Sort on the numeric suffix so the
# stacking order below is deterministic and follows k (a plain string sort would
# put attentions_10.pt before attentions_2.pt).
# NOTE(review): presumably k is the batch index later used as BATCH_INDEX — confirm.
pt_files = sorted(
    (f for f in os.listdir(att_directory) if att_file_pattern.match(f)),
    key=lambda f: int(att_file_pattern.match(f).group(1)),
)
if not pt_files:
    raise FileNotFoundError(f"No attentions_<k>.pt files found in {att_directory}")

all_attention_outputs = []
for file in tqdm(pt_files):
    # Each file holds a list of tensors; stack them along a new leading axis
    # (the num_layers axis of the [num_batches, num_layers, ...] layout used below).
    attention_outputs_list = torch.load(os.path.join(att_directory, file))
    attention_outputs = torch.stack(attention_outputs_list)
    all_attention_outputs.append(attention_outputs)

# Stack all batches: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len].
# torch.stack requires every batch to have exactly the same shape.
concatenated_attention_outputs = torch.stack(all_attention_outputs)
# Verify the content
print(len(attention_outputs_list))
print(len(all_attention_outputs))
print(type(concatenated_attention_outputs))
print(concatenated_attention_outputs.shape)

In [None]:
# Read the captions and pair ids stored in the JSON next to the attention files.
import json

captions_path = '/content/datasets/CIRR/attentions/captions_and_ids.json'
with open(captions_path, 'r') as json_file:
    file_dict = json.load(json_file)

captions, ids = file_dict['caption'], file_dict['pair_id']

In [None]:
from transformers import BertTokenizer

# Load the pre-trained BERT tokenizer (truncating on the right) and register
# "[DEC]" as its BOS special token.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})

In [None]:
import pandas as pd


def get_input_ids_and_tokens(tokenizer, caption, print_flag=False):
    """Tokenize one caption and return its ids together with the token strings.

    The caption is truncated to at most 64 tokens and encoded as PyTorch tensors.

    Args:
        tokenizer: Hugging Face tokenizer (e.g. the BertTokenizer loaded above).
        caption: modification text to tokenize.
        print_flag: if True, print a table with each input id next to its token.

    Returns:
        (input_ids, tokens): the 1-D id tensor of the (single) encoded sequence and
        the list of corresponding token strings.
    """
    encoded = tokenizer(
        caption,
        padding="longest",
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids'][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    if not print_flag:
        return input_ids, tokens

    table = pd.DataFrame({"Input IDs": input_ids, "Tokens": tokens})
    print(table.to_string())
    return input_ids, tokens

In [None]:
batch_samples = [0, 5, 17, 90]  # example samples used in the report

# Print the id/token table of each example caption.
for sample in batch_samples:
    get_input_ids_and_tokens(tokenizer, captions[sample], print_flag=True)

In [None]:
def plot_heatmap(attention, tokens, path_to_save, layer_idx, name):
    '''
    Function to plot a heatmap of attention

    Args:
      attention: [seq_len, seq_len]
      tokens: list of tokens
      name: name of the heatmap
      path_to_save: path to save the heatmap
      layer_idx: index of the layer
    '''
    # Plot the heatmap
    fig_avg, ax_avg = plt.subplots(figsize=(10, 10))
    sns.heatmap(attention, cmap="hot", ax=ax_avg, cbar=False, vmin=0, vmax=1)
    ax_avg.set_title(name, fontsize=20)
    ax_avg.tick_params(left=False, bottom=False)
    ax_avg.set_xticks(np.arange(63))  # Set tick positions for 63 elements
    ax_avg.set_xticklabels([""] * 32 + tokens + ["[PAD]"] * (63 - 32 - len(tokens)), rotation=45, ha="right", fontsize = 8)
    ax_avg.set_yticks(np.arange(63))  # Set tick positions for 63 elements
    ax_avg.set_yticklabels([""] * 32 + tokens + ["[PAD]"] * (63 - 32 - len(tokens)), rotation=0, va="center", fontsize = 8)
    ax_avg.set_xlabel("Keys", fontsize=12)
    ax_avg.set_ylabel("Queries", fontsize=12)
    fig_avg.savefig(f"{path_to_save}/{name}.png", bbox_inches='tight', pad_inches=0)
    fig_avg.savefig(f"{path_to_save}/{name}.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
    plt.close(fig_avg)
    #plt.tight_layout()
    #plt.show()

### Attention matrices layer by layer and head by head

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

def plot_attention_layers(concatenated_attention_outputs, BATCH_INDEX, BATCH_SAMPLE):
    '''
    Plot and save the attention matrices of one sample, layer by layer and head by head.

    For every layer this writes the head-averaged matrix and one matrix per head
    (via plot_heatmap), plus a 1 x num_heads strip of all heads of the layer, which is
    also displayed. Finally it writes a num_layers x num_heads grid of every head.

    Args:
      concatenated_attention_outputs: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len]
      BATCH_INDEX: index of the batch
      BATCH_SAMPLE: index of the sample inside the batch
    '''
    path_to_save = f"/content/drive/MyDrive/CoVR-FP/figures/attentions_heatmaps/batchID_{BATCH_INDEX}-batchSAMPLE_{BATCH_SAMPLE}"
    os.makedirs(path_to_save, exist_ok=True)

    # NOTE(review): the caption is looked up by BATCH_SAMPLE only, which presumably
    # matches only for BATCH_INDEX == 0 — confirm how captions are ordered.
    _, tokens = get_input_ids_and_tokens(tokenizer, captions[BATCH_SAMPLE])

    # Select (and move to CPU) only this sample: [num_layers, num_heads, seq_len, seq_len].
    sample_attention = concatenated_attention_outputs[BATCH_INDEX, :, BATCH_SAMPLE].detach().cpu()
    num_layers, num_heads = sample_attention.shape[0], sample_attention.shape[1]

    for layer_idx in range(num_layers):
        # Head-averaged attention of this layer
        avg_head_attention = torch.mean(sample_attention[layer_idx], dim=0).numpy()
        plot_heatmap(avg_head_attention, tokens, path_to_save, layer_idx, f"Avg Attention Matrix - Layer {layer_idx+1}")

        # squeeze=False keeps axes indexable even when there is a single head.
        fig, axes = plt.subplots(1, num_heads, figsize=(num_heads * 10, 10), squeeze=False)
        axes = axes[0]

        for head_idx in range(num_heads):
            print(f"Computing layer {layer_idx+1} and head {head_idx+1}")
            head_attention = sample_attention[layer_idx, head_idx].numpy()
            plot_heatmap(head_attention, tokens, path_to_save, layer_idx, f"Attention Matrix - Layer {layer_idx + 1} - Head {head_idx + 1}")

            sns.heatmap(head_attention, cmap="hot", ax=axes[head_idx], cbar=False, vmin=0, vmax=1)
            axes[head_idx].tick_params(left=False, bottom=False)
            axes[head_idx].set_xticks([])
            axes[head_idx].set_yticks([])

        # Lay out this figure explicitly: plt.tight_layout() acts on whichever figure
        # pyplot currently considers active. plt.show() is safe here because
        # plot_heatmap closes its own figures and the global grid is created later.
        fig.tight_layout()
        fig.savefig(f"{path_to_save}/heatmap_layer_{layer_idx + 1}.png")
        fig.savefig(f"{path_to_save}/heatmap_layer_{layer_idx + 1}.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
        plt.show()
        plt.close(fig)

    # Global grid with every head of every layer, built once all per-layer figures are done.
    fig_global, axes_global = plt.subplots(num_layers, num_heads, figsize=(num_heads * 10, num_layers * 10), squeeze=False)
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            ax = axes_global[layer_idx, head_idx]
            ax.imshow(sample_attention[layer_idx, head_idx].numpy(), cmap="hot", aspect="auto", vmin=0, vmax=1)
            ax.tick_params(left=False, bottom=False)
            ax.set_xticks([])
            ax.set_yticks([])

    fig_global.tight_layout()
    fig_global.savefig(f"{path_to_save}/all_layers_heads_heatmaps.png")
    fig_global.savefig(f"{path_to_save}/all_layers_heads_heatmaps.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
    plt.show()
    plt.close(fig_global)


In [None]:
batch_samples = [0, 5, 17, 90]  # example samples from batch 0 used in the report
for sample in batch_samples:
    plot_attention_layers(
        concatenated_attention_outputs,
        BATCH_INDEX=0,
        BATCH_SAMPLE=sample,
    )

### Attention rollout matrices layer by layer and head by head

In [None]:
def attention_rollout(As):
    """
    Per-head attention rollout (https://arxiv.org/abs/2005.00928).

    Each layer's attention is mixed with the identity to account for the residual
    connection, A' = 0.5 * A + 0.5 * I (rows still sum to 1 when rows of A do).
    The mixed matrices are then multiplied cumulatively from the first layer up:
    R_0 = A'_0 and R_l = A'_l @ R_{l-1}. Heads are not averaged: the matmul is
    batched over the head axis, so head h of layer l is composed with head h of
    layer l - 1.

    Args:
      As: attention tensor [num_layers, num_heads, seq_len, seq_len]

    Returns:
      Rollout tensor [num_layers, num_heads, seq_len, seq_len]; entry l is the
      rollout up to and including layer l.
    """
    identity = torch.eye(As.shape[-1]).to(As.device)
    cumulative = None
    per_layer = []
    for layer_attention in As:
        mixed = 0.5 * layer_attention + 0.5 * identity
        cumulative = mixed if cumulative is None else torch.matmul(mixed, cumulative)
        per_layer.append(cumulative)
    return torch.stack(per_layer)

def attention_rollout_per_sample(attentions, BATCH_INDEX, BATCH_SAMPLE):
    '''
    Compute the per-head attention rollout of a single sample.

    Args:
      attentions: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len]
      BATCH_INDEX: index of the batch
      BATCH_SAMPLE: index of the sample inside the batch

    Returns:
      Rollout tensor [num_layers, num_heads, seq_len, seq_len] from attention_rollout.
    '''
    # All layers of the requested sample: [num_layers, num_heads, seq_len, seq_len]
    sample_attention = attentions[BATCH_INDEX, :, BATCH_SAMPLE]
    return attention_rollout(sample_attention)


def plot_attention_rollout(attentions, BATCH_INDEX, BATCH_SAMPLE):
    '''
    Plot and save the attention-rollout matrices of one sample, layer by layer and head by head.

    For every layer this writes the head-averaged rollout and one rollout per head
    (via plot_heatmap), plus a 1 x num_heads strip of the layer. Finally it writes a
    num_layers x num_heads grid with every head of every layer.

    Args:
      attentions: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len]
      BATCH_INDEX: index of the batch
      BATCH_SAMPLE: index of the sample inside the batch
    '''
    path_to_save = f"/content/drive/MyDrive/CoVR-FP/figures/attentions_heatmaps/batchID_{BATCH_INDEX}-batchSAMPLE_{BATCH_SAMPLE}/attention_rollout"
    os.makedirs(path_to_save, exist_ok=True)

    # NOTE(review): caption looked up by BATCH_SAMPLE only — presumably correct only for batch 0.
    _, tokens = get_input_ids_and_tokens(tokenizer, captions[BATCH_SAMPLE])

    # Use the tensor passed in (not the notebook-global one): [num_layers, num_heads, seq_len, seq_len]
    rolled_out_attention = attention_rollout_per_sample(attentions, BATCH_INDEX, BATCH_SAMPLE).detach().cpu()
    num_layers, num_heads = rolled_out_attention.shape[0], rolled_out_attention.shape[1]

    # squeeze=False keeps the axes 2-D even for a single layer or head.
    fig_global, axes_global = plt.subplots(num_layers, num_heads, figsize=(num_heads * 10, num_layers * 10), squeeze=False)

    for layer_idx in range(num_layers):
        # Head-averaged rollout of this layer
        avg_rollout_attention = torch.mean(rolled_out_attention[layer_idx], dim=0).numpy()
        plot_heatmap(avg_rollout_attention, tokens, path_to_save, layer_idx, f"Avg Attention Rollout - Layer {layer_idx+1}")

        fig_layer, axes_layer = plt.subplots(1, num_heads, figsize=(num_heads * 10, 10), squeeze=False)

        for head_idx in range(num_heads):
            head_rollout = rolled_out_attention[layer_idx, head_idx].numpy()
            plot_heatmap(head_rollout, tokens, path_to_save, layer_idx, f"Attention Rollout - Layer {layer_idx + 1} - Head {head_idx + 1}")

            # Same image in the global grid and in this layer's strip.
            for ax in (axes_global[layer_idx, head_idx], axes_layer[0, head_idx]):
                ax.imshow(head_rollout, cmap="hot", aspect="auto", vmin=0, vmax=1)
                ax.tick_params(left=False, bottom=False)
                ax.set_xticks([])
                ax.set_yticks([])

        fig_layer.savefig(f"{path_to_save}/heatmap_rollout_layer_{layer_idx+1}.png", bbox_inches='tight', pad_inches=0)
        fig_layer.savefig(f"{path_to_save}/heatmap_rollout_layer_{layer_idx+1}.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
        plt.close(fig_layer)

    # The .png must be written as PNG (it was previously forced to format='pdf').
    fig_global.savefig(f"{path_to_save}/all_layers_heads_attention_rollout.png", bbox_inches='tight', pad_inches=0)
    fig_global.savefig(f"{path_to_save}/all_layers_heads_attention_rollout.pdf", bbox_inches='tight', pad_inches=0, format='pdf')
    plt.close(fig_global)


In [None]:
batch_samples = [0, 5, 17, 90]  # example samples from batch 0 used in the report
for sample in batch_samples:
    plot_attention_rollout(
        concatenated_attention_outputs,
        BATCH_INDEX=0,
        BATCH_SAMPLE=sample,
    )

### Attention rollout weights approach

In [None]:
from torch.nn import functional as F

def compute_weights_based_on_attention_rollout(attentions, BATCH_INDEX, BATCH_SAMPLE, temp=1, return_logs=False, num_queries=32, threshold=0.01):
    '''
    Compute relevance weights for the learnable queries and the text tokens from the attention rollout.

    The last layer's rollout is averaged over heads, and each key position is scored
    by the mean rolled-out attention it receives (mean over the query rows). Scores
    below `threshold` are pushed to a large negative logit. The remaining scores are
    turned into two separate distributions with a temperature softmax.

    Args:
      attentions: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len]
      BATCH_INDEX: index of the batch
      BATCH_SAMPLE: index of the sample inside the batch
      temp: softmax temperature (lower = sharper)
      return_logs: if True, print how many positions were masked out
      num_queries: number of leading learnable-query positions in the sequence
      threshold: scores below this value get (almost) zero weight

    Returns:
      queries_dist: distribution/weights over the learnable queries [num_queries]
      text_dist: distribution/weights over the text positions [seq_len - num_queries]
    '''
    rollout = attention_rollout_per_sample(attentions, BATCH_INDEX, BATCH_SAMPLE)  # [num_layers, num_heads, seq_len, seq_len]
    # Last layer, averaged over heads
    rollout_last_layer = torch.mean(rollout[-1], dim=0)  # [seq_len, seq_len]
    # Mean of each column: how much rolled-out attention each key position receives
    key_relevance = torch.mean(rollout_last_layer, dim=0)  # [seq_len]

    # Split into learnable-query positions and text positions.
    # NOTE(review): the text part presumably also contains batch-padding positions.
    queries_rollout = key_relevance[:num_queries]
    text_rollout = key_relevance[num_queries:]

    # A large finite negative logit (not -inf) gives masked positions ~0 weight while
    # keeping the softmax well-defined (uniform) if every position is masked.
    queries_rollout = queries_rollout.masked_fill(queries_rollout < threshold, -1e7)
    text_rollout = text_rollout.masked_fill(text_rollout < threshold, -1e7)

    if return_logs:
        print(f"Queries rollouts lower than 0: {torch.sum(queries_rollout < 0)} out of {len(queries_rollout)}")
        print(f"Text rollouts lower than 0: {torch.sum(text_rollout < 0)} out of {len(text_rollout)}")

    queries_dist = F.softmax(queries_rollout / temp, dim=0)
    text_dist = F.softmax(text_rollout / temp, dim=0)

    return queries_dist, text_dist



In [None]:
def get_interesting_text_tokens(text_dist, tokenizer, BATCH_SAMPLE):
    '''
    Print every text position that kept a non-zero weight, with its token and weight.

    Args:
      text_dist: distribution/weights of the modification text positions
      tokenizer: tokenizer used to re-tokenize the caption
      BATCH_SAMPLE: index of the sample (used to look up its caption)
    '''
    _, tokens = get_input_ids_and_tokens(tokenizer, captions[BATCH_SAMPLE])

    # Positions whose weight survived the thresholding + softmax
    kept_positions = torch.where(text_dist > 0)[0].tolist()
    for pos in kept_positions:
        # Positions beyond the re-tokenized caption are padding of the batch.
        token = tokens[pos] if pos < len(tokens) else "[PAD]"
        print(f"Token: {token} \t - Prob: {text_dist[pos]:.4f}")

In [None]:
batch_samples = [0, 5, 17, 90]  # example samples from batch 0 used in the report

# Show which caption tokens keep a non-zero rollout weight for each example.
for sample in batch_samples:
    text_dist = compute_weights_based_on_attention_rollout(
        concatenated_attention_outputs, BATCH_INDEX=0, BATCH_SAMPLE=sample
    )[1]
    get_interesting_text_tokens(text_dist, tokenizer, BATCH_SAMPLE=sample)

In [None]:
# Report the token weights of every sample of every batch, using a sharper softmax (temp=0.2).
# NOTE(review): pair ids and captions are indexed by the in-batch position j only,
# which presumably matches only for the first batch — confirm the caption ordering.
num_batches = concatenated_attention_outputs.shape[0]
batch_size = concatenated_attention_outputs.shape[2]
for i in range(num_batches):
    for j in range(batch_size):
        print(f"Pair id: {ids[j]}")
        queries_dist, text_dist = compute_weights_based_on_attention_rollout(
            concatenated_attention_outputs, i, j, temp=0.2, return_logs=True
        )
        get_interesting_text_tokens(text_dist, tokenizer, j)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def _plot_query_evolution(evolution, key_labels, layer_labels, title, out_path, with_colorbar):
    '''
    Save one rollout-evolution image of a single query as a PDF.

    Args:
      evolution: [num_layers, seq_len] array, top row = last layer
      key_labels: one x-axis label per key position
      layer_labels: one y-axis label per row of `evolution`
      title: figure title
      out_path: destination PDF path
      with_colorbar: whether to draw a labelled colorbar
    '''
    fig, ax = plt.subplots(figsize=(12, 8))
    image = ax.imshow(evolution, cmap="hot", aspect="auto", vmin=0, vmax=1)
    ax.tick_params(left=False, bottom=False)
    ax.set_xticks(range(len(key_labels)))
    ax.set_xticklabels(key_labels, rotation=45, ha="right", fontsize=8)
    ax.set_yticks(range(len(layer_labels)))
    ax.set_yticklabels(layer_labels)
    ax.set_title(title, fontsize=20)
    ax.set_xlabel("Keys", fontsize=12)
    ax.set_ylabel("Layers", rotation=90, fontsize=12)
    if with_colorbar:
        cbar = fig.colorbar(image, ax=ax)
        cbar.set_label("Attention Weight")
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0, format="pdf")
    plt.close(fig)


def plot_attention_rollout_per_query(attentions, BATCH_INDEX, BATCH_SAMPLE, num_queries=32):
    '''
    Plot how each query's rolled-out attention evolves over the layers.

    For every query position this saves one PDF per head and one head-averaged PDF.
    Each image has one row per layer (last layer on top) and one column per key.

    Args:
      attentions: [num_batches, num_layers, batch_size, num_heads, seq_len, seq_len]
      BATCH_INDEX: index of the batch
      BATCH_SAMPLE: index of the sample inside the batch
      num_queries: number of leading learnable-query positions in the sequence
    '''
    path_to_save = f"/content/drive/MyDrive/CoVR-FP/figures/attentions_heatmaps/batchID_{BATCH_INDEX}-batchSAMPLE_{BATCH_SAMPLE}/query_evol_rollout"
    os.makedirs(path_to_save, exist_ok=True)

    # NOTE(review): caption looked up by BATCH_SAMPLE only — presumably correct only for batch 0.
    _, tokens = get_input_ids_and_tokens(tokenizer, captions[BATCH_SAMPLE])

    # Use the tensor passed in (not the notebook-global one).
    rolled_out_attention = attention_rollout_per_sample(attentions, BATCH_INDEX, BATCH_SAMPLE).detach().cpu()  # [num_layers, num_heads, seq_len, seq_len]
    num_layers, num_heads, _, seq_len = rolled_out_attention.shape

    # Convert to numpy once and flip the layer axis so the last layer is the top row.
    per_head = rolled_out_attention.numpy()[::-1]                     # [num_layers, num_heads, seq_len, seq_len]
    head_avg = torch.mean(rolled_out_attention, dim=1).numpy()[::-1]  # [num_layers, seq_len, seq_len]

    # One label per key: empty for learnable queries, then caption tokens, then "[PAD]" up to seq_len.
    key_labels = ([""] * num_queries + list(tokens) + ["[PAD]"] * seq_len)[:seq_len]
    layer_labels = [f"{num_layers - i}" for i in range(num_layers)]

    for query_idx in tqdm(range(seq_len)):
        if query_idx < num_queries:
            query_name = f"Learnable query {query_idx}"
        else:
            text_pos = query_idx - num_queries
            token = tokens[text_pos] if text_pos < len(tokens) else "[PAD]"
            query_name = f"Query: {token}"

        # One image per head
        for head_idx in range(num_heads):
            _plot_query_evolution(
                per_head[:, head_idx, query_idx, :],
                key_labels,
                layer_labels,
                f"{query_name} - Head {head_idx+1}",
                f"{path_to_save}/query_{query_idx}_head_{head_idx+1}_evolution.pdf",
                with_colorbar=True,
            )

        # Head-averaged image
        _plot_query_evolution(
            head_avg[:, query_idx, :],
            key_labels,
            layer_labels,
            f"{query_name} - Avg over heads",
            f"{path_to_save}/query_{query_idx}_average_evolution.pdf",
            with_colorbar=False,
        )

In [None]:
batch_samples = [0, 5, 17, 90]  # example samples from batch 0 used in the report
for sample in batch_samples:
    plot_attention_rollout_per_query(
        concatenated_attention_outputs,
        BATCH_INDEX=0,
        BATCH_SAMPLE=sample,
    )