In [1]:
!pip install torchdiffeq



In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from torchdiffeq import odeint_adjoint
import pickle
from tqdm import tqdm

# Seed the PyTorch and NumPy global RNGs so repeated runs draw the same numbers
torch.manual_seed(42)
np.random.seed(42)

# Run on the GPU when CUDA is usable, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Load data
data_dir = Path("./data")
test_data = np.load(data_dir / "test_arr.npz")

# Axis order as the later cells actually use it: axis 2 is split into
# train/validation initial conditions (ICs) and axis 3 is averaged away as the
# ensemble, i.e. (time, features, ics, ensemble).
# NOTE(review): this comment previously read (time, features, ensemble, ics),
# which contradicts that usage. Both axes appear to have length 100, so a swap
# would not raise an error - confirm against the code that wrote test_arr.npz.
data_array = test_data['data']
mask_array = test_data['mask']  # same shape as data_array

# The cells below index axis 2 and reduce over axis 3, so fail early with a
# clear message if the file does not have the expected 4-D layout.
if data_array.ndim != 4:
    raise ValueError(
        f"Expected a 4-D 'data' array (time, features, ics, ensemble), got shape {data_array.shape}"
    )

# Apply masking. np.ma treats True as "masked / excluded".
# NOTE(review): assumes the stored mask follows that convention - confirm.
data_masked = np.ma.MaskedArray(data_array, mask=mask_array)
print(f"\nMasked data - Min: {data_masked.min():.2e}, Max: {data_masked.max():.2e}")


Masked data - Min: 0.00e+00, Max: 6.54e+08


In [4]:
# Hold out the last 20% of initial conditions (ICs) for validation.
# Axis 2 of data_array is used as the IC axis.
n_ics_total = data_array.shape[2]
print(f"Total ICs: {n_ics_total}")
n_ics_train = int(0.8 * n_ics_total)

# Contiguous, unshuffled split: ICs [0, n_ics_train) train, the remainder validate
train_indices, val_indices = np.split(np.arange(n_ics_total), [n_ics_train])
print(f"Train ICs: {len(train_indices)}, Val ICs: {len(val_indices)}")

Total ICs: 100
Train ICs: 80, Val ICs: 20


In [5]:
# Positions along the feature axis (axis 1) that the model consumes
moment_indices = list(range(1, 5))    # moments, labelled qc, nc, qr, nr in this notebook
env_indices = list(range(14, 17))     # environmental parameters, labelled q_w0, r_0, ν

print(f"Moment indices: {moment_indices}")
print(f"Environmental parameter indices: {env_indices}")

# Collapse the last axis (the ensemble of instances) to its mean. Masked
# entries are left out of the average, so each (time, feature, ic) value is
# the mean over the members that are unmasked at that point.
data_mean = np.ma.mean(data_masked, axis=3)  # -> (time, features, ics)
print(f"\nEnsemble mean shape: {data_mean.shape}")

Moment indices: [1, 2, 3, 4]
Environmental parameter indices: [14, 15, 16]



Ensemble mean shape: (3599, 18, 100)


