In [1]:
import time
from functools import partial
from pathlib import Path
from tqdm import tqdm
import pickle
import jax.numpy as jnp
import jax
import jax.random as jr
import numpy as onp
import pandas as pd
import optax
import jaxopt
from flax.training import train_state
import scipy
from neuralss import ss_init, ss_apply
import nonlinear_benchmarks

In [2]:
# Root PRNG key with a fixed seed so the Monte Carlo draws are reproducible.
key = jax.random.key(42)

In [3]:
# Make the first CPU device the default placement for JAX arrays and computations.
jax.config.update("jax_default_device", jax.devices(backend="cpu")[0])

In [4]:
# Load the MAML checkpoint: model config (cfg), the meta-learned initial
# parameters, and the scalers that ss_apply expects as its second argument.
ckpt_path = Path("out") / "maml_1step.p"
# NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
# The context manager guarantees the file handle is closed.
with open(ckpt_path, "rb") as f:
    ckpt = pickle.load(f)

cfg = ckpt["cfg"]
params_maml = ckpt["params"]
scalers = ckpt["scalers"]

In [5]:
# Lengths (in samples) of the training windows to sweep, and the number of
# Monte Carlo repetitions (random windows) drawn for each length.
train_lens = [
    100, 200, 400, 500, 600, 800,
    1_000, 2_000, 3_000, 4_000, 5_000,
]
mc_size = 100

In [6]:
# Load the training record and the multisine validation record from .mat files.
# Outputs and inputs are divided by fixed constants; defining them once keeps
# the train and test normalization in sync (presumably these match the scaling
# used when the MAML checkpoint was trained; confirm).
Y_SCALE = 7e-4
U_SCALE = 50.0

data_folder = Path("bwdataset")
data = scipy.io.loadmat(data_folder / "bw_matlab.mat")

y_train = data["y"] / Y_SCALE
u_train = data["u"] / U_SCALE

y_test = scipy.io.loadmat(data_folder / "yval_multisine.mat")["yval_multisine"].reshape(-1, 1) / Y_SCALE
u_test = scipy.io.loadmat(data_folder / "uval_multisine.mat")["uval_multisine"].reshape(-1, 1) / U_SCALE

# Length of the training record (first axis); random windows are drawn from it.
N = y_train.shape[0]

In [None]:
# Mean Squared Error loss function
def mse_loss_x0_fn(ov, y, u):
    """Mean squared simulation error of the model held in ``ov``.

    ``ov`` is a dict with the model parameters under ``"params"`` and the
    initial state under ``"x0"``. The model is simulated with ``ss_apply``
    (using the global ``scalers``) from ``ov["x0"]`` driven by input ``u``,
    and the mean of the squared difference to the measured output ``y`` is
    returned. Both ``params`` and ``x0`` are thus optimisation variables.
    """
    simulated = ss_apply(ov["params"], scalers, ov["x0"], u)
    residuals = y - simulated
    return jnp.mean(residuals ** 2)


def train(ov, y, u, iters=40_000, lr=1e-3):
    """Fit the optimisation variables ``ov`` to one (y, u) sequence with AdamW.

    Args:
        ov: pytree of optimisation variables, passed as the first argument of
            ``mse_loss_x0_fn`` (model ``"params"`` and initial state ``"x0"``).
        y: measured output sequence.
        u: input sequence.
        iters: number of AdamW steps.
        lr: AdamW learning rate.

    Returns:
        ``(params, losses)``: the optimised ``ov`` pytree and a length-``iters``
        array with the loss evaluated before each update.
    """
    loss_cfg = partial(mse_loss_x0_fn, y=y, u=u)
    opt = optax.adamw(learning_rate=lr)
    state = train_state.TrainState.create(apply_fn=loss_cfg, params=ov, tx=opt)

    @jax.jit
    def make_step(state):
        loss, grads = jax.value_and_grad(state.apply_fn)(state.params)
        state = state.apply_gradients(grads=grads)
        return loss, state

    # Collect losses in a Python list and stack once at the end. Writing into a
    # preallocated array with `.at[idx].set(...)` outside jit copies the whole
    # buffer on every step (O(iters^2) work, times the vmap batch size).
    losses = []
    for _ in tqdm(range(iters)):
        loss, state = make_step(state)
        losses.append(loss)

    # jnp.stack rejects an empty list; keep the original empty result for iters == 0.
    losses = jnp.stack(losses) if losses else jnp.zeros((0,))
    return state.params, losses

In [8]:
# Result containers: one row per training length, one column per Monte Carlo run.
# Filled with NaN rather than onp.empty's uninitialized memory, so entries that
# were never computed (e.g. after an interrupted run) are recognisable in the
# saved results instead of looking like real numbers.
fit_adam = onp.full((len(train_lens), mc_size), onp.nan)     # test fit, AdamW models
fit_bfgs = onp.full((len(train_lens), mc_size), onp.nan)     # test fit, BFGS models
fit_adam_tr = onp.full((len(train_lens), mc_size), onp.nan)  # train fit, AdamW models
fit_bfgs_tr = onp.full((len(train_lens), mc_size), onp.nan)  # train fit, BFGS models
train_time_adam = onp.full(len(train_lens), onp.nan)         # wall-clock seconds, AdamW stage
train_time = onp.full(len(train_lens), onp.nan)              # wall-clock seconds, whole training

