# Notebook 1: Physics-Attention vs Standard Attention

**Goal:** Understand why standard transformers are expensive for physics simulations and how Physics-Attention solves this.

## Outline
1. Download & Load Stokes Flow Data
2. The Quadratic Cost Problem
3. Physics-Attention: The 4-Step Solution
4. How Mesh Points Get Distributed to Slices

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import softmax, download_stokes_dataset, load_stokes_sample

np.random.seed(42)

## 1. Download & Load Stokes Flow Data

We'll use the Stokes flow dataset (the same one as in Lab 2). The cell below downloads it if it is not already present and then loads one sample.

In [None]:
# Download dataset (if not present) and load a sample
download_stokes_dataset()
coords, u, v, p = load_stokes_sample()
N_mesh = len(coords)

# Interactive visualization with Plotly
fig = make_subplots(rows=1, cols=3, subplot_titles=['Velocity u', 'Velocity v', 'Pressure p'])

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=u, colorscale='RdBu_r', showscale=True, 
                colorbar=dict(x=0.28, len=0.8, title='u')),
    name='u', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>u=%{marker.color:.3f}'), row=1, col=1)

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=v, colorscale='RdBu_r', showscale=True,
                colorbar=dict(x=0.62, len=0.8, title='v')),
    name='v', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>v=%{marker.color:.3f}'), row=1, col=2)

fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], mode='markers',
    marker=dict(size=5, color=p, colorscale='Viridis', showscale=True,
                colorbar=dict(x=0.97, len=0.8, title='p')),
    name='p', hovertemplate='x=%{x:.2f}<br>y=%{y:.2f}<br>p=%{marker.color:.3f}'), row=1, col=3)

fig.update_layout(title=f'Stokes Flow Data (N={N_mesh} mesh points)', height=400, width=1100, showlegend=False)
fig.update_xaxes(title_text='x')
fig.update_yaxes(title_text='y', scaleanchor='x', scaleratio=1)
fig.show()

print(f"✓ Loaded mesh with {N_mesh} points")

## 2. The Quadratic Cost Problem

Standard self-attention computes: $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$

The product $QK^T$ is an **N × N** attention matrix, so both compute and memory grow quadratically with the number of mesh points. This gets expensive quickly for large meshes!
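A minimal NumPy sketch of standard self-attention makes the quadratic cost concrete: the score matrix holds one entry for every pair of points. It defines its own `softmax` instead of importing `utils.softmax`, and it assumes Q = K = V = X, leaving out the learned query/key/value projections a real transformer would use.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def standard_attention(Q, K, V):
    """Scaled dot-product attention that materializes the full N×N score matrix."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)      # (N, N): one score per pair of points
    weights = softmax(scores, axis=1)  # each row sums to 1
    return weights @ V, weights


rng = np.random.default_rng(0)
N, d = 500, 16
X = rng.standard_normal((N, d))
out, weights = standard_attention(X, X, X)  # assumption: Q = K = V = X, no learned projections
print(weights.shape)                         # (500, 500)
print(f"score matrix alone: {weights.nbytes / 1e6:.1f} MB (float64)")
```

Doubling N quadruples the score matrix: at N = 50,000, the float64 scores alone would need 20 GB.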

In [None]:
# Visualize the cost scaling
N = N_mesh  # Use actual mesh size
M = 8  # Number of slices

mesh_sizes = np.array([N, N*2, N*5, N*10, N*50, N*100])
standard_cost = mesh_sizes ** 2  # O(N²)
physics_cost = mesh_sizes * M    # O(N·M)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Cost comparison
ax1 = axes[0]
ax1.loglog(mesh_sizes, standard_cost, 'r-o', label='Standard Attention O(N²)', linewidth=2)
ax1.loglog(mesh_sizes, physics_cost, 'g-o', label='Physics-Attention O(N·M)', linewidth=2)
ax1.axvline(x=N, color='blue', linestyle='--', alpha=0.5, label=f'Our mesh (N={N})')
ax1.set_xlabel('Mesh Points (N)')
ax1.set_ylabel('Operations')
ax1.set_title('Computational Cost Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Show what an N×N attention matrix looks like (random scores, subsampled for display)
ax2 = axes[1]
subsample = min(100, N)
std_attn = softmax(np.random.randn(subsample, subsample), axis=1)
im = ax2.imshow(std_attn, cmap='Reds', aspect='auto')
ax2.set_title(f'Standard Attention Matrix\n(showing {subsample}×{subsample} of {N}×{N})')
ax2.set_xlabel('Key points')
ax2.set_ylabel('Query points')
plt.colorbar(im, ax=ax2)

plt.tight_layout()
plt.show()

print(f"\nFor our mesh (N={N}):")
print(f"  Standard Attention: {N}×{N} = {N*N:,} pairwise attention scores")
print(f"  Physics-Attention:  {N}×{M} = {N*M:,} point-to-slice weights")
print(f"  Reduction: {N*N // (N*M)}x fewer entries (= N/M)")

## 3. Physics-Attention: The 4-Step Solution

### What are "Slices"?

In physics simulations, different regions of a mesh often have similar physical behavior:
- **Inlet region**: Smooth laminar flow
- **Obstacle wake**: Recirculating flow (turbulent at high Reynolds number; Stokes flow itself stays laminar)
- **Boundary layer**: High gradients near walls
- **Far-field**: Nearly uniform flow

**Slices** are learned groupings that cluster mesh points with similar physics. Instead of every point attending to every other point (N×N), we:
1. Group points into M slices (soft assignment via learned weights)
2. Compute attention only between slice representations (M×M)
3. Broadcast the updated slice representations back to the mesh points
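The soft assignment in the first step can be sketched in a few lines: a linear map produces one logit per slice, and a softmax over the slice axis turns each point's logits into weights that sum to 1. The random `W` stands in for learned weights, and the temperature `tau` is a hypothetical extra knob added here to show how assignments sharpen; it is not part of this notebook's algorithm.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
N, C, M = 6, 4, 3
X = rng.standard_normal((N, C))   # 6 points with 4 features each
W = rng.standard_normal((C, M))   # stand-in for learned slice weights
logits = X @ W                    # (N, M): one logit per point per slice

S = softmax(logits, axis=1)       # soft assignment; each row sums to 1
dominant = S.argmax(axis=1)       # slice each point belongs to most strongly
print(S.round(2))
print(dominant)

# Hypothetical temperature tau: dividing logits by tau < 1 sharpens, tau > 1 flattens
S_soft = softmax(logits / 2.0, axis=1)
S_sharp = softmax(logits / 0.1, axis=1)
print(S_soft.max(axis=1).mean(), S_sharp.max(axis=1).mean())
```

Keeping the assignment soft means every point contributes to every slice, so the grouping stays differentiable and can be learned by gradient descent.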

### The 4-Step Algorithm

```
Input: X ∈ R^(N×C)  (N mesh points, C features)
       W ∈ R^(C×M)  (learnable slice weights)

Step 1 - SLICE:     S = softmax(X @ W)  (over slices)  → S ∈ R^(N×M)  (assignment weights; rows sum to 1)
Step 2 - AGGREGATE: Z_j = Σ_i S_ij x_i / Σ_i S_ij      → Z ∈ R^(M×C)  (slice tokens = weighted means)
Step 3 - ATTEND:    Z' = Attention(Z, Z, Z)            → Z' ∈ R^(M×C) (only M×M scores!)
Step 4 - DESLICE:   Y = S @ Z'                         → Y ∈ R^(N×C)  (broadcast back to points)

Output: Y ∈ R^(N×C)
```
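The four steps above can be sketched as a single-head NumPy function. Assumptions: the learned parameters are replaced by small random matrices, Step 3 uses hypothetical query/key/value projections `W_q`, `W_k`, `W_v`, and Step 2 divides by each slice's total weight so that token magnitude does not grow with slice size. This illustrates the data flow, not a trained model.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def physics_attention(X, W_slice, W_q, W_k, W_v):
    """Single-head Physics-Attention sketch: slice -> aggregate -> attend -> deslice."""
    # Step 1 - SLICE: soft assignment of every point to every slice, (N, M)
    S = softmax(X @ W_slice, axis=1)
    # Step 2 - AGGREGATE: weighted mean of point features per slice, (M, C)
    Z = (S.T @ X) / (S.sum(axis=0)[:, None] + 1e-8)
    # Step 3 - ATTEND: ordinary scaled dot-product attention among M tokens, (M, M) scores
    Q, K, V = Z @ W_q, Z @ W_k, Z @ W_v
    A = softmax(Q @ K.T / np.sqrt(Q.shape[1]), axis=1)
    Z_new = A @ V
    # Step 4 - DESLICE: each point mixes the updated tokens using its own weights, (N, C)
    Y = S @ Z_new
    return Y, S, A


rng = np.random.default_rng(0)
N, C, M = 1000, 8, 16
X = rng.standard_normal((N, C))
params = [rng.standard_normal(shape) * 0.3 for shape in [(C, M), (C, C), (C, C), (C, C)]]
Y, S, A = physics_attention(X, *params)
print(Y.shape, S.shape, A.shape)  # (1000, 8) (1000, 16) (16, 16)
```

The only N×N-free trick is that every point talks to the others through M tokens. A full model would typically wrap this in the usual transformer block (normalization, residual connection, feed-forward layer); the sketch shows only the attention part.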

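**Key insight:** Step 3 costs O(M²·C) instead of O(N²·C), while Steps 1, 2 and 4 each cost O(N·M·C), which is linear in N. Since M is typically 8 to 64 and N can exceed 10,000, the overall cost drops from quadratic to linear in the mesh size!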
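The complexity claims above can be checked with a back-of-the-envelope count of multiply-adds. This is a rough sketch: it ignores constant factors such as the softmax and any Q/K/V or output projections, and the function name `attention_cost` is illustrative.

```python
def attention_cost(N, M, C):
    """Rough multiply-add counts for one attention layer (projections ignored)."""
    standard = 2 * N * N * C       # QK^T scores plus weighting V, both N×N×C
    physics = (
        N * M * C                  # Step 1: slice logits X @ W
        + N * M * C                # Step 2: aggregate S^T @ X
        + 2 * M * M * C            # Step 3: attention among M slice tokens
        + N * M * C                # Step 4: deslice S @ Z'
    )
    return standard, physics


std, phys = attention_cost(N=10_000, M=64, C=128)
print(f"standard: {std:.2e}, physics: {phys:.2e}, ratio: {std / phys:.0f}x")
```

For N = 10,000, M = 64 and C = 128 the ratio is about 100x, far less than the N²/M² ≈ 24,000x you would get by comparing attention matrices alone, because Steps 1, 2 and 4 still touch every point. The savings still grow linearly with N.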

In [None]:
# Create physics-based features from our mesh (coords + physics values)
C = 8  # Feature dimension
features = np.column_stack([
    coords,  # x, y coordinates
    u.reshape(-1, 1),  # velocity u
    v.reshape(-1, 1),  # velocity v  
    p.reshape(-1, 1),  # pressure
    np.random.randn(N, C-5)  # padding to get C features
])

# Simulate slice assignment (learned weights in real model)
W_slice = np.random.randn(C, M) * 0.5
slice_logits = features @ W_slice
slice_weights = softmax(slice_logits, axis=1)  # (N, M) - soft assignment

# Placeholder M×M attention (random scores; a trained model computes these from the slice tokens)
attn_matrix = softmax(np.random.randn(M, M), axis=1)

# Visualize the comparison
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Standard attention matrix (subsampled)
ax1 = axes[0]
subsample = min(80, N)
std_attn = softmax(np.random.randn(subsample, subsample), axis=1)
im1 = ax1.imshow(std_attn, cmap='Reds')
ax1.set_title(f'Standard Attention\n({N}×{N} = {N*N:,} ops)', fontsize=11)
ax1.set_xlabel('Key points')
ax1.set_ylabel('Query points')

# Slice assignment matrix
ax2 = axes[1]
# Sort by x-coordinate and show only the 200 left-most points for readability
sort_idx = np.argsort(coords[:, 0])
im2 = ax2.imshow(slice_weights[sort_idx[:200]].T, aspect='auto', cmap='Greens')
ax2.set_title('Slice Assignments\n(N points → M slices)', fontsize=11)
ax2.set_xlabel('Mesh points (200 left-most, sorted by x)')
ax2.set_ylabel(f'Slices (M={M})')
plt.colorbar(im2, ax=ax2)

# Physics-Attention: M×M
ax3 = axes[2]
im3 = ax3.imshow(attn_matrix, cmap='Greens')
ax3.set_title(f'Physics-Attention\n({M}×{M} = {M*M} ops)', fontsize=11)
ax3.set_xlabel('Key slices')
ax3.set_ylabel('Query slices')
plt.colorbar(im3, ax=ax3)

plt.tight_layout()
plt.show()

print(f"Attention scores: {N*N:,} → {M*M} = {N*N // (M*M):,}x fewer (slicing and deslicing still add O(N·M) work)")

## 4. How Mesh Points Get Distributed to Slices

Let's visualize exactly how mesh points get assigned to slices. Each point gets a **soft assignment weight** for every slice, and a point's weights across all slices sum to 1. The point "belongs" most strongly to the slice with the highest weight.
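The point counts printed below come from each point's strongest slice (a hard `argmax`), but Physics-Attention itself uses the full soft weights. A small sketch with random stand-in features and weights (not the Stokes data) compares the two ways of counting; both totals equal N, yet the per-slice numbers can differ noticeably.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(2)
N, C, M = 1000, 8, 3
X = rng.standard_normal((N, C))        # stand-in for the notebook's `features`
W = rng.standard_normal((C, M)) * 0.8  # stand-in for learned slice weights
S = softmax(X @ W, axis=1)             # (N, M) soft assignments

hard_counts = np.bincount(S.argmax(axis=1), minlength=M)  # points whose strongest slice is j
soft_counts = S.sum(axis=0)                               # total weight each slice receives
print("hard counts:", hard_counts, "total =", hard_counts.sum())
print("soft counts:", soft_counts.round(1), "total =", round(float(soft_counts.sum()), 1))
```

The soft counts are what a weighted-mean aggregation divides by, so they describe a slice's effective size better than the hard counts do.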

### Step-by-Step with M=3 Slices

In [None]:
# Detailed visualization with M=3 slices
M_demo = 3
np.random.seed(123)  # For reproducibility

# Step 1: Compute slice assignment weights
W_demo = np.random.randn(C, M_demo) * 0.8
slice_logits = features @ W_demo
slice_weights = softmax(slice_logits, axis=1)  # Shape: (N, 3)
dominant_slice = np.argmax(slice_weights, axis=1)

# Count points per slice
slice_colors = ['#e41a1c', '#377eb8', '#4daf4a']  # Red, Blue, Green
slice_names = ['Slice 0 (Red)', 'Slice 1 (Blue)', 'Slice 2 (Green)']
counts = [np.sum(dominant_slice == i) for i in range(M_demo)]

print("="*60)
print(f"SLICE ASSIGNMENT SUMMARY (M={M_demo} slices)")
print("="*60)
for i, (name, count) in enumerate(zip(slice_names, counts)):
    pct = 100 * count / N
    print(f"  {name}: {count:,} points ({pct:.1f}%)")
print("="*60)

# Create comprehensive visualization
fig = plt.figure(figsize=(16, 10))

# Row 1: Show individual slices
for i in range(M_demo):
    ax = fig.add_subplot(2, 4, i+1)
    mask = dominant_slice == i
    
    # Plot all points faded
    ax.scatter(coords[:, 0], coords[:, 1], c='lightgray', s=3, alpha=0.3)
    # Highlight this slice's points
    ax.scatter(coords[mask, 0], coords[mask, 1], c=slice_colors[i], s=8, alpha=0.8)
    
    ax.set_title(f'Slice {i}\n({counts[i]:,} points)', fontsize=11, color=slice_colors[i])
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')

# Row 1, Col 4: All slices combined
ax_combined = fig.add_subplot(2, 4, 4)
for i in range(M_demo):
    mask = dominant_slice == i
    ax_combined.scatter(coords[mask, 0], coords[mask, 1], 
                       c=slice_colors[i], s=6, alpha=0.7, label=f'Slice {i}')
ax_combined.set_title('All 3 Slices Combined', fontsize=11)
ax_combined.set_xlabel('x')
ax_combined.set_ylabel('y')
ax_combined.set_aspect('equal')
ax_combined.legend(loc='upper right', fontsize=9)

# Row 2, Col 1: Soft assignment weights heatmap
ax_weights = fig.add_subplot(2, 4, 5)
# Sort points by dominant slice (then by x) for visualization
sort_idx = np.lexsort((coords[:, 0], dominant_slice))
im = ax_weights.imshow(slice_weights[sort_idx].T, aspect='auto', cmap='YlOrRd',
                       extent=[0, N, M_demo-0.5, -0.5])  # top-to-bottom rows: slice 0, 1, 2
ax_weights.set_title('Soft Assignment Weights\n(each column sums to 1)', fontsize=10)
ax_weights.set_xlabel(f'Mesh Points (sorted, N={N})')
ax_weights.set_ylabel('Slice ID')
ax_weights.set_yticks(range(M_demo))
plt.colorbar(im, ax=ax_weights, label='Weight')

# Row 2, Col 2: Bar chart of distribution
ax_bar = fig.add_subplot(2, 4, 6)
bars = ax_bar.bar(range(M_demo), counts, color=slice_colors, edgecolor='black')
ax_bar.set_xlabel('Slice ID')
ax_bar.set_ylabel('Number of Points')
ax_bar.set_title('Points per Slice', fontsize=11)
ax_bar.set_xticks(range(M_demo))
for bar, count in zip(bars, counts):
    ax_bar.annotate(f'{count}', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                   ha='center', va='bottom', fontsize=10)

# Row 2, Col 3: Soft weights for a few example points
ax_example = fig.add_subplot(2, 4, 7)
# Pick 5 random example points
example_indices = np.random.choice(N, 5, replace=False)
example_weights = slice_weights[example_indices]

x_pos = np.arange(5)
width = 0.25
for i in range(M_demo):
    ax_example.bar(x_pos + i*width, example_weights[:, i], width, 
                  color=slice_colors[i], label=f'Slice {i}', alpha=0.8)
ax_example.set_xlabel('Example Points')
ax_example.set_ylabel('Assignment Weight')
ax_example.set_title('Soft Weights for 5 Sample Points\n(each point sums to 1)', fontsize=10)
ax_example.set_xticks(x_pos + width)
ax_example.set_xticklabels([f'P{i}' for i in range(5)])
ax_example.legend(fontsize=8)
ax_example.set_ylim(0, 1)

# Row 2, Col 4: The resulting M×M attention
ax_attn = fig.add_subplot(2, 4, 8)
phys_attn = softmax(np.random.randn(M_demo, M_demo), axis=1)
im_attn = ax_attn.imshow(phys_attn, cmap='Greens', vmin=0, vmax=1)
ax_attn.set_title(f'Physics-Attention Matrix\n(only {M_demo}×{M_demo}={M_demo**2} ops!)', fontsize=10)
ax_attn.set_xlabel('Key Slice')
ax_attn.set_ylabel('Query Slice')
ax_attn.set_xticks(range(M_demo))
ax_attn.set_yticks(range(M_demo))
plt.colorbar(im_attn, ax=ax_attn)

plt.tight_layout()
plt.show()

print(f"\n✓ Instead of {N}×{N}={N*N:,} attention scores, we only need {M_demo}×{M_demo}={M_demo**2}!")
print(f"✓ That is {N*N // M_demo**2:,}x fewer attention scores (slicing and deslicing still add O(N·M) work)")

### Effect of Different Slice Counts (M)

More slices give a finer grouping of the mesh, at the price of a larger attention matrix (M² scores) and more slice weights (N·M). Typical values range from M=8 to M=64.
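A sketch of this trade-off, using random untrained weights on random stand-in features (so the occupancy numbers say nothing about a trained model): as M grows, the attention matrix grows as M², the slice weights grow as N·M, and the average number of points per slice shrinks. With untrained weights, some slices can also end up as no point's strongest slice, which is why occupancy is counted separately.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(3)
N, C = 2000, 8
X = rng.standard_normal((N, C))  # random stand-in features, not the Stokes data

results = []
for M in [4, 8, 16, 32, 64]:
    W = rng.standard_normal((C, M)) * 0.6        # untrained slice weights
    S = softmax(X @ W, axis=1)
    occupied = np.unique(S.argmax(axis=1)).size  # slices that are some point's strongest
    results.append((M, occupied))
    print(f"M={M:2d}  attention scores={M * M:5d}  slice weights={N * M:7,d}  "
          f"occupied slices={occupied:2d}  points per occupied slice≈{N / occupied:6.1f}")
```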

In [None]:
# Compare M=4, M=8, M=16 slice partitioning
slice_configs = [4, 8, 16]
fig_compare, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, num_slices in enumerate(slice_configs):
    ax = axes[idx]
    np.random.seed(42 + idx)  # Different seed for variety
    
    # Compute slice assignments
    W = np.random.randn(C, num_slices) * 0.6
    logits = features @ W
    weights = softmax(logits, axis=1)
    dominant_slice = np.argmax(weights, axis=1)
    
    # Plot mesh colored by slice
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=dominant_slice, 
                        cmap='tab10' if num_slices <= 10 else 'tab20',
                        s=6, alpha=0.7)
    
    # Reduction in attention-score count relative to N×N
    cost_reduction = N*N // (num_slices**2)
    ax.set_title(f'M = {num_slices} slices\nAttention: {num_slices}×{num_slices}={num_slices**2} scores\n({cost_reduction:,}x fewer than N×N)', 
                fontsize=10)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    plt.colorbar(scatter, ax=ax, label='Slice ID', shrink=0.8)

plt.suptitle('Trade-off: More Slices = Finer Resolution but Higher Cost', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("\nTypical M values in Transolver: 8-64 (paper uses M=32 or M=64)")

## Summary

| Aspect | Standard Attention | Physics-Attention |
|--------|-------------------|-------------------|
| **Complexity** | O(N²): quadratic in mesh size | O(N·M): linear in mesh size |
| **Attention matrix** | N×N | M×M (M typically 8 to 64) |
| **Grouping** | All-to-all between points | Learned soft slices |

**Key Takeaway:** Physics-Attention reduces cost by softly grouping mesh points into M learned "slices", performing attention between these compressed slice tokens, and broadcasting the result back to every point.

**Next:** Notebook 2 shows the full Transolver architecture in PyTorch.