In [None]:
# Train a single full-order model on all data

In [None]:
from functools import partial
from pathlib import Path
import time
import pickle
from tqdm import tqdm
from argparse import Namespace
import jax.numpy as jnp
import jax
import jax.random as jr
import jax.flatten_util
import optax
import jaxopt
from flax.training import train_state
import scipy
from neuralss import ss_init, ss_apply
import matplotlib.pyplot as plt
import nonlinear_benchmarks

In [None]:
# Experiment settings, exposed as attributes (cfg.nx, cfg.skip_loss, ...).
cfg = Namespace(
    nu=1,            # forwarded to ss_init as `nu`
    ny=1,            # forwarded to ss_init as `ny`
    nx=3,            # forwarded to ss_init as `nx`; also the length of x0
    hidden_f=16,
    hidden_g=16,
    skip_loss=500,   # leading test samples left out of the metrics
)

In [None]:
# Root PRNG key of the notebook; sub-keys are split off it where randomness is needed.
key = jr.key(seed=42)

In [None]:
# Allow 64-bit floats in JAX (otherwise float64 requests are truncated to float32).
jax.config.update("jax_enable_x64", True)

# Data precision per optimisation stage: the training data is cast to float32
# for Adam and to float64 for BFGS -- float64 is needed to squeeze out the
# last bit of performance.
dtype_adam, dtype_bfgs = jnp.float32, jnp.float64

In [None]:
#%matplotlib widget

In [None]:
# Place arrays and computations on the first CPU device by default, then
# report the platform of the default backend.
# jax.default_backend() is the public equivalent of the deprecated
# jax.lib.xla_bridge.get_backend().platform and needs no extra import.
jax.config.update("jax_default_device", jax.devices("cpu")[0])
print(jax.default_backend())

In [None]:
# Folder holding the dataset files, and the contents of the .mat file used for training.
data_folder = Path("bwdataset")
data = scipy.io.loadmat(data_folder / "bw_matlab.mat")

In [None]:
# Training input/output as column vectors, divided by fixed scale factors
# (50.0 for the input, 7e-4 for the output).
u_train = data["u"].reshape(-1, 1) / 50.0
y_train = data["y"].reshape(-1, 1) / 7e-4

In [None]:
# Quick visual check of the scaled training output.
fig, ax = plt.subplots()
ax.plot(y_train)

In [None]:
# Multisine validation record: load both files, reshape to column vectors and
# divide by the same scale factors used for the training data.
mat_u_test = scipy.io.loadmat(data_folder / "uval_multisine.mat")
mat_y_test = scipy.io.loadmat(data_folder / "yval_multisine.mat")

u_test = mat_u_test["uval_multisine"].reshape(-1, 1) / 50.0
y_test = mat_y_test["yval_multisine"].reshape(-1, 1) / 7e-4

In [None]:
# Scaling factors handed to ss_apply alongside the model parameters.
scalers = {
    "f": {"lin": 1e-2, "nl": 1e-2},
    "g": {"lin": 1e0, "nl": 1e0},
}

# Draw a fresh sub-key for parameter initialisation and advance the root key.
key, subkey = jr.split(key)

# Optimisation variables: model parameters plus the initial state (started at zero).
# NOTE(review): cfg.hidden_f / cfg.hidden_g are not forwarded to ss_init --
# confirm that its defaults match the values in cfg.
opt_vars_init = {
    "params": ss_init(subkey, nu=cfg.nu, ny=cfg.ny, nx=cfg.nx),
    "x0": jnp.zeros((cfg.nx,)),
}

In [None]:
def loss_full(ov, y, u):
    """Mean-squared error between `y` and the model output for input `u`.

    Args:
        ov: Optimisation variables; a mapping whose "params" and "x0" entries
            are forwarded to `ss_apply` (together with the global `scalers`).
        y: Target output, compared elementwise with the `ss_apply` result.
        u: Input sequence forwarded to `ss_apply`.

    Returns:
        The mean of the squared residuals over all elements.
    """
    residual = y - ss_apply(ov["params"], scalers, ov["x0"], u)
    return jnp.mean(residual**2)


def train_full_model(ov, y, u, iters=100_000, lr=1e-3):
    """Fit the optimisation variables with AdamW on the `loss_full` objective.

    Args:
        ov: Initial optimisation variables (the pytree accepted by `loss_full`).
        y: Target output forwarded to `loss_full`.
        u: Input forwarded to `loss_full`.
        iters: Number of gradient steps to take.
        lr: Learning rate given to `optax.adamw`.

    Returns:
        Tuple `(params, losses)`: the variables after the last step, and a 1-D
        array of length `iters` holding the loss evaluated just before each
        update.
    """
    opt = optax.adamw(learning_rate=lr)
    loss_fn = partial(loss_full, y=y, u=u)
    state = train_state.TrainState.create(apply_fn=loss_fn, params=ov, tx=opt)

    @jax.jit
    def make_step(state):
        # Loss and gradient at the current parameters, followed by one optimiser update.
        loss, grads = jax.value_and_grad(state.apply_fn)(state.params)
        state = state.apply_gradients(grads=grads)
        return loss, state

    # The loss is already pulled to the host every step for the progress bar,
    # so keep those floats in a list and convert once at the end. Updating a
    # JAX array with `.at[idx].set(...)` outside of jit copies the whole array
    # on every iteration, which makes the history cost quadratic in `iters`.
    losses = []
    for _ in (pbar := tqdm(range(iters))):
        loss, state = make_step(state)
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix_str(loss_value)

    return state.params, jnp.array(losses)

