In [1]:
import os
# Restrict this process to GPU 5. This is set in the first cell, before torch
# is imported, so the restriction is in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import requests
from PIL import Image
from io import BytesIO

In [2]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Checkpoint shared by the model and its processor; keeping it in one place
# guarantees the two can never drift apart.
MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Load the Llava model with attention output enabled.
# attn_implementation="eager" is requested explicitly because fused attention
# kernels (SDPA / flash attention) do not return attention weights; without it
# the notebook relies on the library falling back to eager on its own.
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    output_attentions=True,
    attn_implementation="eager",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Load the processor for handling text and vision inputs
processor = AutoProcessor.from_pretrained(MODEL_ID)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# URL of the image
image_url = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"

# Download the image. The timeout keeps the cell from hanging on a stalled
# connection, and raise_for_status() fails right here on an HTTP error instead
# of silently leaving `img_resized` undefined for the cells that follow.
response = requests.get(image_url, timeout=30)
response.raise_for_status()

img = Image.open(BytesIO(response.content))
img_resized = img.resize((120, 120))  # Resize the image to 120x120 pixels

In [4]:
# Single-turn chat: one user message made of an image placeholder followed by
# the text instruction
conversation_1 = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}],
    }
]

# Render the conversation into a prompt string with the generation prompt appended
prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)

# Encode prompt and image as PyTorch tensors, then move the batch to the
# model's device with float16 as the target dtype
inputs = processor(
    text=[prompt_1],
    images=[img_resized],
    return_tensors="pt",
).to(model.device, torch.float16)

In [5]:
# Run the model and extract attention maps.
# This is inference only, so gradient tracking is disabled: autograd would
# otherwise keep graph state alive for the whole forward pass, including every
# attention map, which wastes GPU memory.
with torch.no_grad():
    outputs = model(**inputs)




In [8]:
# Identify positions of image tokens in the input sequence.
# The placeholder id is taken from the model config when it exposes one
# (falling back to 32000, the value this notebook used originally) so the cell
# does not depend on a magic number.
image_token_id = getattr(model.config, "image_token_index", 32000)

# Convert to a plain Python list once instead of comparing tensor elements
# one by one.
token_ids = inputs.input_ids[0].tolist()
img_tok_idx = [idx for idx, val in enumerate(token_ids) if val == image_token_id]
len(img_tok_idx)

576

In [6]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

def visualize_attention_maps(attention_maps):
    """Interactively browse attention maps by layer and head.

    Creates two integer sliders with ``ipywidgets.interact``. Every slider
    change redraws the selected attention matrix as a heat map in which the
    strictly-upper triangle (key position > query position) is masked out.

    Args:
        attention_maps: Sequence with one tensor per layer. Each tensor is
            indexed as ``[0, head_idx]`` to obtain a 2-D matrix, so only the
            first batch element is displayed. Presumably shaped
            (batch, heads, query, key) -- confirm against the model output.

    Raises:
        ValueError: If ``attention_maps`` is ``None`` or empty, e.g. because
            the forward pass did not return attention weights.
    """
    if attention_maps is None or len(attention_maps) == 0:
        raise ValueError(
            "attention_maps is empty; run the model with output_attentions=True"
        )

    num_layers = len(attention_maps)
    num_heads = attention_maps[0].shape[1]

    def visualize(layer_idx, head_idx):
        """Draw the masked attention heat map for one (layer, head) pair."""
        # Upcast before converting to NumPy: the tensors may be half precision,
        # and float32 is the safer dtype to hand to NumPy/matplotlib.
        attn_map = attention_maps[layer_idx][0, head_idx].detach().float().cpu().numpy()

        # True above the main diagonal, i.e. where the key lies after the query.
        upper_mask = np.triu(np.ones_like(attn_map, dtype=bool), k=1)
        masked_attn_map = np.ma.masked_where(upper_mask, attn_map)

        # A single axes that fills the figure; the previous subplot(1, 2, 2)
        # call left the left half of a 15x7 figure blank.
        fig, ax = plt.subplots(figsize=(8, 7))
        image = ax.imshow(masked_attn_map, cmap="viridis", interpolation="nearest")
        fig.colorbar(image, ax=ax, label="Attention Score")
        ax.set_xlabel("Key Position")
        ax.set_ylabel("Query Position")
        ax.set_title(f"Attention Map (Lower Triangle) - Layer {layer_idx}, Head {head_idx}")
        fig.tight_layout()
        plt.show()

    interact(visualize, layer_idx=(0, num_layers - 1), head_idx=(0, num_heads - 1))

# Open the interactive layer/head browser on the attention tensors returned by
# the forward pass above.
visualize_attention_maps(outputs.attentions)

interactive(children=(IntSlider(value=15, description='layer_idx', max=31), IntSlider(value=15, description='h…

In [7]:
def visualize_l2_norms(past_key_values):
    """Interactively plot the L2 norm of cached key vectors per sequence position.

    The first entry of every per-layer item in ``past_key_values`` is taken as
    that layer's key tensor. Norms are computed over dimension 3 and the
    leading dimension is squeezed away, giving one curve per (layer, head)
    that is selected through two ``ipywidgets`` sliders.

    Args:
        past_key_values: Iterable of per-layer items whose element ``[0]`` is
            the key tensor. Presumably shaped (batch=1, heads, seq, head_dim)
            -- confirm against the model output.
    """
    key_tensors = [layer_cache[0] for layer_cache in past_key_values]
    first_keys = key_tensors[0]
    num_layers = len(key_tensors)
    num_heads, seq_len = first_keys.shape[1], first_keys.shape[2]

    # One row of norms per head, for every layer.
    l2_norms = np.zeros((num_layers, num_heads, seq_len))
    for layer, keys in enumerate(key_tensors):
        per_head = torch.norm(keys, p=2, dim=3).squeeze(0)
        l2_norms[layer] = per_head.detach().cpu().numpy()

    positions = range(seq_len)

    def plot_l2_norm(layer_idx, head_idx):
        """Draw the norm-versus-position curve for one (layer, head) pair."""
        fig, ax = plt.subplots(figsize=(15, 7))
        ax.plot(positions, l2_norms[layer_idx, head_idx], marker="o", linestyle="-")
        ax.set_xlabel("Sequence Position")
        ax.set_ylabel("L2 Norm")
        ax.set_title(f"L2 Norm of Past Keys - Layer {layer_idx}, Head {head_idx}")
        ax.grid(True)
        plt.show()

    interact(plot_l2_norm, layer_idx=(0, num_layers-1), head_idx=(0, num_heads-1))

# Plot key-vector L2 norms per position from the key/value cache returned by
# the forward pass above.
visualize_l2_norms(outputs.past_key_values)


interactive(children=(IntSlider(value=15, description='layer_idx', max=31), IntSlider(value=15, description='h…