In [6]:
# Create trajectories with proper masking
def extract_trajectories(data_mean, ic_indices, moment_indices, env_indices):
    """
    Build one trajectory record per initial condition (IC).

    Parameters
    ----------
    data_mean : array indexed as [time, feature, ic]
        Normally a numpy masked array; a plain ndarray or a masked array
        without an explicit mask is treated as fully valid.
    ic_indices : iterable of int
        Positions along the IC axis to extract.
    moment_indices : list of int
        Feature positions of the moment variables.
    env_indices : list of int
        Feature positions of the environmental parameters. They are read at
        time index 0 only, i.e. treated as constant in time.

    Returns
    -------
    list of dict, one per retained IC, with keys
        'moments'    : ndarray (n_valid, len(moment_indices)), valid rows only
        'env_params' : ndarray (len(env_indices),)
        'length'     : int, number of valid timesteps
        'ic_idx'     : the IC index exactly as given in ``ic_indices``
    ICs with fewer than two valid timesteps are skipped.

    Notes
    -----
    A timestep is valid when the FIRST listed moment is unmasked. Valid rows
    are selected with a boolean mask, so if the valid timesteps are not
    contiguous the gaps are closed up in the output. A warning is issued when
    a retained row still has another moment masked, or when an environmental
    parameter is masked, because the raw underlying values are returned in
    those positions.
    """
    import warnings

    trajectories = []

    for ic_idx in ic_indices:
        # Slice out this IC: moments over all time, env params at t=0 only
        moments = data_mean[:, moment_indices, ic_idx]  # (time, n_moments)
        env_params = data_mean[0, env_indices, ic_idx]  # (n_env,)

        # getmaskarray always yields a full boolean array. Reading `.mask`
        # directly can return the scalar np.ma.nomask, whose negation is a
        # scalar True: n_valid would then be 1 and a fully valid trajectory
        # would be dropped without any message.
        moment_mask = np.ma.getmaskarray(moments)
        valid_mask = ~moment_mask[:, 0]
        n_valid = int(valid_mask.sum())

        if n_valid <= 1:  # Need at least 2 timesteps for derivatives
            continue

        # Column 0 is unmasked on every retained row, so any True left here
        # comes from one of the other moments.
        if moment_mask[valid_mask].any():
            warnings.warn(
                f"IC {ic_idx}: some moments are masked on timesteps kept as valid; "
                "their underlying raw values are included in 'moments'",
                stacklevel=2,
            )
        if np.ma.getmaskarray(env_params).any():
            warnings.warn(
                f"IC {ic_idx}: environmental parameters are masked at time index 0; "
                "their underlying raw values are included in 'env_params'",
                stacklevel=2,
            )

        trajectories.append({
            'moments': np.ma.getdata(moments)[valid_mask],  # (n_valid, n_moments)
            'env_params': np.ma.getdata(env_params),        # (n_env,)
            'length': n_valid,
            'ic_idx': ic_idx
        })

    return trajectories

# Build the per-IC trajectory records for both splits with identical settings
train_trajectories, val_trajectories = (
    extract_trajectories(data_mean, split_indices, moment_indices, env_indices)
    for split_indices in (train_indices, val_indices)
)

print(f"\nNumber of training trajectories: {len(train_trajectories)}")
print(f"Number of validation trajectories: {len(val_trajectories)}")

# Preview the lengths of the first few training trajectories (at most five)
print("\nExample trajectory lengths (train):")
for i, example in enumerate(train_trajectories[:5]):
    print(f"  Trajectory {i}: {example['length']} timesteps")


Number of training trajectories: 80
Number of validation trajectories: 20

Example trajectory lengths (train):
  Trajectory 0: 130 timesteps
  Trajectory 1: 779 timesteps
  Trajectory 2: 324 timesteps
  Trajectory 3: 149 timesteps
  Trajectory 4: 59 timesteps


In [7]:
# Standardise the environmental parameters. The scaler statistics come from
# the training trajectories only, so nothing leaks from the validation split.
env_params_train = np.array([traj['env_params'] for traj in train_trajectories])
env_scaler = StandardScaler()
env_scaler.fit(env_params_train)

print(f"\nEnvironmental parameter statistics (before scaling):")
print(f"  Mean: {env_params_train.mean(axis=0)}")
print(f"  Std: {env_params_train.std(axis=0)}")

# Store the scaled vector on every trajectory of both splits.
# transform() expects a 2-D input, hence the temporary leading axis.
for trajectory_set in (train_trajectories, val_trajectories):
    for traj in trajectory_set:
        env_row = traj['env_params'].reshape(1, -1)
        traj['env_params_scaled'] = env_scaler.transform(env_row).flatten()

