In [None]:
import os
import sys
import copy
import numpy as np
import dill as pickle
import matplotlib.pyplot as plt

In [None]:
sys.path.append("/home/zanardi/Codes/ML/ROMAr/romar/")
from romar import env
from romar import utils
from romar.systems import BoxAd

In [None]:
# ROMAr environment: NumPy backend on CPU (device index 0), 2 threads,
# no custom epsilon, float64 arithmetic.
env_opts = dict(
  backend="numpy",
  device="cpu",
  device_idx=0,
  nb_threads=2,
  epsilon=None,
  floatx="float64",
)
env.set(**env_opts)

In [None]:
# Root of the ROMAr checkout. Set the ROMAR_ROOT environment variable to run
# this notebook elsewhere; the default reproduces the original absolute paths.
ROMAR_ROOT = os.environ.get("ROMAR_ROOT", "/home/zanardi/Codes/ML/ROMAr")
# Species / rates database (species JSON files and rate fits are read from here).
path_to_dtb = ROMAR_ROOT + "/romar/examples/database"
# Pickled basis; later cells read its "phi", "psi", "mask", "xref", "xscale" keys.
path_to_basis = ROMAR_ROOT + "/run_old2/rad_on_test10/max_mom_2a/cobras_basis.p"
# Test-case file (only used by the commented-out cells).
path_to_data = ROMAR_ROOT + "/run_old2/rad_on_test10/data/case_mrho_mT.p"
# Directory holding the reference/Plato text files; figures are also saved here.
path_to_rom_data = path_to_dtb + "/utils/rom_data"

In [None]:
# Make sure the ROM data / figure directory exists (no error if it already does).
os.makedirs(name=path_to_rom_data, exist_ok=True)

In [None]:
# r selects basis["phi"][r] / basis["psi"][r] and the number of z components plotted;
# use_proj is forwarded to BoxAd(use_proj=...).
r, use_proj = 10, False

In [None]:
# Box model over species "Ar", "Arp", "em", with radiation enabled and
# use_tables=False; species and rate files come from the database directory.
system = BoxAd(
  species={name: f"{path_to_dtb}/species/{name}.json" for name in ("Ar", "Arp", "em")},
  kin_dtb=f"{path_to_dtb}/rates/kin_fit.p",
  rad_dtb=f"{path_to_dtb}/rates/rad_fit.p",
  use_rad=True,
  use_proj=use_proj,
  use_tables=False,
)

## Load basis and build the ROM

In [None]:
# Load the pickled basis (keys "phi", "psi", "mask", "xref", "xscale" are used
# below). The context manager closes the file handle right after loading.
# NOTE: dill/pickle can execute arbitrary code on load -- only open trusted files.
with open(path_to_basis, "rb") as basis_file:
  basis = pickle.load(basis_file)

In [None]:
# Build the ROM from the rank-r entries of phi/psi plus the basis' mask and
# reference/scale arrays (kwargs passed in the original order).
rom_kwargs = {"phi": basis["phi"][r], "psi": basis["psi"][r]}
rom_kwargs.update((key, basis[key]) for key in ("mask", "xref", "xscale"))
system.rom.build(**rom_kwargs)

## Load reference data and solve the ROM

In [None]:
# Density passed to system.mix.set_rho and to solve_rom below.
rho = 1e-2
# Time grid, reference solution and ROMAr prediction, each stored as a text file.
t, y_true, y_pred = [
  np.loadtxt(f"{path_to_rom_data}/{name}.txt")
  for name in ("t", "y_true", "y_romar")
]

In [None]:
# Get pe: take the reference solution with components along axis 0 (y_true has
# one column per component), compute pe from (w, Te), and store it in the last row.
y = copy.deepcopy(y_true.T)
system.mix.set_rho(rho)
w = y[:-2]
Th = y[-2]
Te = y[-1]  # NOTE: a view into y, so after the write below Te holds pe values
n = system.mix.get_n(w)
pe = system.mix.get_pe(Te, ne=n[0])
y[-1] = pe

In [None]:
# Solve ROM from the first column of y (initial state) over the reference grid t.
# decode=False: the first returned item is kept as `z` and decoded later.
# NOTE(review): the commented-out cell below passes `timeout=` instead of
# `tout=` -- confirm which keyword solve_rom actually accepts.
rom_solution = system.solve_rom(t, y[:, 0], rho, tout=1e2, decode=False)
z = rom_solution[0]

In [None]:
# Save 'z' states for Plato: decode z, take the last primitive variable (Te) of
# the decoded state, and write [z[:-1]; Te] transposed (one row per time step).
# Both y_hat and Te are global names that are reassigned here.
y_hat = system.rom.decode(z.T, is_der=False).T
Te = system.get_prim(y_hat, clip=False)[-1]
zt = np.vstack((z[:-1], Te)).T
np.savetxt(f"{path_to_rom_data}/z_true.txt", zt)

