In [None]:
import os
# Expose only GPU index 7 to this process.
# NOTE(review): this assignment sits before `import torch` below, presumably so that it is in
# place before CUDA is initialised — confirm that index 7 exists on the host being used.
os.environ['CUDA_VISIBLE_DEVICES'] = "7"
import json
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import argparse 
import time
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import matplotlib.pyplot as plt

# Local checkpoint directory shared by the model weights and the processor.
MODEL_PATH = "pretrained/Qwen2.5-VL-3B-Instruct"

# Loading options. The "eager" attention implementation is selected; the author left
# "flash_attention_2" disabled (presumably because attention weights are needed later
# in the notebook — confirm).
model_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    device_map="auto",
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_PATH, **model_load_kwargs)

# Default processor stored alongside the checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_PATH)

In [None]:
img_path = "demo_images/catdog.png"

# Single user turn: one image (the processor is asked to resize it to 448x448) plus a text query.
image_entry = {
    "type": "image",
    "image": img_path,
    "resized_height": 448,
    "resized_width": 448,  # original note: needs at least (560 / 28) ** 2 = 400 tokens
}
text_entry = {"type": "text", "text": "Find the dog in figure"}
messages = [{"role": "user", "content": [image_entry, text_entry]}]

# Preparation for inference: render the chat template to a prompt string (not tokenized yet).
chat_template_kwargs = dict(
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # turn on thinking
)
text = processor.apply_chat_template(messages, **chat_template_kwargs)

# Collect the vision inputs referenced by the messages.
image_inputs, video_inputs = process_vision_info(messages)

# Tokenize the prompt and preprocess the vision inputs into one padded batch of tensors.
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)

# Extra optional argument so that the forward pass returns the attentions as well.
inputs['output_attentions'] = True
for key in inputs.keys():
    print(key)
inputs = inputs.to(model.device)

In [None]:
# The forward pass is used for attention inspection only, so run it with autograd disabled:
# no graph is recorded and intermediate activations are not kept alive.
# NOTE(review): unpacking two values assumes the model implementation in use returns
# (outputs, vit_attns) from its forward — confirm against the installed model code.
with torch.no_grad():
    outputs,vit_attns = model(**inputs)

In [None]:
# Attention maps attached to the model output (requested via `output_attentions` in the inputs);
# presumably one entry per language-model layer — TODO confirm.
vl_attns = outputs.attentions

In [None]:
# Report how many entries each attention collection holds: vit_attns first, then vl_attns.
for attn_stack in (vit_attns, vl_attns):
    print(len(attn_stack))

In [None]:
# Keep the final entry of `vit_attns` and show its shape.
last_vit_attn = vit_attns[-1]
print(last_vit_attn.shape)

In [None]:
# Keep the final entry of `vl_attns` (used for the visualisation below) and show its shape.
last_vl_attn = vl_attns[-1]
print(last_vl_attn.shape)

In [None]:
# Show the token ids, their shape and, after a blank gap, the rendered prompt text.
input_ids = inputs['input_ids']
for shown in (input_ids, input_ids.shape, "\n", text):
    print(shown)

In [None]:
# Token id compared against input_ids to find the image placeholder positions.
image_token = 151655
# Token id of the text token whose attention over the image is inspected.
# NOTE(review): presumably the id of the prompt token for "dog" — confirm with the tokenizer.
attn_token = 5562

# Sequence positions (second axis of input_ids) of every image token.
img_token_mask = (input_ids == image_token)
img_token_positions = torch.nonzero(img_token_mask, as_tuple=True)[1]

# The query token has to occur exactly once: calling `.item()` on zero or several matches
# fails with an opaque conversion error, so report the actual match count instead.
attn_token_positions = torch.nonzero(input_ids == attn_token, as_tuple=True)[1]
if attn_token_positions.numel() != 1:
    raise ValueError(
        f"expected token id {attn_token} exactly once in input_ids, "
        f"found {attn_token_positions.numel()} occurrence(s)"
    )
attn_idx = attn_token_positions.item()

print(img_token_positions)
print(img_token_positions.shape)
print(attn_idx)

In [None]:
import torchvision
import numpy as np
from torch import nn

def visualize_attention(img, attentions, attn_idx, img_token_idx, patch_size = 28):
    """Turn one query token's attention over the image tokens into per-head heat maps.

    Args:
        img: Image tensor without a batch dimension. Only its last two dimensions
            (rows, columns) are read, to size the patch grid.
        attentions: Attention tensor indexed as [batch, head, query, key]; only batch
            element 0 is used.
        attn_idx: Integer position of the query token.
        img_token_idx: 1-D index tensor with the key positions of the image tokens,
            ordered row by row over the patch grid.
        patch_size: Side length, in pixels, of the image area represented by one image token.

    Returns:
        numpy array of shape (heads, rows * patch_size, cols * patch_size), where rows and
        cols are the image height and width floor-divided by ``patch_size``.

    Raises:
        ValueError: if the number of image tokens differs from rows * cols.
    """
    # Grid of whole patches; leftover pixels on the bottom/right edge are ignored.
    grid_rows = img.shape[-2] // patch_size
    grid_cols = img.shape[-1] // patch_size

    nh = attentions.shape[1]  # number of heads

    # Index one axis at a time so the head axis is guaranteed to stay first, independent of
    # how mixed integer/tensor indexing orders the result dimensions.
    head_to_keys = attentions[0][:, attn_idx, :]                   # (heads, keys)
    patch_attn = head_to_keys[:, img_token_idx].reshape(nh, -1)    # (heads, image tokens)

    n_tokens = patch_attn.shape[1]
    if n_tokens != grid_rows * grid_cols:
        raise ValueError(
            f"got {n_tokens} image tokens but the image gives a "
            f"{grid_rows}x{grid_cols} grid for patch_size={patch_size}"
        )

    patch_attn = patch_attn.reshape(nh, grid_rows, grid_cols)

    # Blow every grid cell up to a patch_size x patch_size block so the map matches the image.
    heatmaps = nn.functional.interpolate(
        patch_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0]

    return heatmaps.cpu().detach().numpy()

