In [31]:
from plot_utils import set_size, tex_fonts, LINEWIDTH_L_CSS as linewidth
import pickle
import scipy
from pathlib import Path
import jax.numpy as jnp
import numpy as onp
import matplotlib.pyplot as plt
import jax
from neuralss import ss_init, ss_apply
import nonlinear_benchmarks

In [32]:
from argparse import Namespace

# Model/evaluation settings, exposed as attributes (cfg.nx, cfg.skip_loss, ...).
# nx sizes the initial state vector and skip_loss is the first sample used by
# the plots/metrics further down; the remaining fields are presumably the model
# dimensions and hidden-layer widths used at training time — TODO confirm.
cfg = Namespace(
    nu=1,
    ny=1,
    nx=3,
    hidden_f=16,
    hidden_g=16,
    skip_loss=500,
)

# Scaling factors handed to ss_apply, keyed first by "f"/"g", then by "lin"/"nl".
scalers = dict(
    f=dict(lin=1e-2, nl=1e-2),
    g=dict(lin=1e0, nl=1e0),
)

In [33]:
# LaTeX fonts are currently disabled; re-enable with plt.rcParams.update(tex_fonts).
# Draw a grid on every axes created from here on.
plt.rcParams["axes.grid"] = True

In [34]:
# Select the interactive widget (ipympl) backend for the figures in this notebook.
%matplotlib widget

In [35]:
# Switch JAX to 64-bit floating point.
jax.config.update("jax_enable_x64", True)

# Make the first CPU device the default placement for JAX computations.
cpu_device = jax.devices("cpu")[0]
jax.config.update("jax_default_device", cpu_device)

In [36]:
# Load the saved checkpoint; later cells read its "params" and "H" entries.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
filename = Path("out") / "full_alldata.pkl"
# The context manager closes the file handle, which a bare open() call
# passed straight to pickle.load never does.
with filename.open("rb") as ckpt_file:
    ckpt = pickle.load(ckpt_file)


In [37]:
# Folder holding the multisine validation record, one .mat file per signal.
data_folder = Path("bwdataset")

# Load each signal as an (N, 1) column and apply its fixed normalization
# (input divided by 50, output divided by 7e-4).
u2 = (
    scipy.io.loadmat(data_folder / "uval_multisine.mat")["uval_multisine"].reshape(-1, 1)
    / 50.0
)
y2 = (
    scipy.io.loadmat(data_folder / "yval_multisine.mat")["yval_multisine"].reshape(-1, 1)
    / 7e-4
)

In [None]:
# Simulate the model stored in the checkpoint on the validation input,
# starting from a zero initial state of size cfg.nx.
x0 = jnp.zeros((cfg.nx,))
y2_hat = ss_apply(ckpt["params"], scalers, x0, u2)
# Alternative parameter set, kept for reference (`opt_vars_adam` is not defined
# in the cells shown here):
# y2_hat = ss_apply(opt_vars_adam["params"], scalers, x0, u2)

# Overlay measured output, simulated output and their difference; y2 was
# normalized when loaded, so all curves are in normalized units.
plt.figure()
plt.plot(y2, "k", label="true")
plt.plot(y2_hat, "b", label="reconstructed")
plt.plot(y2 - y2_hat, "r", label="reconstruction error")
# Vertical marker at cfg.skip_loss, the first sample included in the metrics.
plt.axvline(cfg.skip_loss, color="k")
plt.ylim([-4, 4])
plt.xlabel("Sample index")
plt.ylabel("Normalized output")
# The label= arguments above are only rendered once a legend is requested.
plt.legend();

In [None]:
# Score only the samples from index cfg.skip_loss onwards.
y2_eval = y2[cfg.skip_loss:]
y2_hat_eval = y2_hat[cfg.skip_loss:]

fit_full = nonlinear_benchmarks.error_metrics.fit_index(y2_eval, y2_hat_eval)
# 7e-4 undoes the output normalization applied when y2 was loaded; the extra
# 1e5 lets the value be printed as the mantissa of "...e-5" below.
rmse_full = nonlinear_benchmarks.error_metrics.RMSE(y2_eval, y2_hat_eval) * 7e-4 * 1e5

# Reference values recorded from a previous run:
#   (Array([98.54584449], dtype=float64), array([0.96720211]))
# (A bare `fit_full, rmse_full` expression in the middle of a cell is never
# displayed, so the values are reported through the prints only.)
print(f"Fit index: {fit_full[0]:.2f} %")
print(f"RMSE: {rmse_full[0]:.2f}e-5")

In [40]:
# Eigendecomposition of the symmetric matrix stored under ckpt["H"]
# (presumably a Hessian, going by the name of the figure saved further down).
# An extra decomposition of the truncated matrix ckpt["H"][:-3, :-3] used to be
# computed here, but its result was overwritten immediately by the full one, so
# that wasted call has been removed. To analyse the truncated matrix instead,
# pass ckpt["H"][:-3, :-3] to eigh below.
w, v = jnp.linalg.eigh(ckpt["H"])

# eigh returns eigenvalues in ascending order; flip both outputs so that the
# largest eigenvalue comes first and the eigenvector columns stay aligned with w.
w = w[::-1]  # eigenvalues
v = v[:, ::-1]  # eigenvectors (one per column)

In [None]:
# Publication-sized figure of the eigenvalue spectrum, saved as a PDF.
fig, ax = plt.subplots(figsize=set_size(linewidth, fraction=1.0))
# ax.set_title("Hessian eigenvalues")
# NOTE(review): w[1:] leaves out the first entry of w, so the point drawn at
# x = 0 is w[1] — confirm that the one-position shift of the index axis is intended.
ax.plot(w[1:], "k*")
ax.set_xlabel("Eigenvalue index")
ax.set_ylabel("Eigenvalue")
fig.tight_layout()
fig.savefig("hessian_eigenvalues.pdf")
