In [None]:
import time
from functools import partial
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import pickle
import jax.numpy as jnp
import jax
import jax.random as jr
import numpy as onp
import jax.flatten_util
import pandas as pd
import optax
import jaxopt
from flax.training import train_state
import scipy
from neuralss import ss_init, ss_apply
from ae import Encoder, Projector
import nonlinear_benchmarks

In [None]:
SEED = 42  # master seed; all jax.random draws in this notebook are split from `key`
key = jr.key(SEED)

In [None]:
# Run on the first GPU when one is available. Without a GPU backend
# `jax.devices("gpu")` raises RuntimeError; fall back to JAX's default device
# instead of crashing (training will be much slower on CPU).
try:
    default_device = jax.devices("gpu")[0]
except RuntimeError:
    default_device = jax.devices()[0]
    print(f"No GPU found; running on {default_device}.")
jax.config.update("jax_default_device", default_device)

In [None]:
# Load the trained VAE checkpoint (config, parameters, noise level and scalers).
# NOTE: pickle.load can execute arbitrary code — only load checkpoints you produced yourself.
ckpt_path = Path("out") / "vae.p"
with open(ckpt_path, "rb") as ckpt_file:  # `with` closes the handle even if unpickling fails
    ckpt = pickle.load(ckpt_file)

cfg = ckpt["cfg"]
params_enc = ckpt["params_enc"]
params_ss = ckpt["params_dec"]  # stored as the VAE decoder params; presumably the state-space model — confirm
params_proj = ckpt["params_proj"]
sigma_noise = ckpt["sigma_noise"]
scalers = ckpt["scalers"]

# Flat view of the decoder parameters: `n_params` is their total count and
# `unflatten_dec` maps a flat vector back to the original pytree structure.
params_dec_flat, unflatten_dec = jax.flatten_util.ravel_pytree(params_ss)
n_params = params_dec_flat.shape[0]

In [None]:
# Training-sequence lengths to sweep: 100..1000 in steps of 100, then 2000..5000 in steps of 1000.
# For a quick smoke test, temporarily shorten this list (e.g. to [100, 200]).
train_lens = [*range(100, 1_001, 100), *range(2_000, 5_001, 1_000)]
mc_size = 100  # Monte Carlo repetitions (random training windows) per length

In [None]:
import scipy.io  # explicit: a bare `import scipy` does not expose `scipy.io` on older SciPy releases

# Normalisation constants: outputs are divided by Y_SCALE and inputs by U_SCALE,
# identically for training and test data.
Y_SCALE = 7e-4
U_SCALE = 50.0

data_folder = Path("bwdataset")
data = scipy.io.loadmat(data_folder / "bw_matlab.mat")

# Flatten every signal to a single (N, 1) column so train and test data share one layout.
y_train = data["y"].reshape(-1, 1) / Y_SCALE
u_train = data["u"].reshape(-1, 1) / U_SCALE

y_test = scipy.io.loadmat(data_folder / "yval_multisine.mat")["yval_multisine"].reshape(-1, 1) / Y_SCALE
u_test = scipy.io.loadmat(data_folder / "uval_multisine.mat")["uval_multisine"].reshape(-1, 1) / U_SCALE
N = y_train.shape[0]  # number of training samples available for window sampling

In [None]:
# Re-create the projector and encoder modules using sizes from the checkpoint.
# NOTE(review): neither `proj` nor `enc` is used by the Monte Carlo cells below — consider removing.
proj = Projector(unflatten=unflatten_dec, outputs=n_params)
enc = Encoder(rnn_size=cfg.nh, mlp_layers=[cfg.nh, cfg.nz])

In [None]:
def loss_full(ov, y, u):
    """Mean-squared simulation error of the state-space model on one sequence.

    Args:
        ov: optimisation variables; ``ov["params"]`` are the model parameters and
            ``ov["x0"]`` the initial state passed to ``ss_apply``.
        y: measured output sequence.
        u: input sequence used to simulate the model.

    Returns:
        Scalar mean of the squared residuals ``y - y_hat``.
    """
    params, x0 = ov["params"], ov["x0"]
    residual = y - ss_apply(params, scalers, x0, u)
    return jnp.mean(residual ** 2)


def train_full(ov, y, u, iters=40_000, lr=1e-3):
    """Fit a state-space model (parameters and initial state) with AdamW.

    Written to be wrapped in ``jax.vmap`` so several independent models are
    trained in lockstep, one per Monte Carlo training window.

    Args:
        ov: optimisation variables, a dict with keys ``"params"`` and ``"x0"``.
        y: measured output sequence the model is fitted to.
        u: input sequence driving the model.
        iters: number of AdamW steps.
        lr: AdamW learning rate.

    Returns:
        ``(optimised variables, losses)`` where ``losses`` has shape ``(iters,)``
        and holds the training loss evaluated before each update.
    """
    loss_cfg = partial(loss_full, y=y, u=u)
    opt = optax.adamw(learning_rate=lr)
    state = train_state.TrainState.create(apply_fn=loss_cfg, params=ov, tx=opt)

    @jax.jit
    def make_step(state):
        loss, grads = jax.value_and_grad(state.apply_fn)(state.params)
        state = state.apply_gradients(grads=grads)
        return loss, state

    # Collect losses in a list and stack once at the end. Updating a preallocated
    # array with `.at[idx].set` outside jit copies the whole buffer on every step
    # (O(iters**2) work, and a (batch, iters) copy per step under vmap).
    losses = []
    for _ in tqdm(range(iters)):
        loss, state = make_step(state)
        losses.append(loss)

    if not losses:  # iters == 0: jnp.stack rejects an empty list
        return state.params, jnp.empty(0)
    return state.params, jnp.stack(losses)

