In [None]:
import time
from functools import partial
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import pickle
import jax.numpy as jnp
import jax
import jax.random as jr
import numpy as onp
import jax.flatten_util
import pandas as pd
import optax
import jaxopt
from flax.training import train_state
import scipy
from neuralss import ss_init, ss_apply
from ae import Encoder, Projector
import nonlinear_benchmarks

In [None]:
# Root PRNG key for the notebook; the Monte Carlo loop splits it once per training length.
seed = 42
key = jr.key(seed)

In [None]:
# Prefer the first GPU, but keep the notebook runnable on machines without one:
# jax.devices("gpu") raises RuntimeError when no GPU backend is available.
try:
    default_device = jax.devices("gpu")[0]
except RuntimeError:
    default_device = jax.devices()[0]
    print(f"No GPU backend found, falling back to {default_device}.")
jax.config.update("jax_default_device", default_device)

In [None]:
# Load the trained hypernetwork checkpoint.
ckpt_path = Path("out") / "hypernet.p"
# NOTE(security): pickle can execute arbitrary code on load; only open checkpoints you trust.
# The context manager closes the file handle deterministically.
with open(ckpt_path, "rb") as ckpt_file:
    ckpt = pickle.load(ckpt_file)

cfg = ckpt["cfg"]
params_enc, params_proj, params_ss = ckpt["params"]
scalers = ckpt["scalers"]
# Flatten the state-space parameters to get their total count and the inverse (unflatten) map.
params_dec_flat, unflatten_dec = jax.flatten_util.ravel_pytree(params_ss)
n_params = params_dec_flat.shape[0]

In [None]:
# Training-sequence lengths: 100..1000 in steps of 100, then 2000..5000 in steps of 1000.
train_lens = [*range(100, 1_001, 100), *range(2_000, 5_001, 1_000)]
# Number of Monte Carlo training windows drawn per length.
mc_size = 100

In [None]:
data_folder = Path("bwdataset")
data = scipy.io.loadmat(data_folder / "bw_matlab.mat")

# Training signals: input divided by 50.0, output divided by 7e-4.
u_train = data["u"] / 50.0
y_train = data["y"] / 7e-4
N = y_train.shape[0]  # size of the leading axis of the training output

# Multisine test signals, reshaped to column vectors and divided by the same constants.
u_test = (
    scipy.io.loadmat(data_folder / "uval_multisine.mat")["uval_multisine"].reshape(-1, 1)
    / 50.0
)
y_test = (
    scipy.io.loadmat(data_folder / "yval_multisine.mat")["yval_multisine"].reshape(-1, 1)
    / 7e-4
)

In [None]:
# Encoder: applied later to a (y, u) pair to produce the initial latent "z".
enc = Encoder(
    mlp_layers=[cfg.nh, cfg.nz],
    rnn_size=cfg.nh,
)
# Projector: maps a latent to a pytree of offsets that is added to params_ss.
proj = Projector(
    outputs=n_params,
    unflatten=unflatten_dec,
)

In [None]:
def loss_reduced(ov, y, u):
    """Mean squared error between ``y`` and the output of the latent-adapted model.

    Args:
        ov: optimization variables, a dict with entries ``"z"`` (latent passed to
            the projector) and ``"x0"`` (initial state passed to ``ss_apply``).
        y: target output sequence.
        u: input sequence fed to the model.

    Returns:
        Scalar ``mean((y - y_hat)**2)``.
    """
    # Latent -> parameter offsets, then shift the nominal parameters leaf by leaf.
    offsets = proj.apply(params_proj, ov["z"])
    adapted = jax.tree.map(lambda base, delta: base + delta, params_ss, offsets)

    # Run the adapted model from the trainable initial state.
    y_sim = ss_apply(adapted, scalers, ov["x0"], u)

    # Disabled alternative kept from the original notebook:
    #scaled_err = (y1 - y1_hat) / ckpt["sigma_noise"]
    #loss = jnp.sum(scaled_err**2) + jnp.sum(ov["z"]**2)
    residual = y - y_sim
    return jnp.mean(residual**2)


