In [1]:
## Step 1: Download Model Weights and Config

import os

# Create directories to store checkpoints and configs
checkpoint_base_path = "./checkpoints/"
config_base_path = "./configs/"
os.makedirs(checkpoint_base_path, exist_ok=True)
os.makedirs(config_base_path, exist_ok=True)

# Download the model configuration file (YAML format)
# This contains model architecture details, training hyperparameters, and field mappings
!wget  https://huggingface.co/polymathic-ai/walrus/resolve/main/extended_config.yaml \
    -O {config_base_path}/extended_config.yaml

# Download the pretrained model weights (~4.8GB)
# This is the actual neural network parameters trained on The Well datasets
!wget  https://huggingface.co/polymathic-ai/walrus/resolve/main/walrus.pt \
    -O {checkpoint_base_path}/walrus.pt

--2026-01-20 14:05:27--  https://huggingface.co/polymathic-ai/walrus/resolve/main/extended_config.yaml
Resolving huggingface.co (huggingface.co)... 2600:9000:28fd:de00:17:b174:6d00:93a1, 2600:9000:28fd:800:17:b174:6d00:93a1, 2600:9000:28fd:1200:17:b174:6d00:93a1, ...
Connecting to huggingface.co (huggingface.co)|2600:9000:28fd:de00:17:b174:6d00:93a1|:443... connected.
HTTP request sent, awaiting response... 307 Temporary Redirect
Location: /api/resolve-cache/models/polymathic-ai/walrus/f8fd578e7d8a15d2e510d32d5952b9eddc37548c/extended_config.yaml?%2Fpolymathic-ai%2Fwalrus%2Fresolve%2Fmain%2Fextended_config.yaml=&etag=%223eb6c57e518c935eba9ade2e0b7a3b3381f491b6%22 [following]
--2026-01-20 14:05:27--  https://huggingface.co/api/resolve-cache/models/polymathic-ai/walrus/f8fd578e7d8a15d2e510d32d5952b9eddc37548c/extended_config.yaml?%2Fpolymathic-ai%2Fwalrus%2Fresolve%2Fmain%2Fextended_config.yaml=&etag=%223eb6c57e518c935eba9ade2e0b7a3b3381f491b6%22
Reusing existing connection to [huggingfa

OSError: [Errno 5] Input/output error

In [2]:
## Load configuration and checkpoint 
import numpy as np 
import torch
import copy
from walrus.models import IsotropicModel
from walrus.data.well_to_multi_transformer import ChannelsFirstWithTimeFormatter
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, open_dict
from walrus.utils.experiment_utils import align_checkpoint_with_field_to_index_map
from the_well.data.datasets import WellMetadata
from the_well.data.utils import flatten_field_names

# Set paths to downloaded files
checkpoint_path = f"{checkpoint_base_path}/walrus.pt"
checkpoint_config_path = f"{config_base_path}/extended_config.yaml"

# Load the checkpoint (model weights)
# weights_only=True for security - only loads tensor weights, not arbitrary Python objects
# ["app"]["model"] extracts just the model state dict from the full checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["app"]["model"]

# Load the configuration using OmegaConf (hierarchical config management)
config = OmegaConf.load(checkpoint_config_path)

print("Configuration loaded successfully!")
print(f"Model has {config.model.processor_blocks} processor blocks")
print(f"Hidden dimension: {config.model.hidden_dim}")

  from .autonotebook import tqdm as notebook_tqdm


Configuration loaded successfully!
Model has 40 processor blocks
Hidden dimension: 1408


In [3]:
## Load and Convert Dataset 
def Navier_Stokes_Spectral(n_sims, data_dist):
    """
    Load Navier-Stokes spectral simulation data.
    
    Args:
        n_sims: Number of simulations to load
        data_dist: 'ID' for in-distribution or 'OOD' for out-of-distribution
    
    Returns:
        u: Horizontal velocity [B, Nt, Nx, Ny]
        v: Vertical velocity [B, Nt, Nx, Ny]
        p: Pressure [B, Nt, Nx, Ny]
        rho: Density [B, Nt, Nx, Ny]
        x: X-coordinates
        y: Y-coordinates  
        dt: Time step
    """
    data_loc = '/Users/Vicky/Documents/UKAEA/Code/Uncertainty_Quantification/PDE_Residuals/Neural_PDE/Data'
    
    # Load appropriate dataset based on distribution
    if data_dist == 'ID':
        data = np.load(data_loc + '/NS_Spectral_combined.npz')
    elif data_dist == 'OOD':
        data = np.load(data_loc + '/NS_Spectral_combined_pitagora_OOD_nu_1e-2.npz')

    # Extract fields and convert to float32
    u = data['u'].astype(np.float32)[:n_sims]  # Horizontal velocity
    v = data['v'].astype(np.float32)[:n_sims]  # Vertical velocity
    p = data['p'].astype(np.float32)[:n_sims]  # Pressure
    rho = np.ones_like(u)  # Density = 1 (constant)
    
    # Extract coordinates and time step
    x = data['x']
    y = x  # Assuming square domain
    dt = data['dt']

    return u, v, p, rho, x, y, dt

# Load a small subset for demonstration
n_sims = 3  # Number of simulations to use
data_dist = 'ID'  # Use in-distribution data

# Load the raw data
u, v, p, rho, x_coords, y_coords, dt = Navier_Stokes_Spectral(n_sims, data_dist)

print(f"Loaded Navier-Stokes data:")
print(f"  Horizontal velocity (u): {u.shape}")
print(f"  Vertical velocity (v): {v.shape}")
print(f"  Pressure (p): {p.shape}")
print(f"  Spatial grid: {len(x_coords)} × {len(y_coords)}")
print(f"  Time step: {dt}")

Loaded Navier-Stokes data:
  Horizontal velocity (u): (3, 50, 100, 100)
  Vertical velocity (v): (3, 50, 100, 100)
  Pressure (p): (3, 50, 100, 100)
  Spatial grid: 100 × 100
  Time step: 0.01


In [5]:
# Get the field_to_index_map from the config
# This dictionary maps field names (strings) to integer indices used by the model's embedding layer
field_to_index_map = config.data.field_index_map_override

print("Available fields in pretrained model:")
print(f"Total fields: {len(field_to_index_map)}")
print(f"\nExample fields: {list(field_to_index_map.keys())[:10]}")

# For our synthetic example, we'll use:
# - velocity_x (index 4) - already in pretrained model
# - velocity_y (index 5) - already in pretrained model
# - density (index 28) - already in pretrained model
# - "blubber" - a NEW field we're adding as an example

# Create a copy and add our new field
new_field_to_index_map = dict(field_to_index_map)
new_field_to_index_map["blubber"] = max(field_to_index_map.values()) + 1  # Assign next available index

print(f"\nAdded new field 'blubber' with index: {new_field_to_index_map['blubber']}")

Available fields in pretrained model:
Total fields: 67

Example fields: ['closed_boundary', 'open_boundary', 'bias_correction', 'pressure', 'velocity_x', 'velocity_y', 'velocity_z', 'zeros_like_density', 'speed_of_sound', 'concentration']

Added new field 'blubber' with index: 67


In [None]:
# Initialize model architecture using Hydra's instantiate
# n_states = total number of field types the model needs to handle
model = instantiate(
    config.model,
    n_states=max(new_field_to_index_map.values()) + 1,  # +1 because indices are 0-based
)

# Align the pretrained checkpoint with our new field mapping
# This function:
#   - Copies weights for fields that exist in both old and new mappings
#   - Initializes random weights for new fields (like "blubber")
#   - Ensures the embedding layer has the right size
revised_model_checkpoint = align_checkpoint_with_field_to_index_map(
    checkpoint_state_dict=checkpoint,           # Pretrained weights
    model_state_dict=model.state_dict(),        # Current model structure
    checkpoint_field_to_index_map=field_to_index_map,      # Original mapping
    model_field_to_index_map=new_field_to_index_map,       # New mapping with "blubber"
)

# Load the aligned weights into the model
model.load_state_dict(revised_model_checkpoint)

# Move model to GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

# Set to evaluation mode (disables dropout, batch norm training behavior, etc.)
model.eval()

print("Model loaded successfully!")

In [None]:
## Formatter and RevIn 

# Formatter handles tensor shape conversions between Well format and model format
formatter = ChannelsFirstWithTimeFormatter()

# RevIN (Reversible Instance Normalization) handles data normalization
# instantiate() is used because config.trainer.revin is a partial function
revin = instantiate(config.trainer.revin)()  

print("Helper objects initialized successfully!")

In [None]:
## Define Rollout Function 

from walrus.trainer.training import expand_mask_to_match

def rollout_model(model, revin, batch, formatter, max_rollout_steps=200, model_epsilon=1e-5, device=torch.device("cpu")):
    """Rollout the model autoregressively for multiple timesteps.
    
    Args:
        model: The Walrus model
        revin: Reversible normalization object
        batch: Dictionary containing input data
        formatter: Converts between data formats
        max_rollout_steps: Maximum timesteps to predict
        model_epsilon: Small value for numerical stability in normalization
        device: torch device (cuda/cpu)
    
    Returns:
        y_pred_out: Model predictions [B, T, H, W, D, C]
        y_ref: Ground truth reference [B, T, H, W, D, C]
    """
    # Extract metadata (contains dataset info like spatial dims, field names, etc.)
    metadata = batch["metadata"]
    
    # Move all tensors to the target device (GPU/CPU)
    # Skip metadata and boundary_conditions as they're not tensors
    batch = {
        k: v.to(device)
        if k not in {"metadata", "boundary_conditions"}
        else v
        for k, v in batch.items()
    }
    
    # Check if there's a mask field (for masked regions like obstacles)
    # If mask exists, extract it and move to device
    if (
        "mask" in batch["metadata"].constant_field_names[0]
    ):
        mask_index = batch["metadata"].constant_field_names[0].index("mask")
        mask = batch["constant_fields"][..., mask_index : mask_index + 1]
        mask = mask.to(device, dtype=torch.bool)
    else:
        mask = None

    # Format the input data for the model
    # Returns:
    #   inputs: Formatted input fields
    #   y_ref: Reference output (ground truth) for comparison
    inputs, y_ref = formatter.process_input(
        batch,
        causal_in_time=model.causal_in_time,  # Whether model uses causal masking
        predict_delta=True,                    # Predict change (delta) rather than absolute values
        train=False,                           # Inference mode (no training-specific augmentations)
    )

    # Calculate how many timesteps we'll predict
    # Format: inputs is [T, B, C, H, W, D], y_ref is [B, T, H, W, D, C]
    T_in = batch["input_fields"].shape[1]  # Number of input timesteps (e.g., 6)
    max_rollout_steps = max_rollout_steps + (T_in - 1)  # Adjust for input length
    rollout_steps = min(y_ref.shape[1], max_rollout_steps)  # Don't exceed available reference data
    train_rollout_limit = 1  # Only predict 1 step ahead per iteration (autoregressive)

    # Trim reference to match rollout length (saves memory)
    y_ref = y_ref[:, :rollout_steps]
    
    # Create a copy of batch that we'll update with predictions
    moving_batch = copy.deepcopy(batch)
    y_preds = []  # Store all predictions
    
    # Main autoregressive loop
    for i in range(train_rollout_limit - 1, rollout_steps):
        # Format current input window for model
        inputs, _ = formatter.process_input(moving_batch)
        inputs = list(inputs)  # Convert to list for easier manipulation
        
        # Compute normalization statistics from current input window
        # This is done inside torch.no_grad() for efficiency
        with torch.no_grad():
            normalization_stats = revin.compute_stats(
                inputs[0],      # Field data
                metadata,       # Dataset metadata
                epsilon=model_epsilon  # For numerical stability
            )
        
        # Normalize inputs using computed statistics
        # inputs[0] = field data, inputs[1] = constant fields, inputs[2] = boundary conditions
        normalized_inputs = inputs[:]  # Shallow copy
        normalized_inputs[0] = revin.normalize_stdmean(
            normalized_inputs[0], 
            normalization_stats
        )
        
        # Run model forward pass
        # Inputs:
        #   - normalized_inputs[0]: Normalized field data [T, B, C, H, W, D]
        #   - normalized_inputs[1]: Constant fields
        #   - normalized_inputs[2]: Boundary conditions
        #   - metadata: Dataset information
        # Output:
        #   - y_pred: Predicted delta (change) in fields
        y_pred = model(
            normalized_inputs[0],
            normalized_inputs[1],
            normalized_inputs[2].tolist(),
            metadata=metadata,
        )
        
        # For causal models, only keep the last prediction
        if model.causal_in_time:
            y_pred = y_pred[-1:]  # y_pred is [T, B, C, H, W, D], take last T
        
        # Denormalize prediction and add to last input to get absolute value
        # Model predicts delta (change), so: next_state = current_state + delta
        y_pred = (inputs[0][-y_pred.shape[0]:].float()  # Get corresponding input timestep
                  + revin.denormalize_delta(y_pred, normalization_stats))  # Add denormalized delta
        
        # Format output back to Well convention and trim to match reference shape
        y_pred = formatter.process_output(y_pred, metadata)[..., : y_ref.shape[-1]]

        # Apply mask if present (set masked regions to zero)
        if mask is not None:
            mask_pred = expand_mask_to_match(mask, y_pred)
            y_pred.masked_fill_(mask_pred, 0)

        # Zero out padded fields (fields that were added for dimensional consistency)
        y_pred = y_pred.masked_fill(~batch["padded_field_mask"], 0.0)

        # Update moving window for next iteration (autoregressive step)
        # Drop oldest timestep, append new prediction
        if i != rollout_steps - 1:
            moving_batch["input_fields"] = torch.cat(
                [moving_batch["input_fields"][:, 1:],  # Drop first timestep
                 y_pred[:, -1:]],                       # Append prediction
                dim=1
            )
        
        # Store prediction
        # For causal models on first iteration, store all predictions
        # Otherwise, store only the newest prediction
        if model.causal_in_time and i == train_rollout_limit - 1:
            y_preds.append(y_pred)
        else:
            y_preds.append(y_pred[:, -1:])
    
    # Concatenate all predictions along time dimension
    y_pred_out = torch.cat(y_preds, dim=1)
    
    # Apply mask to reference if present
    if mask is not None:
        mask_ref = expand_mask_to_match(mask, y_ref)
        y_ref.masked_fill_(mask_ref, 0)
    
    return y_pred_out, y_ref

print("Rollout function defined!")

In [None]:
## Getting the data ready 

# Define data dimensions
B = 1       # Batch size (number of trajectories)
T_in = 4    # Input timesteps (how many past states to condition on)
T_out = 10  # Output timesteps (how many future states to predict)
H = 100     # Height (spatial resolution in y direction)
W = 100     # Width (spatial resolution in x direction)
D = 1       # Depth (set to 1 for 2D data, would be >1 for 3D)
C_var = 4   # Number of variable fields:
            #   1. velocity_x (real)
            #   2. velocity_y (real)
            #   3. pressure  (real)
            #   4. velocity_z (padding, needed for dimensional consistency)
C_con = 1   # No constant (time-invariant) fields in this example

# Create synthetic trajectory data
# In practice, you would load your own data here
synthetic_trajectory_example = {
    # Input fields: past states that condition the prediction
    # Shape: [batch, time_in, height, width, depth, channels]
    "input_fields": torch.randn(B, T_in, H, W, D, C_var, device=device),
    
    # Output fields: ground truth future states (for evaluation)
    # Shape: [batch, time_out, height, width, depth, channels]
    "output_fields": torch.randn(B, T_out, H, W, D, C_var, device=device),
    
    # Constant fields: time-invariant quantities (e.g., geometry, material properties)
    # Shape: [batch, height, width, depth, const_channels]
    # Empty in this example (C_con = 0)
    "constant_fields": torch.randn(B, H, W, D, C_con, device=device),
    
    # Boundary conditions: encodes domain boundaries
    # Shape: [batch, 3 dimensions, 2 sides (lower/upper)]
    # Values: 0=WALL, 1=OPEN, 2=PERIODIC
    # [[2,2], [2,2], [2,2]] means all periodic (torus topology)
    "boundary_conditions": torch.tensor([[[2, 2], [2, 2], [2, 2]] for _ in range(B)], device=device),
    
    # Padded field mask: indicates which fields are real vs padding
    # True = real field, False = padding
    # [True, True, True, True, False] means first 4 fields are real, 5th is padding
    "padded_field_mask": torch.tensor([True, True, True, True, False], device=device),
    
    # Field indices: maps each field to its embedding index
    # [4, 5, 28, 67, 6] means:
    #   Field 0 → embedding 4 (velocity_x)
    #   Field 1 → embedding 5 (velocity_y)
    #   Field 2 → embedding 28 (density)
    #   Field 3 → embedding 67 (blubber - our new field)
    #   Field 4 → embedding 6 (velocity_z - padding)
    "field_indices": torch.tensor([4, 5, 28, 67, 6], device=device),
    
    # Metadata: describes the dataset properties
    "metadata": WellMetadata(
        dataset_name="synthetic_dataset",       # Name for logging/identification
        n_spatial_dims=3,                       # 3D (even though D=1, we pad to 3D)
        
        # Field organization by rank:
        # 0 = scalars (density, blubber)
        # 1 = vectors (velocity_x, velocity_y, velocity_z)
        # 2 = rank-2 tensors (none in this example)
        field_names={0: ['density', "blubber"], 1: ['velocity_x', 'velocity_y', 'velocity_z'], 2: []},
        
        spatial_resolution=(128, 128, 1),       # Grid size (H, W, D)
        scalar_names=[],                        # Global scalars (e.g., time, parameters)
        constant_field_names={0: [], 1: [], 2: []},  # Constant fields by rank (none here)
        constant_scalar_names=[],               # Constant global scalars
        boundary_condition_types=[],            # Not used in this simplified example
        n_files=[],                             # Not used
        n_trajectories_per_file=[],             # Not used
        n_steps_per_trajectory=[],              # Not used
        grid_type='cartesian'                   # Coordinate system (cartesian, cylindrical, etc.)
    ),
}

print("Synthetic data created!")
print(f"Input shape: {synthetic_trajectory_example['input_fields'].shape}")
print(f"Output shape: {synthetic_trajectory_example['output_fields'].shape}")
print(f"Field indices: {synthetic_trajectory_example['field_indices']}")
print(f"\nField mapping:")
field_list = ['velocity_x', 'velocity_y', 'density', 'blubber', 'velocity_z (padding)']
for i, (idx, name) in enumerate(zip(synthetic_trajectory_example['field_indices'], field_list)):
    is_real = "real" if synthetic_trajectory_example['padded_field_mask'][i] else "padding"
    print(f"  Channel {i}: {name} → embedding {idx} ({is_real})")

In [None]:
## Inference 

# Run inference without computing gradients (faster, less memory)
with torch.no_grad():
    # Ensure mask is on correct device
    synthetic_trajectory_example["padded_field_mask"] = synthetic_trajectory_example["padded_field_mask"].to(device)
    
    # Get the metadata for logging
    fake_metadata = synthetic_trajectory_example["metadata"]
    
    # Run the autoregressive rollout
    # This will predict T_out (10) timesteps, one at a time
    print("Running model rollout...")
    print("This performs autoregressive prediction:")
    print("  1. Use timesteps [0-5] to predict timestep [6]")
    print("  2. Use timesteps [1-6] to predict timestep [7]")
    print("  3. Continue until all 10 timesteps are predicted")
    print()
    
    y_pred, y_ref = rollout_model(
        model,                      # The Walrus model
        revin,                      # Normalization helper
        synthetic_trajectory_example,  # Input data
        formatter,                  # Data format converter
        max_rollout_steps=200,      # Maximum steps to predict
        device=device,              # GPU/CPU
    )
    
    print(f"Prediction complete!")
    print(f"Prediction shape (with padding): {y_pred.shape}")
    print(f"Reference shape (with padding): {y_ref.shape}")
    
    # Remove padded fields (velocity_z) from predictions and reference
    # Only keep the real fields (velocity_x, velocity_y, density, blubber)
    y_pred, y_ref = (
        y_pred[..., synthetic_trajectory_example["padded_field_mask"]],
        y_ref[..., synthetic_trajectory_example["padded_field_mask"]],
    )
    
    # Get human-readable field names for output
    field_names = flatten_field_names(fake_metadata, include_constants=False)
    used_field_names = [
        f
        for i, f in enumerate(field_names)
        if synthetic_trajectory_example["padded_field_mask"][i]
    ]
    
    print(f"\nUsed fields (after removing padding): {used_field_names}")
    print(f"Final prediction shape: {y_pred.shape}")
    print(f"  [batch=1, time=10, height=128, width=128, depth=1, channels=4]")
    print(f"\nYou can now use y_pred for downstream analysis!")
    print(f"For example:")
    print(f"  - Visualize predictions: y_pred[0, :, :, :, 0, i] for field i")
    print(f"  - Compute errors: (y_pred - y_ref).abs().mean()")
    print(f"  - Extract single field: velocity_x = y_pred[..., 0]")