In [1]:
!pip install torchdiffeq



In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from torchdiffeq import odeint_adjoint
import pickle
from tqdm import tqdm

# Seed NumPy and PyTorch with one shared value so reruns are repeatable
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the compute device: CUDA when PyTorch reports it available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Load data
data_dir = Path("./data")

# np.load on an .npz archive returns an NpzFile that keeps the file handle
# open; reading both arrays inside the context manager lets it be closed as
# soon as they are in memory.
with np.load(data_dir / "test_arr.npz") as test_data:
    # NOTE(review): the original comment gave the axis order as
    # (time, features, ensemble, ics), but the later cells take axis 2 as the
    # IC axis (shape[2], data_mean[..., ic_idx]) and average over axis 3 as the
    # ensemble. Both axes have length 100 in this file, so a mix-up would not
    # raise - confirm the true order against whatever produced test_arr.npz.
    data_array = test_data['data']
    mask_array = test_data['mask']  # same shape as data_array

# Apply masking: entries where mask_array is True are excluded from the
# reductions (min/max/mean) computed on data_masked.
data_masked = np.ma.MaskedArray(data_array, mask=mask_array)
print(f"\nMasked data - Min: {data_masked.min():.2e}, Max: {data_masked.max():.2e}")


Masked data - Min: 0.00e+00, Max: 6.54e+08


In [4]:
# Deterministic 80/20 split over the initial conditions (ICs): the leading 80%
# of indices along axis 2 go to training, the trailing ones to validation.
n_ics_total = data_array.shape[2]
n_ics_train = int(0.8 * n_ics_total)
print(f"Total ICs: {n_ics_total}")

# Slice one index range in two; no shuffling is applied
all_ic_indices = np.arange(n_ics_total)
train_indices = all_ic_indices[:n_ics_train]
val_indices = all_ic_indices[n_ics_train:]
print(f"Train ICs: {train_indices.size}, Val ICs: {val_indices.size}")

Total ICs: 100
Train ICs: 80, Val ICs: 20


In [5]:
# Positions along the feature axis (axis 1) that are used downstream.
# The four moments:
moment_indices = [
    1,  # qc
    2,  # nc
    3,  # qr
    4,  # nr
]
# The three environmental parameters:
env_indices = [
    14,  # q_w0
    15,  # r_0
    16,  # ν
]

print(f"Moment indices: {moment_indices}")
print(f"Environmental parameter indices: {env_indices}")

# Collapse the last axis (treated here as the 100 ensemble instances) with a
# masked mean, leaving an array indexed as (time, features, ics).
# NOTE(review): this relies on axis 3 being the ensemble axis - confirm.
data_mean = data_masked.mean(axis=3)
print(f"\nEnsemble mean shape: {data_mean.shape}")

Moment indices: [1, 2, 3, 4]
Environmental parameter indices: [14, 15, 16]



Ensemble mean shape: (3599, 18, 100)


In [6]:
# Create trajectories with proper masking
def extract_trajectories(data_mean, ic_indices, moment_indices, env_indices):
    """
    Extract the valid (unmasked) part of each initial condition's trajectory.

    Parameters
    ----------
    data_mean : numpy.ma.MaskedArray or numpy.ndarray
        Array indexed as [time, feature, ic]. A plain ndarray, or a masked
        array that carries no mask, is treated as fully valid.
    ic_indices : iterable of int
        Positions along the last axis to extract.
    moment_indices : list of int
        Feature positions of the moments.
    env_indices : list of int
        Feature positions of the environmental parameters.

    Returns
    -------
    list of dict
        One dict per IC that has at least two valid timesteps, with keys:
        'moments'    - ndarray (n_valid, len(moment_indices)), valid rows only
        'env_params' - ndarray (len(env_indices),), read at the first valid
                       timestep
        'length'     - int, number of valid timesteps
        'ic_idx'     - the IC index as supplied in ``ic_indices``
    """
    trajectories = []

    for ic_idx in ic_indices:
        moments = data_mean[:, moment_indices, ic_idx]  # (time, n_moments)

        # A timestep counts as valid when the FIRST moment is unmasked.
        # getmaskarray always yields a full boolean array; the bare `.mask`
        # attribute can be the scalar `np.ma.nomask`, for which `~mask` is a
        # single True and every fully-valid trajectory would be rejected.
        valid_mask = ~np.ma.getmaskarray(moments)[:, 0]
        n_valid = int(valid_mask.sum())

        if n_valid < 2:  # Need at least 2 timesteps for derivatives
            continue

        # The environmental parameters are constant across time, so read them
        # at the first valid timestep rather than unconditionally at index 0,
        # which may be masked and hold meaningless fill data.
        first_valid = int(np.argmax(valid_mask))
        env_params = data_mean[first_valid, env_indices, ic_idx]  # (n_env,)

        trajectories.append({
            'moments': np.ma.getdata(moments)[valid_mask],  # (n_valid, n_moments)
            'env_params': np.ma.getdata(env_params),
            'length': n_valid,
            'ic_idx': ic_idx
        })

    return trajectories

