### Neural ODE for Collision-Coalescence Parameterization

Plans:
- ODEFunc: 4 → 50 (Tanh) → 50 (Tanh) → 50 (Tanh) → 4 (linear)
- Solver: **Configurable** - RK4 (fixed-step, faster) or Dopri5 (adaptive, more accurate)
- Training: Variable-length trajectories with masking; training can also be resumed from saved checkpoints

**Solver Comparison:**
- **RK4**: Fixed-step (dt=20s), faster training
- **Dopri5**: Adaptive step size

In [12]:
!pip install torchdiffeq

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint

import numpy as np
import pickle
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

# Set random seeds for reproducibility
# Both global RNGs are seeded: PyTorch's and NumPy's (the latter drives the
# np.random.choice call that picks validation samples to plot further down).
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [13]:
# Hyperparameters
# Central settings dict read by the model construction, the training loop and
# the evaluation cells; it is also stored inside every saved checkpoint.
config = {
    'dt': 20.0,                  # Time step in seconds (spacing of the t_span grid passed to the solver)
    'hidden_size': 50,           # Neurons per hidden layer
    'n_layers': 3,               # Number of hidden layers
    'batch_size': 4,             # Trajectories per batch
    'learning_rate': 1e-4,       # Adam learning rate
    'n_epochs': 100,             # Last epoch index to train to (a resumed run continues up to this epoch)
    'val_every': 3,              # Run validation every N epochs
    'max_grad_norm': 1.0,        # Gradient clipping threshold
                                 # NOTE(review): the training loop does not pass this value to
                                 # train_epoch - confirm the intended threshold is actually applied.
    'ode_solver': 'rk4',         # Choose: 'rk4' or 'dopri5'
    'rtol': 1e-4,                # Relative tolerance (only for adaptive solvers like dopri5)
    'atol': 1e-6,                # Absolute tolerance (only for adaptive solvers like dopri5)
}

# Data paths
# NOTE(review): absolute, machine-specific directory - adjust when running elsewhere.
data_dir = Path('/home/jovyan/cloud_microphysics/data/')
train_path = data_dir / 'train_trajectories.pkl'
val_path = data_dir / 'val_trajectories.pkl'
stats_path = data_dir / 'moment_normalization_stats.pkl'

In [14]:
# Read the preprocessed training/validation trajectories and the normalization
# statistics from disk.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
train_trajectories, val_trajectories, norm_stats = (
    pickle.loads(pkl_path.read_bytes())
    for pkl_path in (train_path, val_path, stats_path)
)

print(f"\nLoaded {len(train_trajectories)} training trajectories")
print(f"Loaded {len(val_trajectories)} validation trajectories")
print(f"\nNormalization stats loaded: {list(norm_stats.keys())}")

# Peek at the first training trajectory to confirm the record layout
sample_traj = train_trajectories[0]
print(f"\nSample trajectory keys: {list(sample_traj.keys())}")
print(f"Moments shape: {sample_traj['moments_scaled'].shape}")
print(f"Trajectory length: {sample_traj['length']}")


Loaded 575 training trajectories
Loaded 144 validation trajectories

Normalization stats loaded: ['asinh_scales', 'moment_scaler_mean', 'moment_scaler_std', 'moment_scaler']

Sample trajectory keys: ['moments', 'length', 'ic_idx', 'moments_scaled']
Moments shape: (174, 4)
Trajectory length: 174


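Note: the files above are loaded with `pickle`, which can execute arbitrary code on load, so only open files from a trusted source. See scikit-learn's notes on the security and maintainability limitations of pickled objects: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations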


In [15]:
class TrajectoryDataset(Dataset):
    """Map-style dataset over variable-length moment trajectories.

    Each item is a dictionary with:
        - initial_state: float32 tensor holding the first row of the trajectory
        - trajectory: float32 tensor with one row per (kept) timestep
        - length: number of timesteps kept for this trajectory
    """

    def __init__(self, trajectories, max_timesteps=None):
        """
        Args:
            trajectories: Sequence of dicts providing 'moments_scaled' and 'length'
            max_timesteps: Optional cap on the number of timesteps returned per
                trajectory; a falsy value (None or 0) disables the cap
        """
        self.trajectories = trajectories
        self.max_timesteps = max_timesteps

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        record = self.trajectories[idx]
        states = record['moments_scaled']
        n_steps = record['length']

        # Truncate only when a cap was requested
        if self.max_timesteps:
            n_steps = min(n_steps, self.max_timesteps)
            states = states[:n_steps]

        # The first row doubles as the ODE initial condition
        return dict(
            initial_state=torch.tensor(states[0], dtype=torch.float32),
            trajectory=torch.tensor(states, dtype=torch.float32),
            length=n_steps,
        )


def collate_variable_length(batch):
    """Collate function to pad variable-length sequences and create masks.

    The per-timestep feature shape is taken from the first trajectory in the
    batch instead of being fixed, so the function works for any state size.

    Args:
        batch: Non-empty list of dictionaries from TrajectoryDataset, each with
            'initial_state', 'trajectory' (length, n_features) and 'length'

    Returns:
        Dictionary containing:
            - initial_states: (batch_size, n_features) tensor
            - trajectories: (batch_size, max_len, n_features) tensor (zero-padded)
            - masks: (batch_size, max_len) bool tensor (True = valid, False = padded)
            - lengths: (batch_size,) long tensor of actual lengths

    Raises:
        ValueError: If `batch` is empty.
    """
    if not batch:
        raise ValueError("collate_variable_length received an empty batch")

    lengths = [item['length'] for item in batch]
    max_len = max(lengths)
    batch_size = len(batch)

    # Everything after the time axis is the feature shape
    feature_shape = tuple(batch[0]['trajectory'].shape[1:])

    padded_trajectories = torch.zeros(batch_size, max_len, *feature_shape)
    for i, (item, length) in enumerate(zip(batch, lengths)):
        padded_trajectories[i, :length] = item['trajectory']

    lengths_tensor = torch.tensor(lengths, dtype=torch.long)

    # Position t of sample i is valid exactly when t < lengths[i]
    mask = torch.arange(max_len).unsqueeze(0) < lengths_tensor.unsqueeze(1)

    return {
        'initial_states': torch.stack([item['initial_state'] for item in batch]),
        'trajectories': padded_trajectories,
        'masks': mask,
        'lengths': lengths_tensor
    }


def masked_mse_loss(predictions, targets, mask):
    """Compute MSE loss only on valid (non-padded) timesteps.

    The result is the mean squared error over every valid *element*, i.e. the
    squared-error sum divided by (valid timesteps x features). Values obtained
    by dividing by the number of valid timesteps only are larger by a factor
    of the feature count and are not directly comparable.

    Args:
        predictions: (batch_size, n_timesteps, n_features) tensor
        targets: (batch_size, n_timesteps, n_features) tensor
        mask: (batch_size, n_timesteps) bool tensor (True = valid)

    Returns:
        Scalar loss tensor; 0 when the mask selects nothing.
    """
    # Broadcast the per-timestep mask across the feature dimension
    valid = mask.unsqueeze(-1).bool().expand_as(predictions)

    # Zero the residuals at padded positions *before* squaring so that
    # NaN/Inf values there cannot leak into the sum (nan * 0 is still nan)
    residuals = (predictions - targets).masked_fill(~valid, 0.0)

    # Guard against an all-padding mask to avoid 0/0
    n_valid_elements = valid.sum().clamp(min=1)

    return (residuals ** 2).sum() / n_valid_elements


# Wrap both splits without truncation (max_timesteps=None keeps every timestep)
train_dataset = TrajectoryDataset(train_trajectories, max_timesteps=None)
val_dataset = TrajectoryDataset(val_trajectories, max_timesteps=None)

# Settings shared by both loaders: batch size and the padding/masking collate
_loader_kwargs = dict(
    batch_size=config['batch_size'],
    collate_fn=collate_variable_length,
)

# Training batches are reshuffled each epoch and an incomplete last batch is dropped
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)

# Validation keeps dataset order and every sample
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"\nDataset statistics:")
print(f"Training trajectories: {len(train_dataset)}")
print(f"Validation trajectories: {len(val_dataset)}")

# Summarize how long the trajectories in each split are
train_lengths = [traj['length'] for traj in train_trajectories]
val_lengths = [traj['length'] for traj in val_trajectories]
for _header, _lengths in (
    (f"\nTraining trajectory lengths:", train_lengths),
    (f"Validation trajectory lengths:", val_lengths),
):
    print(_header)
    print(f"  Min: {min(_lengths)}, Max: {max(_lengths)}, Mean: {np.mean(_lengths):.1f}")


Dataset statistics:
Training trajectories: 575
Validation trajectories: 144

Training trajectory lengths:
  Min: 59, Max: 3599, Mean: 459.2
Validation trajectory lengths:
  Min: 59, Max: 2399, Mean: 448.5


In [16]:
class ODEFunc(nn.Module):
    """Neural network that defines the derivative function dM/dt = f(M).

    Default architecture: 4 → 50(Tanh) → 50(Tanh) → 50(Tanh) → 4(linear)

    The layers live in ``self.net`` (an ``nn.Sequential``) and are created in
    input-to-output order, so state_dict keys are ``net.0.*``, ``net.2.*``, ...
    """

    def __init__(self, hidden_size=50, n_layers=3, state_dim=4):
        """
        Args:
            hidden_size: Width of every hidden layer
            n_layers: Number of hidden (Linear + Tanh) layers; values below 1
                still produce a single hidden layer
            state_dim: Size of the state vector (input and output width)
        """
        super().__init__()

        # Layer widths from the input through the last hidden layer
        widths = [state_dim] + [hidden_size] * max(n_layers, 1)

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())

        # Output layer is linear: derivatives are unbounded
        layers.append(nn.Linear(hidden_size, state_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        """
        Args:
            t: Scalar time (required by odeint, but not used for time-invariant dynamics)
            y: State tensor of shape (batch_size, state_dim)

        Returns:
            dy/dt: Derivative tensor of shape (batch_size, state_dim)
        """
        return self.net(y)


class NeuralODE(nn.Module):
    """Integrates a learned derivative function with torchdiffeq's odeint."""

    def __init__(self, ode_func, method='dopri5', rtol=1e-4, atol=1e-6):
        """
        Args:
            ode_func: ODEFunc instance
            method: ODE solver method ('rk4' or 'dopri5')
            rtol: Relative tolerance, forwarded to the solver unless method is 'rk4'
            atol: Absolute tolerance, forwarded to the solver unless method is 'rk4'
        """
        super().__init__()
        self.ode_func = ode_func
        self.method = method
        self.rtol = rtol
        self.atol = atol

    def forward(self, initial_state, t_span):
        """
        Args:
            initial_state: Tensor of shape (batch_size, 4)
            t_span: Tensor of time points to evaluate, shape (n_timesteps,)

        Returns:
            trajectory: Tensor of shape (n_timesteps, batch_size, 4)
        """
        solver_kwargs = {'method': self.method}

        # Tolerances are only handed over for solvers other than fixed-step rk4
        if self.method != 'rk4':
            solver_kwargs['rtol'] = self.rtol
            solver_kwargs['atol'] = self.atol

        return odeint(self.ode_func, initial_state, t_span, **solver_kwargs)

In [17]:
# Initialize model
# ODEFunc is the learnable derivative network; NeuralODE wraps it together with
# the solver settings from `config`, so calling model(y0, t_span) integrates it.
ode_func = ODEFunc(hidden_size=config['hidden_size'], n_layers=config['n_layers']).to(device)
model = NeuralODE(
    ode_func,
    method=config['ode_solver'],
    rtol=config['rtol'],
    atol=config['atol']
).to(device)

# Count parameters (only those that will receive gradients)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model initialized with {n_params:,} trainable parameters")

# Print solver info (the rk4 branch reports dt, every other solver reports its tolerances)
if config['ode_solver'] == 'rk4':
    print(f"ODE Solver: {config['ode_solver']} (fixed-step, dt={config['dt']}s)")
else:
    print(f"ODE Solver: {config['ode_solver']} (adaptive, rtol={config['rtol']}, atol={config['atol']})")

# Show the layer stack of the derivative network
print(model.ode_func.net)

Model initialized with 5,554 trainable parameters
ODE Solver: rk4 (fixed-step, dt=20.0s)
Sequential(
  (0): Linear(in_features=4, out_features=50, bias=True)
  (1): Tanh()
  (2): Linear(in_features=50, out_features=50, bias=True)
  (3): Tanh()
  (4): Linear(in_features=50, out_features=50, bias=True)
  (5): Tanh()
  (6): Linear(in_features=50, out_features=4, bias=True)
)


In [18]:
# Optimizer (no loss object is created here; training and validation call
# masked_mse_loss directly)
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Training history
# 'train_loss' gets one entry per epoch; 'val_loss' and 'epochs' get one entry
# per validation run, with 'epochs' holding the 1-based epoch number.
history = {
    'train_loss': [],
    'val_loss': [],
    'epochs': []
}

In [19]:
RESUME_TRAINING = True  # Set to True to continue from checkpoint
CHECKPOINT_PATH = 'best_model.pt'  # Path to checkpoint file
# The checkpoint contains model weights, optimizer state,
# epoch number, and training history

# Initialize training variables (defaults for a fresh run)
start_epoch = 0
best_val_loss = float('inf')

# Resolve the checkpoint location once
_checkpoint_file = Path(CHECKPOINT_PATH)

if not RESUME_TRAINING:
    print(f"Starting fresh training for {config['n_epochs']} epochs")

elif not _checkpoint_file.exists():
    # Resume was requested but there is nothing to resume from; the defaults
    # above stay in effect, so training starts from scratch
    print(f"WARNING: RESUME_TRAINING=True but checkpoint not found at {CHECKPOINT_PATH}")

else:
    print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoint files that come from a trusted source
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

    # Load model and optimizer states
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Load training progress; 'epoch' is the 1-based number of the last
    # finished epoch, which is also the 0-based index of the next one
    start_epoch = checkpoint['epoch']
    best_val_loss = checkpoint.get('val_loss', float('inf'))

    # Load history if available (for continuous loss curves)
    if 'history' in checkpoint:
        history = checkpoint['history']
        print(f"Loaded training history with {len(history['train_loss'])} epochs")

    # Optional entries are reported only when present so that checkpoints
    # lacking them do not abort the resume with a KeyError
    print(f"Resuming from Checkpoint")
    print(f"  Epoch: {checkpoint['epoch']}")
    if 'train_loss' in checkpoint:
        print(f"  Train Loss: {checkpoint['train_loss']:.6f}")
    if 'val_loss' in checkpoint:
        print(f"  Val Loss: {checkpoint['val_loss']:.6f}")
    if 'config' in checkpoint:
        print(f"  Config: {checkpoint['config']}")
    print(f"\nWill continue training from epoch {start_epoch + 1} to {config['n_epochs']}")

Loading checkpoint from best_model.pt...
Loaded training history with 3 epochs
Resuming from Checkpoint
  Epoch: 3
  Train Loss: 4.511445
  Val Loss: 4.648163
  Config: {'dt': 20.0, 'hidden_size': 50, 'n_layers': 3, 'batch_size': 4, 'learning_rate': 0.0001, 'n_epochs': 100, 'val_every': 3, 'max_grad_norm': 1.0, 'ode_solver': 'rk4', 'rtol': 0.0001, 'atol': 1e-06}

Will continue training from epoch 4 to 100


In [20]:
def train_epoch(model, train_loader, optimizer, device, dt=20.0, max_grad_norm=1.0):
    """Train for one epoch with variable-length sequences.

    Args:
        model: Module called as ``model(initial_states, t_span)`` that returns
            a (n_timesteps, batch_size, n_features) tensor
        train_loader: Sized iterable of batches with 'initial_states',
            'trajectories' and 'masks' tensors
        optimizer: Optimizer that updates the model parameters
        device: Device the batch tensors are moved to
        dt: Time between consecutive trajectory samples, in seconds
        max_grad_norm: Threshold for gradient-norm clipping

    Returns:
        Mean loss over the batches that were actually used for an update, or
        ``float('inf')`` when no batch produced a finite loss.
    """
    model.train()
    total_loss = 0.0
    n_nan_batches = 0

    for batch in tqdm(train_loader, desc='Training', leave=False):
        initial_states = batch['initial_states'].to(device)  # (batch_size, n_features)
        target_trajectories = batch['trajectories'].to(device)  # (batch_size, max_len, n_features)
        masks = batch['masks'].to(device)  # (batch_size, max_len)

        # Time grid for this batch: one point per padded timestep, dt apart
        max_len = target_trajectories.shape[1]
        t_span = torch.arange(0, max_len, dtype=torch.float32, device=device) * dt

        # Integrate the ODE, then move the batch axis to the front
        pred_trajectory = model(initial_states, t_span)  # (max_len, batch_size, n_features)
        pred_trajectory = pred_trajectory.permute(1, 0, 2)

        # A diverged integration makes the whole batch unusable
        if not torch.isfinite(pred_trajectory).all():
            print(f"\nWARNING: NaN/Inf detected in predictions! Skipping batch.")
            n_nan_batches += 1
            continue

        # Compute masked loss (only on valid timesteps)
        loss = masked_mse_loss(pred_trajectory, target_trajectories, masks)

        if not torch.isfinite(loss):
            print(f"\nWARNING: NaN/Inf loss detected! Skipping batch.")
            n_nan_batches += 1
            continue

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        total_loss += loss.item()

    if n_nan_batches > 0:
        print(f"\nSkipped {n_nan_batches}/{len(train_loader)} batches due to NaN/Inf")

    n_used_batches = len(train_loader) - n_nan_batches
    if n_used_batches <= 0:
        # Nothing was trained on; do not report this as a perfect 0.0 loss
        return float('inf')
    return total_loss / n_used_batches


def validate(model, val_loader, device, dt=20.0):
    """Validate the model with variable-length sequences.

    Args:
        model: Module called as ``model(initial_states, t_span)`` that returns
            a (n_timesteps, batch_size, n_features) tensor
        val_loader: Sized iterable of batches with 'initial_states',
            'trajectories' and 'masks' tensors
        device: Device the batch tensors are moved to
        dt: Time between consecutive trajectory samples, in seconds

    Returns:
        Mean loss over the batches with a finite loss, or ``float('inf')``
        when there is no such batch (so the result can never be mistaken
        for a new best validation loss).
    """
    model.eval()
    total_loss = 0.0
    n_nan_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            initial_states = batch['initial_states'].to(device)
            target_trajectories = batch['trajectories'].to(device)
            masks = batch['masks'].to(device)

            # Time grid for this batch: one point per padded timestep, dt apart
            max_len = target_trajectories.shape[1]
            t_span = torch.arange(0, max_len, dtype=torch.float32, device=device) * dt

            # Integrate, then move the batch axis to the front
            pred_trajectory = model(initial_states, t_span)
            pred_trajectory = pred_trajectory.permute(1, 0, 2)

            # A diverged integration makes the whole batch unusable
            if not torch.isfinite(pred_trajectory).all():
                n_nan_batches += 1
                continue

            loss = masked_mse_loss(pred_trajectory, target_trajectories, masks)
            if torch.isfinite(loss):
                total_loss += loss.item()
            else:
                n_nan_batches += 1

    if n_nan_batches > 0:
        print(f"  Validation: Skipped {n_nan_batches}/{len(val_loader)} batches due to NaN/Inf")

    n_used_batches = len(val_loader) - n_nan_batches
    if n_used_batches <= 0:
        return float('inf')
    return total_loss / n_used_batches

In [None]:
# Training loop
# Runs epochs start_epoch .. n_epochs-1 (0-based), validates every
# config['val_every'] epochs and checkpoints whenever validation improves.
print(f"\nStarting training from epoch {start_epoch + 1}...\n")

for epoch in range(start_epoch, config['n_epochs']):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, device, dt=config['dt'])
    history['train_loss'].append(train_loss)

    # Validate
    if (epoch + 1) % config['val_every'] == 0:
        val_loss = validate(model, val_loader, device, dt=config['dt'])
        history['val_loss'].append(val_loss)
        history['epochs'].append(epoch + 1)

        print(f"Epoch {epoch+1}/{config['n_epochs']} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Write to the same file the resume logic reads from, so a later
            # resume always picks up the checkpoint produced by this run
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'train_loss': train_loss,
                'config': config,
                'history': history  # Save history for continuous loss curves
            }, CHECKPOINT_PATH)
            print(f"  → Saved best model (val_loss: {val_loss:.6f})")
    else:
        # Between validations, report the train loss on every 10th epoch only
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{config['n_epochs']} - Train Loss: {train_loss:.6f}")

print(f"\nTraining complete! Best validation loss: {best_val_loss:.6f}")
print(f"Trained for {len(history['train_loss'])} total epochs")


Starting training from epoch 4...



                                                           

Epoch 6/100 - Train Loss: 4.920209, Val Loss: 4.251313
  → Saved best model (val_loss: 4.251313)


                                                           

Epoch 9/100 - Train Loss: 4.541486, Val Loss: 3.672335
  → Saved best model (val_loss: 3.672335)


                                                           

Epoch 10/100 - Train Loss: 4.600614


                                                           

Epoch 12/100 - Train Loss: 4.794000, Val Loss: 7.030597


Training:   6%|▋         | 9/143 [00:10<02:20,  1.05s/it]

In [None]:
# Plot training curves
fig, ax = plt.subplots(figsize=(10, 5))

# Train loss: one value per epoch, plotted against 1-based epoch numbers
ax.plot(range(1, len(history['train_loss']) + 1), history['train_loss'],
        label='Train Loss', alpha=0.7)

# Val loss: only recorded on validation epochs, listed in history['epochs']
ax.plot(history['epochs'], history['val_loss'],
        label='Val Loss', marker='o', linewidth=2)

ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

plt.tight_layout()
plt.show()

# The history lists may be empty (e.g. training stopped before the first
# validation), so only report the entries that exist
if history['train_loss']:
    print(f"Final train loss: {history['train_loss'][-1]:.6f}")
if history['val_loss']:
    print(f"Final val loss: {history['val_loss'][-1]:.6f}")
print(f"Best val loss: {best_val_loss:.6f}")

### Load Best Model for Evaluation

In [None]:
# Load best model checkpoint
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects - only open checkpoint files that come from a trusted source.
checkpoint = torch.load('best_model.pt', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode before the analysis cells below use the model
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']} (val_loss: {checkpoint['val_loss']:.6f})")

In [None]:
# Select validation samples to visualize
# Draw without replacement so no trajectory is plotted twice, and never ask
# for more samples than the validation set contains
n_samples = min(10, len(val_dataset))
sample_indices = np.random.choice(len(val_dataset), n_samples, replace=False)

# Moment names (one subplot column per moment)
moment_names = ['qc (cloud water)', 'nc (cloud droplets)', 'qr (rain water)', 'nr (rain drops)']
n_moments = len(moment_names)

# Create subplots; squeeze=False keeps `axes` two-dimensional even for one row
fig, axes = plt.subplots(n_samples, n_moments, figsize=(16, 3*n_samples), squeeze=False)

with torch.no_grad():
    for i, sample_idx in enumerate(sample_indices):
        # Get data
        sample = val_dataset[sample_idx]
        initial_state = sample['initial_state'].unsqueeze(0).to(device)  # (1, n_moments)
        target_trajectory = sample['trajectory'].cpu().numpy()  # (length, n_moments)
        length = sample['length']

        # Time grid for this specific trajectory (no padding needed here)
        t_span_sample = torch.arange(0, length, dtype=torch.float32, device=device) * config['dt']

        # Predict from the initial state only
        pred_trajectory = model(initial_state, t_span_sample)  # (length, 1, n_moments)
        pred_trajectory = pred_trajectory.squeeze(1).cpu().numpy()  # (length, n_moments)

        # Time axis
        time_axis = t_span_sample.cpu().numpy()

        # Plot each moment
        for j in range(n_moments):
            ax = axes[i, j]
            ax.plot(time_axis, target_trajectory[:, j], 'k-', label='True', linewidth=2)
            ax.plot(time_axis, pred_trajectory[:, j], 'r--', label='Predicted', linewidth=2)

            # Column titles on the first row, x labels on the last row,
            # sample description on the first column
            if i == 0:
                ax.set_title(moment_names[j], fontsize=12)
            if i == n_samples - 1:
                ax.set_xlabel('Time (s)', fontsize=10)
            if j == 0:
                ax.set_ylabel(f'Sample {i+1}\n(len={length})\nNormalized Value', fontsize=10)

            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Compute metrics on full validation set with variable-length sequences
def _pad_time_axis(array, target_len):
    """Zero-pad axis 1 (time) of `array` so that it has `target_len` entries."""
    pad_width = [(0, 0)] * array.ndim
    pad_width[1] = (0, target_len - array.shape[1])
    return np.pad(array, pad_width, mode='constant')


batch_predictions = []
batch_targets = []
batch_masks = []

model.eval()
with torch.no_grad():
    for batch in val_loader:
        initial_states = batch['initial_states'].to(device)
        target_trajectories = batch['trajectories']  # Keep on CPU for now
        masks = batch['masks']  # (batch_size, max_len)

        # Create dynamic t_span for this batch
        max_len = target_trajectories.shape[1]
        t_span = torch.arange(0, max_len, dtype=torch.float32, device=device) * config['dt']

        # Predict
        pred_trajectory = model(initial_states, t_span)  # (max_len, batch_size, 4)
        pred_trajectory = pred_trajectory.permute(1, 0, 2).cpu()  # (batch_size, max_len, 4)

        batch_predictions.append(pred_trajectory.numpy())
        batch_targets.append(target_trajectories.numpy())
        batch_masks.append(masks.numpy())

# Every batch is padded only to its own longest trajectory, so the time axes
# differ between batches; bring them to one common length before stacking
# (padded mask entries are False, padded values are 0)
global_max_len = max(m.shape[1] for m in batch_masks)
all_predictions = np.concatenate(
    [_pad_time_axis(p, global_max_len) for p in batch_predictions], axis=0
)  # (n_samples, global_max_len, 4)
all_targets = np.concatenate(
    [_pad_time_axis(t, global_max_len) for t in batch_targets], axis=0
)  # (n_samples, global_max_len, 4)
all_masks = np.concatenate(
    [_pad_time_axis(m, global_max_len) for m in batch_masks], axis=0
)  # (n_samples, global_max_len)

# Compute metrics per moment (only on valid timesteps)
print("\nValidation Metrics (normalized space, masked):")

for i, name in enumerate(moment_names):
    # Boolean indexing selects the valid entries directly, so whatever sits in
    # the padded region (possibly NaN/Inf) never enters the statistics
    valid_targets = all_targets[:, :, i][all_masks]
    valid_predictions = all_predictions[:, :, i][all_masks]
    residuals = valid_predictions - valid_targets

    mse = np.mean(residuals ** 2)
    mae = np.mean(np.abs(residuals))

    # R² score; undefined when the targets have no variance
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((valid_targets - np.mean(valid_targets)) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else float('nan')

    print(f"{name:25s} - MSE: {mse:.6f}, MAE: {mae:.6f}, R²: {r2:.6f}")

# Overall metrics: mean over every valid element (timesteps x moments), which
# equals the average of the per-moment values above
overall_residuals = (all_predictions - all_targets)[all_masks]  # (n_valid, 4)
overall_mse = np.mean(overall_residuals ** 2)
overall_mae = np.mean(np.abs(overall_residuals))

print(f"{'Overall':25s} - MSE: {overall_mse:.6f}, MAE: {overall_mae:.6f}")

# Report total number of timesteps evaluated
total_valid_timesteps = all_masks.sum()
total_possible_timesteps = all_masks.size
print(f"\nEvaluated {int(total_valid_timesteps):,} valid timesteps out of {total_possible_timesteps:,} total ({100*total_valid_timesteps/total_possible_timesteps:.1f}% non-padded)")

In [None]:
# Compute MAE as a function of time (accounting for variable-length sequences)
# At each time index the error is averaged only over the trajectories that
# actually have data at that index.

max_len = all_predictions.shape[1]

# Number of trajectories with a valid sample at each time index
count_over_time = all_masks.sum(axis=0).astype(float)  # (max_len,)

# Absolute errors with padded positions forced to 0; np.where (instead of
# multiplying by the mask) keeps NaN/Inf values in the padding out of the sums
abs_errors = np.where(
    all_masks[:, :, np.newaxis],
    np.abs(all_predictions - all_targets),
    0.0,
)  # (n_samples, max_len, n_moments)

# Time indices without any valid trajectory keep an MAE of 0
mae_over_time = abs_errors.sum(axis=0) / np.maximum(count_over_time, 1.0)[:, np.newaxis]

# Create time axis in seconds
time_axis = np.arange(max_len) * config['dt']

# Plot
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

# Plot MAE over time
for i, name in enumerate(moment_names):
    ax1.plot(time_axis, mae_over_time[:, i], label=name, linewidth=2)

ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Mean Absolute Error', fontsize=12)
ax1.set_title('Prediction Error Over Time', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot number of trajectories available at each time
ax2.plot(time_axis, count_over_time, 'k-', linewidth=2)
ax2.fill_between(time_axis, 0, count_over_time, alpha=0.3)
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('Number of Trajectories', fontsize=12)
ax2.set_title('Data Availability Over Time', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTime range: 0 to {time_axis[count_over_time > 0][-1]:.0f} seconds")
print(f"At least one trajectory extends to: {time_axis[-1]:.0f} seconds")