def plot_attention(img, attention, idx = -1, save_dir = "temp"):
    """Show the image beside the head-averaged attention map, then one map per head.

    Args:
        img: Image in a layout accepted by ``plt.imshow``.
        attention: Array indexed as [head, row, col].
        idx: When not -1, the per-head figure is saved as ``demo_{idx}.jpg`` in ``save_dir``.
        save_dir: Directory for saved figures; created if it does not exist yet.
    """
    n_heads = attention.shape[0]

    # Figure 1: original image and the mean over heads, side by side.
    plt.figure(figsize=(10, 10))
    panels = [("Original Image", img), ("Head Mean", np.mean(attention, 0))]
    for i, (title, panel) in enumerate(panels):
        plt.subplot(1, 2, i+1)
        plt.imshow(panel, cmap='inferno',alpha=0.8)
        plt.title(title)
    plt.show()

    # Figure 2: one panel per head, three panels per row.
    plt.figure(figsize=(10, 10))
    for i in range(n_heads):
        plt.subplot((n_heads + 2)//3, 3, i+1)
        plt.imshow(attention[i], cmap='inferno')
        plt.title(f"Head n: {i+1}")
    plt.tight_layout()
    if idx != -1:
        # savefig does not create missing directories, so make sure the target exists.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f"demo_{idx}.jpg"))
    plt.show()

In [None]:
import cv2
# cv2.imread reports a missing or unreadable file by returning None instead of raising,
# so fail here with a clear message rather than later inside cvtColor.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    raise FileNotFoundError(f"could not read image: {img_path}")
img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
# Same 448x448 size as the resized_height/resized_width requested in the prompt message.
img_resized = cv2.resize(img_rgb,(448,448))
# Channels-first tensor, as expected by visualize_attention.
img = torch.tensor(img_resized).permute(2,0,1)
plt.imshow(img.permute(1,2,0))

In [None]:
# Per-head maps of the attention paid by the token at `attn_idx` to the image tokens, computed
# from `last_vl_attn` after casting it to float32.
show_attens = visualize_attention(img,last_vl_attn.to(torch.float32),attn_idx,img_token_positions)

In [None]:
# Plot the image (permuted back to channels-last for display) together with the attention maps.
plot_attention(img.permute(1,2,0),show_attens)

In [None]:
def show_each_layers(vl_attns):
    """Visualise and plot the attention maps for every entry of ``vl_attns``.

    Relies on the notebook-level ``img``, ``attn_idx`` and ``img_token_positions``.
    The position of each entry is passed to ``plot_attention`` as ``idx``.
    """
    for layer_no, layer_attn in enumerate(vl_attns):
        layer_maps = visualize_attention(
            img, layer_attn.to(torch.float32), attn_idx, img_token_positions)
        plot_attention(img.permute(1,2,0), layer_maps, idx = layer_no)

In [None]:
# Render the attention plots for every entry of `vl_attns`.
show_each_layers(vl_attns)

: 

In [None]:
def plot_attention(img, attention, idx = -1, save_dir = "temp"):
    """Overlay the head-averaged attention on the image, then show one map per head.

    Defining this function replaces any ``plot_attention`` defined earlier in the kernel.

    Args:
        img: Image in a layout accepted by ``plt.imshow``.
        attention: Array indexed as [head, row, col].
        idx: When not -1, the per-head figure is saved as ``demo_{idx}.jpg`` in ``save_dir``.
        save_dir: Directory for saved figures; created if it does not exist yet.
    """
    n_heads = attention.shape[0]

    # Figure 1: image with the mean over heads drawn semi-transparently on top.
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.imshow(np.mean(attention, 0), cmap='inferno', alpha=0.5)
    plt.title("Head Mean")
    plt.show()

    # Figure 2: one panel per head, three panels per row.
    plt.figure(figsize=(10, 10))
    for i in range(n_heads):
        plt.subplot((n_heads + 2)//3, 3, i+1)
        plt.imshow(attention[i], cmap='inferno')
        plt.title(f"Head n: {i+1}")
    plt.tight_layout()
    if idx != -1:
        # savefig does not create missing directories, so make sure the target exists.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f"demo_{idx}.jpg"))
    plt.show()
# Redraw the maps stored in `show_attens` with the overlay version defined just above.
plot_attention(img.permute(1,2,0), show_attens)