print(f"\nEnvironmental parameter statistics (after scaling):")
env_params_scaled_train = np.array([traj['env_params_scaled'] for traj in train_trajectories])
print(f"  Mean: {env_params_scaled_train.mean(axis=0)}")
print(f"  Std: {env_params_scaled_train.std(axis=0)}")

# Persist the fitted scaler so inference can apply the same scaling
with open(data_dir / 'env_scaler.pkl', 'wb') as f:
    pickle.dump(env_scaler, f)
print("\nScaler saved to data/env_scaler.pkl")


Environmental parameter statistics (before scaling):
  Mean: [8.84999941e-04 1.19000000e-05 2.10625000e+00]
  Std: [4.63707871e-04 2.02854569e-06 1.19030918e+00]

Environmental parameter statistics (after scaling):
  Mean: [ 8.32667268e-16  4.87387908e-15 -1.74686654e-16]
  Std: [1. 1. 1.]

Scaler saved to data/env_scaler.pkl


Moments span many orders of magnitude and have heavy-tailed distributions, so we apply an asinh transformation:
1. Compute a characteristic scale for each moment variable (the training-set median)
2. Apply the transformation asinh(x / scale)
3. Standardize (z-score) the asinh-transformed values

In [8]:
# Pool every valid training timestep into one (total_timesteps, n_moments) array
all_moments_train = np.vstack([traj['moments'] for traj in train_trajectories])
print(f"Total training timesteps: {all_moments_train.shape[0]}")

moment_names = ['qc', 'nc', 'qr', 'nr']

print(f"\nRaw moment statistics:")
for name, column in zip(moment_names, all_moments_train.T):
    print(f"  {name} - Min: {column.min():.2e}, Max: {column.max():.2e}, Median: {np.median(column):.2e}")

# Characteristic asinh scale per moment: the training-set median.
# With scale = median, asinh(median / scale) = asinh(1) ≈ 0.88.
asinh_scales = np.median(all_moments_train, axis=0)

# The next step divides by these scales. A zero or non-finite median (possible
# when at least half of a moment's samples are exactly 0) would produce
# inf/NaN there, so stop here with a clear message instead.
unusable = ~np.isfinite(asinh_scales) | (asinh_scales == 0)
if unusable.any():
    offending = [name for name, bad in zip(moment_names, unusable) if bad]
    raise ValueError(
        f"Cannot use the median as asinh scale for {offending}: scales are {asinh_scales}"
    )

print(f"\nAsinh scales (using median):")
for name, scale in zip(moment_names, asinh_scales):
    print(f"  {name}: {scale:.2e}")

Total training timesteps: 39371

Raw moment statistics:
  qc - Min: 3.50e-11, Max: 2.00e-03, Median: 7.82e-05
  nc - Min: 6.52e+01, Max: 6.54e+08, Median: 1.27e+07
  qr - Min: 1.74e-23, Max: 5.59e-14, Median: 8.67e-16
  nr - Min: 0.00e+00, Max: 2.00e-03, Median: 2.00e-04

Asinh scales (using median):
  qc: 7.82e-05
  nc: 1.27e+07
  qr: 8.67e-16
  nr: 2.00e-04


In [9]:
# Step 1: compress the training moments with asinh(x / scale), column by column
scaled_ratio = all_moments_train / asinh_scales
moments_asinh_train = np.arcsinh(scaled_ratio)

# Step 2: learn z-score statistics on the asinh-transformed training values
moment_scaler = StandardScaler().fit(moments_asinh_train)

print(f"\nMoment scaler parameters:")
print(f"  Means: {moment_scaler.mean_}")
print(f"  Stds: {moment_scaler.scale_}")


Moment scaler parameters:
  Means: [1.24339712 1.30073491 1.14457747 0.96651209]
  Stds: [1.26184572 1.34025571 1.20207609 0.88758972]


