To run this notebook, download the `model_data` directory and place it in the same folder as this notebook. The model paths below are relative, so the notebook must be started from that folder.

In [6]:
import os
import jax
import numpy as np
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from scipy.stats import gmean
from typing import Callable, Dict, List, Tuple, Union

# Import from apebench and pdequinox
import apebench
from apebench.scenarios import scenario_dict
import exponax as ex
from exponax import metrics as ex_metrics

from pde_emulator.utils import get_equation_encoding

# --- Environment Setup ---
# Suppress TensorFlow logging
# NOTE(review): this assignment runs after `import jax` (and the other imports)
# above. Native logging that was already initialised by those imports may not
# pick it up — consider setting it before the imports if log noise persists.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [7]:
# --- Model 1: PI-FNO-UNET ---
# Rebuild the network's architecture as a "template", then replace its leaves
# with the parameters serialised on disk.
from pde_emulator.models import PI_FNO_UNET

MODEL_LOAD_PATH = "./model_data/PI-FNO-UNET.eqx"
ENCODING_DIM = 7
IN_CHANNELS = 1

# Initialisation needs *a* PRNG key; the randomly drawn weights are discarded
# once the saved leaves are read in below.
key = jr.PRNGKey(99)

# Placeholder normalisation statistics for the length-ENCODING_DIM encoding.
# NOTE(review): assumes these are stored as array leaves in the .eqx file and
# are therefore replaced on load — confirm against PI_FNO_UNET.
dummy_mean, dummy_std = jnp.zeros(ENCODING_DIM), jnp.ones(ENCODING_DIM)

# Architecture skeleton with random weights; its pytree structure tells the
# deserialiser where each saved array belongs.
model_template = PI_FNO_UNET(
    num_spatial_dims=1,
    in_channels=IN_CHANNELS,
    encoding_dim=ENCODING_DIM,
    encoding_mean=dummy_mean,
    encoding_std=dummy_std,
    key=key,
)

# Read the serialised leaves into a copy of the template.
print(f"Loading model from {MODEL_LOAD_PATH}...")
model_1 = eqx.tree_deserialise_leaves(MODEL_LOAD_PATH, like=model_template)
print("Model 1 (PI-FNO-UNET) loaded successfully.")


>>> Instantiating GENERALIZED FiLM-FNO U-Net model with DERIVATIVES, U^2, and U*UX as input: embedding_dim=32, hidden_channels=128, modes=32, input_channels_to_fnonet=7, output_channels_from_fnonet=1, activation=SiLU <<<
Loading model from ./model_data/PI-FNO-UNET.eqx...
Model 1 (PI-FNO-UNET) loaded successfully.


In [8]:
# --- Model 2: LSC-FNO ---
# Same template-then-deserialise pattern; this architecture additionally
# needs the training spatial resolution and data normalisation statistics.
from pde_emulator.models import LSC_FNO

IN_CHANNELS = 1
ENCODING_DIM = 7
# IMPORTANT: must equal the spatial dimension size of the training data.
SPATIAL_DIM_SIZE = 160
MODEL_LOAD_PATH = "model_data/LSC-FNO.eqx"

key = jr.PRNGKey(123)  # initialisation-only key; weights come from the file

# Placeholder statistics: scalar mean/std for the data, and
# length-ENCODING_DIM vectors for the encoding.
# NOTE(review): assumes these are array leaves that the loaded file overwrites.
dummy_data_mean, dummy_data_std = jnp.zeros(()), jnp.ones(())
dummy_encoding_mean, dummy_encoding_std = jnp.zeros(ENCODING_DIM), jnp.ones(ENCODING_DIM)

# Architecture skeleton with random weights, used only as the pytree layout
# for deserialisation.
model_template = LSC_FNO(
    num_spatial_dims=1,
    in_channels=IN_CHANNELS,
    encoding_dim=ENCODING_DIM,
    spatial_dim_size=SPATIAL_DIM_SIZE,
    data_mean=dummy_data_mean,
    data_std=dummy_data_std,
    encoding_mean=dummy_encoding_mean,
    encoding_std=dummy_encoding_std,
    key=key,
)

# Read the serialised leaves into a copy of the template.
print(f"Loading model from {MODEL_LOAD_PATH}...")
model_2 = eqx.tree_deserialise_leaves(MODEL_LOAD_PATH, like=model_template)
print("Model 2 (LSC-FNO) loaded successfully.")


