### Q-Former cross-attention analysis for the CoVR-BLIP-2 model

In [None]:
# Install transformers, pinned to version 4.26.1.
# (IPython `!` shell magic: this line only runs inside a notebook cell.)
!pip install transformers==4.26.1

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# The data paths used in this notebook (captions, annotations, images) point into
# /content/drive/MyDrive/..., so this cell must run before any of them.
from google.colab import drive
drive.mount('/content/drive')

In [25]:
import torch

In [125]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from PIL import Image
import json

import cv2
from tqdm import tqdm

from transformers import BertTokenizer

In [11]:
# Load the saved Q-Former cross-attention maps from Drive.
# Every analysis cell below reads `cross_attentions`, so it must be defined here.
# NOTE(review): presumably a sequence with one tensor per cross-attention layer, each shaped
# [batch_size, num_heads, seq_len1, seq_len2] (queries x visual tokens) — TODO confirm.
# map_location='cpu' lets maps saved from a GPU session load on a CPU-only runtime.
cross_attentions = torch.load('/content/drive/MyDrive/CIRR/attentions/cross_attentions_0.pt', map_location='cpu')
# The self-attention maps are not used by this notebook; uncomment to load them too.
# self_attentions = torch.load('/content/drive/MyDrive/CIRR/attentions/attentions_0.pt', map_location='cpu')

In [30]:
# Read the captions and pair ids exported together with the attention maps.
# The `with` block closes the file once it has been parsed.
with open('/content/drive/MyDrive/CIRR/attentions/captions_and_ids.json', encoding='utf-8') as captions_file:
    captions = json.load(captions_file)

In [31]:
# Split the loaded record into its two parallel lists (pair ids and caption texts).
ids, captions = captions['pair_id'], captions['caption']

In [32]:
tuple(len(sequence) for sequence in (ids, captions))  # sanity check: both lists must have the same length

(128, 128)

In [None]:
# Load the pre-trained BERT tokenizer (the same one used by the model) and
# register "[DEC]" as its BOS token.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})

In [34]:
# Tokenize all captions as a single batch of PyTorch tensors:
# padded to the longest caption and truncated at 64 tokens.
text_tokens = tokenizer(
    captions,
    return_tensors="pt",
    max_length=64,
    truncation=True,
    padding="longest",
)

In [None]:
# Example: inspect how one caption of the batch was tokenized.
i = 25
input_ids = text_tokens['input_ids'][i]
tokens = tokenizer.convert_ids_to_tokens(input_ids)  # token ids -> word pieces

print(f"Input IDs: {input_ids}")
print(f"Tokens: {tokens}")

In [None]:
# Attention evolution across layers, considering all the heads
# (one matrix per head: queries as rows -> visual tokens as columns).
i = 2  # batch index of the sample to inspect
token_values = [0 for _ in range(cross_attentions[0].shape[-1])]
for block, cross_attention_layer in enumerate(cross_attentions):
    # cross_attention_layer is indexed as [batch_size, num_heads, seq_len1, seq_len2];
    # the head count is read from the tensor instead of being hard-coded to 12.
    num_heads = cross_attention_layer.shape[1]

    # One subplot row per head; squeeze=False keeps `axes` 2-D even for a single head.
    fig, axes = plt.subplots(num_heads, 1, figsize=(10, 20), squeeze=False)
    fig.suptitle(f"Attention Maps for {num_heads} Heads (Layer {block})", fontsize=16)

    for head in range(num_heads):
        attention_map = cross_attention_layer[i, head].cpu().detach().numpy()

        # Accumulate the query-averaged attention received by every visual token.
        for j, value in enumerate(attention_map.mean(axis=0)):
            token_values[j] += value

        # Rescale the attention map for visualization (cv2 dsize is (width, height)).
        # NOTE(review): fixed 667 x 63 display size regardless of the map's real shape — confirm.
        attention_rescaled = cv2.resize(attention_map, (667, 63))

        # Min-max normalize to [0, 1]; a constant map would divide by zero, so show it as all zeros.
        value_range = attention_rescaled.max() - attention_rescaled.min()
        if value_range > 0:
            attention_normalized = (attention_rescaled - attention_rescaled.min()) / value_range
        else:
            attention_normalized = np.zeros_like(attention_rescaled)

        ax = axes[head, 0]
        ax.imshow(attention_normalized, cmap='viridis', aspect='auto')
        ax.set_title(f"Head {head + 1}", fontsize=12)
        ax.axis("off")

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave space for the suptitle
    plt.show()

