# Notebook 1: Physics-Attention vs Standard Attention

**Goal:** Understand why standard transformers are expensive for physics simulations and how Physics-Attention solves this.

## Outline
1. Download & Load Stokes Flow Data
2. Standard Attention vs Physics-Attention
3. Physics-Attention: The 4-Step Solution
4. How Mesh Points Get Distributed to Slices

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from utils import softmax, download_stokes_dataset, load_stokes_sample

np.random.seed(42)

## 1. Download & Load Stokes Flow Data

We'll use the Stokes flow dataset (same as Lab 2). Run the cell below to download if needed.

In [None]:
# Download dataset (if not present) and load a sample
download_stokes_dataset()
coords, u, v, p = load_stokes_sample()
N_mesh = len(coords)

# Visualize the Stokes flow data (uniform sizing)
fig, axes = plt.subplots(1, 3, figsize=(15, 3.5))

fields = [('Velocity u', u, 'RdBu_r'), ('Velocity v', v, 'RdBu_r'), ('Pressure p', p, 'viridis')]

for ax, (title, field, cmap) in zip(axes, fields):
    sc = ax.scatter(coords[:, 0], coords[:, 1], c=field, cmap=cmap, s=5)
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(coords[:, 0].min() - 0.05, coords[:, 0].max() + 0.05)
    ax.set_ylim(coords[:, 1].min() - 0.05, coords[:, 1].max() + 0.05)
    ax.set_aspect('equal')
    plt.colorbar(sc, ax=ax, fraction=0.046, pad=0.04)

plt.suptitle(f'Stokes Flow Data (N={N_mesh} mesh points)', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print(f"✓ Loaded mesh with {N_mesh} points")

## 2. Standard Attention vs Physics-Attention

### Standard Self-Attention

The transformer's attention mechanism computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- **Q** (Query), **K** (Key), **V** (Value) are linear projections of input X ∈ ℝ^(N×C)
- The **QKᵀ** term creates an **N × N** attention matrix
- Every point attends to every other point → **O(N²) complexity**

**Problem:** For a mesh with N=10,000 points, the attention matrix alone has N² = 100,000,000 entries per layer, and both compute and memory grow quadratically with N.
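To make the quadratic growth concrete, here is a back-of-the-envelope sketch of how much memory the N×N matrix would take. The 4 bytes per entry (float32) is an assumption for illustration; the actual footprint depends on the dtype, the number of heads, and the framework.

```python
# Rough size of the N×N attention matrix for one layer.
# Assumption: float32 storage (4 bytes per entry).
N = 10_000
entries = N * N
bytes_fp32 = entries * 4

print(f"{entries:,} entries, about {bytes_fp32 / 1e9:.1f} GB in float32")
```

Doubling N quadruples both numbers, which is why fine meshes quickly become impractical for all-to-all attention.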

### Physics-Attention (Transolver)

Instead of N×N attention, Physics-Attention:
1. Groups N mesh points into **M slices** (M << N, typically 8-64)
2. Computes **M × M** attention between slice representations
3. Broadcasts results back to N points

**Result:** O(N·M + M²) ≈ **O(N·M)** complexity, which grows linearly in N instead of quadratically.

| | Standard Attention | Physics-Attention |
|---|---|---|
| **Attention matrix** | N × N | M × M |
| **Complexity** | O(N²) | O(N·M) |
| **N=10,000, M=8** | 100,000,000 ops | ~160,000 ops (2·N·M + M²) |
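A rough way to tally these costs is sketched below. The helper `attention_cost` is hypothetical (not part of `utils`). It counts pairwise weights only: N·M for slicing, M² for slice attention, and N·M for deslicing. It ignores the feature dimension C, so the exact figure depends on which passes are counted.

```python
def attention_cost(n_points, n_slices=None):
    """Rough count of pairwise weights, ignoring the feature dimension C.

    With n_slices=None this is standard attention (N*N); otherwise it is
    slice (N*M) + slice attention (M*M) + deslice (N*M).
    """
    if n_slices is None:
        return n_points * n_points
    return 2 * n_points * n_slices + n_slices * n_slices

N, M = 10_000, 8
std = attention_cost(N)
phys = attention_cost(N, M)
print(f"standard: {std:,}   physics: {phys:,}   ratio: {std / phys:.0f}x")
```

Because M is fixed, the ratio grows roughly like N / (2·M), so larger meshes benefit more.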

In [None]:
# Define both attention functions to compare
N = N_mesh  # Use actual mesh size
C = 8       # Feature dimension

def standard_attention(X, d_k=8):
    """
    Standard self-attention: O(N²) complexity
    X: (N, C) input features
    Returns: (N, C) output, attention matrix (N, N)
    """
    # Linear projections (simplified: identity projections for the demo, no learned weights)
    Q = X  # Query: (N, C)
    K = X  # Key:   (N, C)  
    V = X  # Value: (N, C)
    
    # THE EXPENSIVE PART: N×N attention matrix
    scores = Q @ K.T / np.sqrt(d_k)  # (N, N) - O(N²) operations!
    attn = softmax(scores, axis=1)   # (N, N)
    
    output = attn @ V  # (N, C)
    return output, attn

def physics_attention(X, M, d_k=8):
    """
    Physics-Attention: O(N·M) complexity where M << N
    X: (N, C) input features
    M: number of slices
    Returns: (N, C) output, slice weights (N, M), slice attention (M, M)
    """
    N, C = X.shape
    
    # Step 1: SLICE - assign N points to M slices (soft assignment)
    W_slice = np.random.randn(C, M) * 0.5  # Learnable in real model
    slice_weights = softmax(X @ W_slice, axis=1)  # (N, M) - O(N·M)
    
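    # Step 2: AGGREGATE - compress each slice into a single token
    # (divide by each slice's total weight, as Transolver does, so token scale does not grow with N)
    slice_norm = slice_weights.sum(axis=0)[:, None]  # (M, 1)
    slice_tokens = (slice_weights.T @ X) / (slice_norm + 1e-8)  # (M, C) - weighted mean per slice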
    
    # Step 3: ATTEND - M×M attention (THE CHEAP PART!)
    Q = slice_tokens  # (M, C)
    K = slice_tokens  # (M, C)
    V = slice_tokens  # (M, C)
    scores = Q @ K.T / np.sqrt(d_k)  # (M, M) - only O(M²) operations!
    slice_attn = softmax(scores, axis=1)  # (M, M)
    attended = slice_attn @ V  # (M, C)
    
    # Step 4: DESLICE - broadcast back to N points
    output = slice_weights @ attended  # (N, C) - O(N·M)
    
    return output, slice_weights, slice_attn

# Create sample input
X = np.random.randn(N, C)

# Compare costs for different M values
print("="*70)
print(f"COST COMPARISON: Standard vs Physics-Attention (N = {N:,} mesh points)")
print("="*70)
print(f"\n{'Method':<25} {'Attention Size':<18} {'Operations':<15} {'Speedup':<10}")
print("-"*70)

# Standard attention
std_out, std_attn = standard_attention(X)
std_ops = N * N
print(f"{'Standard Attention':<25} {f'{N}×{N}':<18} {std_ops:,} ops{'':<5} {'1x (baseline)':<10}")

# Physics attention with different M
for M in [4, 8, 16, 32, 64]:
    phys_out, slice_w, slice_attn = physics_attention(X, M)
    # Rough count of pairwise weights: N·M (slice) + M² (attend) + N·M (deslice) = 2·N·M + M²
    # (ignores the aggregate step and the feature dimension C)
    phys_ops = 2 * N * M + M * M
    speedup = std_ops / phys_ops
    print(f"{'Physics-Attention M='+str(M):<25} {f'{M}×{M}':<18} {phys_ops:,} ops{'':<5} {speedup:.0f}x faster")

print("-"*70)
print(f"\n✓ With M=8 slices, we get ~{std_ops // (2*N*8 + 64):.0f}x speedup!")
print("✓ The key: M×M attention instead of N×N")

## 3. Physics-Attention: The 4-Step Solution

### What are "Slices"?

In physics simulations, different regions of a mesh often have similar physical behavior:
- **Inlet region**: Smooth laminar flow
- **Obstacle wake**: Turbulent/recirculating flow  
- **Boundary layer**: High gradients near walls
- **Far-field**: Nearly uniform flow

**Slices** are learned groupings that cluster mesh points with similar physics. Instead of every point attending to every other point (N×N), we:
1. Group points into M slices (soft assignment via learned weights)
2. Compute attention only between slice representations (M×M)

### The 4-Step Algorithm

```
Input: X ∈ R^(N×C)  (N mesh points, C features)
       W ∈ R^(C×M)  (learnable slice weights)

Step 1 - SLICE:     S = softmax(X @ W)           → S ∈ R^(N×M)  (assignment weights)
Step 2 - AGGREGATE: Z = (S^T @ X) / colsum(S)    → Z ∈ R^(M×C)  (slice tokens, each divided by its total weight)
Step 3 - ATTEND:    Z' = Attention(Z, Z, Z)      → Z' ∈ R^(M×C) (M×M attention!)
Step 4 - DESLICE:   Y = S @ Z'                   → Y ∈ R^(N×C)  (broadcast back)

Output: Y ∈ R^(N×C)
```

**Key insight:** Step 3 is O(M²) instead of O(N²), and M is typically 8-64 while N can be 10,000+!
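The four steps can be sketched as a small self-contained NumPy function. This is a minimal illustration, not the Transolver implementation. It assumes identity Q/K/V projections and random, untrained slice weights. Step 2 divides each slice token by its total assignment weight, which keeps token magnitudes independent of N. The helper names `attention` and `four_step` are hypothetical. The final check uses one slice per point (S = identity), in which case the four steps collapse to standard self-attention.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Z):
    # Plain self-attention with identity projections (demo simplification).
    scores = Z @ Z.T / np.sqrt(Z.shape[1])
    return softmax(scores, axis=1) @ Z

def four_step(X, S):
    # S: (N, M) assignment weights from Step 1; each row sums to 1.
    Z = (S.T @ X) / S.sum(axis=0)[:, None]  # Step 2: (M, C) slice tokens
    Z_att = attention(Z)                    # Step 3: (M, M) attention
    return S @ Z_att                        # Step 4: back to (N, C)

rng = np.random.default_rng(0)
N, C, M = 200, 8, 4
X = rng.normal(size=(N, C))
W = rng.normal(size=(C, M)) * 0.5           # random stand-in for learned weights
S = softmax(X @ W, axis=1)                  # Step 1: (N, M)

Y = four_step(X, S)
print(Y.shape)  # (200, 8)

# Sanity check: with one slice per point, the result equals standard attention.
Y_id = four_step(X, np.eye(N))
print(np.allclose(Y_id, attention(X)))  # True
```

The identity-slice check shows that Physics-Attention is a compressed version of standard attention: shrinking M below N trades exact all-to-all interaction for a much smaller attention matrix.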

In [None]:
# Create physics-based features from our mesh (coords + physics values)
C = 8  # Feature dimension
features = np.column_stack([
    coords,  # x, y coordinates
    u.reshape(-1, 1),  # velocity u
    v.reshape(-1, 1),  # velocity v  
    p.reshape(-1, 1),  # pressure
    np.random.randn(N, C-5)  # padding to get C features
])

print(f"✓ Created feature matrix: {features.shape} (N={N} points, C={C} features)")

## 4. How Mesh Points Get Distributed to Slices

Let's visualize exactly how mesh points get assigned to slices. Each point gets a **soft assignment weight** for every slice, and for each point the M weights sum to 1. The point "belongs" most strongly to the slice with the highest weight, which is the one we use to color the plots.
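A tiny example of soft versus hard assignment, using made-up logits for 4 points and 3 slices. The numbers are hypothetical and chosen only to show the mechanics, and `softmax` is redefined here so the snippet runs on its own.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Hypothetical logits: rows are points, columns are slices.
logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 3.0],
                   [1.0, 1.0, 1.0]])  # last point is undecided

weights = softmax(logits, axis=1)     # soft assignment, one row per point
dominant = weights.argmax(axis=1)     # hard assignment used for coloring

print(weights.round(2))
print(weights.sum(axis=1))            # each row sums to 1
print(dominant)                       # ties resolve to the lowest slice index
```

The last point has equal weight on all three slices, so the hard label hides that it contributes equally to every slice token. The colored plots therefore show only the dominant slice, not the full soft assignment.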

### Step-by-Step with M=3 Slices

In [None]:
# Visualize slice assignments with M=3 slices
M_demo = 3
np.random.seed(123)

# Compute slice assignments
W_demo = np.random.randn(C, M_demo) * 0.8
slice_weights = softmax(features @ W_demo, axis=1)  # (N, M) soft assignment
dominant_slice = np.argmax(slice_weights, axis=1)   # Hard assignment for visualization

slice_colors = ['#e41a1c', '#377eb8', '#4daf4a']  # Red, Blue, Green
counts = [np.sum(dominant_slice == i) for i in range(M_demo)]

# Simple visualization: mesh + attention matrix
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left: Mesh colored by slice assignment
ax1 = axes[0]
for i in range(M_demo):
    mask = dominant_slice == i
    ax1.scatter(coords[mask, 0], coords[mask, 1], c=slice_colors[i], 
               s=8, alpha=0.7, label=f'Slice {i} ({counts[i]} pts)')
ax1.set_title(f'Mesh Points Assigned to {M_demo} Slices', fontsize=12)
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_aspect('equal')
ax1.legend(loc='upper right')

# Right: The M×M attention matrix (the cheap part!)
ax2 = axes[1]
# Illustrative random M×M matrix (a placeholder, not computed from the slice tokens)
phys_attn = softmax(np.random.randn(M_demo, M_demo), axis=1)
im = ax2.imshow(phys_attn, cmap='Greens', vmin=0, vmax=1)
ax2.set_title(f'Physics-Attention: {M_demo}×{M_demo} = {M_demo**2} entries\n(instead of {N}×{N} = {N*N:,} entries)', fontsize=11)
ax2.set_xlabel('Key Slice')
ax2.set_ylabel('Query Slice')
ax2.set_xticks(range(M_demo))
ax2.set_yticks(range(M_demo))
plt.colorbar(im, ax=ax2, shrink=0.8)

plt.tight_layout()
plt.show()

print(f"✓ Attention matrix: {N*N:,} → {M_demo**2} entries ({N*N // M_demo**2:,}x fewer); slicing and deslicing still cost O(N·M)")

### Effect of Different Slice Counts (M)

More slices = finer grouping but higher cost. Typical values: M=8 to M=64.
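One way to see what "finer grouping" means is to count how many points land in each slice as M grows. The sketch below uses synthetic Gaussian features and random, untrained slice weights as stand-ins (both are assumptions), so the counts are only illustrative and a trained model may distribute points very differently.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, C = 1000, 8
features = rng.normal(size=(N, C))  # synthetic stand-in for mesh features

occupancy = {}
for M in (4, 8, 16):
    W = rng.normal(size=(C, M)) * 0.6  # random, untrained slice weights
    dominant = softmax(features @ W, axis=1).argmax(axis=1)
    occupancy[M] = np.bincount(dominant, minlength=M)  # points per slice
    print(f"M={M:>2}: largest slice {occupancy[M].max()} pts, "
          f"smallest {occupancy[M].min()} pts")
```

As M increases, each slice covers fewer points on average, while the slice and deslice steps grow linearly with M and the slice attention grows with M².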

In [None]:
# Compare M=4, M=8, M=16 slice partitioning
slice_configs = [4, 8, 16]
fig_compare, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, num_slices in enumerate(slice_configs):
    ax = axes[idx]
    np.random.seed(42 + idx)  # Different seed for variety
    
    # Compute slice assignments
    W = np.random.randn(C, num_slices) * 0.6
    logits = features @ W
    weights = softmax(logits, axis=1)
    dominant_slice = np.argmax(weights, axis=1)
    
    # Plot mesh colored by slice
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=dominant_slice, 
                        cmap='tab10' if num_slices <= 10 else 'tab20',
                        s=6, alpha=0.7)
    
    # Ratio of attention-matrix sizes (excludes the O(N·M) slice/deslice cost)
    size_ratio = N*N // (num_slices**2)
    ax.set_title(f'M = {num_slices} slices\nAttention: {num_slices}×{num_slices}={num_slices**2} entries\n({size_ratio:,}x smaller than N×N)',
                fontsize=10)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    plt.colorbar(scatter, ax=ax, label='Slice ID', shrink=0.8)

plt.suptitle('Trade-off: More Slices = Finer Resolution but Higher Cost', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("\nTypical M values in Transolver: 8-64 (paper uses M=32 or M=64)")

## Summary

| Aspect | Standard Attention | Physics-Attention |
|--------|-------------------|-------------------|
| **Complexity** | O(N²) — expensive! | O(N·M) — efficient! |
| **Attention matrix** | N×N | M×M (M = 8-64) |
| **Grouping** | All-to-all | Learned slices |

**Key Takeaway:** Physics-Attention reduces cost by grouping mesh points into M learned "slices" and performing attention between these compressed representations.

**Next:** Notebook 2 shows the full Transolver architecture in PyTorch.