# Restriction Maps Extraction: Empirical Validation

**Paper #3 Core Experiment**

## Theoretical Prediction

In the sheaf-theoretic framework, Transformer attention defines **Restriction Maps**:

$$\rho_{ij} = \sqrt{A_{ij}} \cdot W_V$$

where:
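- $A_{ij}$ = attention weight that query token $i$ assigns to key token $j$ (row $i$ of the attention matrix)
- $W_V$ = value projection matrix
- $\rho_{ij}$ = restriction map (the linear map that transports information from token $j$ to token $i$ during "message passing")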
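
A quick way to see what the $\sqrt{A_{ij}}$ split does: if each message $\rho_{ij} x_j$ is weighted once more by $\sqrt{A_{ij}}$, the sum over $j$ is exactly the single-head attention output $\sum_j A_{ij} W_V x_j$. The sketch below checks this on toy random data. The sizes, the helper `restriction_map`, and the omission of biases and of the output projection $W_O$ are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d = 4, 3                          # toy sizes (hypothetical)
X = rng.normal(size=(n_tokens, d))          # hidden states x_j, one row per token
W_V = rng.normal(size=(d, d))               # toy value projection

# Toy causal attention: random logits, future positions masked, row-wise softmax
logits = rng.normal(size=(n_tokens, n_tokens))
logits[np.triu_indices(n_tokens, k=1)] = -np.inf
A = np.exp(logits - logits.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)


def restriction_map(A, W_V, i, j):
    """Hypothetical helper: rho_ij = sqrt(A_ij) * W_V."""
    return np.sqrt(A[i, j]) * W_V


# Standard single-head attention output (no bias, no W_O): out_i = sum_j A_ij W_V x_j
attn_out = A @ (X @ W_V.T)

# Sheaf reading: each message rho_ij x_j is weighted once more by sqrt(A_ij)
sheaf_out = np.stack([
    sum(np.sqrt(A[i, j]) * (restriction_map(A, W_V, i, j) @ X[j]) for j in range(n_tokens))
    for i in range(n_tokens)
])

print(np.allclose(attn_out, sheaf_out))  # True
```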

## Key Predictions to Validate

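1. **Contraction vs Expansion**: relative to the value projection, restriction maps contract ($\|\rho_{ij}\| / \|W_V\| < 1$) in layers $l < L^*$ and expand in layers $l > L^*$, where $L^*$ is the transition layer
2. **Spectral Gap**: Sheaf Laplacian $L_F = \delta^T \delta$ shows a distinct spectral signature at $L^*$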
3. **Consistency Measure**: Global section existence correlates with semantic coherence
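
Prediction 3 relies on the standard fact that global sections are exactly the kernel of the sheaf Laplacian, $H^0(G; F) = \ker \delta = \ker L_F$, so their dimension can be read off as the number of (numerically) zero eigenvalues. A minimal sanity check, assuming identity restriction maps so that $L_F = L_G \otimes I_d$ and the global sections are the constant assignments on each connected component; the helper name and the tolerance `tol` are arbitrary choices:

```python
import numpy as np


def global_section_dim(L_F, tol=1e-8):
    """dim H^0(G; F) = dim ker(L_F), counted as eigenvalues below tol."""
    return int(np.sum(np.linalg.eigvalsh(L_F) < tol))


d = 3
# Path graph 0 - 1 - 2 - 3 with identity restriction maps: L_F = L_G (x) I_d
adjacency = np.diag(np.ones(3), k=1) + np.diag(np.ones(3), k=-1)
L_G = np.diag(adjacency.sum(axis=1)) - adjacency
print(global_section_dim(np.kron(L_G, np.eye(d))))      # 3: one constant section per fiber dimension

# Cutting edge 1 - 2 leaves two components, each with its own constant sections
adjacency_split = adjacency.copy()
adjacency_split[1, 2] = adjacency_split[2, 1] = 0.0
L_split = np.diag(adjacency_split.sum(axis=1)) - adjacency_split
print(global_section_dim(np.kron(L_split, np.eye(d))))  # 6
```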

**Author:** Davide D'Elia  
**Date:** 2026-01-04

## 1. Setup

In [None]:
# Install dependencies
!pip install -q transformers accelerate einops scipy matplotlib seaborn

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from scipy import linalg
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Model with Attention Output

In [None]:
# Model configuration
MODEL_NAME = "EleutherAI/pythia-1.4b"  # Smaller for faster extraction
# MODEL_NAME = "EleutherAI/pythia-6.9b"  # Full size (needs more VRAM)

print(f"Loading {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",  # SDPA/flash kernels do not return attention weights
    output_attentions=True,
    output_hidden_states=True
)
model.eval()

n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads
hidden_dim = model.config.hidden_size
head_dim = hidden_dim // n_heads

print(f"\nModel Configuration:")
print(f"  Layers: {n_layers}")
print(f"  Attention Heads: {n_heads}")
print(f"  Hidden Dim: {hidden_dim}")
print(f"  Head Dim: {head_dim}")

## 3. Extract Attention and Value Matrices

For each layer, we need:
- $A_{ij}$ = attention weights, $A = \mathrm{softmax}\left(QK^\top / \sqrt{d_{\text{head}}}\right)$ with a causal mask (one head)
- $W_V$ = value projection matrix
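
For reference, the attention weights of one head can be sketched in NumPy as below. This toy version uses random inputs and omits the rotary position embeddings that GPT-NeoX/Pythia apply to $Q$ and $K$; it only illustrates the two properties the later analysis relies on: each row of $A$ is a probability distribution, and the causal mask zeroes every entry with $j > i$.

```python
import numpy as np


def causal_attention(Q, K):
    """A = softmax(Q K^T / sqrt(d_head)) with future positions masked (single head)."""
    seq_len, d_head = Q.shape
    scores = Q @ K.T / np.sqrt(d_head)
    scores[np.triu_indices(seq_len, k=1)] = -np.inf   # token i may not attend to j > i
    scores -= scores.max(axis=1, keepdims=True)        # numerically stable softmax
    A = np.exp(scores)
    return A / A.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
Q, K = rng.normal(size=(5, 8)), rng.normal(size=(5, 8))
A = causal_attention(Q, K)
print(np.allclose(A.sum(axis=1), 1.0))    # True: each row is a distribution
print(np.allclose(np.triu(A, k=1), 0.0))  # True: no attention to future tokens
```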

In [None]:
def get_value_projection_matrices(model):
    """
    Extract W_V matrices from all layers.
    
    Returns:
        Dict[layer_idx -> W_V tensor of shape (hidden_dim, hidden_dim)]
        Rows h*head_dim:(h+1)*head_dim belong to attention head h.
    """
    n_heads = model.config.num_attention_heads
    hidden_dim = model.config.hidden_size
    head_dim = hidden_dim // n_heads
    W_V_matrices = {}
    
    for layer_idx in range(model.config.num_hidden_layers):
        # Access the attention layer
        # Path depends on model architecture (Pythia uses the GPT-NeoX structure)
        attn = model.gpt_neox.layers[layer_idx].attention
        
        # GPT-NeoX fuses Q, K, V into one projection whose weight has shape
        # (3 * hidden_dim, hidden_dim), stored as (out_features, in_features).
        # The output rows are interleaved per head as [q_h | k_h | v_h], NOT as
        # three contiguous [Q; K; V] blocks, so reshape before slicing.
        qkv_weight = attn.query_key_value.weight.data.float().cpu()
        per_head = qkv_weight.view(n_heads, 3, head_dim, hidden_dim)
        
        # Index 2 selects the value rows of every head; stack the heads back together
        W_V = per_head[:, 2].reshape(hidden_dim, hidden_dim)
        
        W_V_matrices[layer_idx] = W_V
    
    return W_V_matrices

print("Extracting W_V matrices from all layers...")
W_V_matrices = get_value_projection_matrices(model)

print(f"\nExtracted W_V matrices:")
print(f"  Shape: {W_V_matrices[0].shape}")
print(f"  Layers: {len(W_V_matrices)}")
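
Why the reshape matters: in the Hugging Face GPT-NeoX attention module, the output of `query_key_value` is viewed as `(n_heads, 3 * head_dim)` and split into query, key and value inside each head, so the weight rows are interleaved per head rather than stacked as three contiguous blocks. The toy check below (hypothetical sizes and a hypothetical `extract_w_v` helper) labels every row by head and block and compares per-head extraction with naive "last third" slicing; it is worth re-checking the layout against the installed `transformers` version.

```python
import numpy as np

n_heads, head_dim = 2, 3
hidden = n_heads * head_dim  # 6

# Label each output row as 10 * head + block, with block 0/1/2 = Q/K/V,
# following the per-head [q | k | v] ordering of the fused projection
row_labels = [10 * h + blk for h in range(n_heads) for blk in range(3) for _ in range(head_dim)]
qkv_weight = np.repeat(np.array(row_labels, dtype=float)[:, None], hidden, axis=1)


def extract_w_v(qkv_weight, n_heads, head_dim):
    """Hypothetical helper: value rows of every head, shape (n_heads * head_dim, hidden)."""
    hidden = qkv_weight.shape[1]
    per_head = qkv_weight.reshape(n_heads, 3, head_dim, hidden)
    return per_head[:, 2].reshape(n_heads * head_dim, hidden)


W_V = extract_w_v(qkv_weight, n_heads, head_dim)
naive = qkv_weight[2 * hidden:]  # contiguous "last third" slicing

print([float(v) for v in np.unique(W_V[:, 0])])    # [2.0, 12.0]: V rows of head 0 and head 1
print([float(v) for v in np.unique(naive[:, 0])])  # [11.0, 12.0]: head 1's K and V rows
```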

In [None]:
def extract_attention_and_hidden_states(model, tokenizer, prompt, device=None):
    """
    Run forward pass and extract attention weights and hidden states.
    
    Returns:
        attentions: tuple of (n_layers) tensors, each (batch, n_heads, seq_len, seq_len)
        hidden_states: tuple of (n_layers + 1) tensors, each (batch, seq_len, hidden_dim)
        inputs: the tokenized prompt (input_ids, attention_mask)
    """
    # Default to the model's own device (works with device_map="auto" and on CPU)
    device = device or model.device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(
            **inputs,
            output_attentions=True,
            output_hidden_states=True
        )
    
    return outputs.attentions, outputs.hidden_states, inputs

# Test with a sample prompt
TEST_PROMPT = "The capital of France is Paris, which is known for the Eiffel Tower."

print(f"Test prompt: '{TEST_PROMPT}'")
attentions, hidden_states, inputs = extract_attention_and_hidden_states(model, tokenizer, TEST_PROMPT)

seq_len = attentions[0].shape[-1]
print(f"\nExtracted:")
print(f"  Sequence length: {seq_len}")
print(f"  Attention shape per layer: {attentions[0].shape}")
print(f"  Hidden state shape per layer: {hidden_states[0].shape}")

## 4. Compute Restriction Maps

The restriction map from token $j$ to token $i$ is:

$$\rho_{ij} = \sqrt{A_{ij}} \cdot W_V$$

This is a $d \times d$ matrix ($d$ = hidden_dim) that maps the hidden state at position $j$ to its value vector, scaled by $\sqrt{A_{ij}}$, before that value contributes to position $i$.
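
Because $\rho_{ij}$ is a scalar multiple of $W_V$, its norms scale exactly by $\sqrt{A_{ij}}$: $\|\rho_{ij}\|_{op} = \sqrt{A_{ij}}\,\|W_V\|_{op}$. A small numerical check on a random toy matrix (the size and the value of `a_ij` are arbitrary) is below; under this formula, one SVD of $W_V$ per layer would be enough to obtain every map's operator norm.

```python
import numpy as np

rng = np.random.default_rng(0)
W_V = rng.normal(size=(16, 16))   # toy stand-in for a value projection
a_ij = 0.3                        # a toy attention weight in [0, 1]

rho_ij = np.sqrt(a_ij) * W_V
op_rho = np.linalg.norm(rho_ij, ord=2)   # largest singular value
op_W = np.linalg.norm(W_V, ord=2)

print(np.isclose(op_rho, np.sqrt(a_ij) * op_W))  # True: ||rho_ij|| = sqrt(A_ij) ||W_V||
print(np.isclose(op_rho / op_W, np.sqrt(a_ij)))  # True: the per-map ratio is sqrt(A_ij)
```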

In [None]:
def compute_restriction_maps(attention_weights, W_V, head_idx=0, threshold=1e-8):
    """
    Compute restriction maps rho_ij = sqrt(A_ij) * W_V

    Args:
        attention_weights: tensor of shape (batch, n_heads, seq_len, seq_len)
        W_V: tensor of shape (hidden_dim, hidden_dim)
        head_idx: which attention head's weights to use (W_V spans all heads)
        threshold: minimum attention weight to include (lowered for sparse attention)

    Returns:
        restriction_maps: dict of (i, j) -> rho_ij matrix
        sqrt_A: the sqrt of attention weights
        A: attention matrix of the selected head, shape (seq_len, seq_len)
    """
    # Get attention for specific head
    # Shape: (seq_len, seq_len)
    A = attention_weights[0, head_idx].float().cpu()
    seq_len = A.shape[0]

    # Compute sqrt of attention weights
    sqrt_A = torch.sqrt(A + 1e-10)  # Add epsilon for numerical stability

    # For each pair (i, j), compute rho_ij = sqrt(A_ij) * W_V
    # This gives a scalar * matrix = matrix
    restriction_maps = {}

    for i in range(seq_len):
        for j in range(seq_len):
            if A[i, j] > threshold:  # Only compute for non-zero attention
                rho_ij = sqrt_A[i, j] * W_V
                restriction_maps[(i, j)] = rho_ij

    # If no maps found (very sparse attention), include top-k by attention weight
    if len(restriction_maps) == 0:
        # Flatten and get top-k indices
        A_flat = A.flatten()
        k = min(10, len(A_flat))  # Up to 10 maps
        top_k_indices = torch.topk(A_flat, k).indices
        for idx in top_k_indices:
            i = idx.item() // seq_len
            j = idx.item() % seq_len
            rho_ij = sqrt_A[i, j] * W_V
            restriction_maps[(i, j)] = rho_ij

    return restriction_maps, sqrt_A, A

# Compute for first layer
layer_idx = 0
restriction_maps_L0, sqrt_A_L0, A_L0 = compute_restriction_maps(
    attentions[layer_idx], 
    W_V_matrices[layer_idx]
)

print(f"Layer {layer_idx} Restriction Maps:")
print(f"  Number of non-zero maps: {len(restriction_maps_L0)}")
print(f"  Map shape: {list(restriction_maps_L0.values())[0].shape}")

## 5. Analyze Restriction Map Properties

### Key Metrics:
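1. **Operator Norm**: $\|\rho_{ij}\|_{op}$ (largest singular value), measures contraction/expansion
2. **Frobenius Norm**: $\|\rho_{ij}\|_F$, overall magnitude
3. **Spectral Radius**: $\max_k |\lambda_k(\rho_{ij})|$, long-term behavior under repeated application (not computed below: stored as NaN because it is too expensive for large maps)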
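
The three metrics obey $\max_k|\lambda_k(M)| \le \|M\|_{op} \le \|M\|_F$ for any square $M$, and $\|M\|_F / \sqrt{\min(m, n)} \le \|M\|_{op}$, so the Frobenius-based fallback used when the SVD fails is a lower bound on the operator norm rather than an estimate of it. A quick check on a random toy matrix; the helper name `map_metrics` is hypothetical:

```python
import numpy as np


def map_metrics(M):
    """Operator norm, Frobenius norm and spectral radius of a square matrix."""
    op = np.linalg.norm(M, ord=2)                 # largest singular value
    fro = np.linalg.norm(M, ord="fro")            # sqrt of the sum of squared entries
    spec = np.max(np.abs(np.linalg.eigvals(M)))   # largest |eigenvalue|
    return op, fro, spec


rng = np.random.default_rng(0)
M = rng.normal(size=(8, 8))
op, fro, spec = map_metrics(M)
fallback = fro / np.sqrt(min(M.shape))            # fallback used in analyze_restriction_maps

print(spec <= op + 1e-12, op <= fro + 1e-12, fallback <= op + 1e-12)  # True True True
```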

In [None]:
def analyze_restriction_maps(restriction_maps, W_V):
    """
    Compute statistics of restriction maps.
    
    Returns:
        dict with various statistics
    """
    # W_V baseline statistics (compute first, with error handling)
    W_V_np = W_V.numpy()
    try:
        W_V_op_norm = np.linalg.norm(W_V_np, ord=2)
    except np.linalg.LinAlgError:
        # SVD didn't converge: ||W||_F / sqrt(min(shape)) is a lower bound on ||W||_2
        W_V_op_norm = np.linalg.norm(W_V_np, ord='fro') / np.sqrt(min(W_V_np.shape))
    W_V_frob_norm = np.linalg.norm(W_V_np, ord='fro')

    # Handle empty restriction maps
    if len(restriction_maps) == 0:
        return {
            'operator_norms': np.array([]),
            'frobenius_norms': np.array([]),
            'spectral_radii': np.array([]),
            'mean_op_norm': 0.0,
            'std_op_norm': 0.0,
            'max_op_norm': 0.0,
            'min_op_norm': 0.0,
            'W_V_op_norm': W_V_op_norm,
            'W_V_frob_norm': W_V_frob_norm,
            'n_maps': 0,
            'contraction_ratio': 0.0
        }

    operator_norms = []
    frobenius_norms = []
    spectral_radii = []

    for (i, j), rho in restriction_maps.items():
        rho_np = rho.numpy()

        # Operator norm (largest singular value) with fallback
        try:
            op_norm = np.linalg.norm(rho_np, ord=2)
        except np.linalg.LinAlgError:
            # SVD didn't converge: ||rho||_F / sqrt(min(shape)) is a lower bound on ||rho||_2
            op_norm = np.linalg.norm(rho_np, ord='fro') / np.sqrt(min(rho_np.shape))
        operator_norms.append(op_norm)

        # Frobenius norm (always works)
        frob_norm = np.linalg.norm(rho_np, ord='fro')
        frobenius_norms.append(frob_norm)

        # Spectral radius (skip - too expensive and error-prone for large matrices)
        spectral_radii.append(np.nan)

    return {
        'operator_norms': np.array(operator_norms),
        'frobenius_norms': np.array(frobenius_norms),
        'spectral_radii': np.array(spectral_radii),
        'mean_op_norm': float(np.mean(operator_norms)) if operator_norms else 0.0,
        'std_op_norm': float(np.std(operator_norms)) if operator_norms else 0.0,
        'max_op_norm': float(np.max(operator_norms)) if operator_norms else 0.0,
        'min_op_norm': float(np.min(operator_norms)) if operator_norms else 0.0,
        'W_V_op_norm': float(W_V_op_norm),
        'W_V_frob_norm': float(W_V_frob_norm),
        'n_maps': len(restriction_maps),
        'contraction_ratio': float(np.mean(operator_norms) / W_V_op_norm) if (operator_norms and W_V_op_norm > 0) else 0.0
    }

# Analyze layer 0
stats_L0 = analyze_restriction_maps(restriction_maps_L0, W_V_matrices[0])

print(f"Layer 0 Restriction Map Statistics:")
print(f"  Number of maps: {stats_L0['n_maps']}")
print(f"  Mean Operator Norm: {stats_L0['mean_op_norm']:.4f}")
print(f"  Std Operator Norm: {stats_L0['std_op_norm']:.4f}")
print(f"  W_V Operator Norm: {stats_L0['W_V_op_norm']:.4f}")
print(f"  Contraction Ratio: {stats_L0['contraction_ratio']:.4f}")

## 6. Layer-wise Analysis: Contraction vs Expansion

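**Prediction:** 
- Layers < L*: Restriction maps are **contractive** (norm ratio $\|\rho_{ij}\| / \|W_V\| < 1$)
- Layers > L*: Restriction maps become **expansive** (the norm ratio rises again)

By construction, each map's ratio is $\|\rho_{ij}\| / \|W_V\| = \sqrt{A_{ij}} \le 1$, so the layer-mean ratio cannot exceed 1 (up to the `1e-10` stabilizing epsilon). "Expansion" can therefore only appear as a ratio that increases relative to earlier layers, which is what the `expansion_after_Lstar` check in Section 8 compares.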
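
Under $\rho_{ij} = \sqrt{A_{ij}} \cdot W_V$, a layer's contraction ratio reduces to the mean of $\sqrt{A_{ij}}$ over the retained pairs, so it can be computed from the attention matrix alone. The sketch below (toy causal attention, hypothetical helper `contraction_ratio_from_attention`) compares it with the SVD route; it ignores the `1e-10` epsilon and the top-k fallback used in `compute_restriction_maps`.

```python
import numpy as np


def contraction_ratio_from_attention(A, threshold=1e-8):
    """Hypothetical helper: mean sqrt(A_ij) over pairs with A_ij > threshold."""
    return float(np.mean(np.sqrt(A[A > threshold])))


rng = np.random.default_rng(0)
n = 6
logits = rng.normal(size=(n, n))
logits[np.triu_indices(n, k=1)] = -np.inf          # causal mask
A = np.exp(logits - logits.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

W_V = rng.normal(size=(12, 12))                     # toy value projection
op_W = np.linalg.norm(W_V, ord=2)
via_svd = np.mean([np.linalg.norm(np.sqrt(a) * W_V, ord=2) for a in A[A > 1e-8]]) / op_W

ratio = contraction_ratio_from_attention(A)
print(np.isclose(ratio, via_svd))  # True
print(ratio <= 1.0)                # True, since every A_ij <= 1
```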

In [None]:
# Analyze all layers
print("Analyzing restriction maps for all layers...")
print("(Using top-k fallback for sparse attention layers)\n")

layer_stats = {}
sparse_layers = []
failed_layers = []

for layer_idx in tqdm(range(n_layers), desc="Layers"):
    try:
        restriction_maps, sqrt_A, A = compute_restriction_maps(
            attentions[layer_idx],
            W_V_matrices[layer_idx]
        )
        stats = analyze_restriction_maps(restriction_maps, W_V_matrices[layer_idx])
        layer_stats[layer_idx] = stats
        
        # Track sparse layers
        if stats['n_maps'] < 20:
            sparse_layers.append((layer_idx, stats['n_maps']))
    
    except Exception as e:
        # Complete failure - use dummy values
        failed_layers.append((layer_idx, str(e)[:50]))
        W_V_np = W_V_matrices[layer_idx].numpy()
        layer_stats[layer_idx] = {
            'operator_norms': np.array([]),
            'frobenius_norms': np.array([]),
            'spectral_radii': np.array([]),
            'mean_op_norm': 0.0,
            'std_op_norm': 0.0,
            'max_op_norm': 0.0,
            'min_op_norm': 0.0,
            'W_V_op_norm': float(np.linalg.norm(W_V_np, ord='fro')),  # Frobenius norm as a stand-in (upper bound on ||W_V||_2)
            'W_V_frob_norm': float(np.linalg.norm(W_V_np, ord='fro')),
            'n_maps': 0,
            'contraction_ratio': 0.0
        }

print("\nDone!")

if failed_layers:
    print(f"\n⚠️  {len(failed_layers)} layers had numerical errors (using fallback):")
    for layer, err in failed_layers:
        print(f"  Layer {layer}: {err}")

if sparse_layers:
    print(f"\nNote: {len(sparse_layers)} layers had sparse attention (<20 maps):")
    for layer, n_maps in sparse_layers[:10]:  # Show first 10
        print(f"  Layer {layer}: {n_maps} maps")

In [None]:
# Extract metrics for plotting
layers = list(range(n_layers))
mean_op_norms = [layer_stats[l]['mean_op_norm'] for l in layers]
contraction_ratios = [layer_stats[l]['contraction_ratio'] for l in layers]
W_V_op_norms = [layer_stats[l]['W_V_op_norm'] for l in layers]

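# Find transition point (most contractive layer, or where the trend changes)
contraction_ratios_np = np.array(contraction_ratios)

# Method 1: Find minimum contraction ratio (most contractive layer),
# skipping layers whose stats are fallback zeros (n_maps == 0)
valid_layers = np.array([layer_stats[l]['n_maps'] > 0 for l in layers])
L_star_contraction = int(np.argmin(np.where(valid_layers, contraction_ratios_np, np.inf)))

# Method 2: Find the sharpest bend (largest |second difference|);
# two np.diff calls shorten the array by 2, so +1 maps back to a layer index
second_derivative = np.diff(np.diff(contraction_ratios_np))
L_star_inflection = np.argmax(np.abs(second_derivative)) + 1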

print(f"Transition Points:")
print(f"  L* (min contraction): Layer {L_star_contraction}")
print(f"  L* (inflection): Layer {L_star_inflection}")
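
The "inflection" estimate picks the layer with the largest absolute second difference, i.e. the sharpest bend in the curve; strictly, that is a curvature maximum rather than a sign change of the second derivative. `np.diff(v, n=2)[k]` equals `v[k+2] - 2*v[k+1] + v[k]` and is centred on index `k + 1`, which is where the `+ 1` comes from. A toy check with a hypothetical helper `sharpest_bend_layer` and made-up ratios:

```python
import numpy as np


def sharpest_bend_layer(values):
    """Hypothetical helper: layer with the largest |second difference|."""
    second = np.diff(np.asarray(values, dtype=float), n=2)
    return int(np.argmax(np.abs(second))) + 1  # second[k] is centred on layer k + 1


# Made-up curve: flat up to layer 4, then rising linearly
ratios = [0.5, 0.5, 0.5, 0.5, 0.5, 0.6, 0.7, 0.8]
print(sharpest_bend_layer(ratios))  # 4
```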

In [None]:
# Plot Contraction Analysis
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Mean Operator Norm of Restriction Maps
ax1 = axes[0, 0]
ax1.plot(layers, mean_op_norms, 'b-', linewidth=2, marker='o', markersize=4, label='Mean ||rho_ij||')
ax1.plot(layers, W_V_op_norms, 'r--', linewidth=2, label='||W_V||')
ax1.axvline(x=L_star_contraction, color='green', linestyle=':', linewidth=2, label=f'L* = {L_star_contraction}')
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('Operator Norm', fontsize=12)
ax1.set_title('Restriction Map Norms vs W_V Norm', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot 2: Contraction Ratio
ax2 = axes[0, 1]
ax2.plot(layers, contraction_ratios, 'g-', linewidth=2, marker='s', markersize=4)
ax2.axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Neutral (ratio=1)')
ax2.axvline(x=L_star_contraction, color='purple', linestyle=':', linewidth=2, label=f'L* = {L_star_contraction}')
ax2.fill_between(layers, contraction_ratios, 1.0, 
                  where=[c < 1 for c in contraction_ratios], 
                  alpha=0.3, color='blue', label='Contractive')
ax2.fill_between(layers, contraction_ratios, 1.0, 
                  where=[c >= 1 for c in contraction_ratios], 
                  alpha=0.3, color='red', label='Expansive')
ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('Contraction Ratio (mean||rho|| / ||W_V||)', fontsize=12)
ax2.set_title('Contraction vs Expansion by Layer', fontsize=14)
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3)

# Plot 3: W_V Operator Norm across layers
ax3 = axes[1, 0]
ax3.plot(layers, W_V_op_norms, 'm-', linewidth=2, marker='^', markersize=4)
ax3.axvline(x=L_star_contraction, color='green', linestyle=':', linewidth=2, label=f'L* = {L_star_contraction}')
ax3.set_xlabel('Layer', fontsize=12)
ax3.set_ylabel('||W_V|| Operator Norm', fontsize=12)
ax3.set_title('Value Projection Matrix Norm', fontsize=14)
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)

# Plot 4: Attention Entropy (how focused is attention?)
# Higher entropy = more distributed attention
ax4 = axes[1, 1]

# Compute attention entropy per layer
attention_entropies = []
for layer_idx in range(n_layers):
    A = attentions[layer_idx][0, 0].float().cpu().numpy()  # Head 0
    # Row-wise entropy (how distributed is attention from each position)
    entropies = []
    for row in A:
        row = row + 1e-10  # Numerical stability
        row = row / row.sum()  # Ensure normalization
        entropy = -np.sum(row * np.log(row))
        entropies.append(entropy)
    attention_entropies.append(np.mean(entropies))

ax4.plot(layers, attention_entropies, 'c-', linewidth=2, marker='d', markersize=4)
ax4.axvline(x=L_star_contraction, color='green', linestyle=':', linewidth=2, label=f'L* = {L_star_contraction}')
ax4.set_xlabel('Layer', fontsize=12)
ax4.set_ylabel('Mean Attention Entropy', fontsize=12)
ax4.set_title('Attention Distribution (High = Distributed)', fontsize=14)
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_NAME}: Restriction Map Analysis\n(Prediction: Contractive before L*, Expansive after L*)', 
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('restriction_maps_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n>>> Figure saved as 'restriction_maps_analysis.png'")
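
One caveat when reading the entropy panel: under a causal mask, row $i$ spreads over at most $i+1$ tokens, so its entropy is bounded by $\log(i+1)$ nats, and the row-averaged entropy depends on prompt length even for perfectly uniform attention. The vectorized sketch below (hypothetical helper `row_entropies`, toy uniform causal attention) checks that uniform causal rows reach this bound; dividing each row's entropy by $\log(i+1)$ for $i \ge 1$ would be one way to normalize it.

```python
import numpy as np


def row_entropies(A, eps=1e-10):
    """Hypothetical helper: Shannon entropy (nats) of each row of an attention matrix."""
    P = A + eps
    P = P / P.sum(axis=1, keepdims=True)
    return -(P * np.log(P)).sum(axis=1)


n = 5
uniform_causal = np.tril(np.ones((n, n)))
uniform_causal /= uniform_causal.sum(axis=1, keepdims=True)  # row i spreads over i + 1 tokens

H = row_entropies(uniform_causal)
max_H = np.log(np.arange(1, n + 1))                           # log(i + 1)
print(np.allclose(H, max_H, atol=1e-6))  # True
```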

## 7. Sheaf Laplacian Construction

The Sheaf Laplacian is:

$$L_F = \delta^T \delta$$

where $\delta$ is the coboundary operator. For our complete graph with restriction maps:

$$\delta: C^0(G; F) \to C^1(G; F)$$
$$(\delta x)_{ij} = \rho_{ji} x_j - \rho_{ij} x_i$$

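The quadratic form $x^\top L_F x = \|\delta x\|^2$ measures the **inconsistency** of a 0-cochain $x$ (one vector per token) across edges; $x$ is a global section exactly when $L_F x = 0$.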
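
The block formula used in `build_sheaf_laplacian` (diagonal blocks $\sum_{j \ne i} \rho_{ij}^\top \rho_{ij}$, off-diagonal blocks $-\rho_{ij}^\top \rho_{ji}$) can be checked against an explicit coboundary built from $(\delta x)_{ij} = \rho_{ji} x_j - \rho_{ij} x_i$, with one row block per unordered edge $i < j$. A toy check with random, non-causal attention and a random $W$ (all sizes are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3
A = rng.random((n, n))
A /= A.sum(axis=1, keepdims=True)   # toy row-stochastic attention (not causal)
W = rng.normal(size=(d, d))         # toy stand-in for W_V


def rho(i, j):
    """rho_ij = sqrt(A_ij) * W."""
    return np.sqrt(A[i, j]) * W


# Explicit coboundary on the complete graph
edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
delta = np.zeros((len(edges) * d, n * d))
for e, (i, j) in enumerate(edges):
    rows = slice(e * d, (e + 1) * d)
    delta[rows, j * d:(j + 1) * d] = rho(j, i)    # + rho_ji x_j
    delta[rows, i * d:(i + 1) * d] = -rho(i, j)   # - rho_ij x_i

# Block formula from build_sheaf_laplacian
L_blocks = np.zeros((n * d, n * d))
for i in range(n):
    for j in range(n):
        if i != j:
            L_blocks[i * d:(i + 1) * d, i * d:(i + 1) * d] += rho(i, j).T @ rho(i, j)
            L_blocks[i * d:(i + 1) * d, j * d:(j + 1) * d] = -rho(i, j).T @ rho(j, i)

print(np.allclose(delta.T @ delta, L_blocks))  # True
```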

In [None]:
def build_sheaf_laplacian(attention_weights, W_V, head_idx=0, use_subsample=True, max_tokens=8):
    """
    Construct the Sheaf Laplacian L_F = delta^T delta
    
    For computational efficiency, we work with a subsampled graph.
    
    Args:
        attention_weights: (batch, n_heads, seq_len, seq_len)
        W_V: (hidden_dim, hidden_dim)
        head_idx: which attention head
        use_subsample: whether to subsample tokens
        max_tokens: maximum tokens to use if subsampling
    
    Returns:
        L_F: Sheaf Laplacian matrix
        spectral_info: eigenvalue information
    """
    A = attention_weights[0, head_idx].float().cpu()
    seq_len = A.shape[0]
    d = W_V.shape[0]  # hidden_dim
    
    # Subsample for computational efficiency
    if use_subsample and seq_len > max_tokens:
        indices = np.linspace(0, seq_len-1, max_tokens, dtype=int)
        A = A[np.ix_(indices, indices)]
        seq_len = max_tokens
    
    # For efficiency, truncate W_V to its top-left proj_dim x proj_dim block
    # (a crop, not a learned or random projection)
    proj_dim = min(32, d)
    W_V_small = W_V[:proj_dim, :proj_dim].numpy()
    
    # Compute sqrt(A)
    sqrt_A = torch.sqrt(A + 1e-10).numpy()
    
    # Build block Laplacian
    # L_F is (n * d_small) x (n * d_small) where n = seq_len
    n = seq_len
    d_s = proj_dim
    
    L_F = np.zeros((n * d_s, n * d_s))
    
    # Diagonal blocks: sum over neighbors
    for i in range(n):
        block_ii = np.zeros((d_s, d_s))
        for j in range(n):
            if i != j:
                rho_ij = sqrt_A[i, j] * W_V_small
                block_ii += rho_ij.T @ rho_ij
        L_F[i*d_s:(i+1)*d_s, i*d_s:(i+1)*d_s] = block_ii
    
    # Off-diagonal blocks
    for i in range(n):
        for j in range(n):
            if i != j:
                rho_ij = sqrt_A[i, j] * W_V_small
                rho_ji = sqrt_A[j, i] * W_V_small
                block_ij = -rho_ij.T @ rho_ji
                L_F[i*d_s:(i+1)*d_s, j*d_s:(j+1)*d_s] = block_ij
    
    # Compute eigenvalues with error handling
    try:
        # Add small regularization for numerical stability
        L_F_reg = L_F + 1e-10 * np.eye(L_F.shape[0])
        eigenvalues = np.linalg.eigvalsh(L_F_reg)
        eigenvalues = np.sort(np.real(eigenvalues))  # Ensure real and sorted
    except np.linalg.LinAlgError:
        # Eigenvalues didn't converge - use fallback metrics
        eigenvalues = np.array([0.0, np.trace(L_F) / L_F.shape[0]])  # Approximate
    
    # Spectral information with safety checks
    spectral_info = {
        'eigenvalues': eigenvalues,
        'lambda_1': float(eigenvalues[0]) if len(eigenvalues) > 0 else 0.0,
        'lambda_2': float(eigenvalues[1]) if len(eigenvalues) > 1 else 0.0,
        'spectral_gap': float(eigenvalues[1] - eigenvalues[0]) if len(eigenvalues) > 1 else 0.0,
        'trace': float(np.trace(L_F)),
        'frobenius_norm': float(np.linalg.norm(L_F, 'fro'))
    }
    
    return L_F, spectral_info

# Build Laplacian for layer 0
L_F_0, spectral_0 = build_sheaf_laplacian(attentions[0], W_V_matrices[0])

print(f"Layer 0 Sheaf Laplacian:")
print(f"  Shape: {L_F_0.shape}")
print(f"  Lambda_1 (smallest): {spectral_0['lambda_1']:.6f}")
print(f"  Lambda_2: {spectral_0['lambda_2']:.6f}")
print(f"  Spectral Gap: {spectral_0['spectral_gap']:.6f}")
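
Since every $\rho_{ij}$ shares the same $W$, the Laplacian factorizes as $L_F = (D - S) \otimes W^\top W$ with $D_{ii} = \sum_{j \ne i} A_{ij}$ and $S_{ij} = \sqrt{A_{ij} A_{ji}}$. For causal attention $A_{ij} A_{ji} = 0$ whenever $i \ne j$, so $S = 0$ and the off-diagonal blocks vanish (in `build_sheaf_laplacian` they are nonzero only through the `1e-10` epsilon). If that holds, the eigenvalues are the products $D_{ii}\,\mu_k$, with $\mu_k$ the eigenvalues of $W^\top W$, and the first token, which attends only to itself, contributes near-zero eigenvalues; $\lambda_1$ and $\lambda_2$ may then both sit near zero in every layer, which is worth checking against the printed values. A toy check (hypothetical helper `block_laplacian`, random causal attention):

```python
import numpy as np


def block_laplacian(A, W):
    """Hypothetical helper: block formula of build_sheaf_laplacian with rho_ij = sqrt(A_ij) W."""
    n, d = A.shape[0], W.shape[0]
    sqrt_A = np.sqrt(A)
    WtW = W.T @ W
    L = np.zeros((n * d, n * d))
    for i in range(n):
        for j in range(n):
            if i != j:
                L[i * d:(i + 1) * d, i * d:(i + 1) * d] += A[i, j] * WtW
                L[i * d:(i + 1) * d, j * d:(j + 1) * d] = -sqrt_A[i, j] * sqrt_A[j, i] * WtW
    return L


rng = np.random.default_rng(0)
n, d = 5, 3
A = np.tril(rng.random((n, n)))
A /= A.sum(axis=1, keepdims=True)            # toy causal, row-stochastic attention
W = rng.normal(size=(d, d))

off_diag = A - np.diag(np.diag(A))
D = np.diag(off_diag.sum(axis=1))            # D_ii = sum_{j != i} A_ij
S = np.sqrt(off_diag * off_diag.T)           # S_ij = sqrt(A_ij A_ji)

L_F = block_laplacian(A, W)
print(np.allclose(L_F, np.kron(D - S, W.T @ W)))  # True
print(np.allclose(S, 0.0))                        # True: the causal mask removes off-diagonal blocks
```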

In [None]:
# Compute Sheaf Laplacian spectral properties for all layers
print("Computing Sheaf Laplacian for all layers...")
print("(With eigenvalue fallback for ill-conditioned matrices)\n")

laplacian_spectral = {}
failed_layers = []

for layer_idx in tqdm(range(n_layers), desc="Layers"):
    try:
        L_F, spectral = build_sheaf_laplacian(
            attentions[layer_idx], 
            W_V_matrices[layer_idx],
            max_tokens=8  # Keep small for efficiency
        )
        laplacian_spectral[layer_idx] = spectral
    except Exception as e:
        # Complete failure - use dummy values
        failed_layers.append((layer_idx, str(e)))
        laplacian_spectral[layer_idx] = {
            'eigenvalues': np.array([0.0, 0.0]),
            'lambda_1': 0.0,
            'lambda_2': 0.0,
            'spectral_gap': 0.0,
            'trace': 0.0,
            'frobenius_norm': 0.0
        }

print("\nDone!")
if failed_layers:
    print(f"\nWarning: {len(failed_layers)} layers had numerical issues:")
    for layer, err in failed_layers[:5]:  # Show first 5
        print(f"  Layer {layer}: {err[:50]}...")

In [None]:
# Plot Sheaf Laplacian Spectral Analysis
spectral_gaps = [laplacian_spectral[l]['spectral_gap'] for l in layers]
lambda_1s = [laplacian_spectral[l]['lambda_1'] for l in layers]
lambda_2s = [laplacian_spectral[l]['lambda_2'] for l in layers]
traces = [laplacian_spectral[l]['trace'] for l in layers]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Spectral Gap
ax1 = axes[0, 0]
ax1.plot(layers, spectral_gaps, 'b-', linewidth=2, marker='o', markersize=4)
ax1.axvline(x=L_star_contraction, color='red', linestyle='--', linewidth=2, label=f'L* = {L_star_contraction}')
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('Spectral Gap (lambda_2 - lambda_1)', fontsize=12)
ax1.set_title('Sheaf Laplacian Spectral Gap', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot 2: First two eigenvalues
ax2 = axes[0, 1]
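# Clip at 1e-10: a log scale cannot show zero or slightly negative round-off eigenvalues
ax2.semilogy(layers, [max(l, 1e-10) for l in lambda_1s], 'g-', linewidth=2, marker='s', markersize=4, label='lambda_1')
ax2.semilogy(layers, [max(l, 1e-10) for l in lambda_2s], 'r-', linewidth=2, marker='^', markersize=4, label='lambda_2')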
ax2.axvline(x=L_star_contraction, color='purple', linestyle='--', linewidth=2, label=f'L* = {L_star_contraction}')
ax2.set_xlabel('Layer', fontsize=12)
ax2.set_ylabel('Eigenvalue (log scale)', fontsize=12)
ax2.set_title('First Two Eigenvalues of L_F', fontsize=14)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# Plot 3: Trace (total "energy")
ax3 = axes[1, 0]
ax3.plot(layers, traces, 'm-', linewidth=2, marker='d', markersize=4)
ax3.axvline(x=L_star_contraction, color='red', linestyle='--', linewidth=2, label=f'L* = {L_star_contraction}')
ax3.set_xlabel('Layer', fontsize=12)
ax3.set_ylabel('Trace(L_F)', fontsize=12)
ax3.set_title('Sheaf Laplacian Trace (Total Inconsistency)', fontsize=14)
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)

# Plot 4: Eigenvalue spectrum at selected layers
ax4 = axes[1, 1]
key_layers = [0, n_layers // 4, n_layers // 2, 3 * n_layers // 4, n_layers - 1]
colors = plt.cm.viridis(np.linspace(0, 1, len(key_layers)))

for idx, layer in enumerate(key_layers):
    eigs = laplacian_spectral[layer]['eigenvalues'][:20]  # First 20 eigenvalues
    # Clip at 1e-10: tiny negative round-off eigenvalues cannot be drawn on a log scale
    ax4.semilogy(range(len(eigs)), np.clip(eigs, 1e-10, None), 
                 label=f'Layer {layer}', color=colors[idx], linewidth=2)

ax4.set_xlabel('Eigenvalue Index', fontsize=12)
ax4.set_ylabel('Eigenvalue (log scale)', fontsize=12)
ax4.set_title('Eigenvalue Spectrum at Key Layers', fontsize=14)
ax4.legend(fontsize=9)
ax4.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_NAME}: Sheaf Laplacian Spectral Analysis', 
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('sheaf_laplacian_spectral.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n>>> Figure saved as 'sheaf_laplacian_spectral.png'")

## 8. Validation Summary

In [None]:
import json

# Prepare summary
summary = {
    'model': MODEL_NAME,
    'n_layers': int(n_layers),
    'n_heads': int(n_heads),
    'hidden_dim': int(hidden_dim),
    'test_prompt': TEST_PROMPT,
    'seq_len': int(seq_len),
    
    'restriction_maps': {
        'L_star_min_contraction': int(L_star_contraction),
        'L_star_inflection': int(L_star_inflection),
        'contraction_ratios': [float(c) for c in contraction_ratios],
        'mean_op_norms': [float(n) for n in mean_op_norms],
        'W_V_op_norms': [float(n) for n in W_V_op_norms],
        'attention_entropies': [float(e) for e in attention_entropies]
    },
    
    'sheaf_laplacian': {
        'spectral_gaps': [float(g) for g in spectral_gaps],
        'lambda_1': [float(l) for l in lambda_1s],
        'lambda_2': [float(l) for l in lambda_2s],
        'traces': [float(t) for t in traces]
    },
    
    'validation': {
        'formula_tested': 'rho_ij = sqrt(A_ij) * W_V',
        'contraction_before_Lstar': bool(np.mean(contraction_ratios[:L_star_contraction]) < 1),
        'expansion_after_Lstar': bool(np.mean(contraction_ratios[L_star_contraction:]) > np.mean(contraction_ratios[:L_star_contraction]))
    }
}

# Save to JSON
with open('restriction_maps_results.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("="*60)
print("RESTRICTION MAPS VALIDATION SUMMARY")
print("="*60)
print(f"\nModel: {MODEL_NAME}")
print(f"Test prompt: '{TEST_PROMPT[:50]}...'")
print(f"\nFormula tested: rho_ij = sqrt(A_ij) * W_V")
print(f"\nKey Findings:")
print(f"  L* (min contraction): Layer {L_star_contraction}")
print(f"  Contraction before L*: {summary['validation']['contraction_before_Lstar']}")
print(f"  Relative expansion after L*: {summary['validation']['expansion_after_Lstar']}")
print(f"\nSheaf Laplacian:")
print(f"  Spectral gap range: {min(spectral_gaps):.4f} - {max(spectral_gaps):.4f}")
print(f"  Max gap at layer: {np.argmax(spectral_gaps)}")
print(f"\nFiles saved:")
print(f"  - restriction_maps_analysis.png")
print(f"  - sheaf_laplacian_spectral.png")
print(f"  - restriction_maps_results.json")
print("="*60)

In [None]:
# Create ZIP archive
import zipfile
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"restriction_maps_results_{timestamp}.zip"

with zipfile.ZipFile(zip_filename, 'w') as zipf:
    zipf.write('restriction_maps_analysis.png')
    zipf.write('sheaf_laplacian_spectral.png')
    zipf.write('restriction_maps_results.json')

print(f">>> Created: {zip_filename}")

## 9. Interpretation

### Theoretical Prediction

The restriction map formula $\rho_{ij} = \sqrt{A_{ij}} \cdot W_V$ has specific implications:

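1. **Attention-weighted transport**: The $\sqrt{A_{ij}}$ factor scales each transport by attention; weighting the message $\rho_{ij} x_j$ once more by $\sqrt{A_{ij}}$ recovers the standard attention term $A_{ij} W_V x_j$
2. **Value transformation**: $W_V$ transforms the semantic content during transport
3. **Contraction/Expansion**: The norm of $\rho_{ij}$ determines whether information is compressed or expanded; since $\|\rho_{ij}\| = \sqrt{A_{ij}}\,\|W_V\|$ with $A_{ij} \le 1$, the attention factor can only shrink it, so any growth beyond $\|W_V\|$ must come from $W_V$ itself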

### Expected vs Observed

| Prediction | Expected | Status |
|------------|----------|--------|
| $\rho_{ij}$ follows formula | $\sqrt{A_{ij}} \cdot W_V$ | CONSTRUCTED |
| Contraction before L* | Norm ratio < 1 | CHECK RESULTS |
| Expansion after L* | Norm ratio increases | CHECK RESULTS |
| Spectral gap shift at L* | Change in $\lambda_2 - \lambda_1$ | CHECK RESULTS |

### Significance

If validated, these results would support the claims that:
1. Transformers implicitly implement sheaf diffusion
2. The attention mechanism defines restriction maps
3. Layer dynamics follow sheaf-theoretic predictions

This would provide a **mechanistic basis** for the phase-structured dynamics observed in Papers #1 and #2.

## 10. Download Results

In [None]:
# Download all results (works only inside Google Colab)
try:
    from google.colab import files
except ImportError:
    files = None

if files is None:
    print("Not running in Colab: result files are in the current working directory.")
else:
    print("Downloading result files...")
    print()

    # Download ZIP
    print(f"1. ZIP Archive: {zip_filename}")
    files.download(zip_filename)

    # Individual files
    print("\n2. Individual files:")
    print("   - restriction_maps_analysis.png")
    files.download('restriction_maps_analysis.png')
    print("   - sheaf_laplacian_spectral.png")
    files.download('sheaf_laplacian_spectral.png')
    print("   - restriction_maps_results.json")
    files.download('restriction_maps_results.json')

    print("\n>>> All files downloaded!")