# Average over layers (values are still summed over heads).
token_values = [value / len(cross_attentions) for value in token_values]

In [77]:
# get the cumulative average attention from queries to visual tokens given an index from the batch
def get_token_attention_values(index, cross_attentions):
    """Return the attention received by each visual token for one batch element.

    For every layer and every head, the attention map of sample ``index`` is averaged
    over the query axis. These per-head averages are summed over heads, and the total
    is averaged over layers once, after all layers have been accumulated.

    Args:
        index: Batch position of the sample to analyse.
        cross_attentions: Sequence of per-layer attention tensors indexed as
            ``[batch, head, query, visual_token]``.

    Returns:
        list: one value per visual token (length ``cross_attentions[0].shape[-1]``).
    """
    num_layers = len(cross_attentions)
    token_values = np.zeros(cross_attentions[0].shape[-1])
    for block, cross_attention_layer in enumerate(cross_attentions):
        print(f"Block: {block}", cross_attention_layer.shape)
        # Head count comes from the tensor rather than a hard-coded 12.
        num_heads = cross_attention_layer.shape[1]
        for head in range(num_heads):
            attention_map = cross_attention_layer[index, head].cpu().detach().numpy()
            # Mean over queries -> one value per visual token.
            token_values += attention_map.mean(axis=0)
    # Divide by the number of layers exactly once. Dividing inside the loop would
    # shrink earlier layers geometrically.
    return list(token_values / num_layers)

In [None]:
# Compute the per-token attention values for every element of the batch.
print(len(cross_attentions))
all_batch_token_values = [
    get_token_attention_values(i, cross_attentions)
    for i in tqdm(range(len(cross_attentions[0])))
]

In [38]:
# Get the queries' modification text: index the CIRR test annotations by their pair id.
# The `with` block closes the file once it has been parsed.
with open('/content/drive/MyDrive/CIRR/ann_test1.json', encoding='utf-8') as annotations_file:
    annotations = json.load(annotations_file)
annotations_dir = {annotation['pairid']: annotation for annotation in annotations}

In [147]:
# plot heat_maps based on attentions for a few (by default 5 random) images
def plot_heat_maps_for_random_images(batch_token_values, annotations, pair_ids, captions, images_path="drive/MyDrive/CIRR/images/test1/", num_images=5, indexes = None, print_captions=True):
    """Show reference images (top row) with a patch-level 'hot' heat map overlaid (bottom row).

    Only tokens whose value is above the sample's mean are kept (scaled by 10); the
    rest are zeroed. Each value colours one patch_size x patch_size square.

    Args:
        batch_token_values: Per-sample sequences of visual-token values, indexed by batch position.
        annotations: Mapping pair id -> annotation dict whose ``'reference'`` entry names the
            reference image (loaded from ``images_path + name + '.png'``).
        pair_ids: Pair id of each batch position.
        captions: Caption of each batch position.
        images_path: Directory prefix of the reference images.
        num_images: Number of random batch positions to draw when ``indexes`` is None.
        indexes: Explicit batch positions to plot; when given, one column is drawn per entry.
        print_captions: Whether to print the caption of each plotted image.
    """
    import random  # local import: the function must not depend on a later cell importing it

    image_size = 364
    patch_size = 14
    num_patches_per_side = image_size // patch_size
    num_patches = num_patches_per_side * num_patches_per_side

    if indexes is None:
        selected_indices = random.sample(range(len(captions)), num_images)
    else:
        selected_indices = list(indexes)
    if not selected_indices:
        return

    original_images = []
    heatmap_images = []
    selected_captions = []

    for index in selected_indices:
        # Retrieve the image path and caption
        reference_image = annotations[pair_ids[index]]['reference']
        image_path = images_path + reference_image + '.png'
        selected_captions.append(captions[index])

        # Load and resize the original image; the `with` block closes the underlying file.
        with Image.open(image_path) as image_file:
            original_images.append(image_file.resize((image_size, image_size)))

        # Keep only above-mean tokens (scaled by 10) so the strongest patches stand out.
        token_values = np.asarray(batch_token_values[index])
        token_values = np.where(token_values > token_values.mean(), token_values * 10, 0.0)

        # Fill the patch grid row-major. Surplus tokens are ignored and missing cells stay 0.
        # NOTE(review): token 0 lands on the top-left patch. If the visual sequence starts with
        # a [CLS] token, the grid is shifted by one position — confirm against the vision encoder.
        heatmap_grid = np.zeros(num_patches)
        num_used = min(num_patches, token_values.size)
        heatmap_grid[:num_used] = token_values[:num_used]
        heatmap_grid = heatmap_grid.reshape(num_patches_per_side, num_patches_per_side)

        # Colour every patch with the 'hot' colormap, then blow each one up to
        # patch_size x patch_size pixels.
        heatmap = cm.hot(heatmap_grid)
        heatmap_img = np.repeat(np.repeat(heatmap, patch_size, axis=0), patch_size, axis=1)
        heatmap_images.append((heatmap_img * 255).astype(np.uint8))

    if print_captions:
        for i, caption in enumerate(selected_captions):
            print(f"Caption image: {i+1}", caption)

    # One column per selected image; squeeze=False keeps `axes` 2-D even for a single column.
    num_columns = len(selected_indices)
    fig, axes = plt.subplots(2, num_columns, figsize=(20, 10), squeeze=False)

    for i in range(num_columns):
        # Top row: original images
        axes[0, i].imshow(original_images[i])
        axes[0, i].axis('off')

        # Bottom row: original images with the heat map drawn on top
        axes[1, i].imshow(original_images[i])
        axes[1, i].imshow(heatmap_images[i], alpha=0.7)
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

