In [None]:
import time
from functools import partial
from pathlib import Path
from tqdm import tqdm
import pickle
import jax.numpy as jnp
import jax
import jax.random as jr
import numpy as onp
import optax
from flax.training import train_state
import scipy
from neuralss import ss_init, ss_apply
import nonlinear_benchmarks

In [None]:
# Root PRNG key for the experiment; all later random draws are split from it,
# so a fixed seed makes the Monte Carlo windows reproducible.
SEED = 42
key = jr.key(SEED)

In [None]:
# Make the first CPU device JAX's default device, so arrays and computations
# without an explicit placement land on the CPU.
cpu_devices = jax.devices("cpu")
jax.config.update("jax_default_device", cpu_devices[0])

In [None]:
# Load the MAML checkpoint: config, meta-learned parameters and data scalers.
# NOTE: pickle.load can execute arbitrary code -- only load checkpoints you produced yourself.
ckpt_path = Path("out") / "maml_10s.p"
with open(ckpt_path, "rb") as f:  # `with` guarantees the file handle is closed
    ckpt = pickle.load(f)

cfg = ckpt["cfg"]
params_maml = ckpt["params"]
scalers = ckpt["scalers"]

In [None]:
# Training-sequence lengths to sweep, and number of Monte Carlo runs per length.
train_lens = [
    100, 200, 400, 500, 600, 800,
    1_000, 2_000, 3_000, 4_000, 5_000,
]
mc_size = 100

In [None]:
# Load the training record (bw_matlab.mat) and the multisine validation record.
# Outputs and inputs are divided by fixed constants; defining them once guarantees
# training and validation data are scaled identically.
Y_SCALE = 7e-4
U_SCALE = 50.0

data_folder = Path("bwdataset")
data = scipy.io.loadmat(data_folder / "bw_matlab.mat")

y_train = data["y"] / Y_SCALE
u_train = data["u"] / U_SCALE

y_test = scipy.io.loadmat(data_folder / "yval_multisine.mat")["yval_multisine"].reshape(-1, 1) / Y_SCALE
u_test = scipy.io.loadmat(data_folder / "uval_multisine.mat")["uval_multisine"].reshape(-1, 1) / U_SCALE

# Number of training samples (length along the first axis).
N = y_train.shape[0]

In [None]:
# Mean Squared Error loss function
def mse_loss_x0_fn(ov, y, u):
    """Mean squared error between measured outputs and the model's simulated outputs.

    Args:
        ov: optimization variables, a dict holding the model parameters under
            "params" and the initial state under "x0".
        y: measured output sequence.
        u: input sequence fed to `ss_apply`.

    Returns:
        Scalar mean of the squared output error over all elements.

    Relies on the module-level `scalers` loaded from the MAML checkpoint.
    """
    residual = y - ss_apply(ov["params"], scalers, ov["x0"], u)
    return jnp.mean(residual ** 2)

def _fit(ov, y, u, opt, iters):
    """Minimize `mse_loss_x0_fn` over `ov` with the optax optimizer `opt`.

    Shared implementation behind `train` (SGD) and `train_adamw` (AdamW).

    Args:
        ov: optimization variables (dict with "params" and "x0"); every leaf is updated.
        y: measured output sequence.
        u: input sequence.
        opt: optax gradient transformation.
        iters: number of full-batch gradient steps.

    Returns:
        (final optimization variables, array of shape (iters,) holding the loss
        evaluated just before each step).
    """
    loss_fn = partial(mse_loss_x0_fn, y=y, u=u)
    state = train_state.TrainState.create(apply_fn=loss_fn, params=ov, tx=opt)

    @jax.jit
    def make_step(state):
        loss, grads = jax.value_and_grad(state.apply_fn)(state.params)
        return loss, state.apply_gradients(grads=grads)

    step_losses = []
    for _ in tqdm(range(iters)):
        loss, state = make_step(state)
        step_losses.append(loss)

    # Collect in a list and stack once: updating a preallocated array with
    # `.at[idx].set` inside a Python loop copies the whole array every iteration
    # (quadratic in iters, and multiplied by the batch size when vmapped).
    losses = jnp.stack(step_losses) if step_losses else jnp.empty(0)
    return state.params, losses


