In [None]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import seaborn as sns
import matplotlib.colors as Colormap
from matplotlib.colors import LogNorm
import numpy as np
import time

In [None]:
import numpy as np
import torch

def count_top_nprcnt_contribution_llm(llm_attention, vision_tokens_index, topk=100):
    """Print, per LLM layer, the share of text-to-vision attention held by the top-k vision tokens.

    For every layer, the attention that the text queries pay to the vision
    keys is summed over heads and over text queries, which yields one score
    per vision token. The printed percentage is the sum of the ``topk``
    largest scores divided by the sum of all scores.

    Args:
        llm_attention: Sequence with one tensor per layer, each shaped
            ``[1, num_heads, n_all_tokens, n_all_tokens]``. Only batch item 0
            is read.
        vision_tokens_index: 1-D index tensor ``[n_vision_tokens]`` holding
            the sequence positions of the vision tokens.
        topk: Number of highest-scoring vision tokens to count.

    Returns:
        None. One ``Layer i: xx.xx%`` line is printed per layer.
    """
    # Every position after the last vision token is treated as a text query.
    # NOTE(review): this presumes the vision tokens come before the text that
    # should be analysed -- confirm against how the index mapping is built.
    text_token_start = int(vision_tokens_index[-1]) + 1

    for layer, layer_attn in enumerate(llm_attention):
        # [num_heads, n_text_tokens, n_vision_tokens]
        # Upcast before reducing: attention coming from a half-precision model
        # would otherwise be accumulated in fp16, where large sums overflow to
        # inf and small contributions are rounded away.
        text_to_vision = layer_attn[0][:, text_token_start:, vision_tokens_index].float()
        per_vision_token = text_to_vision.sum(dim=(0, 1))  # [n_vision_tokens]

        ranked, _ = torch.sort(per_vision_token, descending=True)
        # sum of topk tokens
        topk_mass = ranked[:topk].sum()
        # sum of all tokens
        total_mass = per_vision_token.sum()

        # share of the attention attained by the topk tokens
        share = (topk_mass / total_mass * 100).item()
        print(f"Layer {layer}: {share:.2f}%")


def count_top_npercents_contribution(vit_attention, topk=100):
    """Print, per ViT layer, the share of position-0 attention held by its top-k keys.

    For every layer, the attention row of the query at position 0 (the token
    the original code names ``cls``) is summed over heads, with the key at
    position 0 itself left out. The printed percentage is the sum of the
    ``topk`` largest entries divided by the sum of all entries.

    Args:
        vit_attention: Sequence with one tensor per layer, each shaped
            ``[1, num_heads, n_all_tokens, n_all_tokens]``. Only batch item 0
            is read.
        topk: Number of highest-scoring key tokens to count.

    Returns:
        None. One ``Layer i: xx.xx%`` line is printed per layer.
    """
    for layer, layer_attn in enumerate(vit_attention):
        # Slice row 0 before reducing so only the needed row is summed, and
        # upcast so a half-precision input is not accumulated in fp16.
        cls_attn = layer_attn[0][:, 0, 1:].float().sum(dim=0)  # [n_all_tokens - 1]

        cls_attn_sorted, _ = torch.sort(cls_attn, descending=True)
        # sum of topk tokens
        cls_attn_topk_sum = cls_attn_sorted[:topk].sum()
        # sum of all tokens
        cls_attn_sum = cls_attn.sum()

        # share of the attention attained by the topk tokens
        share = (cls_attn_topk_sum / cls_attn_sum * 100).item()
        print(f"Layer {layer}: {share:.2f}%")
        

        

    



In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