In [10]:
# Now apply full transformation (asinh + z-score) to all trajectories
def transform_moments(moments, asinh_scales, moment_scaler):
    """Normalise raw moments: asinh(moments / asinh_scales), then moment_scaler.transform.

    ``asinh_scales`` is divided into ``moments`` with normal NumPy
    broadcasting, and ``moment_scaler`` only needs to provide a
    ``transform`` method accepting the resulting array.
    """
    return moment_scaler.transform(np.arcsinh(moments / asinh_scales))

# Store the normalised moments on every trajectory, training split first and
# then validation; both use the scales and scaler fitted on the training data.
for trajectory_set in (train_trajectories, val_trajectories):
    for traj in trajectory_set:
        traj['moments_scaled'] = transform_moments(traj['moments'], asinh_scales, moment_scaler)

The transformed moments are now stored on every trajectory under the 'moments_scaled' key.

In [11]:
# Bundle everything needed to repeat (or invert) the moment normalisation at
# inference time and write it next to the data.
asinh_normalization_stats = dict(
    asinh_scales=asinh_scales,
    moment_scaler_mean=moment_scaler.mean_,
    moment_scaler_std=moment_scaler.scale_,  # the scaler's per-feature scale_
    moment_scaler=moment_scaler,             # full fitted object, for convenience
)

stats_path = data_dir / 'asinh_normalization_stats.pkl'
with stats_path.open('wb') as f:
    pickle.dump(asinh_normalization_stats, f)

print("Asinh normalization parameters saved to data/asinh_normalization_stats.pkl")
print(f"Asinh scales: {asinh_scales}")
print(f"Scaler means: {moment_scaler.mean_}")
print(f"Scaler stds: {moment_scaler.scale_}")

Asinh normalization parameters saved to data/asinh_normalization_stats.pkl
Asinh scales: [7.81641163e-05 1.27452755e+07 8.67027511e-16 1.99642239e-04]
Scaler means: [1.24339712 1.30073491 1.14457747 0.96651209]
Scaler stds: [1.26184572 1.34025571 1.20207609 0.88758972]


Each trajectory now contains:
- 'moments': raw moment data
- 'moments_scaled': asinh-transformed and z-score normalized moments
- 'env_params': raw environmental parameters
- 'env_params_scaled': z-score normalized environmental parameters
- 'length': number of timesteps
- 'ic_idx': initial condition index

In [14]:
# Sanity check: stack the normalised moments of every training trajectory
all_moments_scaled_train = np.vstack([traj['moments_scaled'] for traj in train_trajectories])

# Each column should come out with mean ~0 and std ~1
print(f"\nScaled moment statistics:")
for name, column in zip(['qc', 'nc', 'qr', 'nr'], all_moments_scaled_train.T):
    print(
        f"  {name} - Mean: {column.mean():+.6f}, Std: {column.std():.6f}, "
        f"Range: [{column.min():+.2f}, {column.max():+.2f}]"
    )

# Every entry must be finite (neither NaN nor +/-Inf)
if np.isfinite(all_moments_scaled_train).all():
    print("No NaN or Inf values detected")
else:
    print("WARNING: NaN or Inf values detected in scaled moments!")


Scaled moment statistics:
  qc - Mean: +0.000000, Std: 1.000000, Range: [-0.99, +2.13]
  nc - Mean: +0.000000, Std: 1.000000, Range: [-0.97, +2.49]
  qr - Mean: -0.000000, Std: 1.000000, Range: [-0.95, +3.09]
  nr - Mean: +0.000000, Std: 1.000000, Range: [-1.09, +2.29]
No NaN or Inf values detected


In [15]:
# Write both processed trajectory lists into the data directory
for file_name, trajectories in (
    ("train_trajectories.pkl", train_trajectories),
    ("val_trajectories.pkl", val_trajectories),
):
    with open(data_dir / file_name, "wb") as f:
        pickle.dump(trajectories, f)