In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from modules.fourier import FNO, FNOBlock1d, SpectraclConv1d
from data.utils import TimeWindowDataset, NumpyLoader
from training.train import fit_plain, TRAIN_LOSS_KEY, VAL_LOSS_KEY

In [2]:
# Positional TimeWindowDataset settings shared by both splits.
# NOTE(review): the meaning of these four ints is defined by TimeWindowDataset
# (not visible here) -- confirm before changing them.
WINDOW_ARGS = (140, 256, 20, 20)
BATCH_SIZE = 32

dataset_train = TimeWindowDataset(
    "./data/dump/KdV_train_512.h5", *WINDOW_ARGS, "train", True
)
dataset_val = TimeWindowDataset(
    "./data/dump/KdV_valid.h5", *WINDOW_ARGS, "valid", True
)

# Only the training split is shuffled; validation batches keep a fixed order.
loader_train = NumpyLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val = NumpyLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
# Every parameter initialization below is derived from this one root key.
key = jax.random.PRNGKey(42)
key_proj_in, key_blocks, key_proj_out = jax.random.split(key, 3)

# Lifting network: 20 -> 128 channels via pointwise (kernel size 1) convs.
# A GELU sits between the two convs. Without it, the two stacked linear 1x1
# convs compose into a single affine map, so the second layer would add
# parameters but no expressiveness. The Lambda layer has no parameters.
keys_pin = jax.random.split(key_proj_in, 2)
projection_input = eqx.nn.Sequential([
    eqx.nn.Conv1d(20, 128, 1, 1, "same", key=keys_pin[0]),
    eqx.nn.Lambda(jax.nn.gelu),
    eqx.nn.Conv1d(128, 128, 1, 1, "same", key=keys_pin[1]),
])

# Five Fourier blocks at width 128, each initialized from its own subkey.
# The loop variable is `block_key` so it does not read as rebinding `key`.
# NOTE(review): the third FNOBlock1d argument (32) is presumably the number of
# retained Fourier modes -- confirm against modules.fourier.
keys_blocks = jax.random.split(key_blocks, 5)
fourier_blocks = eqx.nn.Sequential([
    FNOBlock1d(128, 128, 32, jax.nn.gelu, key=block_key)
    for block_key in keys_blocks
])

# Projection network: 128 -> 20 channels. A GELU sits between the pointwise
# convs for the same reason as in the lifting network.
keys_pout = jax.random.split(key_proj_out, 2)
projection_output = eqx.nn.Sequential([
    eqx.nn.Conv1d(128, 128, 1, 1, "same", key=keys_pout[0]),
    eqx.nn.Lambda(jax.nn.gelu),
    eqx.nn.Conv1d(128, 20, 1, 1, "same", key=keys_pout[1]),
])

# NOTE(review): the 20 in/out channels presumably match the 20-step time
# windows produced by TimeWindowDataset -- confirm.
model = FNO(20, 20, fourier_blocks, projection_input, projection_output)

In [4]:
def MSE_LOSS(model, inputs, outputs):
    """Mean squared error of ``model`` over a batch.

    ``model`` is applied per sample with ``eqx.filter_vmap``, which maps over
    the leading axis of ``inputs``. The squared residual is then averaged over
    every element of the batch.

    Returns:
        Scalar loss.
    """
    batched_model = eqx.filter_vmap(model)
    residual = batched_model(inputs) - outputs
    return jnp.mean(jnp.square(residual))

def NMSE_LOSS(model, inputs, outputs, eps=1e-8):
    """Normalized mean squared error of ``model`` over a batch.

    For each sample, the summed squared error is divided by the summed squared
    magnitude of the target, ``||pred - y||^2 / ||y||^2``. The per-sample
    ratios are then averaged over the batch.

    The previous element-wise form ``(pred - y)**2 / y**2`` divided by zero
    (or by tiny values) wherever a single target entry was near 0. This made
    the loss infinite or dominated by those points. Normalizing by each
    sample's total energy avoids that.

    Args:
        model: callable applied to a single sample. It is vectorized over the
            leading (batch) axis of ``inputs`` with ``eqx.filter_vmap``.
        inputs: batched model inputs.
        outputs: batched targets, with the same shape as the model predictions.
        eps: small constant added to each denominator so an all-zero target
            sample does not divide by zero.

    Returns:
        Scalar loss.
    """
    preds = eqx.filter_vmap(model)(inputs)
    # Reduce over every non-batch axis so each sample yields one ratio.
    sample_axes = tuple(range(1, outputs.ndim))
    err = jnp.sum((preds - outputs) ** 2, axis=sample_axes)
    norm = jnp.sum(outputs ** 2, axis=sample_axes)
    return jnp.mean(err / (norm + eps))

In [5]:
# Smooth exponential decay: lr(step) = 1e-3 * 0.95 ** (step / 1000), where
# `step` counts optimizer updates (staircase is left at its default).
lr_scheduler = optax.schedules.exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.95,
)
optimizer = optax.adam(learning_rate=lr_scheduler)

In [None]:
# Train with the normalized MSE loss and validate on loader_val.
# NOTE(review): the positional 20 is presumably the number of epochs --
# confirm against training.train.fit_plain.
# `model` is rebound to the trained model, so re-running this cell alone
# continues from the already-trained weights. Re-run the model cell first
# for a fresh start.
model, opt_state, history = fit_plain(
    model, loader_train, NMSE_LOSS, optimizer, 20, loader_val, print_every=200
)

In [None]:
fig, ax = plt.subplots()
ax.plot(history[TRAIN_LOSS_KEY], label="train")
ax.plot(history[VAL_LOSS_KEY], label="val")
# A log-scaled y axis keeps changes at small loss values visible.
ax.set_yscale("log")
# NOTE(review): x is simply the index into each history list. Confirm that
# fit_plain logs train and val at the same cadence (e.g. both per epoch);
# otherwise the two curves are not aligned on this axis.
ax.set(xlabel="history entry", ylabel="loss", title="Training and validation loss")
ax.legend()
plt.show()