In [1]:
import os

# Process-level configuration. Set before torch / transformers are imported so
# that CUDA and the Hugging Face hub client pick these values up.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "2,3",  # expose only physical GPUs 2 and 3
        "HF_ENDPOINT": "https://hf-mirror.com",
        "HF_HUB_OFFLINE": "1",  # resolve models from the local cache only
    }
)

In [2]:
import torch
import torch.nn.functional as F
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# "eager" attention is used because the generate() call below requests
# output_attentions=True, which needs the attention probabilities to be
# materialised. (Author's note: flash_attention_2 produced the same error.)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Single-turn chat: one local image followed by a text instruction.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "examples/image.png",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
# Follow the model's own placement instead of hard-coding "cuda"
# (device_map="auto" decides where the weights live).
inputs = inputs.to(model.device)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# Baseline generation. return_dict_in_generate=True makes `output_ids` the full
# generate() output object (it exposes `.attentions`, used in the next cell),
# not just a tensor of token ids.
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        use_cache=True,
        max_new_tokens=512,
        return_dict_in_generate=True,
        output_attentions=True,
        output_hidden_states=True,
    )


  return F.conv3d(


In [None]:
from typing import Tuple, List
from torch import Tensor
import copy

# Locate the image span in the prompt. The image tokens are the positions strictly
# between the first <vision_start> and the first <vision_end> (assumes one image).
prompt_token_ids = inputs['input_ids'][0].tolist()
vision_start_token_idx = prompt_token_ids.index(model.config.vision_start_token_id)
vision_end_token_idx = prompt_token_ids.index(model.config.vision_end_token_id)

# output_attn[step][layer]. Step 0 is the prefill pass; every later step is one
# decode pass. The tensors are only read below, and each is copied to the CPU with
# .cpu() before use, so no deep copy of the GPU tensors is needed (a deep copy
# would double the memory held by all attention maps).
# Each tensor is assumed to be (batch, heads, query_len, key_len).
output_attn: Tuple[Tuple[Tensor, ...], ...] = output_ids.attentions
# get the length of the prefilling and full attention
pref_len: int = output_attn[0][0].shape[3]
full_len: int = output_attn[-1][0].shape[3]
prefill_attn: Tuple[Tensor, ...] = output_attn[0]
n_decode_rows = full_len - pref_len

# batchsize should be 1
assert prefill_attn[0].shape[0] == 1
full_attn: List[Tensor] = []
for layer_idx, prefill_layer in enumerate(prefill_attn):
    # (heads, pref_len, pref_len), zero-padded to (heads, full_len, full_len).
    layer_attn = prefill_layer.cpu().squeeze(0).float()
    layer_attn = F.pad(layer_attn, (0, n_decode_rows, 0, n_decode_rows))
    for i in range(n_decode_rows):
        # Decode step i + 1 has a single query row over pref_len + i + 1 keys.
        cur_attn = output_attn[i + 1][layer_idx].cpu().squeeze(0)[:, 0, :].float()
        layer_attn[:, pref_len + i, :pref_len + i + 1] = cur_attn
    full_attn.append(layer_attn)

# Mean over layers and heads. Averaging each layer's heads first is equal to
# torch.stack(full_attn).mean(dim=(0, 1)), because every layer has the same number
# of heads. It also avoids allocating one more (layers, heads, L, L) tensor.
mean_attn = torch.stack([layer_attn.mean(dim=0) for layer_attn in full_attn]).mean(dim=0)

# Average attention that the decode-step rows (positions >= pref_len) pay to each
# image token.
image_output_attn = torch.mean(mean_attn[pref_len:, vision_start_token_idx + 1:vision_end_token_idx], dim=0)

def calculate_dynamic_threshold(attn, percentile=98, bins=100):
    """Return where the ``percentile``-th percentile of ``attn`` falls in its value range.

    ``attn`` is histogrammed into ``bins`` equal-width bins. ``torch.histc`` with
    its default ``min=max=0`` spans ``[attn.min(), attn.max()]``. The function
    returns the index of the first bin whose cumulative fraction strictly exceeds
    ``percentile / 100``, divided by ``bins``.

    Args:
        attn: tensor of attention scores (any shape; all elements are histogrammed).
        percentile: cumulative percentage to locate, in (0, 100].
        bins: number of histogram bins. The result is normalised by this value, so
            changing it no longer silently rescales the output.

    Returns:
        A float in ``[0, (bins - 1) / bins]``. If no bin exceeds the requested
        fraction (e.g. ``percentile=100`` with float rounding), the result is 0.0.
    """
    hist = torch.histc(attn, bins=bins)
    cdf = torch.cumsum(hist, dim=0) / torch.sum(hist)
    # argmax over a 0/1 mask returns the first True position (or 0 if none).
    first_bin = torch.argmax((cdf > percentile / 100).float()).item()
    return first_bin / bins

# The keep-fraction used by weighted_vision_attention, derived from the
# attention profile over the image tokens.
threshold = calculate_dynamic_threshold(attn=image_output_attn)
print(threshold)

0.33


In [None]:
def weighted_vision_attention(attn_map, keep_percentage=threshold):
    """Convert per-vision-token attention scores into multiplicative token weights.

    The ``keep_percentage`` fraction of tokens with the highest attention keep
    weight 1.0. The remaining tokens, taken in order of decreasing attention, get
    weights that fall linearly from 1.0 down to 0.0, so the least-attended token
    is suppressed the most.

    Args:
        attn_map: 1-D tensor with one attention score per vision token.
        keep_percentage: fraction of tokens kept at full weight. It is clamped so
            that between 0 and ``len(attn_map)`` tokens are kept. The default is the
            notebook-level ``threshold`` as it was when this cell was executed.

    Returns:
        A float32 tensor with the same shape and device as ``attn_map``.
    """
    n_tokens = attn_map.shape[0]
    order = torch.argsort(attn_map, descending=True)
    # Clamp so an out-of-range keep_percentage cannot yield a negative ramp length.
    n_keep = min(max(int(n_tokens * keep_percentage), 0), n_tokens)

    weights = torch.zeros_like(attn_map, dtype=torch.float)
    weights[order[:n_keep]] = 1.0
    # `order` is sorted by DEScending attention, so the ramp must also descend:
    # the most-attended remaining token gets 1.0 and the least-attended gets 0.0.
    weights[order[n_keep:]] = torch.linspace(
        1.0, 0.0, n_tokens - n_keep, device=attn_map.device
    )
    return weights


weight_vision_token = weighted_vision_attention(image_output_attn)

In [3]:
# Count the image placeholder tokens in the prompt. Their embeddings are replaced
# by vision features further down, so this must match the number of features.
input_ids = inputs["input_ids"]
n_image_tokens = int((input_ids == model.config.image_token_id).sum())
print(n_image_tokens)

440


In [None]:
# Rebuild the prompt embeddings by hand so that the vision-token embeddings can be
# scaled by `weight_vision_token` before they reach the language model.
input_ids = inputs["input_ids"]
attention_mask = inputs.get("attention_mask")
pixel_values = inputs.get("pixel_values")
image_grid_thw = inputs.get("image_grid_thw")

# Inference only: no autograd graph is needed for these forward passes.
with torch.no_grad():
    inputs_embeds = model.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        pixel_values = pixel_values.type(model.visual.get_dtype())
        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        if weight_vision_token.shape[0] != n_image_features:
            raise ValueError(
                f"Vision token weights and image features do not match: weights {weight_vision_token.shape[0]}, features {n_image_features}"
            )
        image_mask = (
            (input_ids == model.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        # weight_vision_token is derived from attention maps that were moved to the
        # CPU as float32. Multiplying a CUDA tensor by a non-scalar CPU tensor raises,
        # so match device and dtype first.
        token_weights = weight_vision_token.to(image_embeds.device, image_embeds.dtype)
        image_embeds = image_embeds * token_weights[:, None]
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if attention_mask is not None:
    attention_mask = attention_mask.to(inputs_embeds.device)

if pixel_values is not None:
    print(image_embeds.shape)



In [None]:
# Generate from the re-weighted embeddings.
# `input_ids` and `image_grid_thw` are passed alongside `inputs_embeds` because
# Qwen2-VL derives its multimodal rotary (M-RoPE) position ids from them. With
# `inputs_embeds` alone, the image tokens would presumably get plain 1-D text
# positions, unlike the baseline generate(**inputs) call.
# NOTE(review): checked against the transformers ~4.47 Qwen2-VL implementation;
# re-verify after upgrading.
# `pixel_values` is deliberately NOT passed: the (weighted) image features are
# already inside `inputs_embeds`.
generated_ids = model.generate(
    input_ids=input_ids,
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    image_grid_thw=image_grid_thw,
    max_new_tokens=2048,
)
# When input_ids is supplied, generate() returns prompt + completion. Keep only the
# new tokens so decoding prints just the answer, as with inputs_embeds alone.
generated_ids = generated_ids[:, input_ids.shape[1]:]

In [None]:
# Decode the generated token ids and print one completion per batch entry.
output_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for output in output_text:
    print(output)