In [1]:
import jax
import optax
import equinox as eqx
from data.utils import MatDataset, load_mat_data, jax_collate
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
import jax.numpy as jnp

from modules.fourier import FNO, FNOBlock1d, SpectraclConv1d

In [2]:
class BurgersDataset(Dataset):
    """Sliding-window dataset over Burgers trajectories stored in a ``.mat`` file.

    Each sample is a pair ``(inputs, targets)`` cut from a single trajectory:
    ``input_steps`` consecutive time steps, followed by the next
    ``future_steps`` time steps. Every valid window start in every trajectory
    is one sample.

    Args:
        path: Path to the ``.mat`` file, read with ``load_mat_data``.
        input_steps: Number of time steps fed to the model.
        future_steps: Number of subsequent time steps to predict.
        key: PRNG key; stored on the instance but not used by this class.

    Raises:
        ValueError: If ``input_steps + future_steps`` exceeds the number of
            time steps per trajectory.
    """

    def __init__(self, path, input_steps: int, future_steps: int, key: jax.random.PRNGKey):
        raw = load_mat_data(path)
        self.tspan = raw['tspan']
        # Overwrite time step 0 of every trajectory with the stored 'input' field
        # (presumably the initial condition -- TODO confirm against the data generator).
        raw['output'][:, 0, :] = raw['input']
        # Indexed below as (trajectory, time, ...).
        self.data = raw['output']
        self.key = key
        self.input_steps = input_steps
        self.future_steps = future_steps
        # Grid coordinates on [0, 1], sized from the data's last axis instead of a hard-coded 1024.
        self.x = np.linspace(0, 1, self.data.shape[-1], endpoint=True).reshape(1, -1)
        if self._windows_per_trajectory() <= 0:
            raise ValueError(
                f'input_steps + future_steps = {input_steps + future_steps} exceeds '
                f'the {self.data.shape[1]} time steps per trajectory'
            )

    def _windows_per_trajectory(self) -> int:
        """Number of valid window start positions along the time axis of one trajectory."""
        # A window spans input_steps + future_steps consecutive steps, so the last
        # valid start index is T - (input_steps + future_steps); hence the + 1.
        return self.data.shape[1] - self.input_steps - self.future_steps + 1

    def __len__(self):
        return self.data.shape[0] * self._windows_per_trajectory()

    def __getitem__(self, idx):
        # Flat index -> (trajectory i, window start j).
        i, j = divmod(idx, self._windows_per_trajectory())
        split = j + self.input_steps
        return self.data[i, j:split], self.data[i, split:split + self.future_steps]

In [3]:
# Windows of 5 input steps followed by 5 target steps.
dataset = BurgersDataset(
    'datasets/burgers_v100_t100_r1024_N2048.mat',
    input_steps=5,
    future_steps=5,
    key=jax.random.PRNGKey(42),
)

Consider mio5.varmats_from_mat to split file into single variable files
  matfile_dict = MR.get_variables(variable_names)


In [4]:
# Shuffled mini-batches of 64; batches are assembled by jax_collate.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=jax_collate,
)

In [5]:
@partial(eqx.filter_vmap, in_axes=(None, 0, 0))
def loss_fn(model, data, label):
    """Per-sample L2 loss, vmapped over the leading (batch) axis of ``data`` and ``label``.

    The model itself is shared across the batch (``in_axes`` entry ``None``).
    Returns the elementwise ``optax.l2_loss`` for every sample, stacked along
    the batch axis; callers reduce it (e.g. with a mean).
    """
    prediction = model(data)
    per_element = optax.l2_loss(prediction, label)
    return per_element

