In [1]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import seaborn as sns
import matplotlib.colors as Colormap
from matplotlib.colors import LogNorm
import numpy as np
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np
import torch

def cosine_sim(x, y):
    """Cosine similarity between two 1-D embedding tensors.

    Args:
        x: embedding torch.Tensor (the notebook uses vectors of length 1024).
        y: embedding torch.Tensor with the same length as ``x``.

    Returns:
        0-dim torch.Tensor holding ``x . y`` divided by the product of the
        two L2 norms.
    """
    inner_product = torch.matmul(x, y)
    norm_product = x.norm() * y.norm()
    return inner_product / norm_product

def vision_token_redundency(image_features, num_window, start_i):
    """Plot the pairwise cosine similarity of a contiguous window of vision tokens.

    Args:
        image_features: torch.Tensor of shape [1, num_tokens, embedding_dim];
            only batch element 0 is used.
        num_window: int, number of consecutive tokens to compare.
        start_i: int, index of the first token of the window (negative values
            index from the end, as with ordinary tensor indexing).

    Raises:
        IndexError: if any index in ``start_i .. start_i + num_window - 1``
            falls outside the token axis.
    """
    tokens = image_features[0]  # [num_tokens, embedding_dim]
    num_tokens = tokens.shape[0]

    # Check bounds up front: out-of-range advanced indexing on a CUDA tensor
    # triggers a device-side assert instead of a clean IndexError.
    if num_window > 0 and (start_i < -num_tokens or start_i + num_window > num_tokens):
        raise IndexError(
            f"window [{start_i}, {start_i + num_window}) is out of range for {num_tokens} tokens"
        )

    window_index = torch.arange(start_i, start_i + num_window, device=tokens.device)
    # Upcast so float16 features do not lose precision (or overflow) in the norms.
    window = tokens[window_index].detach().float()  # [num_window, embedding_dim]

    # Normalising each row once and taking a single matrix product yields every
    # pairwise cosine similarity, replacing num_window**2 Python-level calls.
    unit_rows = window / window.norm(dim=-1, keepdim=True)
    similarity_matrix = (unit_rows @ unit_rows.T).cpu().numpy()

    # Plot the similarity matrix using a heatmap
    plt.figure(figsize=(6, 6))
    plt.imshow(similarity_matrix, cmap='viridis')
    plt.title('Cosine Similarity')
    plt.colorbar()
    plt.xlabel('token index')
    plt.ylabel('token index')
    plt.show()




def print_accumulated_attn_by_vision_token(attn, vision_tokens_index):
    """Print, for every layer, the fraction of attention mass that lands on vision tokens.

    Args:
        attn: sequence (one entry per layer) of torch.Tensor with shape
            [1, num_heads, n_all_tokens, n_all_tokens]; only batch element 0 is used.
        vision_tokens_index: torch.Tensor [n_vision_tokens] of column (key)
            positions that belong to vision tokens.

    Prints one line per layer: ``Acc attn in Layer <i>: <ratio>``.
    """
    for layer, layer_attn in enumerate(attn):
        # Accumulate in float32: the grand total is roughly num_heads * n_all_tokens,
        # which exceeds the float16 maximum (65504) for long sequences and would
        # otherwise turn the denominator into inf.
        head_summed = layer_attn[0].sum(dim=0, dtype=torch.float32)  # [n_all_tokens, n_all_tokens]
        vision_attn = head_summed[:, vision_tokens_index]  # [n_all_tokens, n_vision_tokens]

        sum_vision_attn = vision_attn.sum()
        sum_full_attn = head_summed.sum()

        acc_attn_by_vision_token = sum_vision_attn / sum_full_attn
        print(f"Acc attn in Layer {layer}: {acc_attn_by_vision_token:.4f}")



In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

def visualize_vit_attention(vit_attention, layer):
    """Show a log-scaled heatmap of one ViT layer's attention from token 0 to the other tokens.

    Args:
        vit_attention: indexable per-layer attention tensors, each
            [1, num_heads, n_all_tokens, n_all_tokens]; batch element 0 is used.
        layer: index into ``vit_attention``; the title shows ``layer + 1``.

    The number of remaining tokens (n_all_tokens - 1) must be a perfect square
    so the averaged row can be reshaped into a square grid.
    """
    # Query row 0 against key columns 1..end for every head
    # (token 0 is presumably the CLS token — TODO confirm).
    first_token_attn = vit_attention[layer][0][:, 0, 1:]  # [num_heads, n_all_tokens - 1]
    head_mean = first_token_attn.mean(dim=0)  # [n_all_tokens - 1]

    side = int(np.sqrt(head_mean.shape[-1]))
    grid = head_mean.cpu().numpy().reshape((side, side))

    # The small offset keeps vmin strictly positive for the log scale.
    color_scale = LogNorm(vmin=grid.min()+1e-8, vmax=grid.max())

    # Plotting the attention heatmap
    plt.figure(figsize=(5, 5))
    sns.heatmap(grid, annot=False, cmap='viridis', norm=color_scale, cbar=False)
    plt.title(f'Layer {layer+1}\'s Vision Tokens Attention Map')
    plt.show()