# Build the per-IC trajectory lists for each split from the ensemble mean
train_trajectories = extract_trajectories(
    data_mean=data_mean,
    ic_indices=train_indices,
    moment_indices=moment_indices,
    env_indices=env_indices,
)
val_trajectories = extract_trajectories(
    data_mean=data_mean,
    ic_indices=val_indices,
    moment_indices=moment_indices,
    env_indices=env_indices,
)

print(f"\nNumber of training trajectories: {len(train_trajectories)}")
print(f"Number of validation trajectories: {len(val_trajectories)}")
print("\nExample trajectory lengths (train):")
# Show at most the first five training trajectories
for i, traj in enumerate(train_trajectories[:5]):
    print(f"  Trajectory {i}: {traj['length']} timesteps")


Number of training trajectories: 80
Number of validation trajectories: 20

Example trajectory lengths (train):
  Trajectory 0: 130 timesteps
  Trajectory 1: 779 timesteps
  Trajectory 2: 324 timesteps
  Trajectory 3: 149 timesteps
  Trajectory 4: 59 timesteps


In [7]:
# Apply log-transform to moments
epsilon = 1e-10  # offset added before the log so that zero-valued moments stay finite

def log_transform_moments(moments):
    """Return log(moments + epsilon), computed element-wise."""
    shifted = moments + epsilon
    return np.log(shifted)

def inverse_log_transform_moments(log_moments):
    """Undo log_transform_moments: return exp(log_moments) - epsilon."""
    shifted = np.exp(log_moments)
    return shifted - epsilon

# Attach a 'log_moments' entry to every trajectory dict in both splits
for split in (train_trajectories, val_trajectories):
    for traj in split:
        traj['log_moments'] = log_transform_moments(traj['moments'])

# Report the value range of the first training trajectory as a sanity check
example_log_moments = train_trajectories[0]['log_moments']
print(f"Example log-moment range: [{example_log_moments.min():.2f}, {example_log_moments.max():.2f}]")

Example log-moment range: [-23.03, 19.21]


In [8]:
# Standardise the environmental parameters. The scaler statistics come from
# the training trajectories only and are then applied to both splits.
env_params_train = np.array([traj['env_params'] for traj in train_trajectories])
env_scaler = StandardScaler().fit(env_params_train)

print("\nEnvironmental parameter statistics (before scaling):")
print(f"  Mean: {env_params_train.mean(axis=0)}")
print(f"  Std: {env_params_train.std(axis=0)}")

# Scale each trajectory's parameter vector: transform() expects a 2-D array,
# so the vector is presented as a single row and flattened again afterwards.
for split in (train_trajectories, val_trajectories):
    for traj in split:
        as_row = traj['env_params'].reshape(1, -1)
        traj['env_params_scaled'] = env_scaler.transform(as_row).ravel()

env_params_scaled_train = np.array([traj['env_params_scaled'] for traj in train_trajectories])
print("\nEnvironmental parameter statistics (after scaling):")
print(f"  Mean: {env_params_scaled_train.mean(axis=0)}")
print(f"  Std: {env_params_scaled_train.std(axis=0)}")

# Persist the fitted scaler next to the data so it can be reloaded later
with (data_dir / 'env_scaler.pkl').open('wb') as f:
    pickle.dump(env_scaler, f)
print("\nScaler saved to data/env_scaler.pkl")


Environmental parameter statistics (before scaling):
  Mean: [8.84999941e-04 1.19000000e-05 2.10625000e+00]
  Std: [4.63707871e-04 2.02854569e-06 1.19030918e+00]

Environmental parameter statistics (after scaling):
  Mean: [ 8.32667268e-16  4.87387908e-15 -1.74686654e-16]
  Std: [1. 1. 1.]

Scaler saved to data/env_scaler.pkl


In [9]:
# Pickle both trajectory lists into data_dir, one file per split
for filename, trajectories in (
    ("train_trajectories.pkl", train_trajectories),
    ("val_trajectories.pkl", val_trajectories),
):
    with open(data_dir / filename, "wb") as f:
        pickle.dump(trajectories, f)