def train(model: 'eqx.Module', dataloader: 'DataLoader', optimizer: optax.GradientTransformation, n_epochs, opt_state=None, history: dict=None, print_every_steps=2000):
    """Train ``model`` with ``optimizer`` for ``n_epochs`` passes over ``dataloader``.

    Args:
        model: Equinox model, evaluated per sample through ``loss_fn``.
        dataloader: Iterable yielding ``(data, labels)`` batches.
        optimizer: Optax gradient transformation.
        n_epochs: Number of passes over ``dataloader``.
        opt_state: Optimizer state to resume from; initialised from the
            model's array leaves when ``None``.
        history: Dict whose ``'loss'`` list receives one loss per step. It is
            mutated in place (so training can be resumed); a fresh dict is
            created when ``None``.
        print_every_steps: Print the loss every this many steps within an
            epoch; ``None`` or a value <= 0 disables printing.

    Returns:
        Tuple ``(model, opt_state, history)`` after training.
    """
    if opt_state is None:
        opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    if history is None:
        history = {'loss': []}
    history.setdefault('loss', [])

    def mean_loss(model, data, labels):
        # loss_fn returns per-sample, elementwise losses; reduce to a scalar for grad.
        return jnp.mean(loss_fn(model, data, labels))

    loss_and_grad = eqx.filter_value_and_grad(mean_loss)

    @eqx.filter_jit
    def train_step(model, data, labels, opt_state):
        loss, grads = loss_and_grad(model, data, labels)
        # Hand the optimizer only the array leaves as params. Non-array leaves
        # have no gradient, so optimizers that read params (e.g. weight decay)
        # need params and grads to share the same tree structure.
        params = eqx.filter(model, eqx.is_array)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    should_log = print_every_steps is not None and print_every_steps > 0
    for epoch in range(n_epochs):
        for step, (data, labels) in enumerate(dataloader):
            model, opt_state, loss = train_step(model, data, labels, opt_state)
            history['loss'].append(loss)
            if should_log and step % print_every_steps == 0:
                print(f'Epoch {epoch}, Step {step}, Loss {loss}')

    return model, opt_state, history

In [6]:
# FNO architecture and optimizer.
# Fourier block widths must chain: block i's output width is block i+1's input width.
in_channels = [10, 32, 64, 64]
out_channels = [32, 64, 64, 64]
assert len(in_channels) == len(out_channels), 'one in/out width per Fourier block'
assert in_channels[1:] == out_channels[:-1], 'Fourier block widths must chain'

n_blocks = len(in_channels)
lifted_channels = in_channels[0]        # width produced by the input projection
hidden_out_channels = out_channels[-1]  # width consumed by the output projection
# Tie the model's input/output channel counts to the dataset's window sizes.
input_steps = dataset.input_steps
future_steps = dataset.future_steps

key = jax.random.PRNGKey(43)
# Keys: 2 for the input projection, 1 per Fourier block, 2 for the output projection.
keys = jax.random.split(key, n_blocks + 4)
activation = jax.nn.gelu
modes = 16

input_projection = eqx.nn.Sequential([
    eqx.nn.Conv1d(input_steps, lifted_channels, 1, 1, 0, key=keys[0]),
    eqx.nn.Conv1d(lifted_channels, lifted_channels, 1, 1, 0, key=keys[1]),
])
fourier_blocks = eqx.nn.Sequential([
    FNOBlock1d(in_channels[i], out_channels[i], modes, activation, keys[i + 2])
    for i in range(n_blocks)
])
output_projection = eqx.nn.Sequential([
    eqx.nn.Conv1d(hidden_out_channels, hidden_out_channels, 1, 1, 0, key=keys[-2]),
    eqx.nn.Conv1d(hidden_out_channels, future_steps, 1, 1, 0, key=keys[-1]),
])
# NOTE(review): first two FNO args assumed to be (in_channels, out_channels) -- confirm in modules.fourier.
fno = FNO(input_steps, future_steps, fourier_blocks, input_projection, output_projection)

# LR starts at 1e-3 and is multiplied by 0.9 every 500 steps (continuous decay; staircase defaults to False).
lr_scheduler = optax.schedules.exponential_decay(1e-3, 500, 0.9)
optimizer = optax.adam(lr_scheduler)

In [7]:
# Fresh run: optimizer state and loss history are created inside train().
model, opt_state, history = train(fno, loader, optimizer, n_epochs=10)

Epoch 0, Step 0, Loss 0.023738935589790344
Epoch 0, Step 2000, Loss 1.1456663884246154e-07
