In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from modules.fourier import FNO, FNOBlock1d, SpectraclConv1d
from data.utils import TimeWindowDataset, NumpyLoader
from training.train import fit_plain, TRAIN_LOSS_KEY, VAL_LOSS_KEY

In [2]:
# Build the train/validation TimeWindowDataset objects from the KdV HDF5 dumps
# and wrap each one in a NumpyLoader.
# NOTE(review): the meaning of the positional values below (and of the trailing
# `True` flag) is defined by TimeWindowDataset — confirm against data.utils.
_window_args = (140, 256, 10, 10)

dataset_train = TimeWindowDataset(
    "./data/dump/KdV_train_512.h5", *_window_args, "train", True
)
dataset_val = TimeWindowDataset(
    "./data/dump/KdV_valid.h5", *_window_args, "valid", True
)

# Only the training loader shuffles; validation batches keep dataset order.
loader_train = NumpyLoader(dataset_train, shuffle=True, batch_size=16)
loader_val = NumpyLoader(dataset_val, shuffle=False, batch_size=32)

In [3]:
# Assemble the FNO from three sub-networks, each initialised from its own
# PRNG sub-key derived from a single fixed seed.
key = jax.random.PRNGKey(42)
key_proj_in, key_blocks, key_proj_out = jax.random.split(key, num=3)

# Input projection: two kernel-size-1 convolutions, 10 -> 64 -> 64 channels.
# NOTE(review): there is no nonlinearity between the two convolutions here (nor
# in the output projection below) — confirm that FNO applies one, or that a
# purely affine projection is intended.
keys_pin = jax.random.split(key_proj_in, num=2)
projection_input = eqx.nn.Sequential(
    [
        eqx.nn.Conv1d(
            in_channels=10,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding="same",
            key=keys_pin[0],
        ),
        eqx.nn.Conv1d(
            in_channels=64,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding="same",
            key=keys_pin[1],
        ),
    ]
)

# Four FNOBlock1d layers with GELU, one sub-key per block.
# NOTE(review): the positional (64, 64, 16) are presumably in/out channels and
# the number of Fourier modes — confirm against modules.fourier.
keys_blocks = jax.random.split(key_blocks, num=4)
fourier_blocks = eqx.nn.Sequential(
    [
        FNOBlock1d(64, 64, 16, jax.nn.gelu, key=block_key)
        for block_key in keys_blocks
    ]
)

# Output projection: two kernel-size-1 convolutions, 64 -> 64 -> 10 channels.
keys_pout = jax.random.split(key_proj_out, num=2)
projection_output = eqx.nn.Sequential(
    [
        eqx.nn.Conv1d(
            in_channels=64,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding="same",
            key=keys_pout[0],
        ),
        eqx.nn.Conv1d(
            in_channels=64,
            out_channels=10,
            kernel_size=1,
            stride=1,
            padding="same",
            key=keys_pout[1],
        ),
    ]
)

# NOTE(review): the leading (10, 10) presumably mirror the projection channel
# counts above — confirm against the FNO constructor.
model = FNO(10, 10, fourier_blocks, projection_input, projection_output)

In [4]:
def loss_fn(model, inputs, outputs):
    """Return the mean squared error of `model` over a batch.

    Args:
        model: Callable applied to one sample at a time; it is mapped over
            the batch with `eqx.filter_vmap`.
        inputs: Batched model inputs.
        outputs: Targets, broadcast-compatible with the batched predictions.

    Returns:
        Scalar mean of the squared prediction error over all elements.
    """
    predictions = eqx.filter_vmap(model)(inputs)
    residual = predictions - outputs
    return jnp.mean(jnp.square(residual))

In [5]:
# Adam driven by an exponentially decaying learning rate: starts at 1e-3 and
# is scaled by 0.95 per 1000 optimizer steps.
lr_scheduler = optax.schedules.exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.95,
)
optimizer = optax.adam(learning_rate=lr_scheduler)

In [6]:
# Run training and keep the returned model, optimizer state and history.
# NOTE: `model` is rebound to the value fit_plain returns, so re-running this
# cell starts from that returned model rather than the freshly built one;
# re-run the model-construction cell first for a clean run.
model, opt_state, history = fit_plain(
    model,
    loader_train,
    loss_fn,
    optimizer,
    2,  # presumably the number of epochs (the log shows epochs 0 and 1) — confirm against fit_plain
    loader_val,
    print_every=200,
)
# Backward-compatible alias for the earlier misspelled name, so any later cell
# that still refers to `opt_tate` keeps working.
opt_tate = opt_state

Epoch 0, step 0, loss: 0.42436572909355164
Epoch 0, step 200, loss: 0.01321825385093689
Epoch 0, step 400, loss: 0.00363359646871686
Epoch 0, step 600, loss: 0.0023752169217914343
Epoch 0, step 800, loss: 0.0008829644066281617
Epoch 0, step 1000, loss: 0.0021984167397022247
Epoch 0, step 1200, loss: 0.0008515235967934132
Epoch 0, step 1400, loss: 0.0006207986734807491
Epoch 0, step 1600, loss: 0.0005138018168509007
Epoch 0, step 1800, loss: 0.0013995023909956217
Epoch 0, step 2000, loss: 0.000327286368701607
Epoch 0, step 2200, loss: 0.0004051524156238884
Epoch 0, step 2400, loss: 0.0010465913219377398
Epoch 0, step 2600, loss: 0.00122267531696707
Epoch 0, step 2800, loss: 0.0007978844805620611
Epoch 0, step 3000, loss: 0.0004257399996276945
Epoch 0, step 3200, loss: 0.0001514315663371235
Epoch 0, step 3400, loss: 0.00028852210380136967
Epoch 0, step 3600, loss: 0.0007204374414868653
Epoch 0, step 3800, loss: 0.00023658228747081012
Validation loss: 0.0008646697970107198
Epoch 1, step 0