In [None]:
import sys
sys.path.append('../..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

import random
import torch
import pickle
from pathlib import Path

from dataset_zoo.aro_datasets import get_controlled_images_b, get_controlled_images_a
from mechanistic.utils.visualization import (
    plot_attention_heatmap, 
    plot_attention_on_image,
    plot_cross_attention_attribution
)
from mechanistic.utils.hooks import AttentionCache
from mechanistic.utils.metrics import attention_entropy, attention_concentration

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# ============================================================
# EXPERIMENT 2B: Statistics across all layers
# ============================================================
#
# For a single sample, summarise every layer's attention weights with
# scalar statistics (mean / std / max / min / median), plot how those
# statistics evolve with depth, and print them as a table.
#
# Relies on `outputs` produced by an earlier cell: a sequence of dicts
# whose 'attentions' entry is a per-layer sequence of tensors.

# Collect statistics for all layers
sample_idx = 0
sample = outputs[sample_idx]
num_layers = len(sample['attentions'])

layer_stats = []
for layer_idx, attention_layer in enumerate(sample['attentions']):
    # Average over dim 1 and keep index 0 of dim 0
    # (presumably heads and batch respectively — TODO confirm layout).
    avg_attn = attention_layer.mean(dim=1)[0]
    # Tensor.numpy() fails for CUDA tensors, tensors that require grad and
    # bfloat16 tensors (NumPy has no such dtype), so normalise to a detached
    # float32 CPU tensor first. This is a no-op for float32 CPU input.
    flat_attn = avg_attn.detach().float().cpu().flatten().numpy()

    layer_stats.append({
        'layer': layer_idx,
        'mean': flat_attn.mean(),
        'std': flat_attn.std(),
        'max': flat_attn.max(),
        'min': flat_attn.min(),
        'median': np.median(flat_attn)
    })

# Plot statistics trends across layers
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

layers = [s['layer'] for s in layer_stats]
means = [s['mean'] for s in layer_stats]
stds = [s['std'] for s in layer_stats]
maxs = [s['max'] for s in layer_stats]
mins = [s['min'] for s in layer_stats]

# Top-left: mean attention weight
axes[0, 0].plot(layers, means, marker='o', linewidth=2, markersize=6)
axes[0, 0].set_title('Mean Attention Weight per Layer', fontweight='bold')
axes[0, 0].set_xlabel('Layer')
axes[0, 0].set_ylabel('Mean')
axes[0, 0].grid(True, alpha=0.3)

# Top-right: standard deviation
axes[0, 1].plot(layers, stds, marker='o', linewidth=2, markersize=6, color='orange')
axes[0, 1].set_title('Std Deviation per Layer', fontweight='bold')
axes[0, 1].set_xlabel('Layer')
axes[0, 1].set_ylabel('Std')
axes[0, 1].grid(True, alpha=0.3)

# Bottom-left: extremes on a shared axis
axes[1, 0].plot(layers, maxs, marker='o', linewidth=2, markersize=6, color='green', label='Max')
axes[1, 0].plot(layers, mins, marker='o', linewidth=2, markersize=6, color='red', label='Min')
axes[1, 0].set_title('Max/Min Attention Weight per Layer', fontweight='bold')
axes[1, 0].set_xlabel('Layer')
axes[1, 0].set_ylabel('Value')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot range (max - min) per layer
ranges = [s['max'] - s['min'] for s in layer_stats]
axes[1, 1].plot(layers, ranges, marker='o', linewidth=2, markersize=6, color='purple')
axes[1, 1].set_title('Range (Max - Min) per Layer', fontweight='bold')
axes[1, 1].set_xlabel('Layer')
axes[1, 1].set_ylabel('Range')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print table of statistics
print("\n" + "="*80)
print("LAYER-WISE ATTENTION STATISTICS")
print("="*80)
print(f"{'Layer':<8} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10} {'Median':<10}")
print("-"*80)
for stats in layer_stats:
    print(f"{stats['layer']:<8} {stats['mean']:<10.6f} {stats['std']:<10.6f} "
          f"{stats['min']:<10.6f} {stats['max']:<10.6f} {stats['median']:<10.6f}")