In [None]:
# # Load test case
# icase = utils.load_case(filename=path_to_data)
# t, y0, rho = [icase[k] for k in ("t", "y0", "rho")]
# # Time grid
# t = system.get_tgrid(1e-12, t.max(), 200)
# Solve ROM
# z = system.solve_rom(t, y[:,0], rho, timeout=1e2, decode=False)[0]

In [None]:
# # Save 'y' states for Plato
# Te = system.get_prim(y, clip=False)[-1]
# yt = np.vstack([y[:-1], Te]).T
# np.savetxt(path_to_rom_data + "/y.txt", yt)

In [None]:
def fun(z, y, model=None):
  """Evaluate decoded states and right-hand sides along a trajectory.

  For each step ``i`` (``i`` in ``range(len(z))``) this computes:
    - ``dydt[i]``     = ``model._fun(0.0, y[i])``
    - ``y_hat[i]``    = ``model.rom.decode(z[i], is_der=False)``
    - ``z_hat[i]``    = ``model.rom.encode(y_hat[i], is_der=False)``
    - ``dydt_hat[i]`` = ``model._fun(0.0, y_hat[i])``
    - ``dzdt[i]``     = ``model.rom.encode(dydt_hat[i], is_der=True)``

  Parameters
  ----------
  z : sequence of arrays
    Encoded states, one per step (e.g. ``z.T``).
  y : sequence of arrays
    Reference full states; must have at least ``len(z)`` entries.
  model : object, optional
    Object providing ``_fun(t, y)`` and ``rom.encode`` / ``rom.decode``.
    Defaults to the notebook-level ``system``.

  Returns
  -------
  tuple of np.ndarray
    ``(y_hat, z_hat, dydt, dydt_hat, dzdt)``, each row-stacked with
    ``np.vstack``. A tuple (not a one-shot ``map`` iterator) so the result
    can be unpacked, indexed or inspected more than once.
  """
  if model is None:
    model = system
  rom = model.rom
  y_hat, z_hat, dydt, dydt_hat, dzdt = [], [], [], [], []
  for i, zi in enumerate(z):
    # RHS evaluated on the reference state.
    dydt.append(model._fun(0.0, y[i]))
    # Decode, then re-encode to check the encode/decode round trip.
    yi = rom.decode(zi, is_der=False)
    y_hat.append(yi)
    z_hat.append(rom.encode(yi, is_der=False))
    # RHS on the decoded state, and its projection (is_der=True).
    dyidt = model._fun(0.0, yi)
    dydt_hat.append(dyidt)
    dzdt.append(rom.encode(dyidt, is_der=True))
  return tuple(np.vstack(rows) for rows in (y_hat, z_hat, dydt, dydt_hat, dzdt))

In [None]:
# Evaluate decoded/re-encoded states and RHS values at every time step (rows = time).
(y_hat, z_hat, dydt, dydt_hat, dzdt) = fun(z.T, y.T)

In [None]:
# Counterparts of fun()'s outputs stored as text files (presumably produced by
# Plato), loaded in the same order as fun() returns them.
y_hat_pred, z_hat_pred, dydt_pred, dydt_hat_pred, dzdt_pred = [
  np.loadtxt(f"{path_to_rom_data}/{name}.txt")
  for name in ("y_hat", "z_hat", "dydt", "dydt_hat", "dzdt")
]

