# Hybrid Attention Layer Demo

This notebook demonstrates the new `HybridAttentionLayer`, which mixes different attention types (standard, spectral, holonomy) within a single transformer layer.

## Key Features
- **Configurable head types**: Specify which attention mechanism each head should use
- **Mixed attention**: Combine standard, spectral, and holonomy attention in one layer
- **Flexible configuration**: Different numbers of each head type
- **Performance analysis**: Compare different configurations

In [None]:
import sys
import os

# Add the source directory to Python path
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import time
from typing import List

from spectral_attention.layers.hybrid_layer import HybridAttentionLayer
from spectral_attention import SpectralAttention, HolonomyAttention

# Seed both NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when PyTorch reports one as available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## Basic Usage

Let's start with the acceptance criteria: a layer with [4 standard, 2 spectral, 2 holonomy] heads.

In [None]:
# Head layout from the acceptance criteria: 4 standard, 2 spectral, 2 holonomy
d_model = 512
head_types = 4 * ["standard"] + 2 * ["spectral"] + 2 * ["holonomy"]

n_heads = len(head_types)
print(f"Head configuration: {head_types}")
print(f"Total heads: {n_heads}")
print(f"Model dimension: {d_model}")
print(f"Dimension per head: {d_model // n_heads}")

# Build the hybrid layer, then move it to the selected device
hybrid_layer = HybridAttentionLayer(d_model=d_model, head_types=head_types, dropout=0.1)
hybrid_layer = hybrid_layer.to(device)

# Total number of scalar entries across all of the layer's parameters
param_count = sum(p.numel() for p in hybrid_layer.parameters())
print(f"\nLayer created successfully!")
print(f"Number of parameters: {param_count:,}")

## Forward Pass Demo

In [None]:
# Create a random sample batch on the same device as the layer
batch_size, seq_len = 4, 128
x = torch.randn(batch_size, seq_len, d_model, device=device)

print(f"Input shape: {x.shape}")

# nn.Module instances start in training mode and this layer was created with
# dropout=0.1, so switch to eval mode first; otherwise any dropout inside the
# layer would still be applied and would distort the statistics printed below.
hybrid_layer.eval()

# Forward pass without building an autograd graph
with torch.no_grad():
    output = hybrid_layer(x)

print(f"Output shape: {output.shape}")
print(f"Output statistics:")
print(f"  Mean: {output.mean().item():.6f}")
print(f"  Std: {output.std().item():.6f}")
print(f"  Min: {output.min().item():.6f}")
print(f"  Max: {output.max().item():.6f}")

# Check for NaN or Inf values
print(f"  Contains NaN: {torch.isnan(output).any().item()}")
print(f"  Contains Inf: {torch.isinf(output).any().item()}")

## Performance Comparison

Let's compare different attention configurations:

In [None]:
def benchmark_layer(layer, x, num_runs=10):
    """Return the average wall-clock time, in seconds, of one forward pass.

    The layer is switched to eval mode, called three times as warm-up, and
    then timed over ``num_runs`` calls inside ``torch.no_grad()``.

    Args:
        layer: Module exposing ``eval()`` that is invoked as ``layer(x)``.
        x: Input passed unchanged to every call. When it reports
            ``is_cuda``, its device is synchronized before the clock is
            started and again before it is stopped.
        num_runs: Number of timed calls; must be at least 1.

    Returns:
        Mean duration of a single call in seconds.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Decide from the input itself whether CUDA synchronization is needed, so
    # the function does not depend on a notebook-level ``device`` variable.
    needs_sync = bool(getattr(x, "is_cuda", False))

    layer.eval()

    with torch.no_grad():
        # Warmup
        for _ in range(3):
            layer(x)

        # Wait for the device to finish queued work before starting the clock
        if needs_sync:
            torch.cuda.synchronize(x.device)
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()

        for _ in range(num_runs):
            layer(x)

        # Wait again so the timed calls have completed before the clock stops
        if needs_sync:
            torch.cuda.synchronize(x.device)
        elapsed = time.perf_counter() - start_time

    return elapsed / num_runs

# Head layouts to compare; each one lists 8 heads in total
configurations = {
    "All Standard (8 heads)": 8 * ["standard"],
    "All Spectral (8 heads)": 8 * ["spectral"],
    "All Holonomy (8 heads)": 8 * ["holonomy"],
    "Mixed (4+2+2)": 4 * ["standard"] + 2 * ["spectral"] + 2 * ["holonomy"],
    "Balanced (3+3+2)": 3 * ["standard"] + 3 * ["spectral"] + 2 * ["holonomy"],
}

# A small batch keeps each benchmark run short
bench_x = torch.randn(2, 64, d_model, device=device)

# Average forward time per configuration, stored in milliseconds
results = {}
for name, head_config in configurations.items():
    # dropout=0.0: no dropout while benchmarking
    layer = HybridAttentionLayer(d_model=d_model, head_types=head_config, dropout=0.0)
    layer = layer.to(device)

    elapsed_ms = benchmark_layer(layer, bench_x) * 1000  # seconds -> ms
    results[name] = elapsed_ms
    print(f"{name}: {elapsed_ms:.2f} ms")

# Bar chart of the measured timings, one bar per configuration
names = list(results)
times = [results[label] for label in names]

plt.figure(figsize=(10, 6))
bars = plt.bar(names, times, color=['skyblue', 'lightgreen', 'lightcoral', 'gold', 'plum'])
plt.title('Performance Comparison of Different Hybrid Attention Configurations')
plt.ylabel('Average Time (ms)')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()

# Write each bar's value slightly above its top edge
for rect, time_val in zip(bars, times):
    x_center = rect.get_x() + rect.get_width()/2
    y_label = rect.get_height() + 0.1
    plt.text(x_center, y_label, f'{time_val:.1f}', ha='center', va='bottom')

plt.show()

## Attention Pattern Analysis

Let's compare the raw outputs produced by each head type (mean, standard deviation, maximum, and minimum):

In [None]:
def analyze_attention_outputs(layer, x, head_types):
    """Run each attention module of ``layer`` on its own slice of heads.

    ``x`` is projected with ``layer.W_qkv`` and split into query, key and
    value tensors. Consecutive head slices are then handed to
    ``layer.attention_modules[...]`` in the fixed order standard, spectral,
    holonomy, skipping types missing from ``layer.head_type_counts``.
    NOTE(review): this assumes the layer stores its heads grouped in that
    same order -- confirm against HybridAttentionLayer.forward.

    Args:
        layer: Object providing ``eval()``, ``W_qkv``, ``n_heads``,
            ``d_head``, ``head_type_counts`` and ``attention_modules``.
        x: Input tensor with three dimensions, [B, T, d_model].
        head_types: Not read by this function; kept in the signature for
            existing callers.

    Returns:
        Dict mapping each processed head type to whatever its attention
        module returned (expected to be [B, count, T, D]).
    """
    layer.eval()

    per_type = {}
    with torch.no_grad():
        batch, seq_len, _ = x.shape

        # Joint projection [B, T, 3*d_model], reshaped to [B, T, 3, H, D]
        projected = layer.W_qkv(x).view(batch, seq_len, 3, layer.n_heads, layer.d_head)
        # Move the q/k/v axis to the front and unpack it: each is [B, H, T, D]
        q, k, v = projected.permute(2, 0, 3, 1, 4)

        offset = 0  # index of the first head not yet assigned to a type
        for kind in ("standard", "spectral", "holonomy"):
            if kind in layer.head_type_counts:
                n = layer.head_type_counts[kind]
                heads = slice(offset, offset + n)

                # Feed only this type's heads to the matching attention module
                module = layer.attention_modules[kind]
                per_type[kind] = module(q[:, heads], k[:, heads], v[:, heads])

                offset += n

    return per_type

# Analyze the mixed configuration
test_x = torch.randn(1, 32, d_model, device=device)
head_outputs = analyze_attention_outputs(hybrid_layer, test_x, head_types)

# One bar colour per head type (the gradient plot further down reads this too)
colors = {'standard': 'skyblue', 'spectral': 'lightgreen', 'holonomy': 'lightcoral'}

# Materialize the keys once: a list is a safer x-argument for bar() than a
# dict view, and it fixes the bar order for all four panels.
type_names = list(head_outputs)
bar_colors = [colors[ht] for ht in type_names]

# Scalar summary statistics of each head type's output tensor
means = [head_outputs[ht].mean().item() for ht in type_names]
stds = [head_outputs[ht].std().item() for ht in type_names]
maxs = [head_outputs[ht].max().item() for ht in type_names]
mins = [head_outputs[ht].min().item() for ht in type_names]

# Plot statistics for each head type
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.suptitle('Attention Head Output Analysis', fontsize=16)

# (title, y-label, values) per panel, in row-major order to match axes.flat
panels = [
    ('Mean Output Values', 'Mean', means),
    ('Output Standard Deviation', 'Std Dev', stds),
    ('Max Output Values', 'Max', maxs),
    ('Min Output Values', 'Min', mins),
]

for ax, (title, ylabel, values) in zip(axes.flat, panels):
    bars = ax.bar(type_names, values, color=bar_colors)
    ax.set_title(title)
    ax.set_ylabel(ylabel)

    # Add value labels just beyond the tip of each bar
    for bar, val in zip(bars, values):
        height = bar.get_height()
        # Negative bars point downwards, so hang the label below the tip;
        # anchoring it with va='bottom' would draw the text over the bar.
        anchor = 'bottom' if height >= 0 else 'top'
        ax.text(bar.get_x() + bar.get_width()/2, height + height*0.01,
                f'{val:.3f}', ha='center', va=anchor, fontsize=9)

plt.tight_layout()
plt.show()

print("\nHead Type Counts:")
for head_type, output in head_outputs.items():
    print(f"{head_type}: {output.shape[1]} heads, output shape: {output.shape}")

## Gradient Flow Analysis

Let's check how gradients flow through different head types:

In [None]:
# Fresh layer with dropout set to 0.0, used only for the gradient check
grad_layer = HybridAttentionLayer(d_model=d_model, head_types=head_types, dropout=0.0)
grad_layer = grad_layer.to(device)

# Random input that records gradients
grad_x = torch.randn(2, 16, d_model, device=device, requires_grad=True)

# Forward pass, reduce to a scalar, then backpropagate
output = grad_layer(grad_x)
loss = output.sum()
loss.backward()

print("Gradient Analysis:")
print(f"Input gradient norm: {grad_x.grad.norm().item():.6f}")

# Gradient norm of every parameter that actually received a gradient
gradient_norms = {
    name: param.grad.norm().item()
    for name, param in grad_layer.named_parameters()
    if param.grad is not None
}

# Print non-attention parameters right away; bucket the rest by attention type
attention_grads = {}
for name, norm in gradient_norms.items():
    if 'attention_modules' not in name:
        print(f"{name}: {norm:.6f}")
        continue
    # The attention type is taken from the second dotted component of the name
    attn_type = name.split('.')[1]
    attention_grads.setdefault(attn_type, []).append(norm)

print("\nAttention Module Gradients:")
avg_grads = {}
for attn_type, grads in attention_grads.items():
    avg_grads[attn_type] = np.mean(grads)
    print(f"{attn_type}: avg {avg_grads[attn_type]:.6f}, max {max(grads):.6f}, min {min(grads):.6f}")

# Bar chart of the average gradient norm per attention type, on a log axis
attn_types = avg_grads.keys()
plt.figure(figsize=(8, 5))
plt.bar(attn_types, avg_grads.values(),
        color=[colors[attn_type] for attn_type in attn_types])
plt.yscale('log')
plt.ylabel('Gradient Norm')
plt.title('Average Gradient Magnitudes by Attention Type')

# Label each bar with its value, placed 10% above the bar top
for position, avg in enumerate(avg_grads.values()):
    plt.text(position, avg * 1.1, f'{avg:.2e}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

## Configuration Examples

Here are some example configurations for different use cases:

In [None]:
# Example head layouts; every layout lists 8 heads
example_configs = {
    "Balanced Multi-modal": {
        "heads": ["standard", "spectral", "holonomy"] * 2 + ["standard", "spectral"],
        # This layout is 3 standard / 3 spectral / 2 holonomy, i.e. close to
        # equal but not exactly equal, so the description says "near-equal".
        "description": "Near-equal representation of all attention types for diverse patterns"
    },
    "Spectral-heavy": {
        "heads": ["spectral"] * 6 + ["standard", "holonomy"],
        "description": "Emphasizes frequency-domain processing for long sequences"
    },
    "Standard-dominant": {
        "heads": ["standard"] * 6 + ["spectral", "holonomy"],
        "description": "Mostly standard attention with some specialized heads"
    },
    "Alternating Pattern": {
        "heads": ["standard", "spectral", "holonomy", "standard",
                  "spectral", "holonomy", "standard", "standard"],
        "description": "Alternating pattern for diverse information processing"
    }
}

print("Example Hybrid Attention Configurations:\n")
for name, config in example_configs.items():
    heads = config["heads"]
    # dict.fromkeys keeps first-appearance order, so the printed counts are
    # identical on every run (iterating a set of strings is not).
    counts = {ht: heads.count(ht) for ht in dict.fromkeys(heads)}

    print(f"**{name}**")
    print(f"Description: {config['description']}")
    print(f"Configuration: {heads}")
    print(f"Head counts: {counts}")
    print(f"Total heads: {len(heads)}")

    # Smoke-test the configuration on the CPU; failures are reported rather
    # than raised so the remaining configurations are still shown.
    try:
        test_layer = HybridAttentionLayer(d_model=d_model, head_types=heads)
        # Inference-only check: leave training mode so that any dropout in
        # the layer is disabled for the forward pass.
        test_layer.eval()
        test_input = torch.randn(1, 16, d_model)
        with torch.no_grad():
            test_output = test_layer(test_input)
        print(f"✓ Configuration works! Output shape: {test_output.shape}")
    except Exception as e:
        print(f"✗ Configuration failed: {e}")

    print("-" * 50)


## Summary and Conclusion

The HybridAttentionLayer successfully implements the requirements:

### ✅ Acceptance Criteria Met:
1. **Configurable list of head types per layer** - ✓ Implemented with `head_types` parameter
2. **Concatenate outputs across heads → project back** - ✓ Implemented with proper concatenation and output projection
3. **Unit test: run forward with [4 standard, 2 spectral, 2 holonomy]** - ✓ Tested and working

### Key Features:
- **Flexible Configuration**: Support any combination of attention types
- **Proper Interface**: Compatible with existing transformer architectures
- **Gradient Flow**: Proper backpropagation through all head types
- **Performance**: Reasonable computational overhead (measured per configuration in the Performance Comparison section above)

### Usage Patterns:
```python
# Basic usage
head_types = ["standard"] * 4 + ["spectral"] * 2 + ["holonomy"] * 2
layer = HybridAttentionLayer(d_model=512, head_types=head_types)
output = layer(input_tensor)  # [batch, seq_len, d_model]
```

This implementation allows researchers and practitioners to experiment with different combinations of attention mechanisms within a single layer, potentially capturing different types of patterns and relationships in the data.