# Two-Hop Attention Hypothesis Analysis

This notebook demonstrates how to analyze attention patterns to test the 2-hop hypothesis:

**Hypothesis:** Text tokens gather information from the image, and the last token reads from those text tokens (rather than directly from the image).

## Metrics Computed:

1. **Last Token Attention Distribution**: Share (%) of the last token's attention on image vs. text tokens
2. **Text Tokens Attention Distribution**: Share (%) of the text tokens' attention on image vs. other text tokens
3. **Attention Flow Score**: Correlation between text→image and last→text attention
4. **Information Bottleneck Score**: Which text tokens act as "hubs"
5. **Direct vs Indirect Image Access**: Compare last→image with last→text→image
6. **Layer-wise Evolution**: How these patterns change across layers
7. **Positional Analysis**: Which token types attend most to the image
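Metrics 1 and 2 reduce to sums over slices of a single attention matrix. The sketch below is one plausible definition, assumed rather than taken from `mechanistic.utils.two_hop_metrics`: `A[q, k]` is the head-averaged attention from query position `q` to key position `k`, and the hypothetical half-open ranges `vision` and `text` mark image and text positions, with the last token excluded from `text`.

```python
import numpy as np

def attention_distribution(A, vision, text):
    """Metrics 1 and 2 on one head-averaged attention matrix (hypothetical definition).

    A      : (seq_len, seq_len) array, A[q, k] = attention from query q to key k
    vision : half-open (start, end) range of image token positions
    text   : half-open (start, end) range of text token positions (last token excluded)
    Returns (last_image_pct, last_text_pct, text_image_pct_mean).
    """
    v0, v1 = vision
    t0, t1 = text
    last_row = A[-1]  # the last token's attention over all keys

    # Metric 1: share of the last token's attention mass on image vs. text keys
    last_image_pct = 100.0 * last_row[v0:v1].sum()
    last_text_pct = 100.0 * last_row[t0:t1].sum()

    # Metric 2: for each text query, share of its attention on image keys, then averaged
    text_image_pct = 100.0 * A[t0:t1, v0:v1].sum(axis=1)
    return float(last_image_pct), float(last_text_pct), float(text_image_pct.mean())


# Toy 6-token causal example: [sys, img, img, txt, txt, last]
A_toy = np.array([
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
    [0.2, 0.4, 0.4, 0.0, 0.0, 0.0],
    [0.1, 0.3, 0.3, 0.3, 0.0, 0.0],
    [0.1, 0.2, 0.2, 0.2, 0.3, 0.0],
    [0.1, 0.1, 0.1, 0.3, 0.3, 0.1],
])
print(attention_distribution(A_toy, vision=(1, 3), text=(3, 5)))
```

On the toy matrix it should print values close to 20, 60 and 50 (percent): the last token puts 20% of its attention on the image, while the text tokens put half of theirs there.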

In [None]:
import sys
sys.path.append('../..')

import torch
from transformers import AutoTokenizer

from mechanistic.utils.two_hop_metrics import (
    compute_all_two_hop_metrics,
    print_metrics_summary,
    identify_token_ranges,
)
from mechanistic.utils.two_hop_visualization import (
    plot_complete_analysis,
    plot_attention_distribution_comparison,
    plot_attention_flow,
    plot_hub_tokens_analysis,
    plot_direct_vs_indirect,
)

import matplotlib.pyplot as plt
%matplotlib inline

## 1. Load Your Data

Load a sample that contains:
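- `attentions`: Tuple with one attention tensor per layer, each typically of shape `(batch, num_heads, seq_len, seq_len)` (the layout Hugging Face models return when called with `output_attentions=True`)
- `input_ids`: Token IDs of the full input sequence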
- Optional: `caption`, `question`, `pred`, `GT`, etc.

In [None]:
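# Example: load a single saved sample (replace with your actual data path).
# map_location='cpu' lets tensors saved from a GPU load on any machine.
# On PyTorch >= 2.6, files containing arbitrary Python objects may also need
# weights_only=False (only do this for files you trust).
sample = torch.load('path/to/your/sample.pt', map_location='cpu')

# Or, if you saved a list of samples:
# samples = torch.load('path/to/your/samples.pt', map_location='cpu')
# sample = samples[0]

print(f"Sample keys: {list(sample.keys())}")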
print(f"Number of layers: {len(sample['attentions'])}")
print(f"Attention shape: {sample['attentions'][0].shape}")
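Before computing any metric, it can help to confirm that the sample has the layout the rest of the notebook assumes. The helper below is hypothetical (it is not part of `mechanistic.utils`) and assumes the Hugging Face layout of one `(batch, num_heads, seq_len, seq_len)` tensor per layer. If the attentions were captured during `generate()`, the per-step tensors have a query length of 1 and this check will flag them.

```python
import numpy as np

def check_sample(sample):
    """Hypothetical sanity check for the sample layout this notebook assumes.

    Expects sample["attentions"] to hold one 4-D array per layer shaped
    (batch, num_heads, seq_len, seq_len) and sample["input_ids"] to end in a
    seq_len dimension. np.shape reads .shape, so torch tensors work as well.
    Returns (num_layers, seq_len).
    """
    for key in ("attentions", "input_ids"):
        if key not in sample:
            raise KeyError(f"sample is missing required key {key!r}")

    seq_len = np.shape(sample["input_ids"])[-1]
    for layer, att in enumerate(sample["attentions"]):
        shape = np.shape(att)
        if len(shape) != 4:
            raise ValueError(f"layer {layer}: expected 4-D attention, got {shape}")
        if shape[-2] != seq_len or shape[-1] != seq_len:
            raise ValueError(
                f"layer {layer}: attention is {shape[-2]}x{shape[-1]} "
                f"but input_ids has {seq_len} tokens"
            )
    return len(sample["attentions"]), seq_len
```

Calling `check_sample(sample)` should return the layer count and sequence length that match the shapes printed above.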

## 2. Load Tokenizer

The tokenizer is needed to decode tokens and to locate the boundaries between vision and text tokens. Use the tokenizer of the same model that produced the attentions.

In [None]:
# Load tokenizer for your model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# Or for Qwen3-VL:
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
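To see roughly what locating the vision/text boundary involves, the sketch below finds the image span between a start marker and an end marker in `input_ids`. It is a hypothetical helper, not the implementation of `identify_token_ranges`, and it assumes batch size 1 and a single image.

```python
import numpy as np

def find_vision_span(input_ids, vision_start_id, vision_end_id):
    """Hypothetical helper: half-open (start, end) positions of the image tokens.

    Assumes batch size 1 and a single image wrapped in one start marker and
    one end marker; the markers themselves are excluded from the span.
    """
    ids = np.asarray(input_ids).reshape(-1).tolist()
    try:
        start = ids.index(vision_start_id)
        end = ids.index(vision_end_id, start + 1)
    except ValueError as err:
        raise ValueError("vision start/end markers not found in input_ids") from err
    return start + 1, end
```

For Qwen2-VL the markers are assumed to be the `<|vision_start|>` and `<|vision_end|>` special tokens, whose IDs can be looked up with `tokenizer.convert_tokens_to_ids(...)`; check your tokenizer's special tokens before relying on these names, and compare the result with `identify_token_ranges`.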

## 3. Compute All Metrics

`compute_all_two_hop_metrics` computes all of the 2-hop metrics listed at the top of this notebook in a single call.

In [None]:
# Compute all metrics
metrics = compute_all_two_hop_metrics(
    sample,
    tokenizer=tokenizer,
    layer_idx=-1,  # Analyze last layer
    average_heads=True,
    analyze_all_layers=True,  # Also compute layer-wise evolution (metric 6); slower
    keywords=["left", "right", "front", "behind", "where"]  # Customize for your task
)

# Print summary
print_metrics_summary(metrics, verbose=True)
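Layer-wise evolution (metric 6) repeats a per-layer computation over every attention tensor. The sketch below tracks only the last token's image share and assumes that `average_heads=True` means a plain mean over heads; `vision` is a hypothetical half-open range of image positions. NumPy cannot hold bfloat16, so convert torch tensors with `.float().cpu()` first if needed.

```python
import numpy as np

def last_token_image_share_by_layer(attentions, vision, batch_idx=0):
    """Per-layer share (%) of the last token's attention on image tokens.

    attentions: sequence of (batch, num_heads, seq_len, seq_len) arrays, one per layer.
    vision: hypothetical half-open (start, end) range of image token positions.
    Heads are averaged with a plain mean (an assumption about average_heads=True).
    """
    v0, v1 = vision
    shares = []
    for layer_att in attentions:
        A = np.asarray(layer_att[batch_idx]).mean(axis=0)  # (seq_len, seq_len)
        shares.append(100.0 * A[-1, v0:v1].sum())
    return np.array(shares)
```

Plotting the returned array against the layer index gives a quick view of where in the network the last token starts or stops reading the image directly.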

## 4. Visualize Results

### 4.1 Attention Distribution Comparison

In [None]:
plot_attention_distribution_comparison(metrics)
plt.show()

### 4.2 Attention Flow Diagram

Shows the 2-hop path: Image → Text Tokens → Last Token

In [None]:
# Get attention tensor and token ranges
attention = sample['attentions'][-1]  # Last layer
vision_range, text_range = identify_token_ranges(sample['input_ids'], tokenizer)

plot_attention_flow(
    attention,
    vision_range,
    text_range,
    metrics,
    top_k_text_tokens=5
)
plt.show()

### 4.3 Hub Tokens Analysis

Identifies which text tokens act as "information hubs"

In [None]:
plot_hub_tokens_analysis(
    metrics,
    attention,
    vision_range,
    text_range
)
plt.show()
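Hub tokens can be scored by how much of the two-hop path runs through each one. The sketch below uses a hypothetical definition: each text token's contribution is last→token attention times token→image attention, and the bottleneck score is the share of the total contribution carried by the top-k tokens. The library's `information_bottleneck_score` may be defined differently.

```python
import numpy as np

def hub_scores(A, vision, text, top_k=3):
    """Hypothetical hub definition (not necessarily the library's).

    contribution_j = A[last, j] * sum_i A[j, i]   for text token j, image tokens i
    bottleneck     = share of the total contribution carried by the top_k tokens
    Returns (contributions, top_positions, bottleneck).
    """
    v0, v1 = vision
    t0, t1 = text
    contrib = A[-1, t0:t1] * A[t0:t1, v0:v1].sum(axis=1)
    order = np.argsort(contrib)[::-1][:top_k]  # offsets within the text range, strongest first
    total = contrib.sum()
    bottleneck = float(contrib[order].sum() / total) if total > 0 else 0.0
    return contrib, (order + t0).tolist(), bottleneck


# Toy 6-token causal example: [img, img, txt, txt, txt, last]
A_toy = np.array([
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
    [0.4, 0.4, 0.2, 0.0, 0.0, 0.0],
    [0.1, 0.1, 0.4, 0.4, 0.0, 0.0],
    [0.25, 0.25, 0.2, 0.2, 0.1, 0.0],
    [0.1, 0.1, 0.4, 0.1, 0.25, 0.05],
])
print(hub_scores(A_toy, vision=(0, 2), text=(2, 5), top_k=1))
```

Here position 2 carries about 69% of the two-hop mass. A top-k share grows with `k` and depends on how many text tokens there are, so it is not directly comparable with the 0.5 threshold in the interpretation guide.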

### 4.4 Direct vs Indirect Image Attention

In [None]:
plot_direct_vs_indirect(metrics)
plt.show()

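# Interpretation (check the sign before inverting the ratio, to avoid dividing by zero)
ratio = metrics.indirect_to_direct_ratio
if ratio > 1:
    print(f"✅ Supports 2-hop hypothesis: indirect path is {ratio:.2f}x stronger than direct")
elif 0 < ratio < 1:
    print(f"❌ Direct path dominates: {1 / ratio:.2f}x stronger than indirect")
elif ratio == 1:
    print("Direct and indirect paths are equally strong")
else:
    print(f"❌ No measurable indirect path (ratio = {ratio})")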
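Metric 5 compares the last token's direct attention to the image with a two-hop estimate that weights each text token's image attention by how much the last token attends to it. The sketch below assumes that definition; `two_hop_metrics` may compute `indirect_to_direct_ratio` differently. One design point: within a single layer, the last token reads each text token's *input* to that layer, so the text→image hop it can actually benefit from comes from earlier layers. The optional `A_text` argument lets you take the first hop from an earlier layer instead of the same one.

```python
import numpy as np

def direct_vs_indirect(A_last, vision, text, A_text=None):
    """Metric 5 (hypothetical definition).

    direct   = sum over image tokens i of A_last[last, i]
    indirect = sum over text tokens j of A_last[last, j] * (sum over i of A_text[j, i])
    A_text supplies the first hop (text -> image); it defaults to A_last.
    Returns (direct, indirect, indirect / direct).
    """
    if A_text is None:
        A_text = A_last
    v0, v1 = vision
    t0, t1 = text
    direct = float(A_last[-1, v0:v1].sum())
    text_to_image = A_text[t0:t1, v0:v1].sum(axis=1)    # first hop, per text token
    indirect = float(A_last[-1, t0:t1] @ text_to_image)  # second hop, weighted by last -> text
    if direct > 0:
        ratio = indirect / direct
    else:
        ratio = float("inf") if indirect > 0 else float("nan")
    return direct, indirect, ratio


# Toy 6-token causal example: [img, img, txt, txt, txt, last]
A_toy = np.array([
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
    [0.4, 0.4, 0.2, 0.0, 0.0, 0.0],
    [0.1, 0.1, 0.4, 0.4, 0.0, 0.0],
    [0.25, 0.25, 0.2, 0.2, 0.1, 0.0],
    [0.1, 0.1, 0.4, 0.1, 0.25, 0.05],
])
print(direct_vs_indirect(A_toy, vision=(0, 2), text=(2, 5)))
```

On the toy matrix the direct share is 0.2 and the two-hop estimate is 0.465, an indirect/direct ratio of about 2.3.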

### 4.5 Generate All Plots at Once

In [None]:
# Save all plots to a directory
plot_complete_analysis(
    metrics,
    attention,
    vision_range,
    text_range,
    output_dir="plots/two_hop_analysis",
    prefix="sample_0"
)

## 5. Analyze Multiple Samples

Compute aggregate statistics across many samples.

In [None]:
# Load multiple samples
samples = torch.load('path/to/your/samples.pt', map_location='cpu')
print(f"Loaded {len(samples)} samples")

# Analyze the first 50 samples. The loop uses new names (s, m) so the
# single-sample `sample` and `metrics` from sections 1-4 are not overwritten.
all_metrics = []
for i, s in enumerate(samples[:50]):
    try:
        m = compute_all_two_hop_metrics(
            s,
            tokenizer=tokenizer,
            layer_idx=-1,
            analyze_all_layers=False  # Faster if you don't need layer-wise metrics
        )
        all_metrics.append(m)
    except Exception as e:
        print(f"Error on sample {i}: {e}")

print(f"Successfully analyzed {len(all_metrics)} samples")

### 5.1 Compute Aggregate Statistics

In [None]:
import numpy as np

# Extract metrics
last_token_image_pcts = [m.last_token_image_pct for m in all_metrics]
text_tokens_image_pcts = [m.text_tokens_image_pct_mean for m in all_metrics]
attention_flow_scores = [m.attention_flow_score for m in all_metrics]
hub_scores = [m.information_bottleneck_score for m in all_metrics]
indirect_ratios = [m.indirect_to_direct_ratio for m in all_metrics]

print("\n" + "="*80)
print("AGGREGATE STATISTICS")
print("="*80)
print(f"\nLast Token → Image: {np.mean(last_token_image_pcts):.2f}% ± {np.std(last_token_image_pcts):.2f}%")
print(f"Text Tokens → Image: {np.mean(text_tokens_image_pcts):.2f}% ± {np.std(text_tokens_image_pcts):.2f}%")
print(f"\nAttention Flow Score: {np.mean(attention_flow_scores):.3f} ± {np.std(attention_flow_scores):.3f}")
print(f"Hub Score: {np.mean(hub_scores):.3f} ± {np.std(hub_scores):.3f}")
print(f"\nIndirect/Direct Ratio: {np.mean(indirect_ratios):.2f}x ± {np.std(indirect_ratios):.2f}x")
print(f"% samples supporting 2-hop (ratio > 1): {np.mean([r > 1 for r in indirect_ratios]) * 100:.1f}%")
print("="*80)
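Ratios are asymmetric around 1: a sample at 2x and one at 0.5x average to 1.25x arithmetically, even though they lean equally far in opposite directions. Summarizing `log(ratio)` (that is, a geometric mean) treats both directions symmetrically. The sketch below is a suggested complement to the mean ± std above, not part of the original pipeline; it drops zero, infinite or NaN ratios, which have no finite log, and reports how many remained.

```python
import numpy as np

def summarize_ratios(ratios):
    """Summarize indirect/direct ratios on a log scale.

    Zero, negative, infinite and NaN ratios have no finite log, so they are
    dropped; n_used reports how many ratios remain.
    """
    r = np.asarray(ratios, dtype=float)
    r = r[np.isfinite(r) & (r > 0)]
    if r.size == 0:
        raise ValueError("no positive, finite ratios to summarize")
    log_r = np.log(r)
    return {
        "geometric_mean": float(np.exp(log_r.mean())),
        "log_ratio_std": float(log_r.std()),
        "median": float(np.median(r)),
        "frac_above_1": float((r > 1).mean()),
        "n_used": int(r.size),
    }
```

Calling `summarize_ratios(indirect_ratios)` gives the summary for the samples analyzed above; `frac_above_1` counts only the ratios that were kept, so it can differ from the percentage printed above if any were dropped.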

### 5.2 Plot Aggregate Distributions

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Last token image attention
axes[0, 0].hist(last_token_image_pcts, bins=30, alpha=0.7, color='#FF6B6B', edgecolor='black')
axes[0, 0].axvline(np.mean(last_token_image_pcts), color='red', linestyle='--', linewidth=2)
axes[0, 0].set_xlabel('Last Token → Image (%)')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Last Token Image Attention')
axes[0, 0].grid(True, alpha=0.3)

# Attention flow score
axes[0, 1].hist(attention_flow_scores, bins=30, alpha=0.7, color='#4ECDC4', edgecolor='black')
axes[0, 1].axvline(np.mean(attention_flow_scores), color='red', linestyle='--', linewidth=2)
axes[0, 1].set_xlabel('Attention Flow Score')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Attention Flow Score (Correlation)')
axes[0, 1].grid(True, alpha=0.3)

# Indirect/Direct ratio
axes[1, 0].hist(indirect_ratios, bins=30, alpha=0.7, color='#95E1D3', edgecolor='black')
axes[1, 0].axvline(np.mean(indirect_ratios), color='red', linestyle='--', linewidth=2)
axes[1, 0].axvline(1.0, color='gray', linestyle=':', linewidth=2, label='Equal (1.0x)')
axes[1, 0].set_xlabel('Indirect/Direct Ratio')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_title('Indirect vs Direct Ratio')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Scatter: Last token vs Text tokens
axes[1, 1].scatter(text_tokens_image_pcts, last_token_image_pcts, alpha=0.5)
axes[1, 1].set_xlabel('Text Tokens → Image (mean %)')
axes[1, 1].set_ylabel('Last Token → Image (%)')
axes[1, 1].set_title('Last Token vs Text Tokens Image Attention')
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle(f'2-Hop Hypothesis: Aggregate Statistics (n={len(all_metrics)})', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Interpretation Guide

### Strong Evidence for 2-Hop:
- ✅ Last token has LOW image attention (< 20%)
- ✅ Text tokens have HIGH image attention (> 40%)
- ✅ Attention flow score is POSITIVE and HIGH (> 0.5)
- ✅ Indirect/Direct ratio > 1 (indirect path is stronger)

### Weak or No Evidence:
- ❌ Last token has HIGH image attention (> 50%)
- ❌ Attention flow score is NEGATIVE or LOW (< 0.2)
- ❌ Indirect/Direct ratio < 1 (direct path dominates)

### Hub Tokens:
- Look for **object names** and **spatial relation words** acting as hubs
- High hub scores (> 0.5) indicate a strong bottleneck effect
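The "Strong Evidence" thresholds can be applied mechanically as a first pass. The function below is a hypothetical helper that reads the same attribute names used in section 5.1 (`last_token_image_pct`, `text_tokens_image_pct_mean`, `attention_flow_score`, `indirect_to_direct_ratio`); the small dataclass only stands in for the real metrics object so the sketch runs on its own. The cut-offs are this guide's heuristics, not calibrated statistical tests.

```python
from dataclasses import dataclass

@dataclass
class TwoHopSummary:
    """Stand-in for the real metrics object; only the fields the guide uses."""
    last_token_image_pct: float
    text_tokens_image_pct_mean: float
    attention_flow_score: float
    indirect_to_direct_ratio: float


def two_hop_checklist(m):
    """Apply the 'Strong Evidence' thresholds; returns {criterion: passed}."""
    return {
        "last token -> image < 20%": m.last_token_image_pct < 20,
        "text tokens -> image > 40%": m.text_tokens_image_pct_mean > 40,
        "attention flow score > 0.5": m.attention_flow_score > 0.5,
        "indirect/direct ratio > 1": m.indirect_to_direct_ratio > 1,
    }


example = TwoHopSummary(15.0, 45.0, 0.6, 1.3)
checks = two_hop_checklist(example)
print(f"{sum(checks.values())}/{len(checks)} strong-evidence criteria met")
```

The example prints `4/4 strong-evidence criteria met`. With the real object, `two_hop_checklist(metrics)` should work as long as it exposes these attributes, which section 5.1 already reads.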