def visualize_vision_attention_in_llm(llm_attention, layer=0, vit_to_llm_mapping=[], mark_topk=0):
    """Show how much attention the tokens after the vision span pay to each vision token.

    Args:
        llm_attention: indexable per-layer attention tensors, each
            [1, num_heads, n_all_tokens, n_all_tokens]; batch element 0 is used.
        layer: index into ``llm_attention``; the title shows ``layer + 1``.
        vit_to_llm_mapping: positions of the vision tokens in the LLM sequence.
            Must be non-empty; the last entry is taken as the end of the vision
            span and every later position is treated as a text token.
        mark_topk: when > 0, outline the ``mark_topk`` most-attended vision
            tokens with a red box.

    Raises:
        IndexError: if ``vit_to_llm_mapping`` is empty.
    """
    num_vision_tokens = len(vit_to_llm_mapping)
    text_token_start = vit_to_llm_mapping[-1] + 1
    num_grid = int(np.ceil(np.sqrt(num_vision_tokens)))

    attn = llm_attention[layer][0]  # [num_heads, n_all_tokens, n_all_tokens]
    vision_attn = attn[:, text_token_start:, vit_to_llm_mapping]  # [num_heads, n_text_tokens, n_vision_tokens]
    vision_attn = vision_attn.float().mean(dim=0)  # [n_text_tokens, n_vision_tokens]
    vision_attn = vision_attn.mean(dim=0)  # [n_vision_tokens]
    flat_attn = vision_attn.cpu().numpy()

    # The token count is not necessarily a perfect square, so pad the tail of
    # the grid with NaN; seaborn leaves NaN cells unpainted.
    grid = np.full(num_grid * num_grid, np.nan, dtype=flat_attn.dtype)
    grid[:num_vision_tokens] = flat_attn
    grid = grid.reshape((num_grid, num_grid))

    # Plotting the attention heatmap
    plt.figure(figsize=(5, 5))
    ax = sns.heatmap(grid, annot=False, cmap='viridis', norm=LogNorm(vmin=0.00009, vmax=np.nanmax(grid)), cbar=False)
    plt.title(f'Layer {layer+1}\'s Vision Tokens Attention Map')

    if mark_topk > 0:
        # Rank on the unpadded vector so NaN padding can never be selected.
        k = min(mark_topk, num_vision_tokens)
        topk_flat = np.argpartition(flat_attn, -k)[-k:]
        rows, cols = np.divmod(topk_flat, num_grid)
        # seaborn draws cell (row, col) over [col, col + 1] x [row, row + 1],
        # so the box's corner is (col, row) with no half-cell offset.
        for row, col in zip(rows, cols):
            ax.add_patch(plt.Rectangle((col, row), 1, 1, fill=False, edgecolor='red', lw=2))

    plt.show()



def visualize_attention(multihead_attention, layer=31, stride=1, vision_tokens_index=[]):
    """Show a log-scaled heatmap of one layer's head-averaged attention matrix.

    Args:
        multihead_attention: indexable per-layer attention tensors, each
            (1, num_heads, n_tokens, n_tokens); batch element 0 is used.
        layer: index into ``multihead_attention``; the title shows ``layer + 1``.
        stride: kernel size and stride of the average pooling applied to the
            matrix before plotting (1 leaves it unchanged).
        vision_tokens_index: accepted for call compatibility; not used.
    """
    # Average over heads in float32 on the source device, then move only the
    # (n_tokens, n_tokens) result to the host.
    averaged_attention = multihead_attention[layer].mean(dim=1, dtype=torch.float32)[0].cpu()
    averaged_attention = torch.nn.functional.avg_pool2d(
        averaged_attention.unsqueeze(0).unsqueeze(0), stride, stride
    ).squeeze(0).squeeze(0)
    averaged_attention = averaged_attention.numpy()

    plt.figure(figsize=(5, 5), dpi=100)
    log_norm = LogNorm(vmin=0.0007, vmax=float(averaged_attention.max()))
    # Pass the colormap by name: plt.cm.get_cmap was removed in Matplotlib 3.9.
    ax = sns.heatmap(averaged_attention, cmap="viridis", norm=log_norm)

    ax.set_xlabel('Token Index')
    ax.set_ylabel('Token Index')

    # do not show ticks
    ax.set_xticks([])
    ax.set_yticks([])

    # Title
    plt.title(f'Attention Map Visualization for {layer+1}th layer')
    plt.show()

