# Using Walrus for Navier-Stokes Prediction

This notebook demonstrates how to use the Walrus foundation model to make predictions on Navier-Stokes spectral data.

## Overview

**Workflow:**
1. Load Navier-Stokes data
2. Convert to Well format
3. Load Walrus model and weights
4. Prepare data in Walrus input format
5. Run autoregressive rollout predictions
6. Visualize and compare predictions vs ground truth

**What Walrus Does:**
- Takes initial timesteps as input
- Predicts future evolution of the flow field
- Uses a large transformer-based architecture pretrained on diverse PDE datasets
- Takes the boundary-condition type of each domain edge as an explicit input

## Step 1: Import Libraries

In [1]:
import numpy as np
import torch
import h5py
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

# Walrus imports
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf

# For working with Well format data
from the_well.data.datasets import WellMetadata

print("Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

Libraries imported successfully
PyTorch version: 2.5.1
CUDA available: False


## Step 2: Load and Convert Navier-Stokes Data

First, we load the Navier-Stokes spectral data and convert it to Well format.
This follows the same process as in `walrus_example_2_NavierStokesToWell.ipynb`.

In [2]:
def Navier_Stokes_Spectral(n_sims, data_dist):
    """
    Load Navier-Stokes spectral simulation data.
    
    Args:
        n_sims: Number of simulations to load
        data_dist: 'ID' for in-distribution or 'OOD' for out-of-distribution
    
    Returns:
        u: Horizontal velocity [B, Nt, Nx, Ny]
        v: Vertical velocity [B, Nt, Nx, Ny]
        p: Pressure [B, Nt, Nx, Ny]
        rho: Density [B, Nt, Nx, Ny]
        x: X-coordinates
        y: Y-coordinates  
        dt: Time step
    """
    data_loc = '/Users/Vicky/Documents/UKAEA/Code/Uncertainty_Quantification/PDE_Residuals/Neural_PDE/Data'
    
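    # Load appropriate dataset based on distribution
    if data_dist == 'ID':
        data = np.load(data_loc + '/NS_Spectral_combined.npz')
    elif data_dist == 'OOD':
        data = np.load(data_loc + '/NS_Spectral_combined_pitagora_OOD_nu_1e-2.npz')
    else:
        # Fail early instead of hitting an undefined `data` further down
        raise ValueError(f"data_dist must be 'ID' or 'OOD', got {data_dist!r}")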

    # Extract fields and convert to float32
    u = data['u'].astype(np.float32)[:n_sims]  # Horizontal velocity
    v = data['v'].astype(np.float32)[:n_sims]  # Vertical velocity
    p = data['p'].astype(np.float32)[:n_sims]  # Pressure
    rho = np.ones_like(u)  # Density = 1 (constant)
    
    # Extract coordinates and time step
    x = data['x']
    y = x  # Assuming square domain
    dt = data['dt']

    return u, v, p, rho, x, y, dt

In [3]:
# Load a small subset for demonstration
n_sims = 3  # Number of simulations to use
data_dist = 'ID'  # Use in-distribution data

# Load the raw data
u, v, p, rho, x_coords, y_coords, dt = Navier_Stokes_Spectral(n_sims, data_dist)

print(f"Loaded Navier-Stokes data:")
print(f"  Horizontal velocity (u): {u.shape}")
print(f"  Vertical velocity (v): {v.shape}")
print(f"  Pressure (p): {p.shape}")
print(f"  Spatial grid: {len(x_coords)} × {len(y_coords)}")
print(f"  Time step: {dt}")

Loaded Navier-Stokes data:
  Horizontal velocity (u): (3, 50, 100, 100)
  Vertical velocity (v): (3, 50, 100, 100)
  Pressure (p): (3, 50, 100, 100)
  Spatial grid: 100 × 100
  Time step: 0.01
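
The loader reads from a local path that other readers will not have. The following is a minimal sketch of a stand-in file with the same keys (`u`, `v`, `p`, `x`, `dt`) and the array layout reported above. The helper name `make_synthetic_ns`, the file name, and the Taylor-Green-style vortex field are illustrative assumptions, not part of the original dataset.

```python
import numpy as np

def make_synthetic_ns(path, n_sims=3, n_t=50, n=100, dt=0.01, nu=1e-3):
    """Write a toy .npz file with keys u, v, p, x, dt (hypothetical stand-in data)."""
    x = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)   # periodic grid
    X, Y = np.meshgrid(x, x, indexing="ij")                 # [Nx, Ny]
    t = (np.arange(n_t) * dt)[:, None, None]                # [Nt, 1, 1]
    decay = np.exp(-2.0 * nu * t)                           # viscous decay factor
    u = np.sin(X) * np.cos(Y) * decay                       # [Nt, Nx, Ny]
    v = -np.cos(X) * np.sin(Y) * decay
    p = 0.25 * (np.cos(2 * X) + np.cos(2 * Y)) * decay**2
    # One amplitude per simulation -> arrays of shape [B, Nt, Nx, Ny]
    amps = np.linspace(1.0, 2.0, n_sims)[:, None, None, None]
    np.savez(path, u=amps * u, v=amps * v, p=amps**2 * p, x=x, dt=dt)
    return path
```

Pointing `data_loc` at the directory holding such a file (and adjusting the file name) would let the remaining cells run on toy data; any predictions on it say nothing about the real dataset.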


## Step 3: Download Walrus Model Weights

Download the pretrained Walrus model and configuration from HuggingFace.

In [4]:
# Create directories for model files
checkpoint_base_path = "./checkpoints/"
config_base_path = "./configs/"
Path(checkpoint_base_path).mkdir(exist_ok=True)
Path(config_base_path).mkdir(exist_ok=True)

# Download model weights and config from HuggingFace
# Only download if not already present
config_path = Path(config_base_path) / "extended_config.yaml"
checkpoint_path = Path(checkpoint_base_path) / "walrus.pt"

if not config_path.exists():
    print("Downloading model configuration...")
    !wget -q https://huggingface.co/polymathic-ai/walrus/resolve/main/extended_config.yaml -O {config_path}
    print("✓ Configuration downloaded")
else:
    print("✓ Configuration already exists")

if not checkpoint_path.exists():
    print("Downloading model weights (~5GB, this may take a few minutes)...")
    !wget https://huggingface.co/polymathic-ai/walrus/resolve/main/walrus.pt -O {checkpoint_path}
    print("✓ Model weights downloaded")
else:
    print("✓ Model weights already exist")

✓ Configuration already exists
✓ Model weights already exist
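
One caveat with the cell above: `wget -O` creates the output file before the transfer starts, so an interrupted download can leave a partial file that the `exists()` check later accepts. A minimal sketch of a standard-library alternative follows; the helper name `download_if_missing` is hypothetical, and the file is only given its final name once the transfer has finished.

```python
from pathlib import Path
from urllib.request import urlretrieve

def download_if_missing(url, dest):
    """Fetch url to dest unless dest already exists. Returns (path, was_downloaded)."""
    dest = Path(dest)
    if dest.exists():
        return dest, False
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".part")
    urlretrieve(url, str(tmp))   # write to a temporary name first
    tmp.replace(dest)            # rename only after the transfer completed
    return dest, True
```

It could be called with the same URLs and paths as the `wget` lines, for example `download_if_missing(url, checkpoint_path)`.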


## Step 4: Load Walrus Configuration and Model

Load the Hydra configuration and instantiate the Walrus model.

In [5]:
# Load the configuration file with OmegaConf (Hydra's `instantiate` builds the model from it later)
# The config file contains model architecture, training settings, and data specs
config = OmegaConf.load(config_path)

print("Model configuration loaded:")
print(f"  Model type: {config.model._target_}")
print(f"  Number of parameters: ~1.3B")

# Note: The config structure may vary, so we'll set our own parameters
print(f"\nConfiguration note:") 
print(f"  We'll manually set input/output sequence lengths for our NS data")
print(f"  Walrus is flexible and can handle variable sequence lengths")


Model configuration loaded:
  Model type: walrus.models.IsotropicModel
  Number of parameters: ~1.3B

Configuration note:
  We'll manually set input/output sequence lengths for our NS data
  Walrus is flexible and can handle variable sequence lengths


In [6]:
# Get pretrained field mapping
# This maps physical field names to embedding indices in the model
pretrained_field_to_index = config.data.field_index_map_override

print("Pretrained field mapping (showing first 10):")
for i, (field, idx) in enumerate(list(pretrained_field_to_index.items())[:10]):
    print(f"  {field}: {idx}")
print(f"  ... ({len(pretrained_field_to_index)} total fields)")

Pretrained field mapping (showing first 10):
  closed_boundary: 0
  open_boundary: 1
  bias_correction: 2
  pressure: 3
  velocity_x: 4
  velocity_y: 5
  velocity_z: 6
  zeros_like_density: 7
  speed_of_sound: 8
  concentration: 9
  ... (67 total fields)


In [7]:
# Add our Navier-Stokes fields to the mapping
# We need to map: velocity (vector), pressure (scalar), density (scalar)

# Create new field mapping with our NS fields
field_to_index = dict(pretrained_field_to_index)

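# Add fields if not already present
# Note: We're using generic field names that might already be in the pretrained model
# Caution: the check below looks for the key "velocity", but the pretrained mapping
# stores the components as "velocity_x" and "velocity_y" (indices 4 and 5 above).
# The branch therefore runs and moves both components to new, untrained indices;
# testing for "velocity_x" instead would keep the pretrained embeddings.
if "velocity" not in field_to_index: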
    max_idx = max(field_to_index.values())
    field_to_index["velocity_x"] = max_idx + 1
    field_to_index["velocity_y"] = max_idx + 2
    print(f"Added velocity_x at index {max_idx + 1}")
    print(f"Added velocity_y at index {max_idx + 2}")
    max_idx += 2
else:
    print("Velocity fields already in pretrained mapping")
    max_idx = max(field_to_index.values())

if "pressure" not in field_to_index:
    field_to_index["pressure"] = max_idx + 1
    print(f"Added pressure at index {max_idx + 1}")
    max_idx += 1
else:
    print("Pressure already in pretrained mapping")

if "density" not in field_to_index:
    field_to_index["density"] = max_idx + 1
    print(f"Added density at index {max_idx + 1}")
else:
    print("Density already in pretrained mapping")

n_states = max(field_to_index.values()) + 1
print(f"\nTotal number of field states: {n_states}")

Added velocity_x at index 67
Added velocity_y at index 68
Pressure already in pretrained mapping
Density already in pretrained mapping

Total number of field states: 69
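
The pretrained mapping printed in cell 6 already lists `velocity_x: 4` and `velocity_y: 5`, yet the output above reports both being added again at 67 and 68, because the check tests for the key `"velocity"`. A minimal sketch of a name-by-name extension that reuses existing indices follows; `extend_field_map` is a hypothetical helper, and the toy mapping only copies indices that appear in this notebook's printed outputs.

```python
def extend_field_map(pretrained, needed):
    """Return (mapping, added): reuse existing indices, append new ones for unknown names."""
    mapping = dict(pretrained)
    next_idx = max(mapping.values(), default=-1) + 1
    added = []
    for name in needed:
        if name not in mapping:
            mapping[name] = next_idx
            added.append(name)
            next_idx += 1
    return mapping, added

# Toy subset of the pretrained mapping (indices as printed in this notebook)
pretrained = {"pressure": 3, "velocity_x": 4, "velocity_y": 5, "density": 28}
mapping, added = extend_field_map(
    pretrained, ["velocity_x", "velocity_y", "pressure", "density"]
)
print(added)                  # []
print(mapping["velocity_x"])  # 4
```

With all four names already present, nothing is added and `n_states` would stay at the pretrained size, so no embedding needs random initialization.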


In [8]:
# Instantiate the model
# The model is a large transformer-based architecture for PDE prediction
print("Initializing Walrus model...")

model = instantiate(
    config.model,
    n_states=n_states,  # Total number of field embeddings
)

print(f"✓ Model initialized")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

Initializing Walrus model...




✓ Model initialized
  Total parameters: 1,287,640,843


In [9]:
# Load pretrained weights
print("Loading pretrained weights...")

checkpoint = torch.load(checkpoint_path, map_location='cpu')

# The checkpoint contains the full training state
# We only need the model weights
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint

print(f"Checkpoint keys (first 10): {list(state_dict.keys())[:10]}")

# Check if we actually need to align the checkpoint
# In many cases, the pretrained model already has embeddings for many fields
# and we can just use it directly
if n_states > len(pretrained_field_to_index):
    print(f"\nNote: Model expects {n_states} field states, pretrained has {len(pretrained_field_to_index)}")
    print("We have added custom fields. The model will use random initialization for new field embeddings.")
    print("For best performance, consider fine-tuning on your dataset.")
    
    # For inference with new fields, we can skip alignment and just load with strict=False
    # This will use pretrained weights for existing parameters and random init for new ones
    print("\nLoading weights (strict=False to allow new field embeddings)...")
else:
    print("\nLoading pretrained weights (all fields match)...")

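# Load weights into model
# strict=False allows loading even if there are missing or extra keys.
# It also means that a checkpoint whose weights sit under a nested key loads nothing:
# every parameter is then reported as missing and keeps its random initialization,
# so check the missing-key count before relying on the model.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)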

if missing_keys:
    print(f"  Missing keys: {len(missing_keys)} (these will use random initialization)")
    if len(missing_keys) <= 5:
        for key in missing_keys:
            print(f"    - {key}")
if unexpected_keys:
    print(f"  Unexpected keys: {len(unexpected_keys)} (these will be ignored)")
    if len(unexpected_keys) <= 5:
        for key in unexpected_keys:
            print(f"    - {key}")

print("✓ Weights loaded successfully")

# Set model to evaluation mode (disables dropout, etc.)
model.eval()

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"✓ Model moved to {device}")

Loading pretrained weights...




Checkpoint keys (first 10): ['app']

Note: Model expects 69 field states, pretrained has 67
We have added custom fields. The model will use random initialization for new field embeddings.
For best performance, consider fine-tuning on your dataset.

Loading weights (strict=False to allow new field embeddings)...
  Missing keys: 857 (these will use random initialization)
  Unexpected keys: 1 (these will be ignored)
    - app
✓ Weights loaded successfully
✓ Model moved to cpu
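
The output above lists a single top-level checkpoint key, `app`, together with 857 missing keys, which suggests the weights sit deeper inside the checkpoint and that no parameter was matched. A minimal sketch for locating the nested dictionary whose keys overlap most with the model's parameter names follows. `find_state_dict` is a hypothetical helper, and the inner key `model` in the toy checkpoint is an assumption for illustration, not a confirmed detail of the Walrus checkpoint.

```python
def find_state_dict(obj, expected_keys, path=()):
    """Depth-first search of nested dicts for the one sharing the most keys with expected_keys.

    Returns (path, n_matching_keys, candidate_dict), where path is the tuple of keys
    leading from the top-level object to the candidate.
    """
    expected = set(expected_keys)
    best = (path, len(expected & set(obj.keys())), obj)
    for key, value in obj.items():
        if isinstance(value, dict):
            cand = find_state_dict(value, expected, path + (key,))
            if cand[1] > best[1]:
                best = cand
    return best

# Toy checkpoint with hypothetical nesting
toy_checkpoint = {
    "app": {"model": {"layer.weight": 1, "layer.bias": 2}, "optimizer": {"lr": 0.1}}
}
path, n_match, sd = find_state_dict(toy_checkpoint, ["layer.weight", "layer.bias"])
print(path, n_match)  # ('app', 'model') 2
```

On the real objects this would be called as `find_state_dict(checkpoint, model.state_dict().keys())`, and the returned dictionary passed to `model.load_state_dict`. Note that `strict=False` only tolerates missing or unexpected keys; tensors whose shapes differ from the model's still raise an error.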


## Step 5: Prepare Data in Walrus Input Format

Walrus expects data in a specific dictionary format. We need to prepare:
- Input fields: Initial timesteps
- Boundary conditions: How the domain is bounded
- Field indices: Which fields we're using
- Padded field mask: Marks which fields are real vs padding
- Metadata: Information about the simulation

In [10]:
# Configuration for the prediction
T_in = 3      # Number of input timesteps (how much history to use)
T_out = 10    # Number of output timesteps to predict
T_total = T_in + T_out  # Total timesteps we need from data

B, Nt, Nx, Ny = u.shape

# Check we have enough timesteps
if Nt < T_total:
    print(f"Warning: Data has {Nt} timesteps but we need {T_total}")
    print(f"Reducing output timesteps to {Nt - T_in}")
    T_out = Nt - T_in
    T_total = Nt

print(f"Prediction setup:")
print(f"  Input timesteps: {T_in}")
print(f"  Output timesteps: {T_out}")
print(f"  Batch size: {B}")
print(f"  Spatial resolution: {Nx} × {Ny}")

Prediction setup:
  Input timesteps: 3
  Output timesteps: 10
  Batch size: 3
  Spatial resolution: 100 × 100


In [11]:
# Prepare input data
# Stack velocity components: [B, Nt, Nx, Ny, 2]
velocity = np.stack([u, v], axis=-1)

# For Walrus, we need to organize fields by their tensor rank
# and add a "depth" dimension (for 2D data, depth=1)

# Extract input and output timesteps
u_input = u[:, :T_in, :, :, np.newaxis]      # [B, T_in, Nx, Ny, 1] (add depth dim)
v_input = v[:, :T_in, :, :, np.newaxis]      # [B, T_in, Nx, Ny, 1]
p_input = p[:, :T_in, :, :, np.newaxis]      # [B, T_in, Nx, Ny, 1]
rho_input = rho[:, :T_in, :, :, np.newaxis]  # [B, T_in, Nx, Ny, 1]

# Ground truth for comparison
u_output = u[:, T_in:T_total, :, :, np.newaxis]      # [B, T_out, Nx, Ny, 1]
v_output = v[:, T_in:T_total, :, :, np.newaxis]      # [B, T_out, Nx, Ny, 1]
p_output = p[:, T_in:T_total, :, :, np.newaxis]      # [B, T_out, Nx, Ny, 1]
rho_output = rho[:, T_in:T_total, :, :, np.newaxis]  # [B, T_out, Nx, Ny, 1]

print(f"Input shapes:")
print(f"  u_input: {u_input.shape}")
print(f"  v_input: {v_input.shape}")
print(f"  p_input: {p_input.shape}")
print(f"\nOutput (ground truth) shapes:")
print(f"  u_output: {u_output.shape}")
print(f"  v_output: {v_output.shape}")
print(f"  p_output: {p_output.shape}")

Input shapes:
  u_input: (3, 3, 100, 100, 1)
  v_input: (3, 3, 100, 100, 1)
  p_input: (3, 3, 100, 100, 1)

Output (ground truth) shapes:
  u_output: (3, 10, 100, 100, 1)
  v_output: (3, 10, 100, 100, 1)
  p_output: (3, 10, 100, 100, 1)


In [12]:
# Combine all fields into one array
# Walrus expects: [B, T, H, W, D, C] where C is the number of fields
# Note: each field array ends in a singleton axis, and concatenating along axis=-1
# merges the fields into that axis. The arrays built here are therefore
# [B, T, Nx, Ny, 4] (u, v, p, rho), with no separate depth axis D.

input_fields = np.concatenate([
    u_input,
    v_input, 
    p_input,
    rho_input
], axis=-1)  # [B, T_in, Nx, Ny, 4]

output_fields = np.concatenate([
    u_output,
    v_output,
    p_output,
    rho_output
], axis=-1)  # [B, T_out, Nx, Ny, 4]

print(f"Stacked field shapes:")
print(f"  input_fields: {input_fields.shape}   # [B, T_in, H, W, C]")
print(f"  output_fields: {output_fields.shape}  # [B, T_out, H, W, C]")

Stacked field shapes:
  input_fields: (3, 3, 100, 100, 4)   # [B, T_in, H, W, C]
  output_fields: (3, 10, 100, 100, 4)  # [B, T_out, H, W, C]
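
The printed shapes have five axes, not six, because `np.concatenate` joins arrays along an existing axis while `np.stack` creates a new one. The short comparison below uses zero-filled placeholder arrays of the same shape as the per-field inputs; which of the two layouts the model call actually requires is something to verify against the Walrus documentation.

```python
import numpy as np

B, T, Nx, Ny = 3, 3, 100, 100
# Four placeholder fields (u, v, p, rho), each with a trailing singleton axis
fields = [np.zeros((B, T, Nx, Ny, 1), dtype=np.float32) for _ in range(4)]

concatenated = np.concatenate(fields, axis=-1)  # singleton axis becomes the channel axis
stacked = np.stack(fields, axis=-1)             # new channel axis appended after the singleton

print(concatenated.shape)  # (3, 3, 100, 100, 4)
print(stacked.shape)       # (3, 3, 100, 100, 1, 4)
```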


In [13]:
# Convert to PyTorch tensors and move to device
input_fields_tensor = torch.from_numpy(input_fields).float().to(device)
output_fields_tensor = torch.from_numpy(output_fields).float().to(device)

print(f"Tensors created and moved to {device}")

Tensors created and moved to cpu


In [14]:
# Define boundary conditions
# For spectral NS, we typically have periodic BCs in all directions
# BC format: [B, n_dims, 2] where n_dims=2 for 2D, and 2 for [lower, upper] bounds
# BC codes: WALL=0, OPEN=1, PERIODIC=2

boundary_conditions = torch.tensor(
    [[[2, 2], [2, 2]]]  # [[x_lower, x_upper], [y_lower, y_upper]] all PERIODIC
).repeat(B, 1, 1).to(device)  # Repeat for each trajectory in batch

print(f"Boundary conditions: {boundary_conditions.shape}  # [B, 2, 2]")
print(f"  BC type: PERIODIC (code=2) in all directions")

Boundary conditions: torch.Size([3, 2, 2])  # [B, 2, 2]
  BC type: PERIODIC (code=2) in all directions
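
Writing the integer codes by hand is easy to get wrong once the boundaries differ per axis. A minimal sketch that builds the `[B, n_dims, 2]` structure from names follows, using the codes listed in the comment above; the `BC` enum and the `bc_codes` helper are hypothetical and defined here, not imported from Walrus or the Well.

```python
from enum import IntEnum

class BC(IntEnum):
    """Boundary-condition codes as listed in the comment above."""
    WALL = 0
    OPEN = 1
    PERIODIC = 2

def bc_codes(per_axis, batch_size):
    """Build a nested list of shape [B, n_dims, 2] from (lower, upper) names per axis."""
    one = [[int(BC[lo]), int(BC[hi])] for lo, hi in per_axis]
    return [[row[:] for row in one] for _ in range(batch_size)]

codes = bc_codes([("PERIODIC", "PERIODIC"), ("PERIODIC", "PERIODIC")], batch_size=3)
print(codes[0])  # [[2, 2], [2, 2]]
```

`torch.tensor(codes)` then yields the same `[3, 2, 2]` tensor as the cell above. A domain that is periodic in x with walls in y would be written as `bc_codes([("PERIODIC", "PERIODIC"), ("WALL", "WALL")], B)`.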


In [15]:
# Create field indices
# Maps each field in our data to its embedding index in the model

# We're using 4 fields: u, v, p, rho
# Let's use velocity indices if available, otherwise use custom indices
if "velocity_x" in field_to_index:
    u_idx = field_to_index["velocity_x"]
    v_idx = field_to_index["velocity_y"]
else:
    # Use generic indices
    u_idx = 0
    v_idx = 1

p_idx = field_to_index.get("pressure", 2)
rho_idx = field_to_index.get("density", 3)

field_indices = torch.tensor([u_idx, v_idx, p_idx, rho_idx]).to(device)

print(f"Field indices: {field_indices}")
print(f"  u (horizontal velocity): {u_idx}")
print(f"  v (vertical velocity): {v_idx}")
print(f"  p (pressure): {p_idx}")
print(f"  rho (density): {rho_idx}")

Field indices: tensor([67, 68,  3, 28])
  u (horizontal velocity): 67
  v (vertical velocity): 68
  p (pressure): 3
  rho (density): 28


In [16]:
# Create padded field mask
# For 2D data, we have real fields (u, v, p, rho) and NO padding
# Mask is True for real fields, False for padded/fake fields

# In our case, all 4 fields are real (no padding needed for 2D)
padded_field_mask = torch.tensor([True, True, True, True]).to(device)

print(f"Padded field mask: {padded_field_mask}")
print(f"  All fields are real (no padding)")

Padded field mask: tensor([True, True, True, True])
  All fields are real (no padding)


In [18]:
# Create metadata
# This provides information about the simulation for the model

metadata = WellMetadata(
    dataset_name="navier_stokes_spectral",
    n_spatial_dims=2,  # 2D simulation
    grid_type="cartesian",
    spatial_resolution=(Nx, Ny),  # Grid dimensions
    scalar_names=[],  # No standalone scalar parameters
    constant_scalar_names=[],  # No constant scalars
    constant_field_names={},  # No constant fields
    field_names={
        0: ["pressure", "density"],  # t0_fields (scalars)
        1: ["velocity"],              # t1_fields (vectors)
        2: []                         # t2_fields (tensors)
    },
    boundary_condition_types=["PERIODIC", "PERIODIC"],  # BCs for x, y
    n_files=1,  # Single file
    n_trajectories_per_file=[B],  # B trajectories
    n_steps_per_trajectory=T_total,  # Total timesteps available
)

print("Metadata created:")
print(f"  Dataset: {metadata.dataset_name}")
print(f"  Spatial dims: {metadata.n_spatial_dims}")
print(f"  Grid type: {metadata.grid_type}")
print(f"  Spatial resolution: {metadata.spatial_resolution}")
print(f"  Field names: {metadata.field_names}")


Metadata created:
  Dataset: navier_stokes_spectral
  Spatial dims: 2
  Grid type: cartesian
  Spatial resolution: (100, 100)
  Field names: {0: ['pressure', 'density'], 1: ['velocity'], 2: []}


In [19]:
# Create the final data dictionary for Walrus
# This is the format the model expects

walrus_data = {
    "input_fields": input_fields_tensor,        # [B, T_in, H, W, C]
    "output_fields": output_fields_tensor,      # [B, T_out, H, W, C] (for evaluation)
    "boundary_conditions": boundary_conditions, # [B, n_dims, 2]
    "field_indices": field_indices,            # [C]
    "padded_field_mask": padded_field_mask,    # [C]
    "metadata": metadata,                      # WellMetadata object
}

print("✓ Walrus data dictionary created")
print(f"\nData dictionary keys: {list(walrus_data.keys())}")

✓ Walrus data dictionary created

Data dictionary keys: ['input_fields', 'output_fields', 'boundary_conditions', 'field_indices', 'padded_field_mask', 'metadata']


## Step 6: Run Autoregressive Rollout

Now we'll use Walrus to predict the future evolution of the flow field.

**Autoregressive Rollout:**
1. Use initial timesteps (0, 1, 2) to predict timestep 3
2. Use timesteps (1, 2, 3_predicted) to predict timestep 4
3. Use timesteps (2, 3_predicted, 4_predicted) to predict timestep 5
4. Continue until we've predicted all T_out timesteps

This is called "autoregressive" because predictions feed back as inputs.
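
The sliding-window logic of the steps above can be sketched independently of Walrus. In the toy example below, the stand-in "model" is a linear extrapolation from the last two frames, chosen only so the result is easy to check by hand; `rollout` and `extrapolate` are hypothetical names.

```python
import numpy as np

def rollout(step_fn, history, n_steps):
    """Autoregressive rollout on time-first arrays.

    history: initial window of shape [T_in, ...].
    step_fn: maps a window [T_in, ...] to the next frame [1, ...].
    Returns the predictions stacked along time, shape [n_steps, ...].
    """
    window = history.copy()
    preds = []
    for _ in range(n_steps):
        nxt = step_fn(window)                                # [1, ...]
        preds.append(nxt)
        window = np.concatenate([window[1:], nxt], axis=0)   # drop oldest, append newest
    return np.concatenate(preds, axis=0)

def extrapolate(window):
    """Stand-in model: continue the trend of the last two frames."""
    return 2.0 * window[-1:] - window[-2:-1]

history = np.array([0.0, 1.0, 2.0]).reshape(3, 1)   # timesteps 0, 1, 2
print(rollout(extrapolate, history, 4).ravel())      # [3. 4. 5. 6.]
```

Every frame after the first depends on earlier predictions rather than on ground truth, which is why errors can compound over the rollout.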

In [22]:
# Run inference (no gradient computation needed)
print("Running autoregressive rollout...")
print(f"  Predicting {T_out} timesteps")
print(f"  Using {T_in} input timesteps")

with torch.no_grad():  # Disable gradient tracking for inference
    # Initialize: current input is the initial timesteps
    current_input = input_fields_tensor  # Should be [B, T_in, H, W, C] for 2D data
    
    print(f"Input shape: {current_input.shape}")
    
    # Check the actual number of dimensions
    if current_input.dim() == 5:  # [B, T, H, W, C] - 2D data without explicit depth dim
        # The model expects input in [T, B, H, W, C] format (time-first) for 2D
        current_input = current_input.permute(1, 0, 2, 3, 4)  # [T_in, B, H, W, C]
    elif current_input.dim() == 6:  # [B, T, H, W, D, C] - 3D data
        # The model expects input in [T, B, H, W, D, C] format (time-first)
        current_input = current_input.permute(1, 0, 2, 3, 4, 5)  # [T_in, B, H, W, D, C]
    
    print(f"Permuted input shape: {current_input.shape}")
    
    # Store predictions
    predictions = []
    
    # Rollout loop: predict one timestep at a time
    for step in tqdm(range(T_out), desc="Predicting"):
        # Predict next timestep
        # Model takes the current T_in timesteps and predicts the next one
        output = model(
            x=current_input,
            state_labels=field_indices,
            bcs=boundary_conditions,
            metadata=metadata,
            train=False,  # Evaluation mode
        )  # Returns [1, B, H, W, (D,) C] (single timestep prediction)
        
        # Extract the prediction
        next_timestep = output  # [1, B, H, W, (D,) C]
        
        # Store prediction (permute back to [B, 1, H, W, (D,) C] for consistency)
        if next_timestep.dim() == 5:  # 2D case
            predictions.append(next_timestep.permute(1, 0, 2, 3, 4))
        elif next_timestep.dim() == 6:  # 3D case
            predictions.append(next_timestep.permute(1, 0, 2, 3, 4, 5))
        
        # Update input for next iteration (autoregressive)
        # Concatenate current input (drop oldest timestep) with new prediction
        current_input = torch.cat([
            current_input[1:, ...],  # Drop first timestep: [T_in-1, B, H, W, (D,) C]
            next_timestep            # Add prediction: [1, B, H, W, (D,) C]
        ], dim=0)  # Result: [T_in, B, H, W, (D,) C]
    
    # Concatenate all predictions along time dimension
    # predictions is a list of [B, 1, H, W, (D,) C] tensors
    predictions = torch.cat(predictions, dim=1)  # [B, T_out, H, W, (D,) C]

print(f"\n✓ Rollout complete!")
print(f"  Predictions shape: {predictions.shape}")


Running autoregressive rollout...
  Predicting 10 timesteps
  Using 3 input timesteps
Input shape: torch.Size([3, 3, 100, 100, 4])
Permuted input shape: torch.Size([3, 3, 100, 100, 4])


Predicting:   0%|          | 0/10 [00:00<?, ?it/s]


AssertionError: 
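
The traceback above carries no message, so the failing condition is not identified here. One way to narrow it down is to check that the inputs are mutually consistent before calling the model. The sketch below assumes a channel-last layout; the checks are guesses at what a consistent input looks like, not the assertions Walrus itself makes, and `preflight` is a hypothetical helper.

```python
def preflight(x_shape, field_indices, mask, bc_shape, n_states, n_spatial_dims=2):
    """Return a list of human-readable problems; an empty list means these checks passed."""
    problems = []
    n_channels = x_shape[-1]  # assumes the channel axis is last
    if n_channels != len(field_indices):
        problems.append(f"{n_channels} channels but {len(field_indices)} field indices")
    if len(mask) != len(field_indices):
        problems.append("padded_field_mask and field_indices differ in length")
    bad = [i for i in field_indices if not 0 <= i < n_states]
    if bad:
        problems.append(f"field indices out of range for n_states={n_states}: {bad}")
    if tuple(bc_shape[1:]) != (n_spatial_dims, 2):
        problems.append(
            f"boundary conditions have shape {tuple(bc_shape)}, expected [B, {n_spatial_dims}, 2]"
        )
    return problems

# Values taken from the outputs printed earlier in this notebook
print(preflight((3, 3, 100, 100, 4), [67, 68, 3, 28], [True] * 4, (3, 2, 2), n_states=69))  # []
```

With the notebook's tensors the call would be `preflight(tuple(current_input.shape), field_indices.tolist(), padded_field_mask.tolist(), tuple(boundary_conditions.shape), n_states)`. An empty result only rules out these particular mismatches; the model's expected axis order and rank still need to be checked against its own documentation.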

## Step 7: Evaluate Predictions

Compare predictions against ground truth using standard metrics.

In [None]:
# Calculate error metrics
# We'll compute MSE and relative L2 error for each field

def compute_metrics(pred, target, field_name):
    """
    Compute MSE and relative L2 error.
    
    Args:
        pred: Predictions, same shape as target (e.g. [B, T, H, W])
        target: Ground truth
        field_name: Name of the field for printing
    """
    # Mean Squared Error
    mse = torch.mean((pred - target) ** 2).item()
    
    # Relative L2 error: ||pred - target||_2 / ||target||_2
    rel_l2 = (torch.norm(pred - target) / torch.norm(target)).item()
    
    print(f"{field_name}:")
    print(f"  MSE: {mse:.6e}")
    print(f"  Relative L2: {rel_l2:.6f}")
    
    return mse, rel_l2

# Extract individual fields from predictions and ground truth
# Remember: fields are stacked as [u, v, p, rho] along the last (channel) axis.
# Indexing with `...` works whether or not the tensors carry a depth axis;
# the six-index form fails on the 5D tensors built in Step 5.
pred_u = predictions[..., 0]  # [B, T_out, H, W] for 5D inputs
pred_v = predictions[..., 1]
pred_p = predictions[..., 2]
pred_rho = predictions[..., 3]

target_u = output_fields_tensor[..., 0]
target_v = output_fields_tensor[..., 1]
target_p = output_fields_tensor[..., 2]
target_rho = output_fields_tensor[..., 3]

print("Prediction Metrics:")
print("=" * 50)
compute_metrics(pred_u, target_u, "Horizontal velocity (u)")
compute_metrics(pred_v, target_v, "Vertical velocity (v)")
compute_metrics(pred_p, target_p, "Pressure (p)")
compute_metrics(pred_rho, target_rho, "Density (rho)")
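
The metrics above average over the whole prediction window, which hides how the error develops during the rollout. A minimal NumPy sketch of a per-timestep relative L2 error follows; `rel_l2_per_step` is a hypothetical helper, and the toy arrays are constructed so that the error grows by a known amount at each step.

```python
import numpy as np

def rel_l2_per_step(pred, target):
    """Relative L2 error for each timestep; inputs are [B, T, ...] arrays of equal shape."""
    axes = (0,) + tuple(range(2, pred.ndim))   # reduce over batch and space, keep time
    num = np.sqrt(np.sum((pred - target) ** 2, axis=axes))
    den = np.sqrt(np.sum(target ** 2, axis=axes))
    return num / den

target = np.ones((2, 3, 4, 4))                                  # [B, T, H, W]
pred = target * np.array([1.0, 1.1, 1.2]).reshape(1, 3, 1, 1)   # error grows with time
print(np.round(rel_l2_per_step(pred, target), 3))               # errors of 0, 0.1 and 0.2
```

For the tensors in this notebook the inputs would be, for example, `pred_u.cpu().numpy()` and `target_u.cpu().numpy()`. A field that is identically zero at some timestep makes the denominator vanish, so such fields need an absolute metric instead.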

## Step 8: Visualize Results

Let's visualize the predictions compared to ground truth for one trajectory.

In [None]:
# Select trajectory and timestep to visualize
traj_idx = 0  # First trajectory
time_idx = T_out - 1  # Last predicted timestep

# Move tensors to CPU and convert to numpy for plotting
# squeeze() drops a singleton depth axis if one is present, leaving [H, W]
pred_u_np = pred_u[traj_idx, time_idx].squeeze().cpu().numpy()
pred_v_np = pred_v[traj_idx, time_idx].squeeze().cpu().numpy()
pred_p_np = pred_p[traj_idx, time_idx].squeeze().cpu().numpy()

target_u_np = target_u[traj_idx, time_idx].squeeze().cpu().numpy()
target_v_np = target_v[traj_idx, time_idx].squeeze().cpu().numpy()
target_p_np = target_p[traj_idx, time_idx].squeeze().cpu().numpy()

# Compute velocity magnitude
pred_vel_mag = np.sqrt(pred_u_np**2 + pred_v_np**2)
target_vel_mag = np.sqrt(target_u_np**2 + target_v_np**2)

print(f"Visualizing trajectory {traj_idx}, timestep {time_idx + T_in} (prediction {time_idx + 1}/{T_out})")

In [None]:
# Create comparison plot
fig, axes = plt.subplots(3, 3, figsize=(15, 14))

# Row 1: Horizontal velocity (u)
vmin_u = min(pred_u_np.min(), target_u_np.min())
vmax_u = max(pred_u_np.max(), target_u_np.max())

im0 = axes[0, 0].imshow(target_u_np.T, origin='lower', cmap='RdBu_r', 
                         vmin=vmin_u, vmax=vmax_u, aspect='auto')
axes[0, 0].set_title('Ground Truth: u (horizontal velocity)')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('y')
plt.colorbar(im0, ax=axes[0, 0])

im1 = axes[0, 1].imshow(pred_u_np.T, origin='lower', cmap='RdBu_r',
                         vmin=vmin_u, vmax=vmax_u, aspect='auto')
axes[0, 1].set_title('Prediction: u (horizontal velocity)')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('y')
plt.colorbar(im1, ax=axes[0, 1])

error_u = np.abs(pred_u_np - target_u_np)
im2 = axes[0, 2].imshow(error_u.T, origin='lower', cmap='hot', aspect='auto')
axes[0, 2].set_title('Absolute Error: u')
axes[0, 2].set_xlabel('x')
axes[0, 2].set_ylabel('y')
plt.colorbar(im2, ax=axes[0, 2])

# Row 2: Vertical velocity (v)
vmin_v = min(pred_v_np.min(), target_v_np.min())
vmax_v = max(pred_v_np.max(), target_v_np.max())

im3 = axes[1, 0].imshow(target_v_np.T, origin='lower', cmap='RdBu_r',
                         vmin=vmin_v, vmax=vmax_v, aspect='auto')
axes[1, 0].set_title('Ground Truth: v (vertical velocity)')
axes[1, 0].set_xlabel('x')
axes[1, 0].set_ylabel('y')
plt.colorbar(im3, ax=axes[1, 0])

im4 = axes[1, 1].imshow(pred_v_np.T, origin='lower', cmap='RdBu_r',
                         vmin=vmin_v, vmax=vmax_v, aspect='auto')
axes[1, 1].set_title('Prediction: v (vertical velocity)')
axes[1, 1].set_xlabel('x')
axes[1, 1].set_ylabel('y')
plt.colorbar(im4, ax=axes[1, 1])

error_v = np.abs(pred_v_np - target_v_np)
im5 = axes[1, 2].imshow(error_v.T, origin='lower', cmap='hot', aspect='auto')
axes[1, 2].set_title('Absolute Error: v')
axes[1, 2].set_xlabel('x')
axes[1, 2].set_ylabel('y')
plt.colorbar(im5, ax=axes[1, 2])

# Row 3: Pressure (p)
vmin_p = min(pred_p_np.min(), target_p_np.min())
vmax_p = max(pred_p_np.max(), target_p_np.max())

im6 = axes[2, 0].imshow(target_p_np.T, origin='lower', cmap='plasma',
                         vmin=vmin_p, vmax=vmax_p, aspect='auto')
axes[2, 0].set_title('Ground Truth: p (pressure)')
axes[2, 0].set_xlabel('x')
axes[2, 0].set_ylabel('y')
plt.colorbar(im6, ax=axes[2, 0])

im7 = axes[2, 1].imshow(pred_p_np.T, origin='lower', cmap='plasma',
                         vmin=vmin_p, vmax=vmax_p, aspect='auto')
axes[2, 1].set_title('Prediction: p (pressure)')
axes[2, 1].set_xlabel('x')
axes[2, 1].set_ylabel('y')
plt.colorbar(im7, ax=axes[2, 1])

error_p = np.abs(pred_p_np - target_p_np)
im8 = axes[2, 2].imshow(error_p.T, origin='lower', cmap='hot', aspect='auto')
axes[2, 2].set_title('Absolute Error: p')
axes[2, 2].set_xlabel('x')
axes[2, 2].set_ylabel('y')
plt.colorbar(im8, ax=axes[2, 2])

plt.tight_layout()
plt.show()

print(f"Timestep {time_idx + T_in} (t = {(time_idx + T_in) * dt:.4f})")

## Step 9: Visualize Time Evolution

Let's see how the predictions evolve over time compared to ground truth.

In [None]:
# Plot time series of velocity magnitude at a single point
# Choose a point in the middle of the domain
x_pt = Nx // 2
y_pt = Ny // 2

# Extract time series for all predicted timesteps
# flatten() yields a 1D series of length T_out with or without a trailing depth axis
pred_vel_mag_series = np.sqrt(
    pred_u[traj_idx, :, x_pt, y_pt].flatten().cpu().numpy()**2 + 
    pred_v[traj_idx, :, x_pt, y_pt].flatten().cpu().numpy()**2
)

target_vel_mag_series = np.sqrt(
    target_u[traj_idx, :, x_pt, y_pt].flatten().cpu().numpy()**2 + 
    target_v[traj_idx, :, x_pt, y_pt].flatten().cpu().numpy()**2
)

# Time values for prediction window
pred_times = np.arange(T_in, T_total) * dt

# Create plot
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

ax.plot(pred_times, target_vel_mag_series, 'o-', label='Ground Truth', linewidth=2)
ax.plot(pred_times, pred_vel_mag_series, 's--', label='Prediction', linewidth=2, alpha=0.7)

ax.set_xlabel('Time', fontsize=12)
ax.set_ylabel('Velocity Magnitude', fontsize=12)
ax.set_title(f'Time Evolution at Point ({x_pt}, {y_pt})', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Time series at spatial location ({x_pt}, {y_pt})")

## Summary

In this notebook, we demonstrated how to use the Walrus foundation model for Navier-Stokes prediction:

### What We Did:
✓ Loaded Navier-Stokes spectral simulation data  
✓ Downloaded and initialized the Walrus model (~1.3B parameters)  
✓ Prepared data in the Walrus input format  
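✓ Set up an autoregressive rollout for multi-step prediction (the recorded run stops at the first model call with an `AssertionError`)  
✓ Defined MSE and relative L2 error metrics for evaluating predictions  
✓ Prepared plots comparing predictions with ground truth  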

### Key Concepts:
- **Autoregressive Rollout**: Predictions feed back as inputs for next timestep
- **Field Indices**: Map physical fields to model embeddings
- **Boundary Conditions**: Tell the model how the domain is bounded
- **Padded Field Mask**: Distinguish real fields from padding
- **Well Format**: Standardized structure for PDE data

### Next Steps:
1. **Fine-tuning**: Train the model on your specific dataset for better performance
2. **Longer Rollouts**: Test prediction quality over longer time horizons
3. **Different BCs**: Experiment with wall or open boundary conditions
4. **Ensemble Predictions**: Run multiple predictions with different random seeds
5. **Data Assimilation**: Incorporate new observations during rollout

### Performance Notes:
- Walrus is pretrained on diverse PDE datasets, so it may need fine-tuning for optimal performance on your specific Navier-Stokes data
- The model uses reversible instance normalization (RevIN) to handle different scales
- Predictions typically degrade over longer rollouts (error accumulation)
- GPU acceleration is highly recommended for larger grids or longer rollouts