# Notebook 2: Training Transolver on Stokes Flow

**Goal:** Train a Transolver model and visualize how it learns to partition the mesh into physics-based slices.

## Outline
1. Load Stokes Flow Dataset
2. Transolver Model Implementation
3. Train for 100 Epochs
4. Visualize Learned Slice Assignments

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os

from utils import download_stokes_dataset, load_stokes_sample, get_num_samples

# Seed both RNGs used in this notebook: NumPy (per-epoch shuffling) and
# torch (weight initialization), so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)

## 1. Load Stokes Flow Dataset

We'll use the same dataset from Lab 2: Stokes flow around obstacles with varying geometries.

# Download (if needed) and load the dataset.
download_stokes_dataset()

# Load up to MAX_TRAIN_SAMPLES samples for training.
MAX_TRAIN_SAMPLES = 50
available_samples = get_num_samples()
if available_samples < 1:
    # Fail loudly here instead of trying to load a sample that does not exist.
    raise RuntimeError(
        f"Stokes dataset reports {available_samples} samples; expected at least 1. "
        "Check that download_stokes_dataset() succeeded."
    )
num_samples = min(available_samples, MAX_TRAIN_SAMPLES)
print(f"Loading {num_samples} samples...")

train_data = []
for i in range(num_samples):
    coords, u, v, p = load_stokes_sample(sample_idx=i)
    train_data.append({
        'coords': torch.tensor(coords, dtype=torch.float32),
        # Stack the three fields column-wise in (u, v, p) order; this gives
        # (N, 3) assuming u, v, p are 1-D per-point arrays — TODO confirm.
        'targets': torch.tensor(np.stack([u, v, p], axis=1), dtype=torch.float32),
    })

# The first sample doubles as the visualization example.
sample_coords = train_data[0]['coords']
sample_targets = train_data[0]['targets']
N = len(sample_coords)
print(f"✓ Loaded {num_samples} samples, each with ~{N} mesh points")
print("  Input: coordinates (N, 2)")
print("  Output: u, v, p (N, 3)")

## 2. Transolver Model

The model learns to partition mesh points into slices based on physical behavior.

In [None]:
class PhysicsAttentionLayer(nn.Module):
    """Physics-Attention layer from Transolver.

    Softly assigns each of the N input points to ``num_slices`` slices. It then
    aggregates point features into one token per slice and runs multi-head
    self-attention among those M slice tokens. Finally it broadcasts the result
    back to the points using the same soft assignment. A pre-LayerNorm and a
    residual connection wrap the whole operation.

    Args:
        dim: Feature dimension D of the input; must be divisible by ``num_heads``.
        num_slices: Number of slices M (learned soft partitions of the points).
        num_heads: Number of attention heads used among the slice tokens.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_slices=8, num_heads=4):
        super().__init__()
        if dim % num_heads != 0:
            # forward() reshapes D into (num_heads, head_dim); fail early with a
            # clear message instead of an opaque shape error there.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_slices = num_slices
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Slice projection: per-point logits over the M slices (this learns the partition).
        self.slice_proj = nn.Linear(dim, num_slices)

        # Multi-head attention projections, applied to the M slice tokens.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

        self.norm = nn.LayerNorm(dim)
        # Detached slice weights from the most recent forward pass (for visualization).
        self._last_slice_weights = None

    def forward(self, x, return_slices=False):
        """Apply slice -> aggregate -> attend -> deslice, plus a residual connection.

        Args:
            x: Point features of shape (B, N, D) with D == ``self.dim``.
            return_slices: If True, also return the soft slice weights.

        Returns:
            Tensor of shape (B, N, D), or the tuple ``(output, slice_weights)``
            when ``return_slices`` is True. ``slice_weights`` has shape
            (B, N, M) and sums to 1 over M.
        """
        B, N, D = x.shape
        residual = x
        x = self.norm(x)

        # Step 1: SLICE - soft assignment of every point to the M slices.
        slice_logits = self.slice_proj(x)  # (B, N, M)
        slice_weights = F.softmax(slice_logits, dim=-1)  # (B, N, M)
        self._last_slice_weights = slice_weights.detach()

        # Step 2: AGGREGATE - weighted mean of point features per slice -> M tokens.
        slice_weights_t = slice_weights.transpose(1, 2)  # (B, M, N)
        z = torch.bmm(slice_weights_t, x)  # (B, M, D)
        # Normalize by each slice's total weight; the epsilon guards near-empty slices.
        z = z / (slice_weights_t.sum(dim=-1, keepdim=True) + 1e-8)

        # Step 3: ATTEND - multi-head self-attention among the M slice tokens.
        # The score matrix is M x M, independent of N (the cheap part).
        M = self.num_slices
        Q = self.q_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(z).view(B, M, self.num_heads, self.head_dim).transpose(1, 2)

        attn = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        z_prime = torch.matmul(attn, V)  # (B, heads, M, head_dim)
        z_prime = z_prime.transpose(1, 2).reshape(B, M, D)
        z_prime = self.out_proj(z_prime)

        # Step 4: DESLICE - each point gets the slice tokens mixed by its own weights.
        out = torch.bmm(slice_weights, z_prime)  # (B, N, D)

        if return_slices:
            return residual + out, slice_weights
        return residual + out


print("PhysicsAttentionLayer defined")

In [None]:
class SimpleTransolver(nn.Module):
    """Simplified Transolver for Stokes flow.

    Per-point pipeline: MLP embedding -> ``num_layers`` stacked
    PhysicsAttentionLayer blocks -> LayerNorm + MLP decoder.

    Args:
        in_dim: Number of input features per point.
        out_dim: Number of predicted output channels per point.
        hidden_dim: Latent feature width; must be divisible by ``num_heads``.
        num_layers: Number of stacked PhysicsAttentionLayer blocks.
        num_slices: Number of slices M used by every attention layer.
        num_heads: Attention heads per layer (default 4).
    """

    def __init__(self, in_dim=2, out_dim=3, hidden_dim=64, num_layers=3, num_slices=8,
                 num_heads=4):
        super().__init__()
        self.num_slices = num_slices

        self.embedding = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.layers = nn.ModuleList([
            PhysicsAttentionLayer(hidden_dim, num_slices=num_slices, num_heads=num_heads)
            for _ in range(num_layers)
        ])

        self.decoder = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Map per-point inputs of shape (B, N, in_dim) to outputs of shape (B, N, out_dim)."""
        h = self.embedding(x)
        for layer in self.layers:
            h = layer(h)
        return self.decoder(h)

    def get_slice_weights(self, x):
        """Return the first layer's soft slice assignments, shape (B, N, num_slices).

        Only the embedding and the first attention layer are evaluated; used
        for visualization.
        """
        h = self.embedding(x)
        _, slice_weights = self.layers[0](h, return_slices=True)
        return slice_weights

# Create model. Architecture hyperparameters are named once so the summary
# printed below always matches what was actually built.
NUM_SLICES = 8  # We'll visualize these
HIDDEN_DIM = 64
NUM_LAYERS = 3
model = SimpleTransolver(
    in_dim=2, out_dim=3, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS, num_slices=NUM_SLICES
).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model created: {n_params:,} parameters")
print(f"  - Hidden dim: {HIDDEN_DIM}")
print(f"  - Layers: {NUM_LAYERS}")
print(f"  - Slices: {NUM_SLICES}")

## 3. Train for 100 Epochs

We'll train the model and periodically snapshot the first layer's slice assignments on the first sample, to see how they evolve during training.

In [None]:
# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

NUM_EPOCHS = 100
losses = []         # mean training loss per epoch
slice_history = []  # (label, first-layer slice weights of shape (N, M)) for the visualization sample

# Snapshot slice assignments at 10%, 25%, 50% and 100% of training
# ([10, 25, 50, 100] for 100 epochs). Deriving these from NUM_EPOCHS keeps the
# snapshots valid if the epoch count changes. It also guarantees the last
# snapshot is the fully trained model, which the final plots read as
# slice_history[-1].
SNAPSHOT_EPOCHS = sorted({max(1, NUM_EPOCHS * pct // 100) for pct in (10, 25, 50, 100)})


def _snapshot_slices(label):
    """Append the first layer's slice weights for the visualization sample to slice_history."""
    model.eval()
    with torch.no_grad():
        x = sample_coords.unsqueeze(0).to(device)
        slice_history.append((label, model.get_slice_weights(x)[0].cpu().numpy()))


# Get initial slice assignments (before training)
_snapshot_slices('Epoch 0 (untrained)')

print("Training Transolver...")
print("-" * 50)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    epoch_loss = 0.0

    # Visit samples in a fresh random order each epoch; one optimizer step per sample.
    for idx in np.random.permutation(len(train_data)):
        data = train_data[idx]
        coords = data['coords'].unsqueeze(0).to(device)    # (1, N, 2)
        targets = data['targets'].unsqueeze(0).to(device)  # (1, N, 3)

        optimizer.zero_grad()
        loss = criterion(model(coords), targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_data)
    losses.append(avg_loss)

    if epoch in SNAPSHOT_EPOCHS:
        _snapshot_slices(f'Epoch {epoch}')

    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:3d}/{NUM_EPOCHS}: Loss = {avg_loss:.6f}")

print("-" * 50)
print(f"✓ Training complete! Final loss: {losses[-1]:.6f}")

## 4. Visualize Learned Slice Assignments

Let's see how the model learned to partition the mesh during training.

In [None]:
# Plot the per-epoch training loss twice: linear y-axis (left) and log y-axis (right).
epoch_ids = range(1, NUM_EPOCHS + 1)
fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(10, 4))

loss_panels = (
    (ax_linear.plot, ax_linear, 'MSE Loss', 'Training Loss'),
    (ax_log.semilogy, ax_log, 'MSE Loss (log scale)', 'Training Loss (Log Scale)'),
)
for draw, ax, y_label, panel_title in loss_panels:
    draw(epoch_ids, losses, 'b-o', linewidth=2, markersize=4)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

In [None]:
# Visualize how slice assignments evolved during training
coords_np = sample_coords.numpy()
num_snapshots = len(slice_history)

# squeeze=False always returns a 2-D array of axes, so a single snapshot needs no special case.
fig, axes = plt.subplots(1, num_snapshots, figsize=(4 * num_snapshots, 4), squeeze=False)

for ax, (title, slice_weights) in zip(axes[0], slice_history):
    # Hard assignment: the slice with the largest soft weight for each point.
    dominant_slice = np.argmax(slice_weights, axis=1)

    # Plot mesh colored by slice. vmin/vmax span every slice id so the same id
    # gets the same color in every panel. Without them, each panel rescales the
    # colormap to the ids it happens to contain.
    ax.scatter(coords_np[:, 0], coords_np[:, 1], c=dominant_slice,
               cmap='tab10', vmin=0, vmax=NUM_SLICES - 1, s=6, alpha=0.7)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')

plt.suptitle('Evolution of Learned Slice Assignments During Training', fontsize=12, y=1.02)
plt.tight_layout()
plt.show()

print("Notice how slices become more spatially coherent as training progresses!")

In [None]:
# Final comparison: learned slices vs ground-truth physics fields.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Hard slice assignment per point, taken from the last snapshot in slice_history.
final_slices = slice_history[-1][1]
dominant_slice = np.argmax(final_slices, axis=1)
targets_np = sample_targets.numpy()

# Ground-truth panels: (axes, target column, title, colormap); columns are u, v, p.
field_panels = (
    (axes[0, 0], 0, 'Ground Truth: Velocity u', 'RdBu_r'),
    (axes[0, 1], 1, 'Ground Truth: Velocity v', 'RdBu_r'),
    (axes[1, 0], 2, 'Ground Truth: Pressure p', 'viridis'),
)
for ax, column, title, cmap in field_panels:
    sc = ax.scatter(coords_np[:, 0], coords_np[:, 1], c=targets_np[:, column], cmap=cmap, s=6)
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    plt.colorbar(sc, ax=ax, shrink=0.8)

# Bottom right: learned slices. The color range is pinned to all slice ids so a
# given id always maps to the same color, whichever ids are present. Integer
# ticks then label each slice on the colorbar.
ax = axes[1, 1]
sc = ax.scatter(coords_np[:, 0], coords_np[:, 1], c=dominant_slice, cmap='tab10',
                vmin=0, vmax=NUM_SLICES - 1, s=6, alpha=0.8)
ax.set_title(f'Learned Slice Assignments (M={NUM_SLICES})', fontsize=11)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal')
cbar = plt.colorbar(sc, ax=ax, shrink=0.8)
cbar.set_ticks(list(range(NUM_SLICES)))
cbar.set_label('Slice ID')

plt.suptitle('Physics Fields vs Learned Slice Partitioning', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()

print("The model learns to group points with similar physical behavior into slices!")

# Visualize each learned slice in its own panel (works for any NUM_SLICES).
print(f"\n--- Visualizing all {NUM_SLICES} slices separately ---")

# Four panels per row; the row count is derived from NUM_SLICES so there is
# always one axis per slice.
ncols = 4
nrows = (NUM_SLICES + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
axes = axes.flatten()

# Palette for highlighting; cycled if there are more slices than colors.
slice_colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
                '#ff7f00', '#ffff33', '#a65628', '#f781bf']

num_points = len(coords_np)
# Number of points whose dominant slice is each id (ids with no points get 0).
slice_counts = np.bincount(dominant_slice, minlength=NUM_SLICES)

for slice_id in range(NUM_SLICES):
    ax = axes[slice_id]
    color = slice_colors[slice_id % len(slice_colors)]
    mask = dominant_slice == slice_id
    count = slice_counts[slice_id]

    # Plot all points faded
    ax.scatter(coords_np[:, 0], coords_np[:, 1], c='lightgray', s=3, alpha=0.3)
    # Highlight this slice's points
    if count > 0:
        ax.scatter(coords_np[mask, 0], coords_np[mask, 1], c=color, s=8, alpha=0.8)

    ax.set_title(f'Slice {slice_id}: {count} points ({100*count/num_points:.1f}%)',
                 fontsize=10, color=color if count > 0 else 'gray')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')

# Hide grid cells left over when NUM_SLICES is not a multiple of ncols.
for ax in axes[NUM_SLICES:]:
    ax.set_visible(False)

plt.suptitle(f'All {NUM_SLICES} Learned Slices (After {NUM_EPOCHS} Epochs)', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()

# Print slice statistics
print("\nSlice Statistics:")
for i in range(NUM_SLICES):
    count = slice_counts[i]
    print(f"  Slice {i}: {count:4d} points ({100*count/num_points:5.1f}%)")

## Summary

**What we demonstrated:**
1. Trained a Transolver model on Stokes flow data for 100 epochs
2. Visualized how **slice assignments evolve** during training
3. Although its only input is the point coordinates, the model learns a mesh partition shaped by the **physical fields** it is trained to predict, not by geometry alone

**Key observations:**
- Initially (untrained): slices are essentially random
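- After training: slices typically become spatially coherent and can align with flow features (e.g., inlet, wake, boundaries)
- This learned partitioning replaces O(N²) point-to-point attention with O(N·M) slicing/deslicing plus O(M²) attention among the M slice tokens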

**References:**
- [Transolver Paper](https://arxiv.org/abs/2402.02366)