In [None]:
# Query the model on test image
#
# Setup cell: imports, run configuration, and loading of the pretrained
# LLaVA-OneVision checkpoint. Everything assigned here is a kernel-global
# that later cells read (tokenizer, model, image_processor, ...).

import warnings
# Suppress UserWarning messages attributed to torch.nn.modules.module.
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.module")

from datasets import load_dataset
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
import torch
import matplotlib.pyplot as plt

# --- Run configuration -------------------------------------------------------
model_path = "lmms-lab/llava-onevision-qwen2-0.5b-ov"  # checkpoint identifier passed to the loader
model_name = "llava_qwen_with_alternating_attn"  # presumably selects the model variant inside the loader; verify against llava.model.builder
model_base = None
load_8bit = False
load_4bit = False
device_map = "auto"
# NOTE(review): `device` is not read by any cell visible in this review, and
# the loader is given device_map="auto" instead. Confirm whether it is still needed.
device = "cuda:1"
attn_implementation = "eager"  # presumably required so attention weights can be returned; confirm

# --- Load tokenizer, model and image processor -------------------------------
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=model_base,
    model_name=model_name,
    load_8bit=load_8bit,
    load_4bit=load_4bit,
    device_map=device_map,
    torch_dtype="float16",
    attn_implementation=attn_implementation,
    multimodal=True
)
# Override the image aspect-ratio mode on the loaded config; process_images()
# is later called with model.config, so this presumably changes preprocessing.
model.config.image_aspect_ratio = "nobase"


In [None]:
# Load the validation split of the BLINK "Visual_Correspondence" config and
# keep its first example as the item to query the model with.
dataset = load_dataset(
    path="BLINK-Benchmark/BLINK",
    name="Visual_Correspondence",
    split="val",
)
item = dataset[0]

In [None]:
# Build a multi-image prompt for `item`, preprocess its images, and run greedy
# generation while collecting attention weights in `output`.

# Work out which of the four image slots are populated exactly once, so the
# image placeholders in the prompt and the list of image tensors cannot drift
# apart. Slots keep their original 1-based number (a missing slot is skipped,
# not renumbered), matching the original labelling.
present_images = [
    (slot, item[f"image_{slot}"])
    for slot in range(1, 5)
    if item[f"image_{slot}"] is not None
]

image_prompt = "".join(f"Image {slot}: {DEFAULT_IMAGE_TOKEN}\n" for slot, _ in present_images)
question_text = "Question: " + item["question"]
detailed_prompt = "Details: " + item["prompt"]
# Fixed: the apostrophe here was mojibake ("â€™"), which put garbage bytes into
# the instruction sent to the model.
directive = "Answer with the option's letter from the given choices directly."
conv = conv_templates["qwen_1_5"].copy()
conv.append_message(conv.roles[0], image_prompt + "\n" + question_text + "\n" + detailed_prompt + "\n" + directive)
conv.append_message(conv.roles[1], None)  # empty assistant turn for the model to complete
prompt = conv.get_prompt()

# Send inputs to wherever the model actually lives instead of a hard-coded
# "cuda:0": the model is loaded with device_map="auto", so the device is not
# known statically.
input_device = model.device

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(input_device)  # add a leading batch dimension

images = [image.convert("RGB") for _, image in present_images]
image_tensors = process_images(images, image_processor, model.config)
# float16 matches the torch_dtype the model was loaded with.
image_tensors = [_image.to(dtype=torch.float16, device=input_device) for _image in image_tensors]
image_sizes = [image.size for image in images]

with torch.inference_mode():
    output = model.generate(
        input_ids,
        images=image_tensors,
        image_sizes=image_sizes,
        modalities=["image"],
        do_sample=False,  # greedy decoding
        max_new_tokens=256,
        use_cache=True,
        output_attentions=True,  # needed by the attention visualisation cell
        return_dict_in_generate=True,  # return a structured output exposing `.attentions`
    )

In [None]:
# Visualise which attention entries are non-zero for one layer.

# First element of the attentions returned by generate(); presumably the
# attention maps of the first generation step (the full prompt), one entry per
# layer. TODO confirm against the model's generate() output layout.
layer_attns = output.attentions[0]
layer_idx = 1  # which layer to inspect
# Index 0 on the first axis and the last index on the second axis, then move to
# host memory as a NumPy array. Assumes layout (batch, heads, query, key), i.e.
# batch element 0 and the last head. TODO confirm.
attn = layer_attns[layer_idx][0, -1].cpu().numpy()

# Casting to bool maps every non-zero weight to True, so the image shows the
# pattern of non-zero entries rather than the attention magnitudes.
plt.imshow(attn.astype(bool))