In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

import random
import numpy as np
from tqdm import tqdm
from PIL import Image
import re
from transformers import LlavaOnevisionForConditionalGeneration, AutoProcessor
from datasets import load_dataset, concatenate_datasets



In [2]:
# Load LLaVA-OneVision (Qwen2-0.5B backbone) in fp16 on a single CUDA device.
model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
device = "cuda"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    # This notebook reads attention weights (output_attentions=True); fused
    # SDPA/flash kernels only return the attention output, not the weights,
    # so request the eager implementation explicitly.
    attn_implementation="eager",
    # `trust_remote_code` is intentionally not passed here: it only applies to
    # Auto* classes and is ignored (with a warning) for this concrete class.
).to(device)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True
)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [24]:
# MMMU subjects loaded for this run (validation split only).
# Subjects currently left out: Accounting, Art_Theory, Basic_Medical_Science,
# Biology, Chemistry, Clinical_Medicine, Computer_Science, Design,
# Diagnostics_and_Laboratory_Medicine, Economics, Electronics,
# Energy_and_Power, Finance, Geography, History, Literature, Manage,
# Marketing, Materials, Math, Mechanical_Engineering, Music.
all_subs = [
    "Agriculture",
    "Architecture_and_Engineering",
    "Art",
    "Pharmacy",
    "Physics",
    "Psychology",
    "Public_Health",
    "Sociology",
]

# One validation split per subject, stitched together into a single dataset.
sub_dataset_list = [
    load_dataset("MMMU/mmmu", subject_name, split="validation")
    for subject_name in all_subs
]
dataset = concatenate_datasets(sub_dataset_list)

In [19]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import math

def get_attention_maps(output, inputs, model, processor):
    """
    Average the prompt-pass attention of every decoder layer into one tensor.

    `output.attentions` is indexed here as [step][layer]. Per the transformers
    generate-output docs, step 0 is the forward pass over the whole prompt, and
    each per-layer tensor is laid out as (batch, heads, query, key). Only batch
    row 0 is used.

    The average is taken over LAYERS; the head axis is kept.

    Args:
        output: result of `model.generate(..., return_dict_in_generate=True,
            output_attentions=True)`.
        inputs, model, processor: unused; kept for call-site compatibility.

    Returns:
        float32 tensor shaped like one layer's attention without the batch
        axis, i.e. (heads, query, key), averaged across layers.

    Raises:
        ValueError: if `output` carries no attention weights.
    """
    attentions = getattr(output, "attentions", None)
    if attentions is None or len(attentions) == 0:
        raise ValueError(
            "output.attentions is empty; call generate() with "
            "output_attentions=True and return_dict_in_generate=True."
        )

    prompt_step_layers = attentions[0]  # one tensor per decoder layer
    num_layers = len(prompt_step_layers)
    first_layer = prompt_step_layers[0]

    # float32 accumulator (torch.zeros default dtype) on the attention's device.
    attention_sum = torch.zeros(first_layer.shape[1:], device=first_layer.device)
    # Accumulate one layer at a time rather than stacking: a stack would hold
    # num_layers full (heads, query, key) copies in memory simultaneously.
    for layer_attention in prompt_step_layers:
        attention_sum += layer_attention[0]

    attention_weights = attention_sum / num_layers
    print(f"Attention weights shape: {attention_weights.shape}")
    return attention_weights

def find_image_token_positions(input_ids, processor):
    """
    Find the start index of every occurrence of the "<image>" token sequence.

    Only the first row of `input_ids` is searched. "<image>" is tokenized with
    `processor.tokenizer.encode(..., add_special_tokens=False)`, so
    multi-token encodings are matched as a contiguous run.

    Args:
        input_ids: integer tensor with a leading batch dimension.
        processor: object exposing `.tokenizer.encode`.

    Returns:
        1-D int64 tensor of start positions on `input_ids.device`. It is empty
        (still int64, so it remains usable as an index) when there is no match,
        when the pattern encodes to nothing, or when the sequence is shorter
        than the pattern.
    """
    image_token = "<image>"
    image_token_ids = processor.tokenizer.encode(image_token, add_special_tokens=False)

    ids = input_ids[0]
    pattern_len = len(image_token_ids)
    if pattern_len == 0 or ids.numel() < pattern_len:
        return torch.empty(0, dtype=torch.long, device=input_ids.device)

    pattern = torch.tensor(image_token_ids, dtype=ids.dtype, device=ids.device)
    # Sliding windows of length pattern_len: shape (len(ids) - pattern_len + 1, pattern_len).
    windows = ids.unfold(0, pattern_len, 1)
    is_match = (windows == pattern).all(dim=1)
    return is_match.nonzero(as_tuple=True)[0].to(input_ids.device)

