# Attention Extraction Prototype

This notebook validates attention extraction from vision-language models for spatial reasoning analysis.

**Goals:**
1. Extract attention weights from a VLM (starting with HuggingFace models)
2. Visualize attention on a sample image
3. Verify we can map attention to image regions

**Models to test:**
- Qwen2.5-VL (HuggingFace version)
- Alternative: PaliGemma (if Qwen attention extraction is difficult)

In [None]:
import sys
sys.path.append('../..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from dataset_zoo.aro_datasets import get_controlled_images_b
from mechanistic.utils.visualization import plot_attention_heatmap, plot_attention_on_image
from mechanistic.utils.hooks import AttentionCache

%matplotlib inline

## Step 1: Load a Sample from Controlled Images B

In [None]:
# Load dataset
dataset = get_controlled_images_b(
    image_preprocess=None,  # No preprocessing, we want raw images
    download=False,
    root_dir='../../data'
)

print(f"Dataset size: {len(dataset)}")

# Get a sample
sample_idx = 0
sample = dataset[sample_idx]

print(f"\nSample {sample_idx}:")
print(f"Caption options: {sample['caption_options']}")
print(f"Correct caption: {sample['caption_options'][0]}")

# Display image
image = sample['image_options'][0]
plt.figure(figsize=(8, 6))
plt.imshow(image)
plt.title(f"Sample Image\n{sample['caption_options'][0]}")
plt.axis('off')
plt.show()

## Step 2: Extract Caption and Create Prompt

Captions in Controlled_B follow the template "A {object1} {relation} a {object2}".
We parse the two objects out of the caption and build a question about their spatial relation.

In [None]:
# Parse the caption to extract objects and relation
caption = sample['caption_options'][0]
words = caption.split()
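object1 = words[1]  # First word after the leading article
object2 = words[-1].rstrip('.')  # Last word, without a trailing period if present
relation = ' '.join(words[2:-2])  # Everything between object1 and the article before object2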

print(f"Object 1: {object1}")
print(f"Relation: {relation}")
print(f"Object 2: {object2}")

# Create a question prompt
question = f"Where is the {object1} in relation to the {object2}? Answer with left, right, front or behind."
print(f"\nQuestion: {question}")
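
The positional parse above assumes single-word object names. A slightly more defensive sketch matches the "A {object1} {relation} a {object2}" template with a regular expression and fails loudly on captions that do not fit. `parse_caption` is a hypothetical helper (not part of `dataset_zoo`), and the example caption is made up for illustration.

```python
import re

# Hypothetical helper: matches "A/An <object1> <relation> a/an/the <object2>".
CAPTION_RE = re.compile(
    r"^(?:a|an)\s+(?P<object1>\w+)\s+(?P<relation>.+?)\s+(?:a|an|the)\s+(?P<object2>\w+)\W*$",
    re.IGNORECASE,
)


def parse_caption(caption: str) -> dict:
    """Split a templated caption into object1, relation and object2."""
    match = CAPTION_RE.match(caption.strip())
    if match is None:
        raise ValueError(f"Caption does not match the expected template: {caption!r}")
    return match.groupdict()


# Illustrative caption, not taken from the dataset
parsed = parse_caption("A mug to the left of a plate.")
print(parsed)
```

Raising on a mismatch makes it easier to spot captions that break the template when looping over many samples.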

## Step 3: Load Model with Attention Output

We'll start with a simpler model to validate the approach: a vision transformer whose attention weights are easy to extract.

**Options:**
1. Use CLIP vision encoder (simpler, just for validation)
2. Use Qwen2.5-VL (full VLM, but more complex)
3. Use PaliGemma (good middle ground)

Let's start with **Option 1 (CLIP)** for quick validation, then move to full VLMs.

In [None]:
# Option 1: CLIP Vision Transformer (for validation)
from transformers import CLIPProcessor, CLIPModel

print("Loading CLIP model...")
model_name = "openai/clip-vit-base-patch32"
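# attn_implementation="eager" (supported in recent transformers releases) makes the model
# materialize attention weights; fused SDPA/flash kernels do not return them
model = CLIPModel.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")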
processor = CLIPProcessor.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Model loaded on {device}")

## Step 4: Extract Attention Weights

In [None]:
# Process inputs
inputs = processor(
    text=[question],
    images=image,
    return_tensors="pt",
    padding=True
).to(device)

# Forward pass with attention
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# Get vision attention weights
vision_attentions = outputs.vision_model_output.attentions
print(f"Number of vision layers: {len(vision_attentions)}")
print(f"Attention shape (first layer): {vision_attentions[0].shape}")
print("Format: [batch, heads, seq_len, seq_len]")

# Get text attention weights  
text_attentions = outputs.text_model_output.attentions
print(f"\nNumber of text layers: {len(text_attentions)}")
print(f"Text attention shape (first layer): {text_attentions[0].shape}")
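
Before visualizing anything, it can help to check that the extracted tensors behave like attention: four dimensions, square in the last two, and every row summing to 1. The sketch below runs on a synthetic array so it is self-contained. `check_attention` is a hypothetical helper, and with real outputs you would pass something like `vision_attentions[0].cpu().numpy()`.

```python
import numpy as np


def check_attention(attn: np.ndarray, atol: float = 1e-4) -> tuple:
    """Validate a [batch, heads, seq_len, seq_len] attention array and return its shape."""
    if attn.ndim != 4:
        raise ValueError(f"Expected 4 dimensions, got {attn.ndim}")
    _, _, q_len, k_len = attn.shape
    if q_len != k_len:
        raise ValueError(f"Self-attention should be square, got {q_len}x{k_len}")
    # Each query row is a softmax distribution over keys, so it should sum to 1
    if not np.allclose(attn.sum(axis=-1), 1.0, atol=atol):
        raise ValueError("Attention rows do not sum to 1")
    return attn.shape


# Synthetic stand-in for one ViT-B/32 layer: 12 heads, 50 tokens (49 patches + CLS)
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 12, 50, 50))
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

shape = check_attention(weights)
print(shape)
```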

## Step 5: Visualize Attention from Last Layer

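We look at the last layer first: its CLS token feeds CLIP's pooled image embedding, so its attention is the closest to the final representation. This is a starting point, not a guarantee that it is the most interpretable layer.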

In [None]:
# Get last layer vision attention
last_layer_attn = vision_attentions[-1]  # [1, heads, seq_len, seq_len]
print(f"Last layer attention shape: {last_layer_attn.shape}")

# Average over heads
avg_attn = last_layer_attn[0].mean(0)  # [seq_len, seq_len]
print(f"Averaged attention shape: {avg_attn.shape}")

# For vision transformers, seq_len = (patches per side)^2 + 1 (CLS token)
# ViT-B/32 at 224x224: 224 / 32 = 7 patches per side, 7 * 7 = 49 patches, + 1 CLS = 50 tokens
num_patches = int(np.sqrt(avg_attn.shape[0] - 1))
print(f"Number of patches per dimension: {num_patches}")
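
The token count follows directly from the image size and the patch size. The sketch below spells out that arithmetic in both directions; both function names are hypothetical. Unlike `int(np.sqrt(...))`, it raises when the patch count is not a perfect square instead of silently rounding down.

```python
import math


def vit_sequence_length(image_size: int, patch_size: int, has_cls: bool = True) -> int:
    """Number of tokens a ViT produces for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side + (1 if has_cls else 0)


def patches_per_side(seq_len: int, has_cls: bool = True) -> int:
    """Recover the patch grid side length from the sequence length."""
    n_patches = seq_len - (1 if has_cls else 0)
    side = math.isqrt(n_patches)
    if side * side != n_patches:
        raise ValueError(f"{n_patches} patches do not form a square grid")
    return side


seq_len = vit_sequence_length(224, 32)  # ViT-B/32 at 224x224
side = patches_per_side(seq_len)
print(seq_len, side)
```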

## Step 6: Visualize Attention Heatmap

In [None]:
# Plot attention heatmap
fig, ax = plot_attention_heatmap(
    avg_attn,
    layer_name="CLIP Vision (Last Layer)",
    figsize=(10, 8)
)
plt.show()

## Step 7: Overlay Attention on Image

We'll visualize how the CLS token attends to different image patches.

In [None]:
# Get CLS token attention to patches (row 0, columns 1:)
cls_to_patches = avg_attn[0, 1:]  # [num_patches^2]

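# Normalize to [0, 1] for visualization (epsilon guards against division by zero if attention is constant)
attn_range = cls_to_patches.max() - cls_to_patches.min()
cls_to_patches = (cls_to_patches - cls_to_patches.min()) / (attn_range + 1e-8)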

# Overlay on image
fig, ax = plot_attention_on_image(
    image,
    cls_to_patches,
    num_patches=num_patches,
    alpha=0.6,
    cmap='hot'
)
plt.show()
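
`plot_attention_on_image` is a project helper whose implementation is not shown in this notebook. As a rough idea of what such an overlay involves (an assumption, not the helper's actual code), the flat CLS-to-patch vector can be reshaped into the patch grid and upsampled to pixel resolution. `patch_attention_to_pixels` is a hypothetical name, and the input here is synthetic.

```python
import numpy as np


def patch_attention_to_pixels(cls_to_patches: np.ndarray, grid: int, patch_size: int) -> np.ndarray:
    """Reshape a flat [grid * grid] attention vector and upsample it to pixel resolution."""
    if cls_to_patches.size != grid * grid:
        raise ValueError(f"Expected {grid * grid} values, got {cls_to_patches.size}")
    # Patches are stored row by row, so patch 0 is the top-left corner
    attn_grid = cls_to_patches.reshape(grid, grid)
    # Nearest-neighbour upsampling: each patch value fills a patch_size x patch_size block
    return np.kron(attn_grid, np.ones((patch_size, patch_size)))


# Synthetic attention that increases from the top-left to the bottom-right patch
flat = np.arange(49, dtype=float) / 48.0
heatmap = patch_attention_to_pixels(flat, grid=7, patch_size=32)
print(heatmap.shape)
```

The resulting array can be drawn over the image with `plt.imshow(heatmap, alpha=0.6, cmap='hot')`. Note that the CLIP processor resizes and center-crops the image to 224x224, so the heatmap aligns with that cropped view and may not line up exactly with the raw image.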

## Step 8: Visualize Multiple Heads

Different attention heads may focus on different aspects.

In [None]:
from mechanistic.utils.visualization import plot_multi_head_attention

# Plot all heads from last layer
fig, axes = plot_multi_head_attention(
    last_layer_attn,
    layer_name="CLIP Vision (Last Layer)",
    max_heads=12,
    figsize=(20, 16)
)
plt.show()

## Step 9: Analyze Attention Across Layers

In [None]:
from mechanistic.utils.metrics import attention_entropy, attention_concentration

# Calculate metrics for each layer
entropies = []
concentrations = []

for layer_idx, attn in enumerate(vision_attentions):
    # Take batch element 0 and average over heads
    avg_layer_attn = attn[0].mean(0)
    
    # Calculate entropy (how focused?)
    entropy = attention_entropy(avg_layer_attn, dim=-1).mean().item()
    entropies.append(entropy)
    
    # Calculate concentration (top-5 sum)
    concentration = attention_concentration(avg_layer_attn, k=5).mean().item()
    concentrations.append(concentration)

# Plot metrics across layers
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(entropies, marker='o')
ax1.set_xlabel('Layer')
ax1.set_ylabel('Attention Entropy')
ax1.set_title('Attention Entropy Across Layers')
ax1.grid(True, alpha=0.3)

ax2.plot(concentrations, marker='o', color='orange')
ax2.set_xlabel('Layer')
ax2.set_ylabel('Top-5 Concentration')
ax2.set_title('Attention Concentration Across Layers')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Average entropy: {np.mean(entropies):.3f}")
print(f"Average concentration: {np.mean(concentrations):.3f}")
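
`attention_entropy` and `attention_concentration` come from `mechanistic.utils.metrics`, which is not shown here. The sketch below gives one plausible definition of each, as an assumption about what they compute: Shannon entropy of every attention row, and the attention mass held by the top-k keys of every row. The names `row_entropy` and `top_k_mass` are hypothetical.

```python
import numpy as np


def row_entropy(attn: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Shannon entropy (in nats) of each attention row; lower means more focused."""
    attn = np.asarray(attn, dtype=float)
    return -(attn * np.log(attn + eps)).sum(axis=-1)


def top_k_mass(attn: np.ndarray, k: int = 5) -> np.ndarray:
    """Sum of the k largest weights in each attention row."""
    attn = np.asarray(attn, dtype=float)
    k = min(k, attn.shape[-1])
    return np.sort(attn, axis=-1)[..., -k:].sum(axis=-1)


# Two extreme rows over 50 tokens: perfectly uniform and fully concentrated
uniform = np.full((1, 50), 1.0 / 50)
peaked = np.zeros((1, 50))
peaked[0, 3] = 1.0

print(row_entropy(uniform), row_entropy(peaked))
print(top_k_mass(uniform), top_k_mass(peaked))
```

Uniform attention over 50 tokens gives the maximum entropy, log(50), and a top-5 mass of 0.1. A one-hot row gives entropy near 0 and a top-5 mass of 1. These two extremes give a scale for reading the per-layer curves.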

## Next Steps

### ✅ What We've Validated:
1. We can extract attention weights from vision transformers
2. We can visualize attention as heatmaps
3. We can overlay attention on images
4. We can calculate attention metrics (entropy, concentration)

### 🚀 Next:
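1. **Adapt for Qwen2.5-VL or PaliGemma**
   - These are full VLMs that generate text conditioned on the image
   - Extract text → image attention (in both models the image tokens sit in the decoder's input sequence, so this is a block of the self-attention matrix rather than a separate cross-attention module)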
   
2. **Run on Multiple Samples**
   - Correct vs incorrect predictions
   - Different spatial relations
   
3. **Build Analysis Pipeline**
   - Automated extraction
   - Statistical analysis
   - Summary visualizations

### 📝 Notes for Full VLM:
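- Need to locate the image-token positions in the input sequence so that text → image attention can be sliced out of the decoder self-attention
- May need to use non-vLLM (HuggingFace Transformers) versions, since optimized inference engines generally do not return attention weights
- Consider using Captum or other attribution libraries for more advanced analysis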
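
For decoder-only VLMs, where image tokens share the input sequence with text tokens, text-to-image attention can be read as a sub-block of the self-attention matrix. The sketch below shows that slicing on a synthetic causal attention tensor. `text_to_image_attention` is a hypothetical helper and the token positions are made up. Finding the real image-token positions is model-specific and not covered here.

```python
import numpy as np


def text_to_image_attention(attn: np.ndarray, text_positions: list, image_positions: list) -> np.ndarray:
    """Slice a [heads, seq_len, seq_len] self-attention tensor to [heads, n_text, n_image]."""
    attn = np.asarray(attn)
    # Rows are queries (text tokens), columns are keys (image tokens)
    return attn[:, text_positions][:, :, image_positions]


# Synthetic causal attention: 4 heads, 10 tokens
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10, 10))
future = np.triu(np.ones((10, 10), dtype=bool), k=1)
logits[:, future] = -np.inf  # a token cannot attend to later positions
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Made-up layout: tokens 2-5 are image tokens, tokens 6-9 are the question text
image_positions = [2, 3, 4, 5]
text_positions = [6, 7, 8, 9]

block = text_to_image_attention(weights, text_positions, image_positions)
image_share = block.sum(axis=-1)  # fraction of each text token's attention spent on the image
print(block.shape, image_share.shape)
```

Summing the block over image positions gives, per head and text token, the share of attention directed at the image, which is one candidate quantity to compare between correct and incorrect predictions.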