# Notebook 1: Physics-Attention vs Standard Attention

**Goal:** Understand why standard transformers are expensive for physics simulations and how Physics-Attention solves this.

## Outline
1. Download & Load Stokes Flow Data
2. Standard Attention vs Physics-Attention
3. Physics-Attention: The 4-Step Solution
4. How Mesh Points Get Distributed to Slices

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import softmax, download_stokes_dataset, load_stokes_sample

np.random.seed(42)

## 1. Download & Load Stokes Flow Data

We'll use the Stokes flow dataset (same as Lab 2). Run the cell below to download if needed.

In [None]:
# Download dataset (if not present) and load a sample
download_stokes_dataset()
coords, u, v, p = load_stokes_sample()
N_mesh = len(coords)

# Interactive visualization with Plotly
fig = make_subplots(rows=1, cols=3, subplot_titles=['Velocity u', 'Velocity v', 'Pressure p'])

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=u, colorscale='RdBu_r', showscale=True, 
                colorbar=dict(x=0.28, len=0.8, title='u')),
    name='u', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>u=%{marker.color:.3f}'), row=1, col=1)

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=v, colorscale='RdBu_r', showscale=True,
                colorbar=dict(x=0.62, len=0.8, title='v')),
    name='v', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>v=%{marker.color:.3f}'), row=1, col=2)

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=p, colorscale='Viridis', showscale=True,
                colorbar=dict(x=0.97, len=0.8, title='p')),
    name='p', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>p=%{marker.color:.3f}'), row=1, col=3)

fig.update_layout(title=f'Stokes Flow Data (N={N_mesh} mesh points)', height=400, width=1100, showlegend=False)
fig.update_xaxes(title_text='x')
fig.update_yaxes(title_text='y', scaleanchor='x', scaleratio=1)
fig.show()

print(f"✓ Loaded mesh with {N_mesh} points")

## 2. Standard Attention vs Physics-Attention

### Standard Self-Attention (Transformers)

Standard self-attention computes: $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$, where $d$ is the dimension of each query/key vector and the softmax is applied row-wise.

**What are Q, K, V?**
- **Q (Query)**: "What information am I looking for?" — Each point asks a question
- **K (Key)**: "What information do I have?" — Each point advertises its content  
- **V (Value)**: "What information to send?" — The actual content to aggregate

In standard attention, **every point queries every other point**:
```
For N mesh points:
  Q = X @ W_q  →  (N, d)   # N queries
  K = X @ W_k  →  (N, d)   # N keys
  V = X @ W_v  →  (N, d)   # N values
  
  Attention = softmax(Q @ K^T / √d) @ V
                       ↑
              This is N×N = O(N²) operations!
```
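
To make the shapes above concrete, here is a minimal NumPy sketch with explicit projection matrices. The toy sizes and the random `W_q`, `W_k`, `W_v` are assumptions for illustration only; in a trained model these matrices are learned.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points, c_in, d = 6, 4, 8          # toy sizes, chosen only for illustration

X_toy = rng.standard_normal((n_points, c_in))
# Random stand-ins for the learned projection matrices W_q, W_k, W_v
W_q = rng.standard_normal((c_in, d))
W_k = rng.standard_normal((c_in, d))
W_v = rng.standard_normal((c_in, d))

Q_toy, K_toy, V_toy = X_toy @ W_q, X_toy @ W_k, X_toy @ W_v   # each (n_points, d)

scores_toy = Q_toy @ K_toy.T / np.sqrt(d)                      # (n_points, n_points)
scores_toy -= scores_toy.max(axis=1, keepdims=True)            # numerical stability
attn_toy = np.exp(scores_toy)
attn_toy /= attn_toy.sum(axis=1, keepdims=True)                # row-wise softmax

out_toy = attn_toy @ V_toy                                     # (n_points, d)
print(attn_toy.shape, out_toy.shape)   # (6, 6) (6, 8)
```

The attention matrix has one row and one column per mesh point, which is where the quadratic growth comes from.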

### Why This is Problematic for Simulations

| Mesh Points (N) | Attention Matrix Entries (N²) | Memory (float32) | Feasible? |
|-----------------|-------------------------------|------------------|-----------|
| 1,000 | 1M | 4 MB | ✓ OK |
| 10,000 | 100M | 400 MB | ⚠️ Slow |
| 100,000 | 10B | 40 GB | ❌ Too expensive |

Real CFD meshes often have 100K+ points — standard attention is infeasible!
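
The memory figures in the table follow from N² entries at 4 bytes each. A small sketch of that arithmetic (the helper name `attention_matrix_cost` is hypothetical, and it counts only one dense attention matrix for a single head):

```python
def attention_matrix_cost(n_points, bytes_per_entry=4):
    """Entries and bytes of one dense n_points x n_points attention matrix."""
    entries = n_points ** 2
    return entries, entries * bytes_per_entry

for n in (1_000, 10_000, 100_000):
    entries, n_bytes = attention_matrix_cost(n)
    print(f"N={n:>7,}: {entries:>15,} entries, {n_bytes / 1e9:>6.3f} GB")
```

Multiple heads, layers, and stored gradients would multiply these numbers further.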

### Physics-Attention: The Key Difference

Instead of every point attending to every other point, **group similar physics together**:

```
Standard:  Point → Point attention    (N×N)
Physics:   Point → Slice → Point      (N×M + M×M + M×N)
```

**In Physics-Attention:**
- **Q, K, V are computed on M slice tokens** (not N points!)
- Slices are learned groupings of mesh points with similar physical behavior
- Typical M = 8-64, while N = 10,000+

This reduces complexity from **O(N²)** to **O(N·M + M²) ≈ O(N·M)**

In [None]:
# Define both attention functions to compare
N = N_mesh  # Use actual mesh size
C = 8       # Feature dimension (e.g., coordinates + physics quantities)

def standard_attention(X, d_k=8):
    """
    Standard self-attention: O(N²) complexity
    
    X: (N, C) - features for N mesh points
    Returns: output (N, C), attention matrix (N, N)
    """
    # Q, K, V: identity projections for simplicity (a real layer would use X @ W_q, X @ W_k, X @ W_v)
    Q = X  # Query: (N, C) - "What am I looking for?" for each of N points
    K = X  # Key:   (N, C) - "What do I contain?" for each of N points
    V = X  # Value: (N, C) - "What to send?" for each of N points
    
    # THE EXPENSIVE PART: Every point attends to every other point
    scores = Q @ K.T / np.sqrt(d_k)  # (N, N) - N² dot products!
    attn = softmax(scores, axis=1)   # (N, N) - attention weights
    
    output = attn @ V  # (N, C) - weighted combination of all values
    return output, attn

def physics_attention(X, M, d_k=8):
    """
    Physics-Attention: O(N·M) complexity where M << N
    
    Key insight: Instead of N points having Q/K/V, we have M SLICES with Q/K/V!
    
    X: (N, C) - features for N mesh points
    M: number of physics-based slices (typically 8-64)
    Returns: output (N, C), slice weights (N, M), slice attention (M, M)
    """
    N, C = X.shape
    
    # Step 1: SLICE - soft-assign N points to M slices
    W_slice = np.random.randn(C, M) * 0.5  # Random stand-in for a learnable projection
    slice_weights = softmax(X @ W_slice, axis=1)  # (N, M) - soft assignment of each point to each slice
    
    # Step 2: AGGREGATE - compress N points into M slice tokens
    # Divide by each slice's total weight so each token is a weighted average, not a sum
    slice_tokens = (slice_weights.T @ X) / (slice_weights.sum(axis=0)[:, None] + 1e-8)  # (M, C)
    
    # Step 3: ATTEND - Q/K/V on SLICES, not points!
    Q = slice_tokens  # (M, C) - only M queries (not N!)
    K = slice_tokens  # (M, C) - only M keys
    V = slice_tokens  # (M, C) - only M values
    scores = Q @ K.T / np.sqrt(d_k)  # (M, M) - only M² ops, not N²!
    slice_attn = softmax(scores, axis=1)  # (M, M)
    attended = slice_attn @ V  # (M, C)
    
    # Step 4: DESLICE - broadcast M tokens back to N points
    output = slice_weights @ attended  # (N, C)
    
    return output, slice_weights, slice_attn

# Create sample input
X = np.random.randn(N, C)

# Compare costs for different M values
print("="*70)
print(f"COST COMPARISON: Standard vs Physics-Attention (N = {N:,} mesh points)")
print("="*70)
print(f"\n{'Method':<25} {'Attention Size':<18} {'Operations':<15} {'Speedup':<10}")
print("-"*70)

# Standard attention
std_out, std_attn = standard_attention(X)
std_ops = N * N
print(f"{'Standard Attention':<25} {f'{N}×{N}':<18} {std_ops:,} ops{'':<5} {'1x (baseline)':<10}")

# Physics attention with different M
for M in [4, 8, 16, 32, 64]:
    phys_out, slice_w, slice_attn = physics_attention(X, M)
    # Simplified op count (ignores the feature dimension C): N·M (slice) + M² (attend) + N·M (deslice) = 2·N·M + M²
    phys_ops = 2 * N * M + M * M
    speedup = std_ops / phys_ops
    print(f"{'Physics-Attention M='+str(M):<25} {f'{M}×{M}':<18} {phys_ops:,} ops{'':<5} {speedup:.0f}x faster")

print("-"*70)
print(f"\n✓ With M=8 slices, we get ~{std_ops // (2*N*8 + 64):.0f}x speedup!")
print("✓ The key: M×M attention instead of N×N")

## 3. Physics-Attention: The 4-Step Solution

### What are "Slices"?

In physics simulations, different regions of a mesh often have similar physical behavior:
- **Inlet region**: Smooth laminar flow
- **Obstacle wake**: Turbulent/recirculating flow  
- **Boundary layer**: High gradients near walls
- **Far-field**: Nearly uniform flow

**Slices** are learned groupings that cluster mesh points with similar physics. Instead of every point attending to every other point (N×N), we:
1. Group points into M slices (soft assignment via learned weights)
2. Compute attention only between slice representations (M×M)

### The 4-Step Algorithm

```
Input: X ∈ R^(N×C)  (N mesh points, C features)
       W ∈ R^(C×M)  (learnable slice weights)

Step 1 - SLICE:     S = softmax(X @ W)           → S ∈ R^(N×M)  (assignment weights)
Step 2 - AGGREGATE: Z = (S^T @ X) / colsum(S)    → Z ∈ R^(M×C)  (slice tokens; each row divided by its slice's total weight)
Step 3 - ATTEND:    Z' = Attention(Z, Z, Z)      → Z' ∈ R^(M×C) (M×M attention!)
Step 4 - DESLICE:   Y = S @ Z'                   → Y ∈ R^(N×C)  (broadcast back)

Output: Y ∈ R^(N×C)
```

**Key insight:** Step 3 is O(M²) instead of O(N²), and M is typically 8-64 while N can be 10,000+!
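
One way to read the four steps is that they apply an implicit N×N point-to-point attention matrix of rank at most M, without ever forming it. The sketch below checks this on toy data. The sizes, the random weights, and the local `softmax_rows` helper are assumptions for illustration; Step 2 here divides by each slice's total weight so that slice tokens are weighted averages.

```python
import numpy as np

def softmax_rows(a):
    """Row-wise softmax (hypothetical local helper, not the notebook's utils.softmax)."""
    a = a - a.max(axis=1, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n_pts, n_feat, n_slices = 200, 8, 4                  # toy sizes for illustration

X_demo = rng.standard_normal((n_pts, n_feat))
W_toy = rng.standard_normal((n_feat, n_slices))      # random stand-in for learned weights

S = softmax_rows(X_demo @ W_toy)                     # Step 1: (n_pts, n_slices)
S_norm = S / S.sum(axis=0, keepdims=True)            # columns rescaled to sum to 1
Z = S_norm.T @ X_demo                                # Step 2: (n_slices, n_feat), weighted averages
A = softmax_rows(Z @ Z.T / np.sqrt(n_feat))          # Step 3: (n_slices, n_slices)
Y = S @ (A @ Z)                                      # Step 4: (n_pts, n_feat)

# Same result via an explicit (and expensive) point-to-point matrix
implicit_attn = S @ A @ S_norm.T                     # (n_pts, n_pts), rank at most n_slices
Y_dense = implicit_attn @ X_demo

print("outputs match:", np.allclose(Y, Y_dense))
print("rank of implicit matrix:", np.linalg.matrix_rank(implicit_attn))
```

The sliced path only ever stores N×M and M×M arrays, while the dense path needs the full N×N matrix to reach the same output.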

In [None]:
# Create physics-based features from our mesh (coords + physics values)
C = 8  # Feature dimension
features = np.column_stack([
    coords,  # x, y coordinates
    u.reshape(-1, 1),  # velocity u
    v.reshape(-1, 1),  # velocity v  
    p.reshape(-1, 1),  # pressure
    np.random.randn(N, C-5)  # random padding to reach C features (assumes coords has 2 columns)
])

print(f"✓ Created feature matrix: {features.shape} (N={N} points, C={C} features)")

## 4. How Mesh Points Get Distributed to Slices

Let's visualize exactly how mesh points get assigned to slices. Each point gets a **soft assignment weight** for every slice, and each point's weights sum to 1. The point "belongs" most strongly to its dominant slice, the one with the highest weight.
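
Before using the real mesh, a tiny hand-made example may help. The logits below are hypothetical numbers chosen for illustration (4 points, 3 slices); they are not taken from the dataset.

```python
import numpy as np

# Hypothetical logits: one row per point, one column per slice
toy_logits = np.array([
    [2.0, 0.0, 0.0],   # clearly prefers slice 0
    [0.0, 3.0, 0.0],   # clearly prefers slice 1
    [0.0, 0.0, 1.0],   # mildly prefers slice 2
    [0.5, 0.6, 0.5],   # nearly undecided
])
shifted = toy_logits - toy_logits.max(axis=1, keepdims=True)   # numerical stability
toy_weights = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

print(np.round(toy_weights, 2))
print("row sums:", toy_weights.sum(axis=1))
print("dominant slice:", toy_weights.argmax(axis=1))
```

Every row sums to 1, and the dominant slices come out as 0, 1, 2, 1. The last point is assigned to slice 1 by a very small margin, which is why the soft weights carry more information than the hard `argmax` coloring used in the plots.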

### Step-by-Step with M=3 Slices

In [None]:
# Detailed visualization with M=3 slices
M_demo = 3
np.random.seed(123)  # For reproducibility

# Step 1: Compute slice assignment weights
W_demo = np.random.randn(C, M_demo) * 0.8
slice_logits = features @ W_demo
slice_weights = softmax(slice_logits, axis=1)  # Shape: (N, 3)
dominant_slice = np.argmax(slice_weights, axis=1)

# Count points per slice
slice_colors = ['#e41a1c', '#377eb8', '#4daf4a']  # Red, Blue, Green
slice_names = ['Slice 0', 'Slice 1', 'Slice 2']
counts = [np.sum(dominant_slice == i) for i in range(M_demo)]

print("="*60)
print(f"SLICE ASSIGNMENT SUMMARY (M={M_demo} slices)")
print("="*60)
for i, (name, count) in enumerate(zip(slice_names, counts)):
    pct = 100 * count / N
    print(f"  {name}: {count:,} points ({pct:.1f}%)")
print("="*60)

# ============================================================
# FIGURE 1: Each Slice Shown Separately (Clear Visualization)
# ============================================================
fig1, axes = plt.subplots(1, 4, figsize=(16, 4))

# Show each slice in its own subplot
for i in range(M_demo):
    ax = axes[i]
    mask = dominant_slice == i
    
    # Plot all points faded
    ax.scatter(coords[:, 0], coords[:, 1], c='lightgray', s=4, alpha=0.3)
    # Highlight this slice's points
    ax.scatter(coords[mask, 0], coords[mask, 1], c=slice_colors[i], s=10, alpha=0.9)
    
    ax.set_title(f'Slice {i}: {counts[i]:,} points ({100*counts[i]/N:.1f}%)', 
                fontsize=12, fontweight='bold', color=slice_colors[i])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')

# Final subplot: All slices combined
ax = axes[3]
for i in range(M_demo):
    mask = dominant_slice == i
    ax.scatter(coords[mask, 0], coords[mask, 1], 
              c=slice_colors[i], s=8, alpha=0.7, label=f'Slice {i} ({counts[i]:,} pts)')
ax.set_title('All Slices Combined', fontsize=12, fontweight='bold')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal')
ax.legend(loc='upper right', fontsize=9)

plt.suptitle(f'Mesh Points Distributed Across {M_demo} Slices', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# ============================================================
# FIGURE 2: Physics-Attention Matrix (M×M instead of N×N)
# ============================================================
fig2, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Left: Points per slice bar chart
bars = ax1.bar(range(M_demo), counts, color=slice_colors, edgecolor='black', linewidth=1.5)
ax1.set_xlabel('Slice ID', fontsize=11)
ax1.set_ylabel('Number of Points', fontsize=11)
ax1.set_title('Points per Slice', fontsize=12, fontweight='bold')
ax1.set_xticks(range(M_demo))
for bar, count in zip(bars, counts):
    ax1.annotate(f'{count:,}', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                ha='center', va='bottom', fontsize=11, fontweight='bold')

# Right: The resulting M×M attention matrix, computed from the slice tokens
slice_tokens = (slice_weights.T @ features) / slice_weights.sum(axis=0)[:, None]  # (M, C)
phys_attn = softmax(slice_tokens @ slice_tokens.T / np.sqrt(C), axis=1)  # (M, M)
im_attn = ax2.imshow(phys_attn, cmap='Greens', vmin=0, vmax=1)
ax2.set_title(f'Physics-Attention: {M_demo}×{M_demo} = {M_demo**2} entries\n(instead of {N}×{N} = {N*N:,} entries)', 
             fontsize=11, fontweight='bold')
ax2.set_xlabel('Key Slice')
ax2.set_ylabel('Query Slice')
ax2.set_xticks(range(M_demo))
ax2.set_yticks(range(M_demo))
plt.colorbar(im_attn, ax=ax2, shrink=0.8)

plt.tight_layout()
plt.show()

print(f"\n✓ Attention matrix: {N*N:,} → {M_demo**2} entries; total cost incl. slice/deslice: ~{N*N // (2*N*M_demo + M_demo**2):,}x fewer operations")

### Effect of Different Slice Counts (M)

More slices = finer grouping but higher cost. Typical values: M=8 to M=64.
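
The cost side of this trade-off can be estimated with the simplified operation count used earlier in this notebook (2·N·M + M²). The sketch below assumes a hypothetical mesh of 10,000 points rather than the loaded dataset, and the helper name `physics_attention_ops` is illustrative.

```python
def physics_attention_ops(n_points, n_slices):
    """Simplified count from this notebook: slice + attend + deslice."""
    return 2 * n_points * n_slices + n_slices ** 2

n_example = 10_000   # hypothetical mesh size, not the loaded dataset
for m in (8, 16, 32, 64):
    exact = n_example ** 2 / physics_attention_ops(n_example, m)
    approx = n_example / (2 * m)          # good approximation when m << n_example
    print(f"M={m:>2}: speedup {exact:6.1f}x (approx. N/(2M) = {approx:6.1f}x)")
```

Under this count the speedup over standard attention is roughly N/(2M), so doubling M about halves the saving.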

In [None]:
# Compare M=4, M=8, M=16 slice partitioning
slice_configs = [4, 8, 16]
fig_compare, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, num_slices in enumerate(slice_configs):
    ax = axes[idx]
    np.random.seed(42 + idx)  # Different seed for variety
    
    # Compute slice assignments
    W = np.random.randn(C, num_slices) * 0.6
    logits = features @ W
    weights = softmax(logits, axis=1)
    dominant_slice = np.argmax(weights, axis=1)
    
    # Plot mesh colored by slice
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=dominant_slice, 
                        cmap='tab10' if num_slices <= 10 else 'tab20',
                        s=6, alpha=0.7)
    
    # Calculate cost reduction, including the slice and deslice steps (2·N·M + M²)
    cost_reduction = N*N // (2*N*num_slices + num_slices**2)
    ax.set_title(f'M = {num_slices} slices\nAttention: {num_slices}×{num_slices}={num_slices**2} entries\n(~{cost_reduction:,}x cheaper overall)', 
                fontsize=10)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    plt.colorbar(scatter, ax=ax, label='Slice ID', shrink=0.8)

plt.suptitle('Trade-off: More Slices = Finer Resolution but Higher Cost', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("\nTypical M values in Transolver: 8-64 (paper uses M=32 or M=64)")

## Summary

| Aspect | Standard Attention | Physics-Attention |
|--------|-------------------|-------------------|
| **Complexity** | O(N²) — expensive! | O(N·M) — efficient! |
| **Attention matrix** | N×N | M×M (M = 8-64) |
| **Grouping** | All-to-all | Learned slices |

**Key Takeaway:** Physics-Attention reduces cost by grouping mesh points into M learned "slices" and performing attention between these compressed representations.

**Next:** Notebook 2 shows the full Transolver architecture in PyTorch.