def train_reduced(ov, y, u, iters=10_000, lr=1e-3):
    """Minimize ``loss_reduced`` over the optimization variables with AdamW.

    Args:
        ov: initial optimization variables (the pytree handed to ``loss_reduced``).
        y: target output sequence.
        u: input sequence.
        iters: number of gradient steps.
        lr: AdamW learning rate.

    Returns:
        Tuple ``(params, losses)``: the variables after the last step and an
        array of length ``iters`` holding the loss evaluated before each update.
    """
    objective = partial(loss_reduced, y=y, u=u)
    train = train_state.TrainState.create(
        apply_fn=objective,
        params=ov,
        tx=optax.adamw(learning_rate=lr),
    )

    @jax.jit
    def make_step(state):
        # Loss and gradient at the current variables, followed by one optimizer update.
        loss, grads = jax.value_and_grad(state.apply_fn)(state.params)
        return loss, state.apply_gradients(grads=grads)

    history = jnp.empty(iters)
    for step in tqdm(range(iters)):
        step_loss, train = make_step(train)
        history = history.at[step].set(step_loss)

    return train.params, jnp.array(history)



In [None]:
# Result buffers, one row/entry per training length.
# Pre-filled with NaN rather than left uninitialized (onp.empty), so entries that a
# partial or interrupted run never reaches are recognizable instead of being garbage.
fit_red = onp.full((len(train_lens), mc_size), onp.nan)
train_time = onp.full(len(train_lens), onp.nan)

In [None]:
# Train mc_size reduced models in parallel (vmap) for every training length,
# record the wall-clock training time and the fit index of each model on the test data.

for len_idx, train_len in enumerate(train_lens):

    print(f"Processing length {train_len}...")

    # perf_counter is monotonic, so the measured interval cannot be skewed by clock adjustments.
    time_start = time.perf_counter()
    # generate mc sequences: mc_size random windows of train_len consecutive samples
    key, subkey = jr.split(key)
    start_indexes = jr.randint(subkey, shape=(mc_size,),  minval=0, maxval=N-train_len)
    mc_indexes = start_indexes[:, None] + jnp.arange(train_len)
    mc_y, mc_u = y_train[mc_indexes], u_train[mc_indexes]

    print(f"Training {mc_size} reduced models...")
    # train reduced models: the latent starts from the encoder output, x0 starts from zero
    z_init = jax.vmap(enc.apply, in_axes=(None, 0, 0))(params_enc, mc_y, mc_u)
    opt_vars_red_init = {"z": z_init, "x0": jnp.zeros((mc_size, cfg.nx, ))}
    opt_vars_red, losses_red = jax.vmap(train_reduced, in_axes=(0, 0, 0))(opt_vars_red_init, mc_y, mc_u)

    # project trained zetas to the ss parameter space
    params_ss_proj = jax.vmap(proj.apply, in_axes=(None, 0))(params_proj, opt_vars_red["z"])
    params_red = jax.tree.map(lambda base, delta: base + delta, params_ss, params_ss_proj)

    # JAX dispatches work asynchronously; wait for the results so the timer
    # covers the actual computation and not just the time to enqueue it.
    params_red = jax.block_until_ready(params_red)
    train_time[len_idx] = time.perf_counter() - time_start

    # test reduced models from a zero initial state on the test input
    x0 = jnp.zeros((cfg.nx, ))
    y_test_hat = jax.vmap(ss_apply, in_axes=(0, None, None, None))(params_red, scalers, x0, u_test)

    # The first cfg.skip_loss samples are excluded from the fit index.
    for mc_idx in range(mc_size):
        fit_red[len_idx, mc_idx] = nonlinear_benchmarks.error_metrics.fit_index(y_test[cfg.skip_loss:], y_test_hat[mc_idx, cfg.skip_loss:])[0]

In [None]:
# Save the final checkpoint: training lengths, training times and fit indexes.
ckpt = {
    "train_lens": train_lens,
    "train_time": train_time,
    "fit": fit_red,
}

ckpt_path = Path("out") / "mc_red.p"
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
# The context manager flushes and closes the file even if pickling fails part-way.
with open(ckpt_path, "wb") as ckpt_file:
    pickle.dump(ckpt, ckpt_file)