# Notebook 1: Physics-Attention vs Standard Attention

**Goal:** Understand why standard transformers are expensive for physics simulations and how Physics-Attention solves this.

## Outline
1. Download & Load Stokes Flow Data
2. The Quadratic Cost Problem
3. Physics-Attention: The 4-Step Solution
4. Visualizing Slices on Our Mesh

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import softmax, download_stokes_dataset, load_stokes_sample

# Seed NumPy's global RNG so every np.random.randn draw in the cells below
# (padding features, slice weights, simulated attention matrices) is
# reproducible across runs of the notebook.
np.random.seed(42)

## 1. Download & Load Stokes Flow Data

We'll use the Stokes flow dataset (same as Lab 2). Run the cell below to download if needed.

In [None]:
# Download dataset (if not present) and load a sample
download_stokes_dataset()
coords, u, v, p = load_stokes_sample()
N_mesh = len(coords)

# Interactive visualization with Plotly: one scatter panel per field.
# Each entry is (values, label, colorscale, colorbar x-position); the colorbar
# positions place each colorbar just to the right of its own panel.
field_panels = [
    (u, 'u', 'RdBu_r', 0.28),
    (v, 'v', 'RdBu_r', 0.62),
    (p, 'p', 'Viridis', 0.97),
]
fig = make_subplots(rows=1, cols=3, subplot_titles=['Velocity u', 'Velocity v', 'Pressure p'])

for col, (values, label, colorscale, cbar_x) in enumerate(field_panels, start=1):
    fig.add_trace(go.Scatter(
        x=coords[:, 0], y=coords[:, 1], mode='markers',
        marker=dict(size=5, color=values, colorscale=colorscale, showscale=True,
                    colorbar=dict(x=cbar_x, len=0.8, title=label)),
        name=label,
        hovertemplate=f'x=%{{x:.2f}}<br>y=%{{y:.2f}}<br>{label}=%{{marker.color:.3f}}',
    ), row=1, col=col)
    # Lock each panel's y-axis to *its own* x-axis ('x', 'x2', 'x3') for a 1:1
    # aspect ratio. Anchoring every y-axis to 'x' would tie panels 2 and 3 to
    # the first panel's x-range, so zooming panel 1 would distort the others.
    fig.update_yaxes(scaleanchor='x' if col == 1 else f'x{col}', scaleratio=1, row=1, col=col)

fig.update_layout(title=f'Stokes Flow Data (N={N_mesh} mesh points)', height=400, width=1100, showlegend=False)
fig.update_xaxes(title_text='x')
fig.update_yaxes(title_text='y')
fig.show()

print(f"✓ Loaded mesh with {N_mesh} points")

## 2. The Quadratic Cost Problem

Standard self-attention computes: $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$

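Here $Q, K, V \in \mathbb{R}^{N \times d}$ hold one query, key, and value vector per mesh point. The product $QK^T$ is therefore an **N × N** attention matrix, with one score for every pair of points. Compute and memory grow as $O(N^2)$, which is expensive for large meshes!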

In [None]:
# Visualize the cost scaling
N = N_mesh  # Use actual mesh size
M = 8  # Number of slices

# Hypothetical mesh sizes from our mesh up to 100x larger. Build them as
# float64: NumPy's default integer is 32-bit on some platforms (e.g. Windows
# with NumPy < 2), and (100*N)**2 silently overflows int32 once N > ~460.
mesh_sizes = np.array([N, N*2, N*5, N*10, N*50, N*100], dtype=np.float64)
standard_cost = mesh_sizes ** 2  # O(N²): one score per (query, key) pair
physics_cost = mesh_sizes * M    # O(N·M): one weight per (point, slice) pair

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Cost comparison
ax1 = axes[0]
ax1.loglog(mesh_sizes, standard_cost, 'r-o', label='Standard Attention O(N²)', linewidth=2)
ax1.loglog(mesh_sizes, physics_cost, 'g-o', label='Physics-Attention O(N·M)', linewidth=2)
ax1.axvline(x=N, color='blue', linestyle='--', alpha=0.5, label=f'Our mesh (N={N})')
ax1.set_xlabel('Mesh Points (N)')
ax1.set_ylabel('Operations')
ax1.set_title('Computational Cost Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Show what an N×N attention matrix looks like: random scores (not from a
# model), subsampled so the image stays readable.
ax2 = axes[1]
subsample = min(100, N)
std_attn = softmax(np.random.randn(subsample, subsample), axis=1)
im = ax2.imshow(std_attn, cmap='Reds', aspect='auto')
ax2.set_title(f'Standard Attention Matrix\n(showing {subsample}×{subsample} of {N}×{N})')
ax2.set_xlabel('Key points')
ax2.set_ylabel('Query points')
plt.colorbar(im, ax=ax2)

plt.tight_layout()
plt.show()

print(f"\nFor our mesh (N={N}):")
print(f"  Standard Attention: {N}×{N} = {N*N:,} operations")
print(f"  Physics-Attention:  {N}×{M} = {N*M:,} operations")
print(f"  Cost reduction: {N*N // (N*M)}x cheaper!")

## 3. Physics-Attention: The 4-Step Solution

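Transolver's key insight: **mesh points in similar physical states can share one representation**, even when they are far apart in space. Instead of attending between all N points, we group them into M ≪ N learned "slices" and attend between the slices.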

1. **Slice**: Assign N points to M groups (soft assignment)
2. **Aggregate**: Compress each slice into a single token  
3. **Attend**: M×M attention between slice tokens (cheap!)
4. **Deslice**: Broadcast back to N points

In [None]:
# Create physics-based features from our mesh (coords + physics values)
C = 8  # Feature dimension: 2 coords + 3 fields (u, v, p) + (C-5) random padding
features = np.column_stack([
    coords,  # x, y coordinates
    u.reshape(-1, 1),  # velocity u
    v.reshape(-1, 1),  # velocity v
    p.reshape(-1, 1),  # pressure
    np.random.randn(N, C-5)  # padding to get C features
])

# Step 1 - Slice: soft-assign every point to the M slices
# (random projection here; learned weights in the real model)
W_slice = np.random.randn(C, M) * 0.5
slice_logits = features @ W_slice
slice_weights = softmax(slice_logits, axis=1)  # (N, M) - soft assignment

# Simulated M×M attention weights between slice tokens
# (in the real model these come from the slice tokens' queries and keys)
attn_matrix = softmax(np.random.randn(M, M), axis=1)

# Step 2 - Aggregate: each slice token is the slice-weighted mean of the point
# features. The epsilon guards against a slice with ~zero total weight.
slice_mass = slice_weights.sum(axis=0)                                        # (M,)
slice_tokens = (slice_weights.T @ features) / (slice_mass[:, None] + 1e-8)    # (M, C)
# Step 3 - Attend: mix the slice tokens with the M×M attention weights.
attended_tokens = attn_matrix @ slice_tokens                                  # (M, C)
# Step 4 - Deslice: broadcast tokens back to points using the same soft weights.
point_outputs = slice_weights @ attended_tokens                               # (N, C)
print(f"Shapes: features {features.shape} → slice tokens {slice_tokens.shape} "
      f"→ attended {attended_tokens.shape} → point outputs {point_outputs.shape}")

# Visualize the comparison
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Standard attention matrix (random scores, subsampled for display)
ax1 = axes[0]
subsample = min(80, N)
std_attn = softmax(np.random.randn(subsample, subsample), axis=1)
im1 = ax1.imshow(std_attn, cmap='Reds')
ax1.set_title(f'Standard Attention\n({N}×{N} = {N*N:,} ops)', fontsize=11)
ax1.set_xlabel('Key points')
ax1.set_ylabel('Query points')

# Slice assignment matrix
ax2 = axes[1]
# Sort by x-coordinate, then take evenly spaced points along that order so the
# panel spans the whole domain (the first 200 would only be the left-most points).
sort_idx = np.argsort(coords[:, 0])
n_show = min(200, N)
show_idx = sort_idx[np.linspace(0, N - 1, n_show).astype(int)]
im2 = ax2.imshow(slice_weights[show_idx].T, aspect='auto', cmap='Greens')
ax2.set_title('Slice Assignments\n(N points → M slices)', fontsize=11)
ax2.set_xlabel(f'Mesh points (sorted by x, {n_show} of {N} shown)')
ax2.set_ylabel(f'Slices (M={M})')
plt.colorbar(im2, ax=ax2)

# Physics-Attention: M×M
ax3 = axes[2]
im3 = ax3.imshow(attn_matrix, cmap='Greens')
ax3.set_title(f'Physics-Attention\n({M}×{M} = {M*M} ops)', fontsize=11)
ax3.set_xlabel('Key slices')
ax3.set_ylabel('Query slices')
plt.colorbar(im3, ax=ax3)

plt.tight_layout()
plt.show()

print(f"Attention-matrix reduction: {N*N:,} → {M*M} = {N*N // (M*M):,}x fewer attention operations!")
# Slicing (step 1-2) and deslicing (step 4) still touch every (point, slice)
# pair, which is why the overall cost is O(N·M) rather than O(M²).
print(f"(Slicing and deslicing still cost O(N·M): {N}×{M} = {N*M:,} weights each.)")

## 4. Visualizing Slices on Our Mesh

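Let's see how different numbers of slices partition our Stokes flow mesh. The slice weights here are random (untrained), so these partitions only illustrate the mechanism. In a trained model, the slices learn to group points with similar physical behavior.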

In [None]:
# Interactive visualization of slice assignments with Plotly
slice_counts = [4, 8, 16]
fig = make_subplots(rows=1, cols=3, subplot_titles=[f'M = {m} slices' for m in slice_counts])

# 16 distinct colors (tab10 plus lighter tab20 shades) so even M=16 gets
# one unique color per slice id.
colors_tab10 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
                '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', '#c49c94']

# Slice ids that already have a legend entry. Each id is listed once, taken
# from whichever panel first contains it, so colors for slices 4..15 (which
# only exist for larger M) are labelled too.
legend_shown = set()

for idx, num_slices in enumerate(slice_counts):
    # Random (untrained) slice assignment: dominant slice per point
    W = np.random.randn(C, num_slices) * 0.5
    logits = features @ W
    weights = softmax(logits, axis=1)
    dominant_slice = np.argmax(weights, axis=1)

    # One trace per non-empty slice
    for slice_id in range(num_slices):
        mask = dominant_slice == slice_id
        if not mask.any():
            continue
        fig.add_trace(go.Scatter(
            x=coords[mask, 0], y=coords[mask, 1], mode='markers',
            marker=dict(size=6, color=colors_tab10[slice_id % len(colors_tab10)]),
            name=f'Slice {slice_id}',
            # Group by slice id so one legend click toggles that color in every panel
            legendgroup=f'slice{slice_id}',
            showlegend=slice_id not in legend_shown,
            hovertemplate=f'Slice {slice_id}<br>x=%{{x:.2f}}<br>y=%{{y:.2f}}'
        ), row=1, col=idx+1)
        legend_shown.add(slice_id)

fig.update_layout(
    title='How Different Slice Counts (M) Partition the Mesh',
    height=400, width=1200,
    legend=dict(x=1.02, y=0.5)
)
fig.update_xaxes(title_text='x')
fig.update_yaxes(title_text='y', scaleanchor='x', scaleratio=1)
fig.show()

print("Note: In a trained model, slices would group physically similar regions")
print("      (e.g., inlet, wake, boundary layers) rather than random partitions.")

## Summary

| Aspect | Standard Attention | Physics-Attention |
|--------|-------------------|-------------------|
| **Complexity** | O(N²) — expensive! | O(N·M) — efficient! |
| **Attention matrix** | N×N | M×M (M≈64 in practice; M=8 in this notebook) |
| **Grouping** | All-to-all | Learned slices |

**Key Takeaway:** Physics-Attention reduces cost by grouping mesh points into M learned "slices" and performing attention between these compressed representations.

**Next:** Notebook 2 shows the full Transolver architecture in PyTorch.