# Qwen3-VL Attention Extraction for Spatial Reasoning

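This notebook extracts **text → image attention** from Qwen3-VL-4B-Instruct to understand how the model attends to image regions when reasoning about spatial relations. For brevity, the rest of the notebook calls this pattern "cross-attention".

**Key Difference from CLIP:**
- CLIP: Separate image and text encoders, each with self-attention only (no text → image interaction)
- Qwen3-VL: A decoder-only language model without dedicated cross-attention layers; image patch tokens are inserted into the text sequence, so text tokens attend to image patches through the language model's causal self-attention

**What we'll extract:**
1. Language model self-attention over the joint sequence (text attending to text + image)
2. **Cross-attention** (the block of that matrix where text tokens attend to image patches) ← Most important!

Vision encoder self-attention (image patches attending to each other) is not extracted in this notebook.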

In [None]:
import sys
sys.path.append('../..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from dataset_zoo.aro_datasets import get_controlled_images_b
from mechanistic.utils.visualization import (
    plot_attention_heatmap, 
    plot_attention_on_image,
    plot_cross_attention_attribution
)
from mechanistic.utils.hooks import AttentionCache
from mechanistic.utils.metrics import attention_entropy, attention_concentration

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Step 1: Load Qwen3-VL Model

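Attention weights are requested with `output_attentions=True` at the forward pass (Step 5), not at load time. Fused attention kernels (SDPA, FlashAttention) generally do not return attention weights, so the model should use the eager attention implementation (`attn_implementation="eager"`).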

In [None]:
# Model path (update if needed)
model_path = "/leonardo_work/EUHPC_D27_102/compmech/models/Qwen3-VL-4B-Instruct"

print(f"Loading Qwen3-VL from: {model_path}")

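# Load model. A Qwen3-VL checkpoint needs a Qwen3-VL-aware class, so let the
# Auto class resolve it from the checkpoint config (requires a transformers
# release that includes Qwen3-VL support).
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Eager attention is needed for output_attentions=True to return weights;
    # output_attentions itself is passed at the forward pass (Step 5).
    attn_implementation="eager",
)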

# Load processor
processor = AutoProcessor.from_pretrained(model_path)

print(f"Model loaded successfully")
print(f"Device: {model.device}")

## Step 2: Load Sample from Controlled Images B

In [None]:
# Load dataset
dataset = get_controlled_images_b(
    image_preprocess=None,
    download=False,
    root_dir='/leonardo_work/EUHPC_D27_102/compmech/whatsup_vlms_data'
)

print(f"Dataset size: {len(dataset)}")

# Get a sample
sample_idx = 0
sample = dataset[sample_idx]

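# Extract caption and objects
# Assumes captions of the form "A <object1> ... a <object2>" with single-word
# object names; multi-word names would need a different parse.
caption = sample['caption_options'][0]
words = caption.split()
object1 = words[1]
object2 = words[-1]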

print(f"\nSample {sample_idx}:")
print(f"Caption: {caption}")
print(f"Object 1: {object1}")
print(f"Object 2: {object2}")

# Display image
image = sample['image_options'][0]
plt.figure(figsize=(8, 6))
plt.imshow(image)
plt.title(f"Sample Image\n{caption}")
plt.axis('off')
plt.show()

## Step 3: Create Prompt in Qwen Format

In [None]:
# Create question
question = f"Where is the {object1} in relation to the {object2}? Answer with left, right, front or behind."
print(f"Question: {question}")

# Prepare messages in Qwen format
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": question},
        ],
    }
]

# Apply chat template
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(f"\nFormatted prompt (first 200 chars):\n{text[:200]}...")

## Step 4: Process Input and Extract Attention

**Important:** Qwen-VL processes images into vision tokens that get concatenated with text tokens. We need to:
1. Identify which tokens are image tokens
2. Extract attention from text tokens to image tokens (cross-attention pattern)
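
Because image tokens and text tokens share one sequence, the text → image pattern is a rectangular block of the language model's self-attention matrix. The sketch below illustrates this on a toy random attention matrix. The token layout, the sizes, and the helper `causal_attention` are made up for illustration and do not come from the model.

```python
import numpy as np


def causal_attention(num_tokens: int, seed: int = 0) -> np.ndarray:
    """Return a random row-stochastic causal attention matrix [seq, seq]."""
    rng = np.random.default_rng(seed)
    logits = rng.normal(size=(num_tokens, num_tokens))
    # Mask out future positions (strict upper triangle) before the softmax.
    future = np.triu(np.ones((num_tokens, num_tokens), dtype=bool), k=1)
    logits[future] = -np.inf
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


# Toy layout (assumed): 2 prefix tokens, 6 image tokens, 4 text tokens.
image_slice = slice(2, 8)
text_start = 8
attn = causal_attention(12)

# Rows are text tokens (queries), columns are image tokens (keys).
text_to_image = attn[text_start:, image_slice]

# Share of each text token's attention that lands on the image.
image_share = text_to_image.sum(axis=-1)

print(text_to_image.shape)  # (4, 6)
```

Each row of `text_to_image` sums to less than 1, since the same text token also attends to the prefix and to other text tokens. The reverse block (image → later text) is all zeros under causal masking.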

In [None]:
# Process inputs
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

print(f"Input shape: {inputs['input_ids'].shape}")
print(f"Total tokens: {inputs['input_ids'].shape[1]}")

# Decode tokens to see the sequence
tokens = processor.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
print(f"\nFirst 20 tokens: {tokens[:20]}")
print(f"Last 20 tokens: {tokens[-20:]}")

## Step 5: Forward Pass with Attention Output

In [None]:
# Forward pass with attention
with torch.no_grad():
    outputs = model(
        **inputs,
        output_attentions=True,
        return_dict=True
    )

print(f"Output keys: {outputs.keys()}")

# Get attentions (stays None if the attention backend returns no weights,
# so later cells can test for it without raising a NameError)
attentions = getattr(outputs, 'attentions', None)
if attentions is not None:
    print(f"\nNumber of layers with attention: {len(attentions)}")
    print(f"Attention shape (first layer): {attentions[0].shape}")
    print("Format: [batch, heads, seq_len, seq_len]")
else:
    print("\nWarning: No attention weights found in outputs!")
    print("This might mean the model doesn't support output_attentions or we need a different approach.")

## Step 6: Identify Image vs Text Tokens

Qwen-VL uses special tokens to mark image regions. We need to identify:
- `<|vision_start|>` and `<|vision_end|>` markers
- Image patch tokens between these markers
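
The marker search can also be done on token IDs instead of decoded strings, which avoids substring matching on text. A minimal sketch follows, assuming the marker IDs are known; the IDs used here are hypothetical placeholders, and in the notebook they would presumably come from `processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")` and its `<|vision_end|>` counterpart. The helper `find_image_span` is illustrative, not part of any library.

```python
from typing import List, Optional, Tuple


def find_image_span(
    token_ids: List[int], start_id: int, end_id: int
) -> Optional[Tuple[int, int]]:
    """Return (start, end) indices of the first vision marker pair, or None."""
    try:
        start = token_ids.index(start_id)
        end = token_ids.index(end_id, start + 1)
    except ValueError:
        return None
    return start, end


# Hypothetical IDs for illustration only; real IDs come from the tokenizer.
VISION_START, IMAGE_PAD, VISION_END = 1001, 1002, 1003
ids = [1, 2, VISION_START, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD, IMAGE_PAD, VISION_END, 7, 8]

span = find_image_span(ids, VISION_START, VISION_END)
if span is not None:  # compare against None: index 0 is a valid position
    start, end = span
    num_image_tokens = end - start - 1
    print(span, num_image_tokens)  # (2, 7) 4
```

Returning `None` rather than relying on truthiness matters because a marker at position 0 would otherwise be treated as "not found".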

In [None]:
# Find image token positions
token_ids = inputs['input_ids'][0]
tokens_list = [processor.tokenizer.decode(t) for t in token_ids]

# Look for vision markers
vision_start_token = "<|vision_start|>"
vision_end_token = "<|vision_end|>"
image_pad_token = "<|image_pad|>"

# Find indices of the first start marker and the first end marker after it
vision_start_idx = None
vision_end_idx = None

for i, token in enumerate(tokens_list):
    if vision_start_idx is None and vision_start_token in token:
        vision_start_idx = i
    if vision_start_idx is not None and vision_end_token in token:
        vision_end_idx = i
        break

# Compare against None explicitly: index 0 is a valid position but is falsy
if vision_start_idx is not None and vision_end_idx is not None:
    num_image_tokens = vision_end_idx - vision_start_idx - 1
    print(f"Image tokens span: [{vision_start_idx}, {vision_end_idx}]")
    print(f"Number of image tokens: {num_image_tokens}")
    
    # Calculate grid size (only valid when the image tokens form a square grid)
    grid_size = int(np.sqrt(num_image_tokens))
    if grid_size * grid_size != num_image_tokens:
        print(f"Warning: {num_image_tokens} image tokens do not form a square grid; "
              "the grid_size x grid_size reshape in later steps will fail for this image.")
    print(f"Image grid size: {grid_size}x{grid_size}")
else:
    print("Could not find vision tokens!")
    print(f"Tokens: {tokens_list}")

## Step 7: Extract Cross-Attention (Text → Image)

We want to see how text tokens (especially spatial words like "left", "right") attend to image patches.

In [None]:
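if attentions is not None and vision_start_idx is not None and vision_end_idx is not None:
    # Get last layer attention
    last_layer_attn = attentions[-1]  # [batch, heads, seq, seq]
    
    # Average over heads; cast to float32 first because bfloat16 tensors
    # cannot be converted to NumPy in the later plotting steps
    avg_attn = last_layer_attn[0].float().mean(0)  # [seq, seq]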
    
    # Extract cross-attention: text tokens → image tokens
    # Text tokens are after vision_end_idx
    text_start_idx = vision_end_idx + 1
    
    # Attention from text to image patches
    cross_attn = avg_attn[text_start_idx:, vision_start_idx+1:vision_end_idx]
    
    print(f"Cross-attention shape: {cross_attn.shape}")
    print(f"Format: [num_text_tokens, num_image_patches]")
    
    # Get text tokens for labeling
    text_tokens = tokens_list[text_start_idx:]
    print(f"\nText tokens: {text_tokens[:20]}...")  # First 20 text tokens
else:
    print("Cannot extract cross-attention - missing attention or vision token indices")

## Step 8: Visualize Cross-Attention for Key Words

Let's find spatial words (left, right, front, behind) and object names in the text, then visualize where they attend.
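
Turning one token's attention over image patches into a 2D map needs the height and width of the token grid, which is not square for non-square images. The sketch below derives the grid from a `(t, h, w)` patch count and a merge factor, then reshapes and min-max normalizes the attention row. It assumes the processor output carries an `image_grid_thw` entry in pre-merge patch units and that 2×2 patches are merged into one token, as in the Qwen2-VL processor; whether Qwen3-VL matches this exactly should be checked. It also assumes image tokens are laid out row by row. Both helper names are illustrative.

```python
from typing import Sequence, Tuple

import numpy as np


def merged_grid_shape(grid_thw: Sequence[int], merge_size: int = 2) -> Tuple[int, int]:
    """Map a (t, h, w) patch grid to the (rows, cols) layout of image tokens."""
    t, h, w = (int(v) for v in grid_thw)
    if t != 1:
        raise ValueError("expected a single image (t == 1)")
    return h // merge_size, w // merge_size


def attention_to_heatmap(token_attn: Sequence[float], grid_shape: Tuple[int, int]) -> np.ndarray:
    """Reshape one token's attention over image tokens and min-max normalize it."""
    rows, cols = grid_shape
    values = np.asarray(token_attn, dtype=np.float32)
    if values.size != rows * cols:
        raise ValueError(f"{values.size} attention values do not fit a {rows}x{cols} grid")
    grid = values.reshape(rows, cols)
    return (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)


# Toy example: a 16x24 patch grid merges into 8x12 image tokens.
grid_shape = merged_grid_shape((1, 16, 24))
token_attn = np.linspace(0.0, 1.0, grid_shape[0] * grid_shape[1])
heatmap = attention_to_heatmap(token_attn, grid_shape)

print(grid_shape, heatmap.shape)  # (8, 12) (8, 12)
```

The explicit size check fails early with a clear message instead of a reshape error deep inside the plotting loop.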

In [None]:
if 'cross_attn' in locals():
    # Find indices of interesting tokens
    spatial_words = ['left', 'right', 'front', 'behind', 'where']
    object_words = [object1.lower(), object2.lower()]
    interesting_words = spatial_words + object_words
    
    interesting_indices = []
    interesting_labels = []
    
    for i, token in enumerate(text_tokens):
        token_clean = token.strip().lower()
        for word in interesting_words:
            if word in token_clean:
                interesting_indices.append(i)
                interesting_labels.append(f"{token}({i})")
                break
    
    print(f"Found {len(interesting_indices)} interesting tokens: {interesting_labels}")
    
    if len(interesting_indices) > 0:
        # Plot attention for each interesting token
        n_tokens = min(len(interesting_indices), 6)  # Max 6 tokens
        fig, axes = plt.subplots(1, n_tokens + 1, figsize=(4 * (n_tokens + 1), 4))
        
        # Plot original image
        axes[0].imshow(image)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        # Plot attention for each token
        for plot_idx, token_idx in enumerate(interesting_indices[:n_tokens]):
            token_attn = cross_attn[token_idx]  # [num_image_patches]
            
            # Reshape to 2D grid
            attn_grid = token_attn.cpu().reshape(grid_size, grid_size).numpy()
            
            # Normalize
            attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min() + 1e-8)
            
            # Resize to image size
            from scipy.ndimage import zoom
            img_array = np.array(image)
            zoom_factor = (img_array.shape[0] / grid_size, img_array.shape[1] / grid_size)
            attn_resized = zoom(attn_grid, zoom_factor, order=1)
            
            # Plot
            axes[plot_idx + 1].imshow(img_array)
            axes[plot_idx + 1].imshow(attn_resized, cmap='hot', alpha=0.6)
            axes[plot_idx + 1].set_title(f'Token: {interesting_labels[plot_idx]}')
            axes[plot_idx + 1].axis('off')
        
        plt.suptitle(f'Cross-Attention: Text Tokens → Image Patches\n{caption}', fontsize=14, y=1.02)
        plt.tight_layout()
        plt.show()
    else:
        print("No interesting tokens found in text!")

## Step 9: Aggregate Attention for Spatial Words

In [None]:
if 'cross_attn' in locals() and len(interesting_indices) > 0:
    # Get attention for spatial words only
    spatial_token_indices = []
    for i, token in enumerate(text_tokens):
        if any(word in token.lower() for word in spatial_words):
            spatial_token_indices.append(i)
    
    if len(spatial_token_indices) > 0:
        # Average attention across all spatial tokens
        spatial_attn = cross_attn[spatial_token_indices].mean(0)  # [num_image_patches]
        
        # Reshape and visualize
        attn_grid = spatial_attn.cpu().reshape(grid_size, grid_size).numpy()
        attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min() + 1e-8)
        
        # Resize
        from scipy.ndimage import zoom
        img_array = np.array(image)
        zoom_factor = (img_array.shape[0] / grid_size, img_array.shape[1] / grid_size)
        attn_resized = zoom(attn_grid, zoom_factor, order=1)
        
        # Plot
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        
        ax1.imshow(image)
        ax1.set_title('Original Image')
        ax1.axis('off')
        
        ax2.imshow(img_array)
        im = ax2.imshow(attn_resized, cmap='hot', alpha=0.6)
        ax2.set_title('Averaged Spatial Word Attention')
        ax2.axis('off')
        
        plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
        plt.suptitle(f'Where does the model look for spatial reasoning?\n{caption}', fontsize=14)
        plt.tight_layout()
        plt.show()
    else:
        print("No spatial words found in text!")

## Step 10: Analyze Attention Metrics
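
Entropy and top-k concentration are both computed per text token over its attention distribution across image patches. The sketch below shows one plausible definition of each on plain Python lists. It is not the implementation behind `attention_entropy` and `attention_concentration` in `mechanistic.utils.metrics`, which is not shown here; the function names below are hypothetical.

```python
import math
from typing import List, Sequence


def normalize(weights: Sequence[float]) -> List[float]:
    """Rescale non-negative weights so they sum to 1."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have positive total mass")
    return [w / total for w in weights]


def row_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def top_k_mass(probs: Sequence[float], k: int) -> float:
    """Total probability held by the k largest entries."""
    return sum(sorted(probs, reverse=True)[:k])


# A uniform row is maximally diffuse; a peaked row is focused.
uniform = normalize([1.0] * 8)
peaked = normalize([0.93] + [0.01] * 7)

print(row_entropy(uniform) > row_entropy(peaked))   # True
print(top_k_mass(peaked, 1) > top_k_mass(uniform, 1))  # True
```

For a uniform distribution over `n` patches the entropy is `log(n)`, which gives a useful upper bound when reading the histogram.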

In [None]:
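if 'cross_attn' in locals():
    # Each row of cross_attn holds only the image-token part of a softmax row,
    # so it sums to less than 1. Renormalize over image patches so that entropy
    # and concentration describe the distribution within the image.
    cross_attn_norm = cross_attn.float()
    cross_attn_norm = cross_attn_norm / cross_attn_norm.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    
    # Calculate metrics for cross-attention
    entropy = attention_entropy(cross_attn_norm, dim=-1)
    concentration = attention_concentration(cross_attn_norm, k=5)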
    
    print(f"Cross-Attention Metrics:")
    print(f"  Average entropy: {entropy.mean():.3f} (lower = more focused)")
    print(f"  Average top-5 concentration: {concentration.mean():.3f} (higher = more concentrated)")
    
    # Plot distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    ax1.hist(entropy.cpu().numpy(), bins=30, alpha=0.7, edgecolor='black')
    ax1.axvline(entropy.mean().item(), color='red', linestyle='--', label=f'Mean: {entropy.mean():.3f}')
    ax1.set_xlabel('Attention Entropy')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Distribution of Attention Entropy\n(per text token)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    ax2.hist(concentration.cpu().numpy(), bins=30, alpha=0.7, edgecolor='black', color='orange')
    ax2.axvline(concentration.mean().item(), color='red', linestyle='--', label=f'Mean: {concentration.mean():.3f}')
    ax2.set_xlabel('Top-5 Concentration')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Distribution of Attention Concentration\n(per text token)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

## Next Steps

### ✅ What This Notebook Demonstrates:
1. Loading Qwen3-VL and extracting attention weights
2. Identifying image vs text tokens in the sequence
3. Extracting **cross-attention** (text → image)
4. Visualizing where spatial words attend on the image
5. Analyzing attention metrics (entropy, concentration)

### 🚀 Next:
1. **Test on multiple samples**
   - Compare correct vs incorrect predictions
   - Different spatial relations (left/right vs front/behind)
   
2. **Layer-wise analysis**
   - How does cross-attention evolve across layers?
   - Which layers are most important?
   
3. **Statistical analysis**
   - Do correct predictions have more focused attention?
   - Do spatial words consistently attend to relevant regions?
   
4. **Build automated pipeline**
   - Batch processing
   - Save attention maps
   - Generate summary statistics
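
As a starting point for the layer-wise analysis, one simple per-layer summary is the average share of attention that text tokens place on image tokens. The sketch below computes it from a list of per-layer arrays shaped `[heads, seq, seq]`, using toy matrices in place of real model outputs. The function name, the token layout, and the choice of this particular statistic are assumptions for illustration.

```python
from typing import List, Sequence

import numpy as np


def image_share_per_layer(
    attentions: Sequence[np.ndarray], image_slice: slice, text_start: int
) -> List[float]:
    """For each layer, average the attention mass text tokens put on image tokens."""
    shares = []
    for layer_attn in attentions:
        # Average over heads: [heads, seq, seq] -> [seq, seq]
        head_avg = np.asarray(layer_attn, dtype=np.float64).mean(axis=0)
        text_rows = head_avg[text_start:]
        shares.append(float(text_rows[:, image_slice].sum(axis=-1).mean()))
    return shares


# Toy layout (assumed): 1 prefix token, 3 image tokens, 2 text tokens.
seq_len, image_slice, text_start = 6, slice(1, 4), 4

# Layer 0: every token attends uniformly to itself and all earlier tokens.
uniform = np.tril(np.ones((seq_len, seq_len)))
uniform /= uniform.sum(axis=-1, keepdims=True)
# Layer 1: every token attends only to itself.
identity = np.eye(seq_len)

toy_attentions = [np.stack([uniform, uniform]), np.stack([identity, identity])]
shares = image_share_per_layer(toy_attentions, image_slice, text_start)

print([round(s, 3) for s in shares])  # [0.55, 0.0]
```

With real outputs, each tensor in `outputs.attentions` would first need to be reduced to a single example and converted to a float32 NumPy array.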