In [203]:
# Function to plot heatmaps for a set of (by default random) images
def plot_heat_maps_for_random_images_new(batch_token_values, annotations, pair_ids, captions, images_path="drive/MyDrive/CIRR/images/test1/", num_images=5, indexes=None, print_captions=True):
    """Plot reference images (top row) and the same images blended with an attention heat map (bottom row).

    Only tokens whose value is above the sample's mean are kept (scaled by 255); the rest
    are zeroed. The patch grid is then upsampled bicubically to the image size, coloured
    with ``viridis`` and alpha-blended (alpha = 0.7) onto the image's RGB channels.

    Args:
        batch_token_values: Per-sample sequences of visual-token values, indexed by batch position.
        annotations: Mapping pair id -> annotation dict whose ``'reference'`` entry names the
            reference image (loaded from ``images_path + name + '.png'``).
        pair_ids: Pair id of each batch position.
        captions: Caption of each batch position.
        images_path: Directory prefix of the reference images.
        num_images: Number of random batch positions to draw when ``indexes`` is None.
        indexes: Explicit batch positions to plot; when given, one column is drawn per entry.
        print_captions: Whether to print the caption of each plotted image.
    """
    import random  # local import: the function must not depend on a later cell importing it

    image_size = 364
    patch_size = 14
    num_patches_per_side = image_size // patch_size
    num_patches = num_patches_per_side * num_patches_per_side
    alpha = 0.7  # weight of the heat map in the blend
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still supported.
    colormap = plt.get_cmap('viridis')

    # If indices are provided keep them; otherwise sample `num_images` random batch positions.
    if indexes is None:
        selected_indices = random.sample(range(len(captions)), num_images)
    else:
        selected_indices = list(indexes)
    if not selected_indices:
        return

    original_images = []
    heatmap_images = []
    selected_captions = []

    for index in selected_indices:
        # Retrieve the image path and caption
        reference_image = annotations[pair_ids[index]]['reference']
        image_path = images_path + reference_image + '.png'
        selected_captions.append(captions[index])

        # Load and resize the original image; the `with` block closes the underlying file.
        with Image.open(image_path) as image_file:
            original_image = image_file.resize((image_size, image_size))
        original_images.append(original_image)

        # Keep only above-mean tokens (scaled by 255) to make the heat map more informative.
        token_values = np.asarray(batch_token_values[index])
        token_values = np.where(token_values > token_values.mean(), token_values * 255, 0.0)

        # Fill the patch grid row-major. Surplus tokens are ignored and missing cells stay 0.
        # NOTE(review): token 0 lands on the top-left patch. If the visual sequence starts with
        # a [CLS] token, the grid is shifted by one position — confirm against the vision encoder.
        heatmap_grid = np.zeros(num_patches)
        num_used = min(num_patches, token_values.size)
        heatmap_grid[:num_used] = token_values[:num_used]
        heatmap_grid = heatmap_grid.reshape(num_patches_per_side, num_patches_per_side)

        # Upsample the grid to the image size (bicubic, float32 'F' image) and apply the colormap.
        heatmap_resized = np.array(
            Image.fromarray(heatmap_grid.astype(np.float32)).resize((image_size, image_size), resample=Image.BICUBIC)
        )
        heatmap_colored_array = (colormap(heatmap_resized)[:, :, :3] * 255).astype(np.uint8)

        # Alpha-blend the RGB channels; the image's alpha channel is left untouched.
        combined_image = np.array(original_image.convert("RGBA"))
        combined_image[..., :3] = alpha * heatmap_colored_array + (1 - alpha) * combined_image[..., :3]
        heatmap_images.append(combined_image)

    # Print captions if required
    if print_captions:
        for i, caption in enumerate(selected_captions, start=1):
            print(f"Caption image {i}: {caption}")

    # One column per selected image; squeeze=False keeps `axes` 2-D even for a single column.
    num_columns = len(selected_indices)
    fig, axes = plt.subplots(2, num_columns, figsize=(20, 10), squeeze=False)

    for i in range(num_columns):
        # Top row: original images
        axes[0, i].imshow(original_images[i])
        axes[0, i].axis('off')

        # Bottom row: images blended with their heat maps
        axes[1, i].imshow(heatmap_images[i])
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
import random

