# FFN Expansion Analysis

**Paper #3 Hypothesis Test**

## Motivation

The Restriction Maps experiment showed:
- **Attention contracts**: All contraction ratios < 1
- **But W_V explodes**: ||W_V|| grows from 5 to 28 in late layers

**New Hypothesis:**
> Attention compresses information (sheaf diffusion onto H⁰).
> FFN expands information (for prediction/logit separation).

## What We Measure

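1. **FFN Gain**: ||FFN(x)|| / ||x|| per layer, where x is the normalized MLP input (output of `post_attention_layernorm`)
2. **Attention Gain**: ||Attn(x)|| / ||x|| per layer, where x is the normalized attention input (output of `input_layernorm`)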
3. **Residual Contribution**: How much does each component change the hidden state?
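Item 3 is not computed by the cells in this notebook, which measure each block's output relative to its LayerNorm-ed input. One possible definition is the relative update ||Δh|| / ||h||, where h is the residual stream entering a layer and Δh is the output of one block. The sketch below assumes that definition; `relative_update` is a hypothetical helper. In practice h could be taken from `model(**inputs, output_hidden_states=True)`, assuming the Hugging Face convention that `hidden_states[l]` (for l < n_layers) is the residual stream entering layer l.

```python
import numpy as np

def relative_update(delta, resid, eps=1e-10):
    """Mean over tokens of ||delta_t|| / ||resid_t|| (hypothetical metric).

    delta: (seq_len, hidden) output of one block (attention or MLP) at one layer.
    resid: (seq_len, hidden) residual stream entering that layer.
    """
    delta = np.asarray(delta, dtype=np.float64)
    resid = np.asarray(resid, dtype=np.float64)
    # Per-token L2 norms over the hidden dimension, then the per-token ratio
    per_token = np.linalg.norm(delta, axis=-1) / (np.linalg.norm(resid, axis=-1) + eps)
    return float(per_token.mean())

# Toy example: residual norms 5 and 2, update norms 1 and 1 -> (0.2 + 0.5) / 2
print(relative_update([[0, 1], [1, 0]], [[3, 4], [0, 2]]))
```

Averaging per-token ratios here (rather than dividing averaged norms) keeps every token's relative change on the same footing.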

**Prediction** (L* = transition layer; Layer 10 in the Restriction Maps experiment on Pythia-1.4B):
- Attention Gain < 1 (contraction) for l < L*
- FFN Gain > 1 (expansion) for l > L*
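The prediction is stated per layer range, so it can be checked range by range rather than by a global count. A minimal sketch: `check_prediction` is a hypothetical helper, and `l_star` is an input you choose (for instance 10, the Restriction Maps value). It returns the share of early layers where attention contracts and the share of late layers where the FFN expands.

```python
import numpy as np

def check_prediction(attn_gains, mlp_gains, l_star):
    """Return (attn_frac, mlp_frac).

    attn_frac: share of layers l < l_star with attention gain < 1.
    mlp_frac:  share of layers l > l_star with MLP gain > 1.
    NaN is returned for a half whose layer range is empty.
    """
    attn = np.asarray(attn_gains, dtype=float)
    mlp = np.asarray(mlp_gains, dtype=float)
    layers = np.arange(len(attn))
    early = layers < l_star
    late = layers > l_star
    attn_frac = float((attn[early] < 1).mean()) if early.any() else float("nan")
    mlp_frac = float((mlp[late] > 1).mean()) if late.any() else float("nan")
    return attn_frac, mlp_frac

# Toy gains for 5 layers with l_star = 2
print(check_prediction([0.5, 0.6, 1.2, 0.9, 0.8], [0.5, 0.5, 0.5, 1.5, 0.7], l_star=2))
```

Once Section 5 has run, `check_prediction(attn_gains, mlp_gains, l_star=10)` would give a more targeted verdict than the majority count exported in Section 8.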

**Author:** Davide D'Elia  
**Date:** 2026-01-05

## 1. Setup

In [None]:
# Install dependencies
!pip install -q transformers accelerate matplotlib seaborn

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Model

In [None]:
# Model configuration
MODEL_NAME = "EleutherAI/pythia-1.4b"
# MODEL_NAME = "EleutherAI/pythia-6.9b"  # Larger model

print(f"Loading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
model.eval()

n_layers = model.config.num_hidden_layers
hidden_dim = model.config.hidden_size

print(f"\nModel Configuration:")
print(f"  Layers: {n_layers}")
print(f"  Hidden Dim: {hidden_dim}")
print(f"  Intermediate Dim: {model.config.intermediate_size}")

## 3. Setup Hooks for FFN and Attention

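We'll use PyTorch forward hooks to capture, for every layer:
- Input to Attention (output of `input_layernorm`)
- Output of Attention (after the attention output projection)
- Input to FFN (MLP) (output of `post_attention_layernorm`)
- Output of FFN (MLP)

Pythia (GPT-NeoX) uses a parallel residual (`use_parallel_residual=True` in its config): both LayerNorms read the same residual stream, and the attention and MLP outputs are added to it together. The "input to FFN" therefore does not contain the attention output of the same layer.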

In [None]:
class ActivationCapture:
    """
    Captures activations at specific points in the model.
    """
    def __init__(self):
        self.activations = defaultdict(dict)
        self.hooks = []
    
    def clear(self):
        self.activations = defaultdict(dict)
    
    def remove_hooks(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
    
    def register_hooks(self, model):
        """Register hooks on attention and MLP modules."""
        self.remove_hooks()
        
        for layer_idx in range(model.config.num_hidden_layers):
            layer = model.gpt_neox.layers[layer_idx]
            
            # Hook for attention input (post layer norm)
            def make_attn_input_hook(idx):
                def hook(module, input, output):
                    # input_layernorm output = attention input
                    self.activations[idx]['attn_input'] = output.detach()
                return hook
            
            # Hook for attention output
            def make_attn_output_hook(idx):
                def hook(module, input, output):
                    # output[0] is the attention output
                    self.activations[idx]['attn_output'] = output[0].detach()
                return hook
            
            # Hook for MLP input (post layer norm)
            def make_mlp_input_hook(idx):
                def hook(module, input, output):
                    self.activations[idx]['mlp_input'] = output.detach()
                return hook
            
            # Hook for MLP output
            def make_mlp_output_hook(idx):
                def hook(module, input, output):
                    self.activations[idx]['mlp_output'] = output.detach()
                return hook
            
            # Register hooks
            h1 = layer.input_layernorm.register_forward_hook(make_attn_input_hook(layer_idx))
            h2 = layer.attention.register_forward_hook(make_attn_output_hook(layer_idx))
            h3 = layer.post_attention_layernorm.register_forward_hook(make_mlp_input_hook(layer_idx))
            h4 = layer.mlp.register_forward_hook(make_mlp_output_hook(layer_idx))
            
            self.hooks.extend([h1, h2, h3, h4])
        
        print(f"Registered {len(self.hooks)} hooks on {model.config.num_hidden_layers} layers")

# Create capture object and register hooks
capture = ActivationCapture()
capture.register_hooks(model)
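Each hook above is built through a small factory (`make_attn_input_hook(idx)` and friends) instead of being defined directly in the loop. Python closures look up variables when they are called, not when they are defined, so hooks defined inline would all see the final `layer_idx` and write into the last layer's slot. A pure-Python illustration (no model needed; the function names are illustrative):

```python
def build_late_bound(n):
    fns = []
    for idx in range(n):
        # Each lambda refers to the loop variable idx, which ends at n - 1
        fns.append(lambda: idx)
    return [f() for f in fns]

def build_with_factory(n):
    def make(idx):
        # idx is a parameter of make(), so each closure keeps its own value
        return lambda: idx
    fns = [make(idx) for idx in range(n)]
    return [f() for f in fns]

print(build_late_bound(3))    # [2, 2, 2]
print(build_with_factory(3))  # [0, 1, 2]
```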

## 4. Run Forward Pass and Capture Activations

In [None]:
# Test prompts (same as restriction maps for comparison)
TEST_PROMPTS = [
    "The capital of France is Paris, which is known for the Eiffel Tower.",
    "Machine learning models learn patterns from data by optimizing loss functions.",
    "The transformer architecture uses self-attention to process sequences.",
    "Once upon a time in a faraway kingdom, there lived a wise old wizard.",
    "Water boils at 100 degrees Celsius under standard atmospheric pressure.",
    "If all mammals are warm-blooded and whales are mammals, then whales are warm-blooded.",
    "Functional programming emphasizes immutability and pure functions without side effects.",
    "The speed of light in a vacuum is approximately 299,792,458 meters per second.",
]

print(f"Using {len(TEST_PROMPTS)} test prompts")

In [None]:
def compute_gains(capture, n_layers):
    """
    Compute gain metrics from captured activations.
    
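    Gain = mean_t ||output_t|| / mean_t ||input_t||  (per-token L2 norms
    averaged over tokens t: a ratio of means, not a mean of per-token ratios)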
    """
    gains = {
        'attn_gain': [],      # ||attn_out|| / ||attn_in||
        'mlp_gain': [],       # ||mlp_out|| / ||mlp_in||
        'attn_norm_in': [],   # ||attn_in||
        'attn_norm_out': [],  # ||attn_out||
        'mlp_norm_in': [],    # ||mlp_in||
        'mlp_norm_out': [],   # ||mlp_out||
    }
    
    for layer_idx in range(n_layers):
        acts = capture.activations[layer_idx]
        
        # Attention gain
        if 'attn_input' in acts and 'attn_output' in acts:
            attn_in = acts['attn_input'].float()
            attn_out = acts['attn_output'].float()
            
            # Compute mean norm over all tokens
            norm_in = torch.norm(attn_in, dim=-1).mean().item()
            norm_out = torch.norm(attn_out, dim=-1).mean().item()
            
            gains['attn_norm_in'].append(norm_in)
            gains['attn_norm_out'].append(norm_out)
            gains['attn_gain'].append(norm_out / (norm_in + 1e-10))
        else:
            gains['attn_norm_in'].append(0)
            gains['attn_norm_out'].append(0)
            gains['attn_gain'].append(0)
        
        # MLP gain
        if 'mlp_input' in acts and 'mlp_output' in acts:
            mlp_in = acts['mlp_input'].float()
            mlp_out = acts['mlp_output'].float()
            
            norm_in = torch.norm(mlp_in, dim=-1).mean().item()
            norm_out = torch.norm(mlp_out, dim=-1).mean().item()
            
            gains['mlp_norm_in'].append(norm_in)
            gains['mlp_norm_out'].append(norm_out)
            gains['mlp_gain'].append(norm_out / (norm_in + 1e-10))
        else:
            gains['mlp_norm_in'].append(0)
            gains['mlp_norm_out'].append(0)
            gains['mlp_gain'].append(0)
    
    return gains

# Collect gains over all prompts
all_gains = []

print("Processing prompts...")
for prompt in tqdm(TEST_PROMPTS):
    capture.clear()
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        _ = model(**inputs)
    
    gains = compute_gains(capture, n_layers)
    all_gains.append(gains)

print("\nDone!")

In [None]:
# Average gains across all prompts
avg_gains = {
    'attn_gain': np.mean([g['attn_gain'] for g in all_gains], axis=0),
    'mlp_gain': np.mean([g['mlp_gain'] for g in all_gains], axis=0),
    'attn_norm_in': np.mean([g['attn_norm_in'] for g in all_gains], axis=0),
    'attn_norm_out': np.mean([g['attn_norm_out'] for g in all_gains], axis=0),
    'mlp_norm_in': np.mean([g['mlp_norm_in'] for g in all_gains], axis=0),
    'mlp_norm_out': np.mean([g['mlp_norm_out'] for g in all_gains], axis=0),
}

layers = list(range(n_layers))

print("Average Gains per Layer:")
print("\nLayer | Attn Gain | MLP Gain | Attn > 1? | MLP > 1?")
print("-" * 55)
for l in layers:
    attn_g = avg_gains['attn_gain'][l]
    mlp_g = avg_gains['mlp_gain'][l]
    print(f"  {l:2d}  |   {attn_g:.4f}  |  {mlp_g:.4f}  |    {'YES' if attn_g > 1 else 'no '}    |   {'YES' if mlp_g > 1 else 'no '}")

## 5. Visualize Results

In [None]:
# Find transition points
attn_gains = avg_gains['attn_gain']
mlp_gains = avg_gains['mlp_gain']

# L* for attention (minimum gain = maximum contraction)
L_star_attn = np.argmin(attn_gains)

# L* for MLP (layer of maximum gain)
L_star_mlp = np.argmax(mlp_gains)

print(f"Transition Points:")
print(f"  Attention min gain: Layer {L_star_attn} (gain = {attn_gains[L_star_attn]:.4f})")
print(f"  MLP max gain: Layer {L_star_mlp} (gain = {mlp_gains[L_star_mlp]:.4f})")

In [None]:
# Main visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Attention vs MLP Gain
ax1 = axes[0, 0]
ax1.plot(layers, attn_gains, 'b-', linewidth=2, marker='o', markersize=4, label='Attention Gain')
ax1.plot(layers, mlp_gains, 'r-', linewidth=2, marker='s', markersize=4, label='MLP/FFN Gain')
ax1.axhline(y=1.0, color='gray', linestyle='--', linewidth=2, label='Neutral (gain=1)')
ax1.axvline(x=L_star_attn, color='blue', linestyle=':', linewidth=1.5, alpha=0.7)
ax1.axvline(x=L_star_mlp, color='red', linestyle=':', linewidth=1.5, alpha=0.7)
ax1.fill_between(layers, attn_gains, 1, where=[g < 1 for g in attn_gains], 
                  alpha=0.2, color='blue', label='Attn Contraction')
ax1.fill_between(layers, mlp_gains, 1, where=[g > 1 for g in mlp_gains], 
                  alpha=0.2, color='red', label='MLP Expansion')
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('Gain (||output|| / ||input||)', fontsize=12)
ax1.set_title('Attention vs MLP Gain per Layer', fontsize=14)
ax1.legend(fontsize=9, loc='best')
ax1.grid(True, alpha=0.3)

# Plot 2: Norms
ax2 = axes[0, 1]
ax2.plot(layers, avg_gains['attn_norm_out'], 'b-', linewidth=2, marker='o', markersize=4, label='Attn Output Norm')
ax2.plot(layers, avg_gains['mlp_norm_out'], 'r-', linewidth=2, marker='s', markersize=4, label='MLP Output Norm')
ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('Mean Norm', fontsize=12)
ax2.set_title('Output Norms per Layer', fontsize=14)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# Plot 3: Ratio of gains (MLP/Attn)
ax3 = axes[1, 0]
gain_ratio = np.array(mlp_gains) / (np.array(attn_gains) + 1e-10)
ax3.plot(layers, gain_ratio, 'g-', linewidth=2, marker='^', markersize=4)
ax3.axhline(y=1.0, color='gray', linestyle='--', linewidth=2, label='Equal gains')
ax3.fill_between(layers, gain_ratio, 1, where=[r > 1 for r in gain_ratio], 
                  alpha=0.3, color='red', label='MLP dominates')
ax3.fill_between(layers, gain_ratio, 1, where=[r < 1 for r in gain_ratio], 
                  alpha=0.3, color='blue', label='Attn dominates')
ax3.set_xlabel('Layer', fontsize=12)
ax3.set_ylabel('MLP Gain / Attn Gain', fontsize=12)
ax3.set_title('Relative Contribution: MLP vs Attention', fontsize=14)
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)

# Plot 4: Combined view with phases
ax4 = axes[1, 1]
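# Heuristic index: product of the two block gains. Both block outputs are added
# to the residual stream, so this is not the measured change in hidden-state norm.
combined_gain = np.array(attn_gains) * np.array(mlp_gains)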
ax4.plot(layers, combined_gain, 'm-', linewidth=2, marker='d', markersize=4, label='Combined Gain (Attn × MLP)')
ax4.axhline(y=1.0, color='gray', linestyle='--', linewidth=2)
ax4.fill_between(layers, combined_gain, 1, where=[g > 1 for g in combined_gain], 
                  alpha=0.3, color='green', label='Net Expansion')
ax4.fill_between(layers, combined_gain, 1, where=[g < 1 for g in combined_gain], 
                  alpha=0.3, color='purple', label='Net Contraction')
ax4.set_xlabel('Layer', fontsize=12)
ax4.set_ylabel('Combined Gain', fontsize=12)
ax4.set_title('Net Effect: Attn × MLP Gain', fontsize=14)
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_NAME}: FFN vs Attention Expansion Analysis\n(Hypothesis: Attention contracts, FFN expands)', 
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('ffn_expansion_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n>>> Figure saved as 'ffn_expansion_analysis.png'")

## 6. Phase Analysis

In [None]:
# Identify phases based on gains
print("="*60)
print("PHASE ANALYSIS")
print("="*60)

# Count layers in each regime
attn_contracting = sum(1 for g in attn_gains if g < 1)
attn_expanding = sum(1 for g in attn_gains if g > 1)
mlp_contracting = sum(1 for g in mlp_gains if g < 1)
mlp_expanding = sum(1 for g in mlp_gains if g > 1)

print(f"\nAttention:")
print(f"  Contracting (gain < 1): {attn_contracting}/{n_layers} layers")
print(f"  Expanding (gain > 1):   {attn_expanding}/{n_layers} layers")
print(f"  Min gain: {min(attn_gains):.4f} at layer {np.argmin(attn_gains)}")
print(f"  Max gain: {max(attn_gains):.4f} at layer {np.argmax(attn_gains)}")

print(f"\nMLP/FFN:")
print(f"  Contracting (gain < 1): {mlp_contracting}/{n_layers} layers")
print(f"  Expanding (gain > 1):   {mlp_expanding}/{n_layers} layers")
print(f"  Min gain: {min(mlp_gains):.4f} at layer {np.argmin(mlp_gains)}")
print(f"  Max gain: {max(mlp_gains):.4f} at layer {np.argmax(mlp_gains)}")

# Net effect
combined = np.array(attn_gains) * np.array(mlp_gains)
net_contracting = sum(1 for g in combined if g < 1)
net_expanding = sum(1 for g in combined if g > 1)

print(f"\nNet Effect (Attn × MLP):")
print(f"  Net Contracting: {net_contracting}/{n_layers} layers")
print(f"  Net Expanding:   {net_expanding}/{n_layers} layers")
print("="*60)

In [None]:
# Detailed layer-by-layer analysis
print("\nDetailed Layer Analysis:")
print("="*70)
print(f"{'Layer':^6} | {'Attn Gain':^10} | {'MLP Gain':^10} | {'Combined':^10} | {'Phase':^17}")
print("-"*70)

for l in layers:
    ag = attn_gains[l]
    mg = mlp_gains[l]
    cg = ag * mg
    
    # Determine phase
    if ag < 1 and mg < 1:
        phase = "COMPRESS"
    elif ag < 1 and mg > 1:
        phase = "ATTN-/MLP+"
    elif ag > 1 and mg < 1:
        phase = "ATTN+/MLP-"
    else:
        phase = "EXPAND"
    
    # Add net indicator
    if cg < 0.9:
        net = "[NET-]"
    elif cg > 1.1:
        net = "[NET+]"
    else:
        net = "[~1.0]"
    
    print(f"{l:^6} | {ag:^10.4f} | {mg:^10.4f} | {cg:^10.4f} | {phase:^10} {net}")

print("="*70)

## 7. Comparison with Restriction Maps Results

In [None]:
# Reference values from restriction maps experiment (Pythia-1.4B)
RESTRICTION_MAPS_REF = {
    'L_star_contraction': 10,  # Layer with minimum contraction ratio
    'L_star_entropy': 10,       # Layer with minimum attention entropy
    'contraction_ratios': [0.29, 0.26, 0.26, 0.30, 0.20, 0.21, 0.18, 0.20, 0.23, 0.26,
                           0.12, 0.16, 0.15, 0.18, 0.15, 0.18, 0.17, 0.22, 0.18, 0.17, 0.14]
}

print("="*60)
print("COMPARISON WITH RESTRICTION MAPS")
print("="*60)

print(f"\nRestriction Maps (Attention-only):")
print(f"  L* (min contraction): Layer {RESTRICTION_MAPS_REF['L_star_contraction']}")
print(f"  Contraction ratio at L*: {RESTRICTION_MAPS_REF['contraction_ratios'][RESTRICTION_MAPS_REF['L_star_contraction']]:.4f}")

print(f"\nThis Experiment (Attention + MLP):")
print(f"  Attention min gain: Layer {L_star_attn} (gain = {attn_gains[L_star_attn]:.4f})")
print(f"  MLP max gain: Layer {L_star_mlp} (gain = {mlp_gains[L_star_mlp]:.4f})")

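# Check correlation over the layers covered by both experiments.
# The reference list has 21 entries while Pythia-1.4B has 24 layers, so a
# strict length check would silently skip this step.
n_common = min(len(RESTRICTION_MAPS_REF['contraction_ratios']), len(attn_gains))
ref_ratios = RESTRICTION_MAPS_REF['contraction_ratios'][:n_common]
correlation = np.corrcoef(ref_ratios, attn_gains[:n_common])[0, 1]
print(f"\nCorrelation (Restriction Ratio vs Attn Gain, layers 0-{n_common - 1}): {correlation:.4f}")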

print("="*60)

## 8. Summary and Export

In [None]:
import json

# Prepare summary
summary = {
    'model': MODEL_NAME,
    'n_layers': int(n_layers),
    'hidden_dim': int(hidden_dim),
    'n_prompts': len(TEST_PROMPTS),
    
    'attention': {
        'gains': [float(g) for g in attn_gains],
        'min_gain': float(min(attn_gains)),
        'max_gain': float(max(attn_gains)),
        'L_star_min': int(np.argmin(attn_gains)),
        'n_contracting': int(attn_contracting),
        'n_expanding': int(attn_expanding)
    },
    
    'mlp': {
        'gains': [float(g) for g in mlp_gains],
        'min_gain': float(min(mlp_gains)),
        'max_gain': float(max(mlp_gains)),
        'L_star_max': int(np.argmax(mlp_gains)),
        'n_contracting': int(mlp_contracting),
        'n_expanding': int(mlp_expanding)
    },
    
    'combined': {
        'gains': [float(g) for g in combined],
        'n_net_contracting': int(net_contracting),
        'n_net_expanding': int(net_expanding)
    },
    
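    # Coarse majority-vote test over all layers; it does not use the
    # layer ranges (l < L*, l > L*) stated in the prediction
    'hypothesis_test': {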
        'attention_contracts': bool(attn_contracting > attn_expanding),
        'mlp_expands': bool(mlp_expanding > mlp_contracting),
        'hypothesis_confirmed': bool(attn_contracting > attn_expanding and mlp_expanding > mlp_contracting)
    }
}

# Save to JSON
with open('ffn_expansion_results.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("="*60)
print("FFN EXPANSION ANALYSIS SUMMARY")
print("="*60)
print(f"\nModel: {MODEL_NAME}")
print(f"Prompts: {len(TEST_PROMPTS)}")
print(f"\nHypothesis: 'Attention contracts, FFN expands'")
print(f"\nResults:")
print(f"  Attention contracting: {attn_contracting}/{n_layers} layers")
print(f"  MLP expanding: {mlp_expanding}/{n_layers} layers")
print(f"\nHypothesis Confirmed: {summary['hypothesis_test']['hypothesis_confirmed']}")
print(f"\nFiles saved:")
print(f"  - ffn_expansion_analysis.png")
print(f"  - ffn_expansion_results.json")
print("="*60)

In [None]:
# Create ZIP archive
import zipfile
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"ffn_expansion_results_{timestamp}.zip"

with zipfile.ZipFile(zip_filename, 'w') as zipf:
    zipf.write('ffn_expansion_analysis.png')
    zipf.write('ffn_expansion_results.json')

print(f">>> Created: {zip_filename}")

## 9. Interpretation

### Original Hypothesis

From the Restriction Maps experiment:
- Contraction ratios were ALL < 1 (attention always contracts)
- But ||W_V|| explodes in late layers

**Revised Hypothesis:**
> Attention compresses (sheaf diffusion onto H⁰).
> FFN expands (for prediction/logit separation).

### Expected Pattern

```
Layer:    0 -------- L* -------- N
          |          |           |
Attn:     Contract   Min         Contract
MLP:      Contract   Transition  EXPAND
Net:      Contract   ~Neutral    EXPAND
```

### Theoretical Significance

If confirmed:
1. **Division of Labor**: Attention and FFN have complementary roles
2. **Sheaf Interpretation**: Attention = sheaf diffusion (compression)
3. **Prediction Mechanism**: FFN = amplification for logit separation

This refines the Sheaf-theoretic framework:
- Original: "Attention does both contraction and expansion"
- Revised: "Attention contracts (sheaf), FFN expands (prediction)"

## 10. Download Results

In [None]:
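# Download all results (requires Google Colab; elsewhere the files stay in the working directory)
try:
    from google.colab import files
except ImportError:
    files = None

if files is None:
    print("Not running in Google Colab; result files are in the current working directory:")
    print(f"  - {zip_filename}")
    print("  - ffn_expansion_analysis.png")
    print("  - ffn_expansion_results.json")
else:
    print("Downloading result files...")
    print()

    # Download ZIP
    print(f"1. ZIP Archive: {zip_filename}")
    files.download(zip_filename)

    # Individual files
    print("\n2. Individual files:")
    print("   - ffn_expansion_analysis.png")
    files.download('ffn_expansion_analysis.png')
    print("   - ffn_expansion_results.json")
    files.download('ffn_expansion_results.json')

    print("\n>>> All files downloaded!")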

In [None]:
# Cleanup
capture.remove_hooks()
print("Hooks removed.")