>>> Instantiating GENERALIZED FNO (LNO variant) model with parameters: latent_dim=80, embedding_dim=64, fno_hidden=128, fno_depth=12, num_modes=40, activation=silu, original_spatial_dim_size=160, encoding_dim=7, encoder_hidden=64 <<<
Loading model from model_data/LSC-FNO.eqx...
Model 2 (LSC-FNO) loaded successfully.


In [9]:
# --- Model 3: PINO-FNO ---
# Unlike the other models, the normalisation constants for this one live in a
# separate .npz file and are passed to the constructor explicitly.
# NOTE(review): `GLOBAL` is not used in this cell; it looks like an accidental
# IDE auto-import. Remove it if no later cell references it.
from pickle import GLOBAL
from pde_emulator.models import PINO

MODEL_SAVE_DIR = "model_data/"
MODEL_FILENAME = "PINO_model.eqx"
CONSTANTS_FILENAME = "PINO_constants.npz"
MODEL_LOAD_PATH = os.path.join(MODEL_SAVE_DIR, MODEL_FILENAME)
CONSTANTS_LOAD_PATH = os.path.join(MODEL_SAVE_DIR, CONSTANTS_FILENAME)

print(f"Loading normalization constants from {CONSTANTS_LOAD_PATH}...")
# np.load on an .npz archive returns an NpzFile that reads lazily and keeps the
# underlying file handle open until it is closed. Copy every array into a plain
# dict inside a `with` block, so the handle is released deterministically while
# `constants[...]` lookups keep working.
with np.load(CONSTANTS_LOAD_PATH) as npz_file:
    constants = {name: npz_file[name] for name in npz_file.files}
DATA_MEAN_U = jnp.array(constants['global_data_mean_u'])
DATA_STD_U = jnp.array(constants['global_data_std_u'])
ENCODING_MIN_VALS = jnp.array(constants['global_encoding_min_vals'])
ENCODING_MAX_VALS = jnp.array(constants['global_encoding_max_vals'])
print("Global constants loaded and set successfully.")

# --- Configuration for creating template model ---
IN_CHANNELS_U = 1
ENCODING_DIM = 7
# NOTE(review): not passed to PINO below; kept for reference only. It must
# still match the spatial dimension of the training data.
SPATIAL_RESOLUTION = 160

# --- Create a dummy/template model instance ---
# The key only seeds the random initial weights, which the saved leaves replace.
key = jax.random.PRNGKey(456)
model_template = PINO(
    num_spatial_dims=1,
    in_channels_u=IN_CHANNELS_U,
    encoding_dim=ENCODING_DIM,
    key=key,
    encoding_min_vals=ENCODING_MIN_VALS,
    encoding_max_vals=ENCODING_MAX_VALS,
    data_mean=DATA_MEAN_U,
    data_std=DATA_STD_U,
)

# --- Load the saved weights into the model structure ---
print(f"Loading model parameters from {MODEL_LOAD_PATH}...")
model_3 = eqx.tree_deserialise_leaves(MODEL_LOAD_PATH, model_template)
print("Model loaded successfully.")

Loading normalization constants from model_data/PINO_constants.npz...
Global constants loaded and set successfully.

>>> Instantiating GENERALIZED FiLM + FNO model with parameters: embedding_dim=96, encoder_hidden=192, fno_hidden=256, num_fno_modes=32, fno_depth=6, encoding_projection_channels=2 <<<
Loading model parameters from model_data/PINO_model.eqx...
Model loaded successfully.


In [10]:
# --- Model 4: Learned Correction (LC) ---
# Build the architecture template (no normalisation statistics needed here),
# then fill it with the parameters serialised on disk.
from pde_emulator.models import LC

MODEL_LOAD_PATH = "model_data/LC.eqx"
ENCODING_DIM = 7
IN_CHANNELS = 1

key = jr.PRNGKey(789)  # initialisation-only key; weights come from the file

# Skeleton with random weights whose pytree layout guides deserialisation.
model_template = LC(
    num_spatial_dims=1,
    in_channels=IN_CHANNELS,
    encoding_dim=ENCODING_DIM,
    key=key,
)

# Read the serialised leaves into a copy of the template.
print(f"Loading model from {MODEL_LOAD_PATH}...")
model_4 = eqx.tree_deserialise_leaves(MODEL_LOAD_PATH, like=model_template)
print("Model 4 (LC) loaded successfully.")


>>> Instantiating GENERALIZED FiLM Correction model with parameters: embedding_dim=32, cnn_hidden=160, cnn_depth=14, using SpectralConv1d with N=160, num_modes_fraction=0.3333333333333333, and improved complex weight initialization. Augmented input channels: 49 <<<
Loading model from model_data/LC.eqx...
Model 4 (LC) loaded successfully.