# Heat maps of `all_batch_token_values` for the first five batch elements.
plot_heat_maps_for_random_images_new(
    all_batch_token_values,
    annotations_dir,
    ids,
    captions,
    indexes=[0, 1, 2, 3, 4],
)

In [210]:
# method to get the token_values evolution through the attention blocks
# In this case we save the cumulative (running) average of the cross-attention values up to each layer
def get_token_attention_values_evolution(cross_attentions):
    """Return, for every layer, the running average of the per-token attention up to that layer.

    For each layer the attention is averaged over heads (dim 1) and then over queries,
    giving a ``(batch, num_visual_tokens)`` array. Entry ``k`` of the result is the
    mean of those arrays over layers ``0..k``.

    Args:
        cross_attentions: Sequence of per-layer attention tensors indexed as
            ``[batch, head, query, visual_token]``.

    Returns:
        list of numpy arrays, one per layer, each shaped ``(batch, num_visual_tokens)``.
    """
    token_values_evolution = []
    running_mean = None

    # iterate over blocks
    for block, cross_attention_layer in enumerate(cross_attentions):
        print(f"Block: {block}")

        # get the mean across heads for this block
        mean_cross_attention_layer = torch.mean(cross_attention_layer, dim=1)
        print(mean_cross_attention_layer.shape)

        # get the mean across query tokens
        mean_cross_attention_layer = torch.mean(mean_cross_attention_layer, dim=1)
        print(mean_cross_attention_layer.shape)

        layer_values = mean_cross_attention_layer.cpu().detach().numpy()
        if running_mean is None:
            running_mean = layer_values
        else:
            # Incremental mean over layers 0..block: the previous average already covers
            # `block` layers, so it is weighted by `block` and the new layer by 1.
            running_mean = (block * running_mean + layer_values) / (block + 1)
        token_values_evolution.append(running_mean)

    print(len(token_values_evolution))
    return token_values_evolution


In [None]:
# Per-layer cumulative token values for the whole batch.
print(f"{len(cross_attentions)}")
token_values_evolution = get_token_attention_values_evolution(cross_attentions=cross_attentions)