In [9]:
# Monte Carlo study: for every training length, draw mc_size random windows of
# the training record and fit one model per window. First AdamW is run on all
# windows in parallel (vmap) from the MAML parameters, then each AdamW model is
# refined with BFGS.


def _fit_indexes(ov_batch, seq_y, seq_u):
    """Fit indexes of a batch of models on the test record and on their own windows.

    Args:
        ov_batch: dict with batched ``"params"`` and ``"x0"`` (leading axis = model).
        seq_y, seq_u: training windows, leading axis = model.

    Returns:
        ``(test_fits, train_fits)``: numpy arrays with one entry per model.
        The test fit simulates from a zero initial state on (u_test, y_test) and
        ignores the first ``cfg.skip_loss`` samples. The train fit simulates each
        model from its own estimated ``x0`` on its own window.
    """
    x0_test = jnp.zeros((cfg.nx,))
    y_test_hat = jax.vmap(ss_apply, in_axes=(0, None, None, None))(ov_batch["params"], scalers, x0_test, u_test)
    y_train_hat = jax.vmap(ss_apply, in_axes=(0, None, 0, 0))(ov_batch["params"], scalers, ov_batch["x0"], seq_u)
    n_models = seq_y.shape[0]
    test_fits = onp.empty(n_models)
    train_fits = onp.empty(n_models)
    for mc_idx in range(n_models):
        test_fits[mc_idx] = nonlinear_benchmarks.error_metrics.fit_index(
            y_test[cfg.skip_loss:], y_test_hat[mc_idx, cfg.skip_loss:]
        )[0]
        train_fits[mc_idx] = nonlinear_benchmarks.error_metrics.fit_index(
            seq_y[mc_idx, :], y_train_hat[mc_idx, :]
        )[0]
    return test_fits, train_fits


for len_idx, train_len in enumerate(train_lens):

    print(f"Processing length {train_len}...")

    # Random training windows. randint's maxval is exclusive, so "+ 1" lets a
    # window end on the last sample of the record.
    key, subkey = jr.split(key)
    start_indexes = jr.randint(subkey, shape=(mc_size,), minval=0, maxval=N - train_len + 1)
    mc_indexes = start_indexes[:, None] + jnp.arange(train_len)
    mc_y, mc_u = y_train[mc_indexes], u_train[mc_indexes]

    # AdamW: all mc_size models at once. The MAML params are shared at init
    # (in_axes None) and every model gets its own x0.
    print(f"Training  {mc_size} full models starting from MAML initialization with ADAM...")
    time_start = time.time()
    opt_vars_init = {"params": params_maml, "x0": jnp.zeros((mc_size, cfg.nx))}
    opt_vars_adam, losses_full = jax.vmap(train, in_axes=({"params": None, "x0": 0}, 0, 0))(opt_vars_init, mc_y, mc_u)
    # JAX dispatches work asynchronously: wait for the results before stopping the clock.
    opt_vars_adam = jax.block_until_ready(opt_vars_adam)
    train_time_adam[len_idx] = time.time() - time_start

    fit_adam[len_idx], fit_adam_tr[len_idx] = _fit_indexes(opt_vars_adam, mc_y, mc_u)

    # BFGS refinement of each AdamW model, one model at a time.
    time_start_bfgs = time.time()
    opt_vars_bfgs = []
    states_bfgs = []
    for mc_idx in tqdm(range(mc_size), desc=f"BFGS (length {train_len})"):
        loss_i = partial(mse_loss_x0_fn, y=mc_y[mc_idx], u=mc_u[mc_idx])
        solver = jaxopt.ScipyMinimize(fun=loss_i, tol=1e-6, method="BFGS", maxiter=10_000)
        ov_adam_i = jax.tree.map(lambda x: x[mc_idx], opt_vars_adam)
        ov_bfgs, s_bfgs = solver.run(ov_adam_i)
        opt_vars_bfgs.append(ov_bfgs)
        states_bfgs.append(s_bfgs)
    opt_vars_bfgs = jax.block_until_ready(jax.tree.map(lambda *x: jnp.stack(x), *opt_vars_bfgs))
    states_bfgs = jax.tree.map(lambda *x: jnp.stack(x), *states_bfgs)
    # Total training cost of the AdamW + BFGS pipeline (evaluation time excluded).
    train_time[len_idx] = train_time_adam[len_idx] + (time.time() - time_start_bfgs)

    fit_bfgs[len_idx], fit_bfgs_tr[len_idx] = _fit_indexes(opt_vars_bfgs, mc_y, mc_u)

Processing length 100...
Training  100 full models starting from MAML initialization with ADAM...


  0%|          | 0/10000 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Save the final checkpoint: the swept training lengths with per-length
# timings and the per-run fit indexes of the AdamW and BFGS models.
ckpt = {
    "train_lens": train_lens,
    "train_time_adam": train_time_adam,
    "train_time": train_time,
    "fit_adam": fit_adam,
    "fit_adam_tr": fit_adam_tr,
    "fit_bfgs": fit_bfgs,
    "fit_bfgs_tr": fit_bfgs_tr,
}

ckpt_path = Path("out") / "mc_maml.p"
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
# Context manager: the file is flushed and closed deterministically, even if
# pickling raises.
with open(ckpt_path, "wb") as f:
    pickle.dump(ckpt, f)