def visualize_attention(multihead_attention, layer=31, stride=1, vision_tokens_index=None):
    """Plot a log-scaled heatmap of one layer's head-averaged attention map.

    Args:
        multihead_attention: Indexable by layer; each entry is a tensor of
            shape ``(1, num_heads, n_tokens, n_tokens)``. Only batch item 0 is
            plotted.
        layer: Index of the layer to plot. The title shows ``layer + 1``.
        stride: Kernel size and stride of the average pooling used to shrink
            the map before plotting (1 keeps the full resolution).
        vision_tokens_index: Optional 1-D collection (tensor, array or list)
            of token positions to mark on the diagonal. ``None`` or an empty
            collection draws no markers.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # (num_heads, n_tokens, n_tokens) on CPU. Upcast before averaging so the
    # head mean is not computed in half precision.
    layer_attention = multihead_attention[layer][0].detach().cpu().float()

    # Average across the heads -> (n_tokens, n_tokens)
    averaged_attention = layer_attention.mean(dim=0)

    # Pooling to reduce size; avg_pool2d wants a (N, C, H, W) input.
    averaged_attention = torch.nn.functional.avg_pool2d(
        averaged_attention[None, None], stride, stride
    )[0, 0]

    # Figure settings
    plt.figure(figsize=(5, 5), dpi=100)

    # Normalization for color mapping; vmax is handed over as a plain float.
    log_norm = LogNorm(vmin=0.0007, vmax=averaged_attention.max().item())

    # Heatmap plot. The colormap is given by name because plt.cm.get_cmap is
    # no longer available in recent Matplotlib releases.
    ax = sns.heatmap(averaged_attention.numpy(), cmap="viridis", norm=log_norm)

    if vision_tokens_index is not None and len(vision_tokens_index) > 0:
        marked = torch.as_tensor(vision_tokens_index).detach().cpu()

        # Apply the pooling stride to the indices. Several original positions
        # collapse onto the same pooled cell, so duplicates are dropped, and
        # cells beyond the pooled map (n_tokens not divisible by stride) are
        # discarded.
        marked = torch.unique(torch.div(marked, stride, rounding_mode="floor"))
        marked = marked[marked < averaged_attention.shape[-1]]

        # Mark the vision tokens on the diagonal
        for idx in marked.tolist():
            ax.add_patch(plt.Rectangle((idx, idx), 1, 1, fill=True, edgecolor='red', lw=1))

    ax.set_xlabel('Token Index')
    ax.set_ylabel('Token Index')

    # do not show ticks
    ax.set_xticks([])
    ax.set_yticks([])

    # Title
    plt.title(f'Attention Map Visualization for {layer+1}th layer')

    plt.show()

In [None]:
# Hugging Face Hub checkpoint used by the loading cell below.
# NOTE(review): the name points at a LLaVA v1.6 checkpoint while the loading
# cell uses LlavaForConditionalGeneration -- confirm the installed transformers
# build supports this combination.
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

In [None]:
# Processor (tokenizer + image preprocessing) for the chosen checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# Half-precision weights; the eager attention implementation is requested
# (presumably so that attention weights can be returned -- verify).
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
# Move the model to device 0.
model = model.to(0)

### Run the cells below

In [None]:
# Fetch a sample image and keep a copy on disk as image.jpg.
# NOTE(review): picsum.photos presumably serves a different image on every
# request, so results are not reproducible across runs -- confirm.
url = "https://picsum.photos/400"
# A timeout keeps the cell from hanging forever if the server never answers.
response = requests.get(url, timeout=30)

if response.status_code == 200:
    with open('image.jpg', 'wb') as file:
        file.write(response.content)
    print("Image successfully retrieved and saved.")
else:
    print(f"Failed to retrieve image. HTTP Status code: {response.status_code}")
    # Make the hidden state explicit: the open() below can only succeed with a
    # file left behind by an earlier run.
    print("Falling back to an existing 'image.jpg' from an earlier run, if any.")
# raw_image = Image.open('./license.png')
raw_image = Image.open('image.jpg')
raw_image.show()

In [None]:
# Prompt passed to the processor together with the image in the generation cell.
# NOTE(review): the <image> placeholder is presumably where the processor
# inserts the image tokens -- confirm for the checkpoint in use.
# Alternative prompt kept for experiments:
# prompt = "USER: <image>\nTell me the story of two friends and the bear\nASSISTANT:"
prompt = "USER: <image>\nDescribe the image in detail\nASSISTANT:"

In [None]:
# Attach the fast_vlm_config dictionary to the loaded model's config.
# NOTE(review): this attribute is presumably read by a patched model
# implementation; the meaning of the individual keys cannot be determined from
# this notebook -- document them once confirmed.
model.config.fast_vlm_config = {
    "spatial_budget": 0,
    "alpha_vision_token_budget": 0.2,
    "beta_sub_images_budget": 0.2,
    "clip_attn_layer": 22,
}

In [None]:
# Time the whole pipeline: preprocessing, generation, and one extra forward
# pass over the generated text. perf_counter is a monotonic clock meant for
# measuring intervals.
start_time = time.perf_counter()
# Keyword arguments keep the call independent of the processor's positional
# (text, images) / (images, text) ordering, which differs between releases.
inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.float16)

# Greedy decoding of up to 100 new tokens.
# NOTE(review): the attentions and scores requested here are not read anywhere
# in this cell -- confirm they are needed, since keeping them costs memory.
with torch.inference_mode():
    output_ids = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=100,
        use_cache=True,
        output_attentions=True,
        output_scores=True,
        return_dict_in_generate=True,
        )

output_text = processor.decode(output_ids['sequences'][0], skip_special_tokens=False)
print(output_text)

# Second pass: feed prompt + generated text back in a single forward call so
# the attention maps cover the full sequence.
inputs = processor(text=output_text, images=raw_image, return_tensors='pt').to(0, torch.float16)
with torch.inference_mode():
    output = model(**inputs, output_attentions=True, return_dict = True)

# CUDA kernels run asynchronously; wait for them so the timer covers the
# actual GPU work rather than just the kernel launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.perf_counter()

# NOTE(review): vit_attentions and vit_to_llm_mapping are presumably provided
# by a patched model implementation -- confirm.
vit_attention = model.vit_attentions
vit_to_llm_mapping = output.vit_to_llm_mapping
llm_attention = output.attentions

# print the time

print(f"Time taken: {end_time - start_time} seconds")

In [None]:
# Number of highest-attention tokens counted by the two contribution summaries below.
topk = 115

In [None]:
# Per-layer share of the vision encoder's position-0 attention captured by the top-k tokens.
count_top_npercents_contribution(vit_attention, topk=topk)

In [None]:
# Per-layer share of the LLM's text-to-vision attention captured by the top-k vision tokens
# (uses the first entry of the vision-to-LLM index mapping).
count_top_nprcnt_contribution_llm(llm_attention, vision_tokens_index=vit_to_llm_mapping[0], topk=topk)