In [None]:
# Plot the heat maps of the cumulative average cross-attention values, one figure per layer.
indexes = [0, 1, 2, 3, 4]
# indexes = random.sample(range(len(captions)), 5)
for i, layer_token_values in enumerate(token_values_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # captions are printed for the first layer only
    plot_heat_maps_for_random_images_new(layer_token_values, annotations_dir, ids, captions, indexes=indexes, print_captions=print_captions)

In [213]:
# method to get the token_values evolution through the attention blocks, per query
# In this case we do not average over the queries
def get_token_attention_evolution(cross_attentions):
    """Return, for every layer, the running average of the head-averaged attention up to that layer.

    Each layer's attention is averaged over heads (dim 1), keeping the query axis.
    Entry ``k`` of the result is the mean of those arrays over layers ``0..k``.

    Args:
        cross_attentions: Sequence of per-layer attention tensors indexed as
            ``[batch, head, query, visual_token]``.

    Returns:
        list of numpy arrays, one per layer, each shaped ``(batch, num_queries, num_visual_tokens)``.
    """
    token_values_evolution = []
    running_mean = None
    for block, cross_attention_layer in enumerate(cross_attentions):
        print(f"Block: {block}")

        # get the mean across heads
        mean_cross_attention_layer = torch.mean(cross_attention_layer, dim=1)
        print(mean_cross_attention_layer.shape)

        layer_values = mean_cross_attention_layer.cpu().detach().numpy()
        if running_mean is None:
            running_mean = layer_values
        else:
            # Incremental mean over layers 0..block: the previous average already covers
            # `block` layers, so it is weighted by `block` and the new layer by 1.
            running_mean = (block * running_mean + layer_values) / (block + 1)
        token_values_evolution.append(running_mean)

    return token_values_evolution

In [None]:
# Per-layer cumulative, head-averaged attention (query axis kept) for the whole batch.
token_attention_evolution = get_token_attention_evolution(cross_attentions=cross_attentions)

In [None]:
# Plot, for every layer, the cumulative heat maps of the first query only.
indexes = [0, 1, 2, 3, 4]
# indexes = random.sample(range(len(captions)), 5)
query_index = 0  # first query
token_values_query = [layer_attention[:, query_index, :] for layer_attention in token_attention_evolution]
print(token_values_query[0].shape)
for i, query_values in enumerate(token_values_query):
    print(f"Layer: {i}")
    print_captions = i == 0  # captions are printed for the first layer only
    plot_heat_maps_for_random_images_new(query_values, annotations_dir, ids, captions, indexes=indexes, print_captions=print_captions)

In [None]:
# Plot, for every layer, the cumulative heat maps of the last query (index 31) only.
indexes = [0, 1, 2, 3, 4]
# indexes = random.sample(range(len(captions)), 5)
query_index = 31  # last query
token_values_query = [layer_attention[:, query_index, :] for layer_attention in token_attention_evolution]
print(token_values_query[0].shape)
for i, query_values in enumerate(token_values_query):
    print(f"Layer: {i}")
    print_captions = i == 0  # captions are printed for the first layer only
    plot_heat_maps_for_random_images_new(query_values, annotations_dir, ids, captions, indexes=indexes, print_captions=print_captions)

In [230]:
# Function to plot heatmaps for 5 (by default random) queries given the attention values of one specific layer
def plot_heat_maps_for_5_random_queries(image_index, all_token_attentions, annotations, pair_ids, captions,
                                             images_path="drive/MyDrive/CIRR/images/test1/", num_queries=5,
                                             indexes=None, print_caption=True):
    """Plot one reference image (top row) and its per-query attention heat maps (bottom row).

    For each selected query, only tokens above that query's mean are kept (scaled by 255).
    The patch grid is upsampled bicubically to the image size, coloured with ``viridis``
    and alpha-blended (alpha = 0.7) onto the image's RGB channels.

    Args:
        image_index: Batch position of the image to analyse.
        all_token_attentions: Attention values of one layer, indexed as
            ``[batch, query, visual_token]``.
        annotations: Mapping pair id -> annotation dict whose ``'reference'`` entry names the
            reference image (loaded from ``images_path + name + '.png'``).
        pair_ids: Pair id of each batch position.
        captions: Caption of each batch position.
        images_path: Directory prefix of the reference images.
        num_queries: Number of random queries to draw when ``indexes`` is None.
        indexes: Explicit query indexes to plot; when given, one column is drawn per entry.
        print_caption: Whether to print the image's caption.
    """
    import random  # local import: the function must not depend on a later cell importing it

    image_size = 364
    patch_size = 14
    num_patches_per_side = image_size // patch_size
    num_patches = num_patches_per_side * num_patches_per_side
    alpha = 0.7  # weight of the heat map in the blend
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still supported.
    colormap = plt.get_cmap('viridis')

    query_attentions = all_token_attentions[image_index]

    # If query indexes were already provided keep them; otherwise sample `num_queries` random ones.
    if indexes is None:
        selected_queries = random.sample(range(len(query_attentions)), num_queries)
        print(selected_queries)
    else:
        selected_queries = list(indexes)
    if not selected_queries:
        return

    # Load and resize the reference image; the `with` block closes the underlying file.
    reference_image = annotations[pair_ids[image_index]]['reference']
    image_path = images_path + reference_image + '.png'
    with Image.open(image_path) as image_file:
        original_image = image_file.resize((image_size, image_size))
    original_image_array = np.array(original_image.convert("RGBA"))

    # iterate over queries
    heatmap_images = []
    for query in selected_queries:
        # Keep only above-mean tokens (scaled by 255) to make the heat map more informative.
        token_values = np.asarray(query_attentions[query])
        token_values = np.where(token_values > token_values.mean(), token_values * 255, 0.0)

        # Fill the patch grid row-major. Surplus tokens are ignored and missing cells stay 0.
        # NOTE(review): token 0 lands on the top-left patch. If the visual sequence starts with
        # a [CLS] token, the grid is shifted by one position — confirm against the vision encoder.
        heatmap_grid = np.zeros(num_patches)
        num_used = min(num_patches, token_values.size)
        heatmap_grid[:num_used] = token_values[:num_used]
        heatmap_grid = heatmap_grid.reshape(num_patches_per_side, num_patches_per_side)

        # Upsample the grid to the image size (bicubic, float32 'F' image) and apply the colormap.
        heatmap_resized = np.array(
            Image.fromarray(heatmap_grid.astype(np.float32)).resize((image_size, image_size), resample=Image.BICUBIC)
        )
        heatmap_colored_array = (colormap(heatmap_resized)[:, :, :3] * 255).astype(np.uint8)

        # Alpha-blend the RGB channels; the image's alpha channel is left untouched.
        combined_image = original_image_array.copy()
        combined_image[..., :3] = alpha * heatmap_colored_array + (1 - alpha) * combined_image[..., :3]
        heatmap_images.append(combined_image)

    if print_caption:
        print(f"Caption image: {captions[image_index]}")

    # One column per selected query; squeeze=False keeps `axes` 2-D even for a single column.
    num_columns = len(selected_queries)
    fig, axes = plt.subplots(2, num_columns, figsize=(20, 10), squeeze=False)

    for i in range(num_columns):
        # Top row: original image
        axes[0, i].imshow(original_image)
        axes[0, i].axis('off')

        # Bottom row: image blended with the query's heat map
        axes[1, i].imshow(heatmap_images[i])
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()


# Alias for cells that call this function as `plot_heat_maps_for_layer5_random_queries`.
plot_heat_maps_for_layer5_random_queries = plot_heat_maps_for_5_random_queries


Next, we plot how the attention values of a subset of 5 queries evolve across all the layers. This gives us insight into how the cross-attention mechanism works for our specific task.

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 1.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 1
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 100.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 100
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 35.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 35
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 120.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 120
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 15.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 15
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 0.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 0
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

In [None]:
# Plot, for every layer, the heat maps of queries 0, 10, 20, 25 and 31 for batch element 90.
indexes = [0, 10, 20, 25, 31]
# indexes = random.sample(range(token_attention_evolution[0].shape[1]), 5)  # random queries
image_index = 90
for i, layer_attention in enumerate(token_attention_evolution):
    print(f"Layer: {i}")
    print_captions = i == 0  # the caption is printed for the first layer only
    print(layer_attention.shape)
    # The plotting function is defined as `plot_heat_maps_for_5_random_queries`.
    plot_heat_maps_for_5_random_queries(image_index, layer_attention, annotations_dir, ids, captions, indexes=indexes, print_caption=print_captions)

### Analysis

Overall, we notice that in the early stages of the architecture (the first layers), each query in this subset appears to attend to complementary features of the image, which suggests some specialization. As we go deeper into the architecture, all of them converge to very similar attention maps. This could be because those patches carry more relevant and more generic information for the specific query. Moreover, the behavior of each query appears to be consistent across all the images we considered.

An important outcome of this analysis is that it would be worth running experiments to check how many queries or layers are actually needed to maintain good results on the task, since we hypothesize that fewer queries, or even fewer layers, would be enough.