In [1]:
import os, gc, sys, inspect

os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HUB_OFFLINE"] = "1"
%load_ext autoreload

In [2]:
import os
from typing import List, Tuple
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
import matplotlib.pyplot as plt


class Qwen2VLInference:
    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype="auto",
            attn_implementation="eager",
        ).to(device)
        self.processor = Qwen2VLProcessor.from_pretrained(model_name)
        self.apply_monkey_patch()

    def apply_monkey_patch(self):
        """
        Apply the monkey patch to modify model behavior.
        """
        from qwen_mod import get_rope_index_modified  # Adjust this to your module's structure
        self.model.get_rope_index = get_rope_index_modified.__get__(self.model)

    def preprocess_inputs(self, image_path: str, conversation: List[dict]):
        image = Image.open(image_path)
        text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
        return inputs.to(self.device), text_prompt

    def generate_output(self, inputs, max_new_tokens=512):
        with torch.no_grad():
            return self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                return_dict_in_generate=True,
                output_hidden_states=True,
                output_attentions=True,
                use_cache=True,
            )

    def calculate_attention(self, output_attn, pref_len, full_len):
        prefill_attn = output_attn[0]
        assert prefill_attn[0].shape[0] == 1
        full_attn = []
        for l, layer in enumerate(prefill_attn):
            layer = layer.cpu().squeeze(0).float()
            layer = F.pad(layer, (0, full_len - pref_len, 0, full_len - pref_len))
            for i in range(full_len - pref_len):
                cur_attn = output_attn[i + 1][l].cpu().squeeze(0)[:, 0, :].float()
                layer[:, pref_len + i, :pref_len + i + 1] = cur_attn
            full_attn.append(layer)
        return torch.stack(full_attn).mean(dim=(0, 1))

    def filter_vision_tokens(self, inputs, mean_attn, vision_start_idx, vision_end_idx, keep_ratio=0.6):
        image_attention = torch.mean(mean_attn[vision_start_idx:vision_end_idx], dim=0)
        _, top_indices = image_attention.topk(int(len(image_attention) * keep_ratio))
        top_indices += vision_start_idx
        for i in range(vision_start_idx, vision_end_idx):
            if i in top_indices:
                inputs["attention_mask"][0, i] = True
            else:
                inputs["attention_mask"][0, i] = False
        return inputs

    def visualize_attention(self, mean_attn, full_len):
        sqrt_attn_map = np.sqrt(mean_attn.numpy())
        plt.imshow(sqrt_attn_map[1:, 1:], cmap="inferno")
        plt.colorbar()
        plt.show()

    def run_two_stage_inference(self, image_path: str, conversation: List[dict], keep_ratio=0.6):
        # Stage 1: Initial inference
        inputs, text_prompt = self.preprocess_inputs(image_path, conversation)
        output_ids = self.generate_output(inputs)
        output_attn = output_ids.attentions
        pref_len = output_attn[0][0].shape[3]
        full_len = output_attn[-1][0].shape[3]

        # Calculate attention map
        mean_attn = self.calculate_attention(output_attn, pref_len, full_len)

        # Visualize attention (optional)
        #self.visualize_attention(mean_attn, full_len)

        # Filter vision tokens based on attention
        vision_start_idx = inputs["input_ids"][0].tolist().index(self.model.config.vision_start_token_id)
        vision_end_idx = inputs["input_ids"][0].tolist().index(self.model.config.vision_end_token_id)
        inputs = self.filter_vision_tokens(inputs, mean_attn, vision_start_idx, vision_end_idx, keep_ratio)

        # Stage 2: Inference with filtered tokens
        refined_output_ids = self.generate_output(inputs)
        refined_text = self.processor.batch_decode(
            refined_output_ids.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        return refined_text




In [3]:
if __name__ == "__main__":
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    handler = Qwen2VLInference(model_name)
    
    image_path = "examples/image.png"
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Find the missing amounts for company B. Options are ['$63,020', '$58,410', '$71,320', '$77,490']"},
            ],
        },
    ]

    result = handler.run_two_stage_inference(image_path, conversation, keep_ratio=0.6)
    print("Final Output:", result)

`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

  return F.conv3d(
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


Final Output: ["system\nYou are a helpful assistant.\nuser\nFind the missing amounts for company B. Options are ['$63,020', '$58,410', '$71,320', '$77,490']\nassistant\nTo find the missing amounts for Company B, we need to balance the income statement. The income statement is structured as follows:\n\n\\[ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ 

In [None]:
from datasets import load_dataset
# 加载数据集
dataset = load_dataset("MMMU/MMMU")