In [None]:
# Wall-clock reference point; the training-time figures reported later are measured from here.
time_start = time.time()

In [None]:
# Stage 1: AdamW on the training data cast to dtype_adam.
opt_vars_adam, losses_adam = train_full_model(
    opt_vars_init,
    y=y_train.astype(dtype_adam),
    u=u_train.astype(dtype_adam),
    iters=40_000,
    lr=1e-3,
)

# Time elapsed since time_start. The name `tima_adam` is read again when the
# checkpoint is assembled, so it is kept as is.
tima_adam = time.time() - time_start
print(f"Adam took {tima_adam:.2f} seconds")

In [None]:
# Stage 2: refine the Adam result with SciPy's BFGS on the training data cast to dtype_bfgs.
options = dict(disp=True, return_all=True)

loss_bfgs = partial(
    loss_full,
    y=y_train.astype(dtype_bfgs),
    u=u_train.astype(dtype_bfgs),
)
solver = jaxopt.ScipyMinimize(
    fun=loss_bfgs,
    method="BFGS",
    tol=1e-6,
    maxiter=10_000,
    options=options,
)

# Start the solver from the Adam solution.
opt_vars_bfgs, state_full_bfgs = solver.run(opt_vars_adam)

In [None]:
# Total wall-clock time elapsed since time_start was recorded.
train_time = time.time() - time_start
print(f"Training time: {train_time:.2f} s")

In [None]:
# Evaluate the BFGS-refined parameters on the test input, starting from a zero
# initial state, and compare with the measured test output.
x0 = jnp.zeros((cfg.nx, ))
y2_hat = ss_apply(opt_vars_bfgs["params"], scalers, x0, u_test)
plt.figure()
plt.plot(y_test, "k", label="true")
plt.plot(y2_hat, "b", label="reconstructed")
plt.plot(y_test - y2_hat, "r", label="reconstruction error")
# Samples left of this line are excluded when the test metrics are computed.
plt.axvline(cfg.skip_loss, color="k")
# The label= arguments above are only displayed once a legend is drawn.
plt.legend()
plt.ylim([-4, 4]);

In [None]:
# Score only the samples from index cfg.skip_loss onwards.
y_test_eval = y_test[cfg.skip_loss:]
y2_hat_eval = y2_hat[cfg.skip_loss:]

fit_full = nonlinear_benchmarks.error_metrics.fit_index(y_test_eval, y2_hat_eval)
# Multiply by the 7e-4 output scale factor, then by 1e5 so the value reads in
# units of 1e-5 as printed below.
rmse_full = nonlinear_benchmarks.error_metrics.RMSE(y_test_eval, y2_hat_eval)*7e-4 * 1e5

print(f"Fit index: {fit_full[0]:.2f} %")
print(f"RMSE: {rmse_full[0]:.2f}e-5")

In [None]:
# Evaluate the BFGS-refined parameters and the estimated initial state on the
# training input, and compare with the training output.
y1_hat = ss_apply(opt_vars_bfgs["params"], scalers, opt_vars_bfgs["x0"], u_train)

plt.figure()
plt.plot(y_train, "k", label="true")
plt.plot(y1_hat, "b", label="reconstructed")
plt.plot(y_train - y1_hat, "r", label="reconstruction error")
# The label= arguments above are only displayed once a legend is drawn.
plt.legend();

In [None]:
# Hessian of the BFGS-stage loss with respect to the flattened optimisation
# variables, evaluated at the BFGS solution.
loss_fn = loss_bfgs
opt_vars_full_flat, unflatten_full = jax.flatten_util.ravel_pytree(opt_vars_bfgs)


def loss_fn_flat(of):
    """Loss as a function of the flat variable vector `of`."""
    return loss_fn(unflatten_full(of))


# Evaluate once before differentiating; the value itself is discarded.
loss_fn_flat(opt_vars_full_flat)
H = jax.hessian(loss_fn_flat)(opt_vars_full_flat)

In [None]:
# Destination of the pickled checkpoint.
filename = Path("out") / "full_alldata.pkl"

# Everything stored for later use: the loss Hessian, the trained variables,
# the settings they were trained with, and the timing figures.
ckpt = {
    "H": H,
    "params": opt_vars_bfgs["params"],
    "x0": opt_vars_bfgs["x0"],
    "cfg": cfg,
    "scalers": scalers,
    "train_time_adam": tima_adam,
    "train_time": train_time,
}

# Make sure the output folder exists, and write through a context manager so
# the file handle is flushed and closed even if pickling fails.
filename.parent.mkdir(parents=True, exist_ok=True)
with open(filename, "wb") as ckpt_file:
    pickle.dump(ckpt, ckpt_file)

In [None]:
# Training time: 5068.36 s
# Fit index: 98.91 %
# RMSE: 0.73e-5