In [None]:
# Result containers, one row per training length. Initialised with NaN instead of
# `onp.empty` (uninitialised memory) so entries that were never filled — e.g. after
# an interrupted run — are easy to spot and are ignored by NaN-aware statistics.
fit_adam = onp.full((len(train_lens), mc_size), onp.nan)
fit_bfgs = onp.full((len(train_lens), mc_size), onp.nan)
fit_red = onp.full((len(train_lens), mc_size), onp.nan)  # NOTE(review): not filled or saved by the cells below
train_time_adam = onp.full(len(train_lens), onp.nan)
train_time = onp.full(len(train_lens), onp.nan)

In [None]:
# Monte Carlo study: for every training length, fit `mc_size` full state-space models
# on random windows of the training data. The Adam runs are vmapped (all models are
# trained in parallel); each Adam solution is then refined individually with BFGS.


def _fit_on_test(params_batch):
    """Test-set fit index (first output channel) of every model in a batched parameter pytree.

    Each model is simulated from a zero initial state; the first `cfg.skip_loss`
    samples are excluded from the metric.
    """
    x0_test = jnp.zeros((cfg.nx,))
    y_test_hat = jax.vmap(ss_apply, in_axes=(0, None, None, None))(params_batch, scalers, x0_test, u_test)
    return onp.array([
        nonlinear_benchmarks.error_metrics.fit_index(y_test[cfg.skip_loss:], y_hat_i[cfg.skip_loss:])[0]
        for y_hat_i in y_test_hat
    ])


def _refine_with_bfgs(opt_vars_adam, mc_y, mc_u):
    """Refine each Adam solution with BFGS on its own training window.

    Returns the BFGS optimisation variables and solver states, each stacked along a
    leading Monte Carlo axis.
    """
    opt_vars_list, states_list = [], []
    for mc_idx in tqdm(range(mc_size), desc="BFGS", leave=False):
        loss_i = partial(loss_full, y=mc_y[mc_idx], u=mc_u[mc_idx])
        solver = jaxopt.ScipyMinimize(fun=loss_i, tol=1e-6, method="BFGS", maxiter=10_000)
        ov_adam_i = jax.tree.map(lambda x: x[mc_idx], opt_vars_adam)
        ov_bfgs, s_bfgs = solver.run(ov_adam_i)
        opt_vars_list.append(ov_bfgs)
        states_list.append(s_bfgs)

    def stack(*xs):
        return jnp.stack(xs)

    return jax.tree.map(stack, *opt_vars_list), jax.tree.map(stack, *states_list)


# Work on a copy of the global PRNG key so that re-running this cell reproduces the same draws.
mc_key = key

for len_idx, train_len in enumerate(train_lens):
    print(f"Processing length {train_len}...")
    if train_len > N:
        raise ValueError(f"train_len={train_len} exceeds the {N} available training samples")

    # Draw mc_size random windows of length train_len. randint's maxval is exclusive,
    # so the +1 keeps the last valid start index (N - train_len) reachable.
    mc_key, subkey = jr.split(mc_key)
    start_indexes = jr.randint(subkey, shape=(mc_size,), minval=0, maxval=N - train_len + 1)
    mc_indexes = start_indexes[:, None] + jnp.arange(train_len)
    mc_y, mc_u = y_train[mc_indexes], u_train[mc_indexes]

    # --- Adam: train all mc_size models in parallel ---
    time_start = time.time()
    print(f"Training  {mc_size} full models with ADAM...")
    mc_key, subkey = jr.split(mc_key)
    params_init = jax.vmap(ss_init)(jr.split(subkey, mc_size))
    opt_vars_init = {"params": params_init, "x0": jnp.zeros((mc_size, cfg.nx))}
    opt_vars_adam, losses_full_adam = jax.vmap(train_full, in_axes=(0, 0, 0))(opt_vars_init, mc_y, mc_u)
    # JAX dispatches work asynchronously: wait for the results before stopping the clock.
    jax.block_until_ready((opt_vars_adam, losses_full_adam))
    train_time_adam[len_idx] = time.time() - time_start

    fit_adam[len_idx] = _fit_on_test(opt_vars_adam["params"])

    # --- BFGS: refine each Adam solution (sequentially) ---
    time_start_bfgs = time.time()
    opt_vars_bfgs, states_bfgs = _refine_with_bfgs(opt_vars_adam, mc_y, mc_u)
    jax.block_until_ready(opt_vars_bfgs)
    # Total training cost = Adam warm start + BFGS refinement; test evaluation is not timed.
    train_time[len_idx] = train_time_adam[len_idx] + (time.time() - time_start_bfgs)

    fit_bfgs[len_idx] = _fit_on_test(opt_vars_bfgs["params"])

In [None]:
out_dir = Path("out")


def _fits_long(fit, model_name):
    """Long-format table (columns: model, length, fit) of a (len(train_lens), mc_size) fit matrix.

    `length` holds the training length as a string; there is one row per Monte Carlo run.
    """
    wide = pd.DataFrame(fit.T, columns=[str(length) for length in train_lens])
    long = wide.melt(var_name="length", value_name="fit")
    long.insert(0, "model", model_name)
    return long


df_adam = _fits_long(fit_adam, "full (adam)")
df_bfgs = _fits_long(fit_bfgs, "full (bfgs)")
df_all = pd.concat((df_adam, df_bfgs), ignore_index=True)
df_all.to_pickle(out_dir / "df_mc_full.pkl")

df_time = pd.DataFrame({"length": train_lens, "time_adam": train_time_adam, "time": train_time})
df_time.to_pickle(out_dir / "df_mc_full_time.pkl")