def train(ov, y, u, iters=10, lr=0.1):
    """Fit `ov` to the sequence (y, u) with plain SGD.

    Returns (final optimization variables, per-step losses of shape (iters,)).
    """
    return _fit(ov, y, u, optax.sgd(learning_rate=lr), iters)


def train_adamw(ov, y, u, iters=10_000, lr=1e-3):
    """Fit `ov` to the sequence (y, u) with AdamW.

    Written for a single (y, u) sequence; the experiment loop vmaps it over the
    Monte Carlo axis. Returns (final optimization variables, per-step losses of
    shape (iters,)).
    """
    return _fit(ov, y, u, optax.adamw(learning_rate=lr), iters)

In [None]:
# Result buffers, one row per training length and one column per Monte Carlo run:
# validation fit index (fit) and training-window fit index (fit_tr), plus the
# per-length training wall-clock time. NaN-filled (instead of onp.empty garbage)
# so entries that were never computed, e.g. after an interrupted run, are obvious.
fit = onp.full((len(train_lens), mc_size), onp.nan)
fit_tr = onp.full((len(train_lens), mc_size), onp.nan)
train_time = onp.full(len(train_lens), onp.nan)

In [None]:
# For every training length, fine-tune mc_size models at once (jax.vmap over the
# Monte Carlo axis) starting from the MAML initialization, then score them.

for len_idx, train_len in enumerate(train_lens):

    print(f"Processing length {train_len}...")

    # Draw mc_size random windows of train_len consecutive training samples.
    # NOTE(review): randint's maxval is exclusive, so the final window (start
    # N - train_len) is never drawn. Left as-is so the drawn windows stay identical
    # for a given seed (presumably shared with related experiments -- confirm
    # before changing).
    key, subkey = jr.split(key)
    start_indexes = jr.randint(subkey, shape=(mc_size,), minval=0, maxval=N - train_len)
    mc_indexes = start_indexes[:, None] + jnp.arange(train_len)
    mc_y, mc_u = y_train[mc_indexes], u_train[mc_indexes]

    # Fine-tune one model per window from the MAML initialization with AdamW.
    print(f"Training {mc_size} full models starting from MAML initialization with AdamW...")
    # The split result is unused, but advancing `key` here keeps the window draws
    # at later lengths identical to earlier runs of this notebook.
    key, _ = jr.split(key)
    opt_vars_init = {"params": params_maml, "x0": jnp.zeros((cfg.nx,))}
    time_start = time.time()
    opt_vars_adam, losses_full = jax.vmap(train_adamw, in_axes=(None, 0, 0))(opt_vars_init, mc_y, mc_u)
    # JAX dispatches asynchronously: wait for the results so the measured time
    # covers the computation itself, not just its dispatch.
    jax.block_until_ready(opt_vars_adam)
    # Training time only; it is no longer overwritten by a later timestamp that
    # also included the evaluation below.
    train_time[len_idx] = time.time() - time_start

    # Validation: simulate the multisine input from a zero initial state and score
    # the output after the first cfg.skip_loss samples.
    x0 = jnp.zeros((cfg.nx,))
    y_test_hat = jax.vmap(ss_apply, in_axes=(0, None, None, None))(opt_vars_adam["params"], scalers, x0, u_test)
    # Training fit: re-simulate each window from its own estimated initial state.
    y_train_hat = jax.vmap(ss_apply, in_axes=(0, None, 0, 0))(opt_vars_adam["params"], scalers, opt_vars_adam["x0"], mc_u)
    for mc_idx in range(mc_size):
        fit[len_idx, mc_idx] = nonlinear_benchmarks.error_metrics.fit_index(y_test[cfg.skip_loss:], y_test_hat[mc_idx, cfg.skip_loss:])[0]
        fit_tr[len_idx, mc_idx] = nonlinear_benchmarks.error_metrics.fit_index(mc_y[mc_idx, :], y_train_hat[mc_idx, :])[0]

In [None]:
# Save the Monte Carlo results: fit indices per (training length, MC run) and
# per-length training times.
ckpt = {
    "train_lens": train_lens,
    "train_time": train_time,
    "fit": fit,
    "fit_tr": fit_tr,
}

ckpt_path = Path("out") / "mc_maml.p"
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
with open(ckpt_path, "wb") as f:  # closes (and flushes) the file deterministically
    pickle.dump(ckpt, f)