# Notebook 2: Transolver for Stokes Flow

**Goal:** Apply Transolver to the same Stokes flow problem from Lab 2 and compare architectures.

## Outline
1. The Stokes Flow Problem (Recap)
2. Transolver Architecture
3. PyTorch Implementation
4. Comparison: MeshGraphNet vs Transolver

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from utils import softmax, download_stokes_dataset, load_stokes_sample

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. The Stokes Flow Problem (Recap)

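From Lab 2, we're solving the steady, incompressible Stokes equations for velocity $\mathbf{u} = (u, v)$ and pressure $p$, with viscosity $\nu$: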

$$-\nu \Delta \mathbf{u} + \nabla p = 0, \quad \nabla \cdot \mathbf{u} = 0$$

**Input:** Mesh with obstacle geometry  
**Output:** Velocity field $(u, v)$ and pressure $p$

**Challenge:** Variable mesh sizes, irregular geometries
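
As a quick sanity check on the equations above, the sketch below evaluates the momentum residual and the divergence for plane Poiseuille flow, $\mathbf{u} = (y(1-y), 0)$ with $p = -2\nu x$, which satisfies both equations exactly. The grid size, the value of `nu`, and the use of `np.gradient` finite differences are illustrative assumptions and are not part of the Lab 2 pipeline.

```python
import numpy as np

nu = 0.1                      # assumed viscosity, for illustration only
n = 33
x = np.linspace(0.0, 1.0, n)
y = np.linspace(0.0, 1.0, n)
X, Y = np.meshgrid(x, y, indexing="ij")   # axis 0 is x, axis 1 is y

# Plane Poiseuille flow: an exact solution of the Stokes equations
u = Y * (1.0 - Y)
v = np.zeros_like(u)
p = -2.0 * nu * X

def laplacian(f):
    """Finite-difference Laplacian built from repeated np.gradient calls."""
    d2x = np.gradient(np.gradient(f, x, axis=0), x, axis=0)
    d2y = np.gradient(np.gradient(f, y, axis=1), y, axis=1)
    return d2x + d2y

momentum_x = -nu * laplacian(u) + np.gradient(p, x, axis=0)
momentum_y = -nu * laplacian(v) + np.gradient(p, y, axis=1)
divergence = np.gradient(u, x, axis=0) + np.gradient(v, y, axis=1)

# Trim two boundary layers: the one-sided edge stencils are less accurate
interior = (slice(2, -2), slice(2, -2))
max_momentum = max(np.abs(momentum_x[interior]).max(),
                   np.abs(momentum_y[interior]).max())
max_divergence = np.abs(divergence[interior]).max()
print(f"max |momentum residual| = {max_momentum:.2e}")
print(f"max |divergence|        = {max_divergence:.2e}")
```

Both residuals should sit at round-off level in the interior. The same two quantities can serve as physics-based diagnostics for a model's predicted $(u, v, p)$ once the fields are interpolated to a regular grid.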

## 2. Transolver Architecture

```
Input (N, d_in) → Embedding → [Physics-Attention Block] × L → Decoder → Output (N, d_out)
                                     ↓
                              Slice → Aggregate → Attend → Deslice
```
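
Written out, the four steps inside each block take the following form. This follows the simplified layer implemented in Section 3 of this notebook; the notation ($W_s$, $b_s$, $\varepsilon$) is introduced here for illustration, and the original paper's formulation may differ in details. For point features $\mathbf{x}_i \in \mathbb{R}^D$, $i = 1, \dots, N$, and $M$ slices:

$$\mathbf{w}_i = \operatorname{softmax}(W_s \mathbf{x}_i + b_s) \in \mathbb{R}^M \quad \text{(slice)}$$

$$\mathbf{z}_j = \frac{\sum_{i=1}^{N} w_{ij}\, \mathbf{x}_i}{\sum_{i=1}^{N} w_{ij} + \varepsilon}, \quad j = 1, \dots, M \quad \text{(aggregate)}$$

$$(\mathbf{z}'_1, \dots, \mathbf{z}'_M) = \operatorname{MultiHeadAttention}(\mathbf{z}_1, \dots, \mathbf{z}_M) \quad \text{(attend)}$$

$$\mathbf{x}'_i = \sum_{j=1}^{M} w_{ij}\, \mathbf{z}'_j \quad \text{(deslice)}$$

Because attention acts on $M$ slice tokens rather than $N$ points, its cost is $O(M^2)$, while slicing and deslicing cost $O(N \cdot M)$. The Section 3 layer additionally applies LayerNorm to the input before slicing and adds a residual connection around the whole block.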

## 3. PyTorch Implementation

In [None]:
class PhysicsAttentionLayer(nn.Module):
    """Physics-Attention layer from Transolver."""
    
    def __init__(self, dim, num_slices=64, num_heads=8):
        super().__init__()
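        # Heads split the feature dimension evenly, so it must divide cleanly
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_slices = num_slices
        self.num_heads = num_heads
        self.head_dim = dim // num_heads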
        
        # Slice projection
        self.slice_proj = nn.Linear(dim, num_slices)
        
        # Multi-head attention projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        
        # Layer norm
        self.norm = nn.LayerNorm(dim)
    
    def forward(self, x):
        """
        Args:
            x: (batch, N, dim) - mesh point features
        Returns:
            out: (batch, N, dim)
        """
        B, N, D = x.shape
        residual = x
        x = self.norm(x)
        
        # Step 1: SLICE - compute soft assignments
        slice_logits = self.slice_proj(x)  # (B, N, M)
        slice_weights = F.softmax(slice_logits, dim=-1)  # (B, N, M)
        
        # Step 2: AGGREGATE - compress to M tokens
        # Weighted sum: (B, M, D)
        slice_weights_t = slice_weights.transpose(1, 2)  # (B, M, N)
        z = torch.bmm(slice_weights_t, x)  # (B, M, D)
        # Normalize by sum of weights
        z = z / (slice_weights_t.sum(dim=-1, keepdim=True) + 1e-8)
        
        # Step 3: ATTEND - M×M multi-head attention
        M = self.num_slices
        Q = self.q_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        z_prime = torch.matmul(attn, V)  # (B, heads, M, head_dim)
        z_prime = z_prime.transpose(1, 2).reshape(B, M, D)  # (B, M, D)
        z_prime = self.out_proj(z_prime)
        
        # Step 4: DESLICE - broadcast back to N points
        out = torch.bmm(slice_weights, z_prime)  # (B, N, D)
        
        return residual + out

# Test the layer
B, N, D, M = 2, 500, 128, 32
layer = PhysicsAttentionLayer(dim=D, num_slices=M).to(device)
x_test = torch.randn(B, N, D).to(device)
out = layer(x_test)
print(f"Input:  {x_test.shape}")
print(f"Output: {out.shape}")
assert out.shape == x_test.shape  # the layer must preserve (batch, N, dim)
print("✓ Physics-Attention layer preserves the input shape")
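
One property worth checking beyond shapes is that the slice, aggregate, attend, deslice pipeline is permutation-equivariant: reordering the mesh points reorders the output the same way, since no step depends on point ordering. The sketch below verifies this on a stripped-down, single-head, unbatched version of the layer. The function `slice_attend_deslice` and its random weight matrices are hypothetical stand-ins, not the `PhysicsAttentionLayer` defined above; LayerNorm, biases, the output projection, and the residual are omitted.

```python
import torch
import torch.nn.functional as F

def slice_attend_deslice(x, w_slice, w_q, w_k, w_v, eps=1e-8):
    """Single-head, unbatched sketch. x: (N, D) -> (N, D)."""
    weights = F.softmax(x @ w_slice, dim=-1)                       # (N, M) soft assignments
    z = weights.T @ x / (weights.sum(dim=0).unsqueeze(-1) + eps)   # (M, D) slice tokens
    q, k, v = z @ w_q, z @ w_k, z @ w_v
    attn = F.softmax(q @ k.T / (x.shape[-1] ** 0.5), dim=-1)       # (M, M)
    return weights @ (attn @ v)                                    # (N, D) broadcast back

torch.manual_seed(0)
N, D, M = 50, 16, 4
x = torch.randn(N, D)
w_slice = torch.randn(D, M)
w_q, w_k, w_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))

perm = torch.randperm(N)
out = slice_attend_deslice(x, w_slice, w_q, w_k, w_v)
out_perm = slice_attend_deslice(x[perm], w_slice, w_q, w_k, w_v)

# Equivariance: permuting the input permutes the output identically
is_equivariant = torch.allclose(out[perm], out_perm, atol=1e-4)
print(f"Permutation-equivariant: {is_equivariant}")
```

This is why the model can consume mesh points in any order without positional indices: geometry enters only through the coordinate features.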

In [None]:
class SimpleTransolver(nn.Module):
    """Simplified Transolver for Stokes flow."""
    
    def __init__(self, in_dim=2, out_dim=3, hidden_dim=128, num_layers=4, num_slices=64):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.layers = nn.ModuleList([
            PhysicsAttentionLayer(hidden_dim, num_slices=num_slices)
            for _ in range(num_layers)
        ])
        
        self.decoder = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        """
        Args:
            x: (batch, N, in_dim) - mesh coordinates
        Returns:
            out: (batch, N, out_dim) - predicted fields (u, v, p)
        """
        h = self.embedding(x)
        for layer in self.layers:
            h = layer(h)
        return self.decoder(h)

# Test full model
model = SimpleTransolver(in_dim=2, out_dim=3).to(device)
coords = torch.randn(2, 500, 2).to(device)  # (batch, N points, 2D coords)
pred = model(coords)
print(f"Coordinates: {coords.shape}")
print(f"Prediction:  {pred.shape} (u, v, p for each point)")

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters:  {n_params:,}")
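
The test above batches meshes of identical size. When meshes of different sizes share a batch, a common approach is to pad to the longest mesh and mask out the padded points so they do not contribute to the slice tokens. The sketch below shows one way to do this for the aggregate step. `masked_slice_aggregate` and its `mask` argument are assumptions for illustration and are not part of the `PhysicsAttentionLayer` above.

```python
import torch
import torch.nn.functional as F

def masked_slice_aggregate(x, slice_logits, mask, eps=1e-8):
    """Aggregate points into slice tokens, ignoring padded points.

    x: (B, N, D), slice_logits: (B, N, M), mask: (B, N) bool, True = real point.
    """
    weights = F.softmax(slice_logits, dim=-1)                   # (B, N, M)
    weights = weights * mask.unsqueeze(-1).to(weights.dtype)    # zero out padded rows
    z = torch.bmm(weights.transpose(1, 2), x)                   # (B, M, D)
    z = z / (weights.sum(dim=1).unsqueeze(-1) + eps)            # normalize per slice
    return z, weights

torch.manual_seed(0)
n_real, n_pad, D, M = 5, 3, 4, 2
x_real = torch.randn(1, n_real, D)
logits_real = torch.randn(1, n_real, M)

# Pad with arbitrary values; the mask should make them irrelevant
x_padded = torch.cat([x_real, torch.randn(1, n_pad, D)], dim=1)
logits_padded = torch.cat([logits_real, torch.randn(1, n_pad, M)], dim=1)
mask = torch.tensor([[True] * n_real + [False] * n_pad])

z_ref, _ = masked_slice_aggregate(
    x_real, logits_real, torch.ones(1, n_real, dtype=torch.bool))
z_pad, w_pad = masked_slice_aggregate(x_padded, logits_padded, mask)

padding_ignored = torch.allclose(z_ref, z_pad, atol=1e-5)
print(f"Padding ignored: {padding_ignored}")
```

Reusing the masked `weights` in the deslice step also yields zero output at padded positions, though the residual connection would still pass padded inputs through, so the loss should be masked as well.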

## 4. Comparison: MeshGraphNet vs Transolver

| Aspect | MeshGraphNet (Lab 2) | Transolver |
|--------|---------------------|------------|
| **Architecture** | Graph Neural Network | Transformer + Physics-Attention |
| **Message Passing** | Local (neighbors only) | Global (via slices) |
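| **Complexity** | O(E) per layer (E = number of edges) | O(N·M + M²) per layer (M = number of slices) |
| **Inductive Bias** | Mesh connectivity | Learned physics slices |
| **Scalability** | Cost grows with edge count; long-range coupling needs many message-passing steps | Cost grows linearly in N for fixed M |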
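
To make the complexity row concrete, the following sketch counts pairwise weight entries for full self-attention versus the slice-based scheme. It is a rough count under the assumption that cost is dominated by the (N×M) slice and deslice products plus the (M×M) attention among slice tokens; it ignores the feature dimension, the number of heads, and constant factors. The function name `attention_entry_counts` is hypothetical.

```python
def attention_entry_counts(n_points, n_slices):
    """Return (full, sliced) counts of pairwise weight entries."""
    full = n_points ** 2                                   # N x N attention matrix
    sliced = 2 * n_points * n_slices + n_slices ** 2       # slice + deslice + M x M
    return full, sliced

for n in (500, 5000, 10000):
    full, sliced = attention_entry_counts(n, n_slices=64)
    print(f"N={n:6d}: full={full:,}  sliced={sliced:,}  ratio={full / sliced:.1f}x")
```

The ratio grows roughly as N / (2M), so the advantage of slicing widens as the mesh gets larger while M stays fixed.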

In [None]:
# Timing comparison (forward pass only)
import time

mesh_sizes = [500, 1000, 2000, 5000, 10000]
transolver_times = []

for N in mesh_sizes:
    model = SimpleTransolver(in_dim=2, out_dim=3, num_slices=64).to(device)
    x = torch.randn(1, N, 2).to(device)
    
    # Warmup
    with torch.no_grad():
        _ = model(x)
    
    # Time forward pass
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution timer
    with torch.no_grad():
        for _ in range(10):
            _ = model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / 10
    transolver_times.append(elapsed * 1000)  # ms
    print(f"N={N:5d}: {elapsed*1000:.2f} ms")

# Plot scaling
plt.figure(figsize=(8, 5))
plt.plot(mesh_sizes, transolver_times, 'g-o', label='Transolver', linewidth=2)
# O(N²) reference curve, anchored at the smallest mesh for comparison
quadratic_reference = [transolver_times[0] * (n / mesh_sizes[0])**2 for n in mesh_sizes]
plt.plot(mesh_sizes, quadratic_reference,
         'r--', label='O(N²) reference', alpha=0.5)
plt.xlabel('Mesh Size (N)')
plt.ylabel('Forward Pass Time (ms)')
plt.title('Transolver Forward-Pass Time vs. Mesh Size')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
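
A plot alone makes it hard to judge how close to linear the scaling is. One option is to fit a power law $t \approx c\,N^{\alpha}$ by linear regression in log-log space and read off the exponent $\alpha$. The helper `fit_scaling_exponent` below is a hypothetical name, and the timings are synthetic placeholders rather than measurements; in the notebook, `mesh_sizes` and `transolver_times` would be passed instead.

```python
import numpy as np

def fit_scaling_exponent(sizes, times_ms):
    """Least-squares slope of log(time) against log(size)."""
    slope, _intercept = np.polyfit(np.log(sizes), np.log(times_ms), deg=1)
    return float(slope)

sizes = np.array([500, 1000, 2000, 5000, 10000], dtype=float)

# Synthetic timings with known exponents, used only to validate the helper
linear_times = 0.002 * sizes
quadratic_times = 1e-6 * sizes ** 2

alpha_linear = fit_scaling_exponent(sizes, linear_times)
alpha_quadratic = fit_scaling_exponent(sizes, quadratic_times)
print(f"linear data:    alpha = {alpha_linear:.3f}")
print(f"quadratic data: alpha = {alpha_quadratic:.3f}")
```

On real measurements, fixed overheads such as kernel launches can dominate at small N and pull the fitted exponent below 1, so the fit is most informative over the larger mesh sizes.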

## Summary & Key Takeaways

**What we learned:**
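1. **Physics-Attention** reduces the per-layer attention cost from O(N²) to O(N·M) by grouping the N mesh points into M learned "slices" (with M ≪ N)
2. **Transolver** stacks Physics-Attention layers to create a scalable transformer for PDEs
3. Unlike message-passing GNNs, which propagate information one hop per layer, Transolver couples all mesh points within a single layer through the slice tokens, so it can capture **global** interactions efficiently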

**When to use what:**
- **MeshGraphNet**: When mesh connectivity is important, moderate mesh sizes
- **Transolver**: Large meshes, need for global interactions, variable mesh topologies

**Next steps:**
- Train Transolver on the Stokes dataset from Lab 2
- Compare accuracy with MeshGraphNet
- Experiment with number of slices (M) and layers
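
As a starting point for the training step, here is a minimal sketch of an optimization loop with an MSE loss and a relative L2 error metric. Everything in it is an assumption for illustration: the network is a small placeholder MLP standing in for `SimpleTransolver`, the data are synthetic rather than the Stokes dataset, and the optimizer settings are not tuned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder network: swap in SimpleTransolver(in_dim=2, out_dim=3) in the notebook
model = nn.Sequential(nn.Linear(2, 64), nn.GELU(), nn.Linear(64, 3))

# Synthetic stand-in data: (batch, N, 2) coordinates and (batch, N, 3) targets
coords = torch.rand(4, 200, 2)
targets = torch.stack(
    [torch.sin(coords[..., 0]),
     torch.cos(coords[..., 1]),
     coords[..., 0] * coords[..., 1]],
    dim=-1,
)

def relative_l2(pred, target):
    """Relative L2 error per sample, averaged over the batch."""
    diff = (pred - target).flatten(1).norm(dim=1)
    ref = target.flatten(1).norm(dim=1)
    return (diff / ref).mean()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

losses = []
for step in range(50):
    optimizer.zero_grad()
    loss = loss_fn(model(coords), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    final_error = relative_l2(model(coords), targets).item()
print(f"MSE: {losses[0]:.4f} -> {losses[-1]:.4f}, relative L2: {final_error:.3f}")
```

Reporting the same relative L2 metric for both models on a held-out split would give a like-for-like accuracy comparison with MeshGraphNet.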

## References
- [Transolver Paper](https://arxiv.org/abs/2402.02366)
- [MeshGraphNet Paper](https://arxiv.org/abs/2010.03409)
- [PhysicsNeMo (GitHub repository)](https://github.com/NVIDIA/physicsnemo)