def get_adjusted_grid_size(num_tokens, patch_size, image_size):
    """
    Tile `image_size` with square patches and report the grid it yields.

    Args:
        num_tokens: accepted for interface compatibility; not used.
        patch_size: patch side length in pixels.
        image_size: (height, width) in pixels.

    Returns:
        (grid_h, grid_w, grid_h * grid_w), each dimension floor-divided by
        `patch_size`.

    NOTE(review): the grid is derived from the raw image dimensions, not from
    the processor's resized/tiled input, so grid_h * grid_w generally differs
    from the number of image tokens — confirm this mapping is intended.
    """
    rows, cols = (side // patch_size for side in image_size)
    return rows, cols, rows * cols

def reshape_attention_map(attention_tensor, grid_h, grid_w, num_tokens):
    """
    Lay a 1-D vector of per-token scores out on a (grid_h, grid_w) grid.

    If `num_tokens` exceeds the grid capacity, the first element is dropped
    before layout. The remaining values fill the grid in row-major order from
    the top-left. Surplus values are discarded, and unfilled cells stay zero.

    Returns:
        float32 tensor of shape (grid_h, grid_w) on `attention_tensor.device`.
    """
    capacity = grid_h * grid_w
    values = attention_tensor if num_tokens <= capacity else attention_tensor[1:]
    values = values[:capacity]

    grid = torch.zeros(capacity, device=attention_tensor.device)
    grid[: values.shape[0]] = values
    return grid.view(grid_h, grid_w)

def process_attention_for_image(attention_weights, image_size, input_ids, processor, model):
    """
    Build a per-pixel attention map for the image from prompt attention weights.

    Each image token is scored by the attention it RECEIVES (key axis). The
    score is averaged over heads and over the query positions that follow the
    last image token (the question text and generation prompt, which in a
    causal decoder can attend to every image token). If nothing follows the
    image tokens, all query rows are used. The per-token scores are laid out
    on a grid and bilinearly upsampled to `image_size`.

    Args:
        attention_weights: tensor indexed as (heads, query, key), e.g. the
            output of get_attention_maps.
        image_size: (height, width) of the returned map.
        input_ids: token ids with a leading batch dimension (row 0 is searched).
        processor: supplies the tokenizer used to locate image tokens.
        model: supplies `model.config.vision_config.patch_size`.

    Returns:
        numpy.ndarray of shape `image_size`.

    Raises:
        ValueError: if `input_ids` contains no image tokens.
    """
    image_positions = find_image_token_positions(input_ids, processor)
    num_tokens = len(image_positions)
    print(f"Found {num_tokens} image token positions")
    if num_tokens == 0:
        raise ValueError("No image tokens found in input_ids; cannot build an attention map.")
    image_positions = image_positions.to(attention_weights.device)

    # Index image tokens on the KEY axis. Selecting them as queries and
    # averaging over keys gives ~1/key_len for every token (each softmax row
    # sums to 1), i.e. a flat, uninformative map.
    image_attention = attention_weights[:, :, image_positions]
    print(f"Image attention shape: {image_attention.shape}")

    # Restrict to queries after the last image token when there are any.
    query_start = int(image_positions.max().item()) + 1
    if query_start < image_attention.shape[1]:
        image_attention = image_attention[:, query_start:, :]

    # Average over heads and selected queries -> one score per image token.
    token_importance = image_attention.float().mean(dim=(0, 1))
    print(f"Token importance shape: {token_importance.shape}")

    # Get patch size from model config
    patch_size = model.config.vision_config.patch_size
    print(f"Patch size: {patch_size}")

    # Calculate grid dimensions
    grid_h, grid_w, total_patches = get_adjusted_grid_size(num_tokens, patch_size, image_size)
    print(f"Grid dimensions: {grid_h}x{grid_w} (total patches: {total_patches})")

    # Reshape attention to grid
    attention_map = reshape_attention_map(token_importance, grid_h, grid_w, num_tokens)

    # Upsample to original image size
    attention_map = F.interpolate(
        attention_map.unsqueeze(0).unsqueeze(0),
        size=image_size,
        mode='bilinear',
        align_corners=False
    )

    # Index rather than squeeze() so a 1-pixel-high or -wide image keeps both axes.
    return attention_map[0, 0].cpu().numpy()

def visualize_attention(image, output, inputs, model, processor, save_path=None):
    """
    Plot the image, its attention heat map, and an overlay of the two side by side.

    Args:
        image: PIL image given to the model (`image.size` is (width, height)).
        output: `generate` output carrying `.attentions`.
        inputs: model inputs; `inputs['input_ids']` is used to locate image tokens.
        model: the vision-language model (its config supplies the patch size).
        processor: the matching processor (its tokenizer locates image tokens).
        save_path: if given, the figure is saved there and closed; otherwise
            it is shown.

    Raises:
        RuntimeError: any RuntimeError is re-raised. When its message contains
            "out of memory", the CUDA cache is emptied first.
    """
    try:
        attention_weights = get_attention_maps(output, inputs, model, processor)

        attention_map = process_attention_for_image(
            attention_weights,
            image.size[::-1],  # PIL gives (width, height); the map wants (height, width)
            inputs['input_ids'],
            processor,
            model
        )

        # Min-max scale to [0, 1]. A constant map has no contrast; show zeros
        # instead of the NaNs that 0/0 would produce.
        map_min = attention_map.min()
        map_range = attention_map.max() - map_min
        if map_range > 0:
            attention_map = (attention_map - map_min) / map_range
        else:
            attention_map = np.zeros_like(attention_map)

        fig, (ax_image, ax_heat, ax_overlay) = plt.subplots(1, 3, figsize=(15, 5))

        ax_image.imshow(image)
        ax_image.set_title('Original Image')
        ax_image.axis('off')

        heat = ax_heat.imshow(attention_map, cmap='hot')
        ax_heat.set_title('Attention Heatmap')
        fig.colorbar(heat, ax=ax_heat)
        ax_heat.axis('off')

        ax_overlay.imshow(image)
        ax_overlay.imshow(attention_map, cmap='hot', alpha=0.5)
        ax_overlay.set_title('Attention Overlay')
        ax_overlay.axis('off')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            plt.close(fig)
        else:
            plt.show()

    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            print("Out of memory error. Try adjusting parameters.")
        raise

def analyze_vqa_attention(image, question, answers, model, processor, device, save_path=None):
    """
    Run one VQA sample through the model, print its answer, and plot its attention.

    Args:
        image: PIL image for the question.
        question: question text.
        answers: answer-options text, appended after the question.
        model: vision-language model supporting `generate(..., output_attentions=True)`.
        processor: matching processor (chat template, tokenizer, image preprocessing).
        device: device the inputs are moved to.
        save_path: optional path; when given, the figure is saved there
            instead of being shown.

    Returns:
        (output, inputs): the `generate` output (with `.sequences` and
        `.attentions`) and the dict of model inputs on `device`.

    NOTE(review): output_attentions=True keeps every decoder layer's
    (heads, seq, seq) attention for the prompt pass. With thousands of image
    tokens this can exhaust GPU memory; consider downscaling the image or
    otherwise reducing the number of image tokens.
    """
    instruct = "you are an advanced question answer model, please answer this question to the best of your ability with a single letter in brackets like this [X]."
    content = [
        {"type": "text", "text": instruct},
        {"type": "image"},
        {"type": "text", "text": question},
        {"type": "text", "text": answers},
    ]
    message = {"role": "user", "content": content}

    prompt = processor.apply_chat_template([message], add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt",
        padding=True
    )

    # Move inputs to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Greedy decoding; attentions are returned so they can be visualized below.
    generation_config = {
        "max_new_tokens": 32,
        "temperature": 1.0,
        "do_sample": False,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "return_dict_in_generate": True,
        "output_attentions": True
    }

    with torch.no_grad():
        output = model.generate(**inputs, **generation_config)

    # `output.sequences` holds the prompt followed by the new tokens; decode
    # only the new ones so multi-line answers are not truncated and an empty
    # answer is not replaced by the last line of the prompt.
    prompt_len = inputs['input_ids'].shape[1]
    response = processor.batch_decode(
        output.sequences[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0].strip()

    print(f"Model response: {response}")

    # Debug view of the prompt text around the first image placeholder.
    decoded = processor.tokenizer.decode(inputs['input_ids'][0])
    print("\nA few tokens around first image token:")
    pos = decoded.find("<image>")
    if pos >= 0:
        print(decoded[max(0, pos - 50):pos + 50])
    else:
        print("(no <image> token found in the prompt)")

    visualize_attention(image, output, inputs, model, processor, save_path=save_path)

    return output, inputs

In [25]:
# Get a sample from your dataset
sample = dataset[0]

# Analyze attention patterns
output, inputs = analyze_vqa_attention(
    sample['image_1'].convert('RGB'),
    sample['question'],
    sample['options'],
    model,
    processor,
    device
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 186.00 MiB. GPU 0 has a total capacty of 39.38 GiB of which 38.12 MiB is free. Including non-PyTorch memory, this process has 39.33 GiB memory in use. Of the allocated memory 37.86 GiB is allocated by PyTorch, and 992.97 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF