# Notebook 2: Training Transolver on Stokes Flow

**Goal:** Train a Transolver model and visualize how it learns to partition the mesh into physics-based slices.

## Outline
1. Load Stokes Flow Dataset
2. Transolver Model (PhysicsNeMo)
3. Training Loop (100 Epochs)
4. Visualize Learned Slice Assignments

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os

from utils import download_stokes_dataset, load_stokes_sample, get_num_samples

# Fix the global RNG seeds so weight initialization and data shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Load Stokes Flow Dataset

We'll use the same dataset from Lab 4: Stokes flow around obstacles with varying geometries.

In [None]:
# Download and load dataset
download_stokes_dataset()

# Load multiple samples for training. Every sample is visited once per epoch,
# so the cap keeps each epoch short.
MAX_TRAIN_SAMPLES = 50
available_samples = get_num_samples()
if available_samples < 1:
    # Fail with a clear message rather than trying to load a sample that does not exist.
    raise RuntimeError(
        "No Stokes flow samples available after download_stokes_dataset(); "
        "check the dataset download/location."
    )
num_samples = min(available_samples, MAX_TRAIN_SAMPLES)
print(f"Loading {num_samples} samples...")

train_data = []
for i in range(num_samples):
    coords, u, v, p = load_stokes_sample(sample_idx=i)
    train_data.append({
        # Point coordinates: the model input (fed to Transolver as the embedding)
        'coords': torch.tensor(coords, dtype=torch.float32),
        # Targets stacked column-wise in (u, v, p) order
        'targets': torch.tensor(np.stack([u, v, p], axis=1), dtype=torch.float32),
    })

# Use first sample for visualization
sample_coords = train_data[0]['coords']
sample_targets = train_data[0]['targets']
N = len(sample_coords)
print(f"✓ Loaded {num_samples} samples, each with ~{N} mesh points")
print(f"  Input: coordinates (N, 2)")
print(f"  Output: u, v, p (N, 3)")

## 2. Transolver Model

We'll use the Transolver implementation from PhysicsNeMo (the `physicsnemo` package must be installed).

In [None]:
# Import PhysicsNeMo Transolver
from physicsnemo.models.transolver import Transolver

# Model configuration
NUM_SLICES = 32   # Number of learned slices (M) that mesh points are softly assigned to
HIDDEN_DIM = 128  # Hidden feature width of the Transolver (n_hidden)
NUM_LAYERS = 4    # Number of stacked Transolver blocks (n_layers)
NUM_HEADS = 8     # Attention heads per block (n_head)

# Each attention head is assumed to work on HIDDEN_DIM // NUM_HEADS channels
# (how Transolver splits the hidden width across heads — confirm against the
# installed PhysicsNeMo version). Require an exact split so no channels are lost
# to integer division when these values are edited.
if HIDDEN_DIM % NUM_HEADS != 0:
    raise ValueError(
        f"HIDDEN_DIM ({HIDDEN_DIM}) must be divisible by NUM_HEADS ({NUM_HEADS})"
    )

print("✓ PhysicsNeMo Transolver imported")
print(f"  Config: hidden_dim={HIDDEN_DIM}, layers={NUM_LAYERS}, slices={NUM_SLICES}, heads={NUM_HEADS}")

In [None]:
# Create PhysicsNeMo Transolver model.
# Unstructured mesh: the (x, y) coordinates are passed in as the embedding
# (embedding_dim=2) and there are no additional per-point input features
# (functional_dim=0).
transolver_config = dict(
    functional_dim=0,        # no functional input, just coordinates
    out_dim=3,               # predicted channels: u, v, p
    embedding_dim=2,         # 2D coordinates as embeddings
    n_layers=NUM_LAYERS,
    n_hidden=HIDDEN_DIM,
    n_head=NUM_HEADS,
    slice_num=NUM_SLICES,
    unified_pos=False,       # we provide our own embeddings (the coordinates)
    structured_shape=None,   # irregular / unstructured mesh
    use_te=False,            # do not require Transformer Engine
)
model = Transolver(**transolver_config)
model = model.to(device)

# Storage intended for slice weights captured by a forward hook.
# NOTE(review): nothing in this notebook writes to this dict; slice weights are
# recomputed by get_slice_weights() instead.
captured_slice_weights = {}

def capture_slice_weights_hook(module, input, output):
    """Forward-hook stub for a PhysicsAttention module; currently a no-op.

    Has the ``(module, input, output)`` signature PyTorch passes to forward
    hooks, but records nothing and is not registered anywhere in this notebook.

    Returns:
        None, which tells PyTorch to leave the module's output unchanged.
    """
    return None

# Helper function to get slice weights by running a modified forward pass
def get_slice_weights(model, coords):
    """
    Extract slice weights from the first PhysicsAttention layer.

    PhysicsNeMo's Transolver computes slice weights internally in each block's
    ``Attn`` module, so this helper replays the start of the forward pass by
    hand: ``preprocess`` -> first block's ``ln_1`` -> ``Attn.in_project_x`` ->
    ``Attn.in_project_slice`` -> temperature-scaled softmax over slices.

    Args:
        model: PhysicsNeMo ``Transolver``. Relies on the internal attributes
            ``preprocess`` and ``blocks[0].{ln_1, Attn}``, where ``Attn``
            exposes ``in_project_x``, ``in_project_slice``, ``heads``,
            ``dim_head`` and ``temperature``. These are implementation details
            and may change between PhysicsNeMo versions.
        coords: (B, N, 2) point coordinates, given the same way they are passed
            as ``embedding`` during training, on the model's device.

    Returns:
        (B, N, slice_num) tensor of slice weights averaged over heads; each
        point's weights sum to 1.

    The model is temporarily switched to eval mode and run without gradients;
    its previous train/eval mode is restored before returning, so calling this
    in the middle of training has no lasting side effect.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            B, N, _ = coords.shape

            # Empty functional input: the model sees only the coordinates,
            # concatenated in the same (embedding, fx) order as the forward pass.
            fx = torch.zeros(B, N, 0, device=coords.device, dtype=coords.dtype)
            h = model.preprocess(torch.cat((coords, fx), -1))  # (B, N, hidden_dim)
            # NOTE(review): any extra term Transolver.forward adds after
            # preprocess (e.g. a learned placeholder bias in some versions) is not
            # replayed here — confirm against the installed version.

            first_block = model.blocks[0]
            attn = first_block.Attn

            h_normed = first_block.ln_1(h)

            # Per-head features: (B, N, heads, dim_head)
            x_mid = attn.in_project_x(h_normed).view(B, N, attn.heads, attn.dim_head)

            # Per-head slice logits: (B, N, heads, slice_num)
            slice_projections = attn.in_project_slice(x_mid)

            # Temperature-scaled softmax over slices.
            # NOTE(review): clamp bounds kept from the original notebook code —
            # confirm they match the installed PhysicsNeMo attention layer.
            temp = torch.clamp(attn.temperature, min=0.5, max=5)
            slice_weights = F.softmax(slice_projections / temp, dim=-1)

            # Average across heads to get one assignment per point.
            return slice_weights.mean(dim=2)
    finally:
        # Restore the caller's mode (eval() above would otherwise leak out).
        model.train(was_training)

# Report model size (total number of scalar parameters).
n_params = sum(map(torch.numel, model.parameters()))
print(f"✓ PhysicsNeMo Transolver created: {n_params:,} parameters")
print(f"  Architecture: {NUM_LAYERS} layers × {NUM_HEADS} heads × {NUM_SLICES} slices")

## 3. Train for 100 Epochs

We'll train the model and track how the slice assignments evolve during training.

In [None]:
# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

NUM_EPOCHS = 100
losses = []         # mean training loss per epoch
slice_history = []  # (label, slice weights) snapshots for the visualization sample

# Epochs after which slice assignments are snapshotted. NUM_EPOCHS is always
# included so slice_history[-1] reflects the fully trained model even when
# NUM_EPOCHS is changed (e.g. with 200 epochs the fixed list alone would stop
# recording at epoch 100).
SNAPSHOT_EPOCHS = {10, 25, 50, 100, NUM_EPOCHS}

# Fixed input (first sample) used for every slice snapshot.
x = sample_coords.unsqueeze(0).to(device)

# Get initial slice assignments (before training)
initial_slices = get_slice_weights(model, x)[0].cpu().numpy()
slice_history.append(('Epoch 0 (untrained)', initial_slices))

print("Training PhysicsNeMo Transolver...")
print("-" * 50)

for epoch in range(NUM_EPOCHS):
    # (Re-)enable training mode at the start of every epoch; slice snapshots
    # evaluate the model in eval mode.
    model.train()
    epoch_loss = 0.0

    # Visit the samples in a fresh random order each epoch (one sample per step).
    for idx in np.random.permutation(len(train_data)):
        data = train_data[idx]
        coords = data['coords'].unsqueeze(0).to(device)    # (1, N, 2) - used as embedding
        targets = data['targets'].unsqueeze(0).to(device)  # (1, N, 3)

        # PhysicsNeMo Transolver forward: coordinates go in as `embedding`, and
        # fx is an empty (1, N, 0) tensor because there is no functional input.
        # new_zeros avoids overwriting the notebook-level N.
        fx = coords.new_zeros(coords.shape[0], coords.shape[1], 0)

        optimizer.zero_grad()
        pred = model(fx, embedding=coords)
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_data)
    losses.append(avg_loss)

    epoch_num = epoch + 1

    # Save slice assignments at key epochs
    if epoch_num in SNAPSHOT_EPOCHS:
        slices = get_slice_weights(model, x)[0].cpu().numpy()
        slice_history.append((f'Epoch {epoch_num}', slices))

    if epoch_num % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch_num:3d}/{NUM_EPOCHS}: Loss = {avg_loss:.6f}")

print("-" * 50)
print(f"✓ Training complete! Final loss: {losses[-1]:.6f}")

## 4. Visualize Training Results and Learned Slices

Let's look at the training loss, then at how the model learned to partition the mesh during training.

In [None]:
# Plot the training loss on linear and log scales side by side.
# The x-axis is derived from len(losses) rather than NUM_EPOCHS, so the figure
# still renders if training was interrupted before all epochs finished.
epochs_run = range(1, len(losses) + 1)

fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(10, 4))

ax_lin.plot(epochs_run, losses, 'b-o', linewidth=2, markersize=4)
ax_lin.set_xlabel('Epoch')
ax_lin.set_ylabel('MSE Loss')
ax_lin.set_title('Training Loss')
ax_lin.grid(True, alpha=0.3)

ax_log.semilogy(epochs_run, losses, 'b-o', linewidth=2, markersize=4)
ax_log.set_xlabel('Epoch')
ax_log.set_ylabel('MSE Loss (log scale)')
ax_log.set_title('Training Loss (Log Scale)')
ax_log.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

In [None]:
# Visualize how slice assignments evolved during training
coords_np = sample_coords.numpy()
num_snapshots = len(slice_history)

# Color each point by its dominant slice ID using a colormap with exactly
# NUM_SLICES colors and a value range shared by every panel. 'tab10' has only
# 10 colors, so with NUM_SLICES=32 several slices would share a color. Per-panel
# autoscaling would also give the same slice ID a different color in each panel.
slice_cmap = plt.get_cmap('nipy_spectral', NUM_SLICES)

# squeeze=False keeps `axes` 2-D even when there is only one snapshot.
fig, axes = plt.subplots(1, num_snapshots, figsize=(4 * num_snapshots, 4), squeeze=False)

for ax, (title, slice_weights) in zip(axes[0], slice_history):
    # Get dominant slice for each point
    dominant_slice = np.argmax(slice_weights, axis=1)

    # Plot mesh colored by slice; the +/-0.5 range puts each integer ID in its own color bin
    ax.scatter(coords_np[:, 0], coords_np[:, 1], c=dominant_slice,
               cmap=slice_cmap, vmin=-0.5, vmax=NUM_SLICES - 0.5, s=6, alpha=0.7)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')

plt.suptitle('Evolution of Learned Slice Assignments During Training', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("Notice how slices become more spatially coherent as training progresses!")

In [None]:
# Final comparison: Learned slices vs actual physics
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Final slice assignments = last snapshot recorded during training
final_slices = slice_history[-1][1]
dominant_slice = np.argmax(final_slices, axis=1)
targets_np = sample_targets.numpy()

# Ground-truth panels: (axis, target column, title, colormap).
# Target columns follow the (u, v, p) stacking order used when loading the data.
field_panels = [
    (axes[0, 0], 0, 'Ground Truth: Velocity u', 'RdBu_r'),
    (axes[0, 1], 1, 'Ground Truth: Velocity v', 'RdBu_r'),
    (axes[1, 0], 2, 'Ground Truth: Pressure p', 'viridis'),
]
for ax, column, title, cmap_name in field_panels:
    sc = ax.scatter(coords_np[:, 0], coords_np[:, 1], c=targets_np[:, column], cmap=cmap_name, s=6)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    plt.colorbar(sc, ax=ax, shrink=0.8)

# Bottom right: Learned slices. A colormap with exactly NUM_SLICES colors and a
# fixed value range gives every slice ID its own color. 'tab10' has only 10 colors,
# so with NUM_SLICES=32 several slices would be drawn identically.
slice_cmap = plt.get_cmap('nipy_spectral', NUM_SLICES)
ax = axes[1, 1]
sc = ax.scatter(coords_np[:, 0], coords_np[:, 1], c=dominant_slice, cmap=slice_cmap,
                vmin=-0.5, vmax=NUM_SLICES - 0.5, s=6, alpha=0.8)
ax.set_title(f'Learned Slice Assignments (M={NUM_SLICES})', fontsize=11)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal')
cbar = plt.colorbar(sc, ax=ax, shrink=0.8)
cbar.set_label('Slice ID')

plt.suptitle('Physics Fields vs Learned Slice Partitioning', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()

print("The model learns to group points with similar physical behavior into slices!")

## Summary

**What we demonstrated:**
1. Trained a Transolver model on Stokes flow data for 100 epochs
2. Visualized how **slice assignments evolve** during training
3. The mesh partition is learned from the **physics targets** during training rather than being a fixed geometric decomposition

**Key observations:**
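- Initially (untrained): slice assignments are arbitrary, determined only by the random initialization
- After training: slices tend to align with flow features (e.g., inlet, regions near the obstacle, boundaries)
- Because attention is computed among M slice tokens instead of all N points, its cost scales as O(N·M) rather than O(N²) (for M ≪ N)

**References:**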
- [Transolver Paper](https://arxiv.org/abs/2402.02366)
