In [1]:
import jax
import optax
import equinox as eqx
from models.fno import FNO1d
from data.utils import MatDataset, load_mat_data, jax_collate
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
import jax.numpy as jnp

In [2]:
class BurgersDataset(Dataset):
    """Sliding-window samples cut from trajectories stored in a ``.mat`` file.

    The file is read with ``load_mat_data`` and must provide the keys
    ``'tspan'``, ``'input'`` and ``'output'``. ``'output'`` is indexed as a
    rank-3 array (assumed layout: trajectory, time step, grid point -- TODO
    confirm against the data file). Its first time step is overwritten with
    ``'input'`` so every trajectory starts from the stored initial condition.

    Each sample is a pair ``(inputs, targets)``:

    * ``inputs``  -- ``input_steps`` consecutive time steps followed by one
      extra row holding the grid coordinates ``linspace(0, 1, n_grid)``,
      i.e. shape ``(input_steps + 1, n_grid)``.
    * ``targets`` -- the next ``future_steps`` time steps, shape
      ``(future_steps, n_grid)``.

    Args:
        path: Location of the ``.mat`` file.
        input_steps: Number of consecutive time steps fed to the model.
        future_steps: Number of subsequent time steps to predict.
        key: JAX PRNG key; stored on the instance but not used by this class.

    Raises:
        ValueError: If one window (``input_steps + future_steps``) is longer
            than the stored trajectories.
    """

    def __init__(self, path, input_steps: int, future_steps: int, key: jax.random.PRNGKey):
        mat = load_mat_data(path)
        self.tspan = mat['tspan']
        trajectories = mat['output']
        # Replace the first stored step by the initial condition (in place).
        trajectories[:, 0, :] = mat['input']
        self.data = trajectories
        self.key = key
        self.input_steps = input_steps
        self.future_steps = future_steps

        n_steps = self.data.shape[1]
        if input_steps + future_steps > n_steps:
            raise ValueError(
                f'input_steps + future_steps = {input_steps + future_steps} '
                f'exceeds the {n_steps} time steps available per trajectory'
            )

        # Grid coordinates appended to every input; the resolution follows
        # the data instead of being fixed to 1024 points.
        self.x = np.linspace(0, 1, self.data.shape[-1], endpoint=True).reshape(1, -1)

    @property
    def _windows_per_series(self):
        """Number of distinct window start offsets inside one trajectory."""
        # Start offsets 0 .. T - (input + future) are all valid, hence the +1.
        return self.data.shape[1] - self.input_steps - self.future_steps + 1

    def __len__(self):
        """Total number of windows over all trajectories."""
        return self.data.shape[0] * self._windows_per_series

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair for flat window index ``idx``."""
        # Flat index -> (trajectory, start offset inside that trajectory).
        series, start = divmod(idx, self._windows_per_series)
        split = start + self.input_steps
        stop = split + self.future_steps
        inputs = np.concatenate([self.data[series, start:split], self.x], axis=0)
        targets = self.data[series, split:stop]
        return inputs, targets

In [3]:
# Windows of 5 input steps -> 5 predicted steps over the Burgers trajectories.
# The PRNG key is only stored by the dataset; the class shown above never uses it.
dataset = BurgersDataset(
    'datasets/burgers_v100_t100_r1024_N2048.mat',
    input_steps=5,
    future_steps=5,
    key=jax.random.PRNGKey(42),
)

Consider mio5.varmats_from_mat to split file into single variable files
  matfile_dict = MR.get_variables(variable_names)


In [4]:
# Shuffled mini-batches of 64 windows; `jax_collate` (from data.utils) assembles each batch.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=jax_collate,
)

In [5]:
@partial(eqx.filter_vmap, in_axes=(None, 0, 0))
def loss_fn(model, data, label):
    """Per-sample elementwise L2 loss of ``model`` on a batch.

    The decorator maps over the leading axis of ``data`` and ``label`` while
    sharing ``model`` across the batch, so ``model`` itself only ever sees a
    single sample.

    Args:
        model: Callable applied to one sample.
        data: Batch of inputs, batched along axis 0.
        label: Batch of targets, batched along axis 0.

    Returns:
        ``optax.l2_loss(model(sample), target)`` for each sample, stacked
        along axis 0 (no reduction is applied here).
    """
    return optax.l2_loss(model(data), label)

def train(model: 'eqx.Module', dataloader: 'DataLoader', optimizer: optax.GradientTransformation, n_epochs, opt_state=None, history: dict=None, print_every_steps=2000):
    """Run a mini-batch training loop and return the updated state.

    Args:
        model: Equinox model to optimise.
        dataloader: Iterable yielding ``(data, labels)`` batches.
        optimizer: Optax gradient transformation.
        n_epochs: Number of full passes over ``dataloader``.
        opt_state: Optimizer state to resume from; initialised from the
            model's array leaves when ``None``.
        history: Dict to append to; a ``'loss'`` list is created if missing.
            A fresh dict is used when ``None``.
        print_every_steps: Print the loss every this many steps within an
            epoch. ``0`` or ``None`` disables printing.

    Returns:
        Tuple ``(model, opt_state, history)``. ``history['loss']`` holds one
        entry per optimisation step (the loss value returned by the jitted
        step, not converted to a Python float).
    """
    if opt_state is None:
        opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    if history is None:
        history = {}
    # Also covers a caller-supplied dict that has no 'loss' entry yet.
    history.setdefault('loss', [])

    def batch_loss(model, data, labels):
        # Mean over the per-sample, per-element losses from `loss_fn`.
        return jnp.mean(loss_fn(model, data, labels))

    loss_and_grad = eqx.filter_value_and_grad(batch_loss)

    @eqx.filter_jit
    def train_step(model, data, labels, opt_state):
        loss, grads = loss_and_grad(model, data, labels)
        # Pass the current parameters so transformations that need them
        # (e.g. weight decay in optax.adamw) work, not only plain Adam.
        params = eqx.filter(model, eqx.is_array)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for epoch in range(n_epochs):
        for step, (data, labels) in enumerate(dataloader):
            model, opt_state, loss = train_step(model, data, labels, opt_state)
            history['loss'].append(loss)
            # A falsy interval disables logging instead of dividing by zero.
            if print_every_steps and step % print_every_steps == 0:
                print(f'Epoch {epoch}, Step {step}, Loss {loss}')

    return model, opt_state, history

In [6]:
# FNO1d's positional parameters are defined in models.fno (not visible here).
# NOTE(review): 6 and 5 presumably are the input/output channel counts
# (5 input steps + 1 grid row -> 5 future steps) -- confirm against FNO1d.
model = FNO1d(6, 5, 64, 6, 4, jax.nn.gelu, jax.random.PRNGKey(10))

# Adam whose learning rate starts at 1e-3 and decays by a factor 0.9 per 500 steps.
lr_scheduler = optax.schedules.exponential_decay(
    init_value=1e-3,
    transition_steps=500,
    decay_rate=0.9,
)
optimizer = optax.adam(learning_rate=lr_scheduler)

In [7]:
# Train for 10 epochs. This cell rebinds `model`, so re-running it continues
# from the current weights with a freshly initialised optimizer state.
model, opt_state, history = train(model, loader, optimizer, n_epochs=10)

Epoch 0, Step 0, Loss 0.021976759657263756
Epoch 0, Step 2000, Loss 0.00011546241876203567
Epoch 1, Step 0, Loss 4.141132376389578e-05
Epoch 1, Step 2000, Loss 1.2299183254071977e-05
Epoch 2, Step 0, Loss 1.971153324120678e-05
Epoch 2, Step 2000, Loss 4.1831495764199644e-05
Epoch 3, Step 0, Loss 5.2846666221739724e-05
Epoch 3, Step 2000, Loss 2.122859950759448e-05
Epoch 4, Step 0, Loss 1.739593608363066e-05
Epoch 4, Step 2000, Loss 1.4753706636838615e-05
Epoch 5, Step 0, Loss 2.024178502324503e-05
Epoch 5, Step 2000, Loss 9.969068742066156e-06
Epoch 6, Step 0, Loss 1.3603561455965973e-05
Epoch 6, Step 2000, Loss 1.1475144674477633e-05
Epoch 7, Step 0, Loss 1.3502404726750683e-05
Epoch 7, Step 2000, Loss 1.515400072094053e-05
Epoch 8, Step 0, Loss 1.8986858776770532e-05
Epoch 8, Step 2000, Loss 1.0824790479091462e-05
Epoch 9, Step 0, Loss 1.0140410267922562e-05
Epoch 9, Step 2000, Loss 7.931656000437215e-06