In [4]:
# Checkpoint identifier passed to both from_pretrained() calls in the next cell.
model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"

In [5]:
# Load the checkpoint in float16 and move it to device 0. Eager attention is
# requested here; later cells call the model with output_attentions=True.
# NOTE(review): the recorded output of this cell warns that a `llava_next`
# checkpoint is being instantiated as a `llava` model — confirm that
# LlavaForConditionalGeneration is the intended class for this model_id.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
    attn_implementation="eager",
).to(0)
# Processor loaded from the same checkpoint id as the model.
processor = AutoProcessor.from_pretrained(model_id)

You are using a model of type llava_next to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it]


### Run the cells below

In [6]:
url = "https://picsum.photos/400"
# Bound the request so a stalled connection cannot hang the kernel indefinitely.
response = requests.get(url, timeout=30)

if response.status_code == 200:
    with open('image.jpg', 'wb') as file:
        file.write(response.content)
    print("Image successfully retrieved and saved.")
else:
    print(f"Failed to retrieve image. HTTP Status code: {response.status_code}")
# NOTE(review): the file saved above ('image.jpg') is not the one opened here;
# the image used downstream is the local '../billboard.jpg'. Confirm whether
# the download is still needed.
raw_image = Image.open('../billboard.jpg')
# raw_image.show()

Image successfully retrieved and saved.


In [7]:
# Text prompt handed to the processor together with raw_image in the inference cell.
# The commented-out line is an alternative prompt kept for reference.
# prompt = "USER: <image>\nTell me the story of two friends and the bear\nASSISTANT:"
prompt = "USER: <image>\nWhat is the main text written on the billboard?\nASSISTANT:"

In [8]:
# Attach a custom settings dict to the model config.
# NOTE(review): nothing in this notebook reads these keys, so their meaning is
# presumably defined by the model implementation — confirm against that code
# before changing any value.
model.config.fast_vlm_config = {
    "spatial_budget": 0,
    "alpha_vision_token_budget": 0.5,
    "beta_sub_images_budget": 0.5,
    "clip_attn_layer": 22,
}

In [9]:
# start time
# The timed region covers both passes below: generation and the extra forward pass.
start_time = time.time()
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

# Pass 1: greedy decoding (do_sample=False) of at most 250 new tokens.
# NOTE(review): the recorded output of this cell is a ValueError raised because
# `image_sizes` (one of the entries in `inputs`) is not accepted by the model,
# so nothing after this call ran in the saved notebook.
with torch.inference_mode():
    output_ids = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=250,
        use_cache=True,
        output_attentions=True,
        output_scores=True,
        return_dict_in_generate=True,
        )

# Decode the first returned sequence, keeping special tokens in the text.
output_text = processor.decode(output_ids['sequences'][0], skip_special_tokens=False)
print(output_text)
# Pass 2: re-encode the decoded text with the same image and run one forward
# pass to obtain attention over the whole sequence at once.
inputs = processor(output_text, raw_image, return_tensors='pt').to(0, torch.float16)
with torch.inference_mode():
    output = model(**inputs, output_attentions=True, return_dict = True)
end_time = time.time()
# NOTE(review): vit_attentions, vit_to_llm_mapping and image_features are read
# straight off the model object; presumably the model implementation stores
# them during the forward pass — confirm they reflect pass 2.
vit_attention = model.vit_attentions
vit_to_llm_mapping = model.vit_to_llm_mapping
llm_attention = output.attentions
image_features = model.image_features
# print the time
print(f"Time taken: {end_time - start_time} seconds")

ValueError: The following `model_kwargs` are not used by the model: ['image_sizes'] (note: typos in the generate arguments will also show up in this list)

: 

In [None]:
# Head-averaged attention map for layer index 9 (the plot title reads "10th layer").
# NOTE(review): visualize_attention accepts vision_tokens_index but never reads it.
visualize_attention(multihead_attention=llm_attention,layer=9,stride=1,vision_tokens_index=vit_to_llm_mapping[0])

In [None]:
# vision_token_redundency(image_features=image_features,num_window=200,start_i=0)