In [None]:
# Plot 'y' solution: reference (solid) vs. ROMAr prediction (dashed, same color)
# for each component. t[0] is skipped (presumably t[0] == 0, not drawable on log axes).
for comp in range(system.nb_comp):
  (ref_line,) = plt.loglog(t[1:], y_true[1:, comp])
  plt.loglog(t[1:], y_pred[1:, comp], ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel("$t$ [s]")
plt.ylabel("$w_i$")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/yt_sol.png")
plt.show()
plt.close()

# Plot 'y' error: pointwise relative error, last two columns excluded.
# Zero reference entries give inf/nan (NumPy emits a RuntimeWarning).
err = np.abs((y_true[:, :-2] - y_pred[:, :-2]) / y_true[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel("$t$ [s]")
plt.ylabel("$w_i$ rel. error")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/yt_err.png")
plt.show()
plt.close()

In [None]:
# Plot 'y_hat' solution: fun()'s decoded states (solid) vs. y_hat_pred loaded
# from y_hat.txt (dashed, same color). LaTeX labels use raw strings so "\hat"
# is not parsed as an (invalid) escape sequence.
for comp in range(system.nb_comp):
  (ref_line,) = plt.loglog(t[1:], y_hat[1:, comp])
  plt.loglog(t[1:], y_hat_pred[1:, comp], ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\hat{w}_i$")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/y_hat_sol.png")
plt.show()
plt.close()

# Plot 'y_hat' error: pointwise relative error, last two columns excluded.
# Zero reference entries give inf/nan.
err = np.abs((y_hat[:, :-2] - y_hat_pred[:, :-2]) / y_hat[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\hat{w}_i$ rel. error")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/y_hat_err.png")
plt.show()
plt.close()

In [None]:
# Plot 'z_hat' solution: |zt| (solid, the array saved as z_true.txt) vs.
# |z_hat_pred| loaded from z_hat.txt (dashed, same color), first r columns.
# NOTE(review): the z_hat returned by fun() is not plotted here -- confirm that
# comparing against zt rather than z_hat is intended.
# LaTeX labels use raw strings so "\hat" is not an invalid escape sequence.
for k in range(r):
  (ref_line,) = plt.loglog(t[1:], np.abs(zt[1:, k]))
  plt.loglog(t[1:], np.abs(z_hat_pred[1:, k]), ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\left|\hat{z}_i\right|$")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/z_hat_sol.png")
plt.show()
plt.close()

# Plot 'z_hat' error: pointwise relative error, last two columns excluded.
err = np.abs((zt[:, :-2] - z_hat_pred[:, :-2]) / zt[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\hat{z}_i$ rel. error")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/z_hat_err.png")
plt.show()
plt.close()

In [None]:
# Plot 'dydt' solution: |RHS| on the reference states (solid) vs. |dydt_pred|
# loaded from dydt.txt (dashed, same color), for each component.
for comp in range(system.nb_comp):
  (ref_line,) = plt.loglog(t[1:], np.abs(dydt[1:, comp]))
  plt.loglog(t[1:], np.abs(dydt_pred[1:, comp]), ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel("$t$ [s]")
plt.ylabel(r"$\left|\dfrac{dw_i}{dt}\right|$")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/dydt_sol.png")
plt.show()
plt.close()

# Plot 'dydt' error: pointwise relative error, last two columns excluded.
err = np.abs((dydt[:, :-2] - dydt_pred[:, :-2]) / dydt[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel("$t$ [s]")
plt.ylabel(r"$\dfrac{dw_i}{dt}$ rel. error")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/dydt_err.png")
plt.show()
plt.close()

In [None]:
# Plot 'dydt_hat' solution: |RHS| on the decoded states (solid) vs.
# |dydt_hat_pred| loaded from dydt_hat.txt (dashed, same color).
# LaTeX labels use raw strings so "\hat" is not an invalid escape sequence.
for comp in range(system.nb_comp):
  (ref_line,) = plt.loglog(t[1:], np.abs(dydt_hat[1:, comp]))
  plt.loglog(t[1:], np.abs(dydt_hat_pred[1:, comp]), ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\left|\dfrac{d\hat{w}_i}{dt}\right|$")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/dydt_hat_sol.png")
plt.show()
plt.close()

# Plot 'dydt_hat' error: pointwise relative error, last two columns excluded.
err = np.abs((dydt_hat[:, :-2] - dydt_hat_pred[:, :-2]) / dydt_hat[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel(r"$t$ [s]")
plt.ylabel(r"$\dfrac{d\hat{w}_i}{dt}$ rel. error")
plt.tight_layout()
plt.savefig(path_to_rom_data + "/dydt_hat_err.png")
plt.show()
plt.close()

In [None]:
# Plot 'dzdt' solution: |encoded RHS| from fun() (solid) vs. |dzdt_pred| loaded
# from dzdt.txt (dashed, same color), for the first r columns.
for k in range(r):
  (ref_line,) = plt.loglog(t[1:], np.abs(dzdt[1:, k]))
  plt.loglog(t[1:], np.abs(dzdt_pred[1:, k]), ls="--", lw=2.0, color=ref_line.get_color())
plt.xlabel("$t$ [s]")
plt.ylabel(r"$\left|\dfrac{dz_i}{dt}\right|$")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/dzdt_sol.png")
plt.show()
plt.close()

# Plot 'dzdt' error: pointwise relative error, last two columns excluded.
err = np.abs((dzdt[:, :-2] - dzdt_pred[:, :-2]) / dzdt[:, :-2])
plt.loglog(t[1:], err[1:])
plt.xlabel("$t$ [s]")
plt.ylabel(r"$\dfrac{dz_i}{dt}$ rel. error")
plt.tight_layout()
plt.savefig(f"{path_to_rom_data}/dzdt_err.png")
plt.show()
plt.close()