In [None]:
import os
import sys
import copy
import torch
import numpy as np
import scipy as sp
import pandas as pd
import dill as pickle
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Scratch check: positions of the `None` entries in a mixed list.
x = np.random.rand(5, 5)
ls = [x, None, x, x, None]
np.flatnonzero([item is None for item in ls])

In [None]:
# Scratch check: shape of each row of `x`.
[np.shape(row) for row in x]

In [None]:
sys.path.append("/home/zanardi/Codes/ML/ROMAr/romar/")
from romar import env
from romar import const
from romar import utils
from romar import backend as bkd
from romar.systems import BoxAd
# from romar.roms import CoBRAS

import romar.postproc.plotting as pltt

In [None]:
# Numerical environment for romar: NumPy backend on CPU, double precision.
env_opts = dict(
  backend="numpy",
  device="cpu",
  device_idx=0,
  nb_threads=2,
  epsilon=1e-16,
  floatx="float64",
)
env.set(**env_opts)

In [None]:
path_to_dtb = "/home/zanardi/Codes/ML/ROMAr/romar/examples/database"

In [None]:
system = BoxAd(
  species={k: path_to_dtb + "/species/" + k + ".json" for k in ("Ar", "Arp", "em")},
  kin_dtb=path_to_dtb + "/rates/kin_fit.p",
  rad_dtb=path_to_dtb + "/rates/rad_fit.p",
  use_rad=True,
  use_proj=False,
  use_tables=False
)
system.compute_c_mat(
  max_mom=2,
  state_specs=False,
  include_em=False,
  include_temp=False
)

In [None]:
# Single parameter point mu = (rho, T, Te) with a unit quadrature weight.
rho, T, Te = 1e-2, 3.5e4, 3e2
quad_mu = {
  "x": np.array([[rho, T, Te]]),
  "w": np.ones((1, 1)),
}
tlim = [1e-14, 1e-5]

In [None]:
tfull = np.geomspace(*tlim, 100)
tfull = np.insert(tfull, 0, 0.0)
t = tfull

In [None]:
np.geomspace(1e-5, 2e-5, 5), np.geomspace(1e-14, 1e-5, 5)

In [None]:
i = 50
t0 = float(t[i])
tf = 2e-6
j = np.abs(t-tf).argmin()
tf = float(t[j])
tij = t[i:j+1]
t0-tij[0], tf-tij[-1]

In [None]:
y0, rho = system.equil.get_init_sol(quad_mu["x"].squeeze())
y = system.solve_fom(tfull, y0, rho)[0]

In [None]:
eps_vec = 1e-3*np.mean(y, axis=-1)

In [None]:
t = np.asarray([0.0, tlim[-1]])

In [None]:
# Scratch check of reversing along the last axis.
# NOTE(review): this rebinds the global `x`.
x = np.random.rand(2, 5)
x, x[..., ::-1]

In [None]:
# Scratch check: overwrite every entry of `x` in place.
x.fill(1.0)
x

In [None]:
def build_sol_interp(
  t: np.ndarray,
  x: np.ndarray,
  kind: str = "cubic"
) -> sp.interpolate.interp1d:
  """Build an interpolant of a sampled trajectory x(t).

  The time axis of `x` is found by matching its length with `len(t)`: the
  first axis is used when it matches, otherwise the second one.

  Args:
    t: 1-D array of sample times.
    x: Samples, 1-D or 2-D, with the first or second axis of length `len(t)`.
    kind: Interpolation kind forwarded to `scipy.interpolate.interp1d`
      (default "cubic").

  Returns:
    An `interp1d` object; calling it with a time returns the interpolated state.

  Raises:
    ValueError: if neither of the first two axes of `x` has length `len(t)`.
  """
  nt = len(t)
  # When both axes match (square x), the first axis wins; keep that in mind
  # for trajectories with as many states as time samples.
  if x.shape[0] == nt:
    axis = 0
  elif x.ndim > 1 and x.shape[1] == nt:
    axis = 1
  else:
    raise ValueError(
      f"no axis of x with shape {x.shape} matches len(t) = {nt}"
    )
  return sp.interpolate.interp1d(t, x, kind=kind, axis=axis)

def solve_adjoint(
  t: np.ndarray,
  g0: np.ndarray,
  interp: sp.interpolate.interp1d,
  rtol: float = 1e-2,
  atol: float = 0.0,
  first_step: float = 1e-10
) -> np.ndarray:
  """Integrate the adjoint system backward in time and return its final value.

  The problem is posed in reversed time tau = tf - t, where tf = max(t), so
  the integration runs forward in tau from g(tau=0) = g0 to
  tau = tf - min(t). The right-hand side and Jacobian are `adjoint_fun` and
  `adjoint_jac`, evaluated along the forward trajectory given by `interp`.

  Args:
    t: Times (any order); only the sorted, reversed values are used.
    g0: Adjoint terminal condition at tf.
    interp: Interpolant of the forward trajectory in physical time.
    rtol, atol, first_step: Options forwarded to `solve_ivp`; defaults are
      the values previously hard-coded.

  Returns:
    The adjoint state at the earliest time in `t`.

  Raises:
    RuntimeError: if `solve_ivp` reports a failed integration.
  """
  t = np.flip(np.sort(t))
  tf = t[0]
  tau = tf-t
  sol = sp.integrate.solve_ivp(
    fun=adjoint_fun,
    t_span=[tau[0],tau[-1]],
    y0=g0,
    method="BDF",
    t_eval=tau,
    args=(tf, interp),
    first_step=first_step,
    rtol=rtol,
    atol=atol,
    jac=adjoint_jac
  )
  # Do not silently return a partial/garbage solution.
  if not sol.success:
    raise RuntimeError(f"Adjoint integration failed: {sol.message}")
  return sol.y[:,-1]

def adjoint_fun(
  tau: float,
  g: np.ndarray,
  tf: float,
  interp: sp.interpolate.interp1d
) -> np.ndarray:
  """Right-hand side of the adjoint ODE in reversed time tau = tf - t.

  Returns `adjoint_jac(tau, ...) @ g`, i.e. the transposed FOM Jacobian
  along the forward trajectory applied to the adjoint state `g`.
  """
  return adjoint_jac(tau, g, tf, interp) @ g

def adjoint_jac(
  tau: float,
  g: np.ndarray,
  tf: float,
  interp: sp.interpolate.interp1d
) -> np.ndarray:
  """Transposed FOM Jacobian along the forward trajectory, in reversed time.

  Args:
    tau: Reversed time, tau = tf - t.
    g: Adjoint state; unused, present because `solve_ivp` calls
      `jac(t, y, *args)`.
    tf: Final physical time.
    interp: Interpolant of the forward trajectory in physical time.

  Returns:
    `system.jac(t, x(t)).T` evaluated at the physical time t = tf - tau.
  """
  # Map reversed time back to physical time.
  t_phys = tf-tau
  x = interp(t_phys)
  # Evaluate the Jacobian at the time of the state, not at the notebook-global `t`.
  j = system.jac(t_phys, x)
  return j.T

In [None]:
x_interp = build_sol_interp(tfull, y)

In [None]:
grad_adj = [solve_adjoint(t, g0, x_interp) for g0 in system.C]
grad_adj = np.vstack(grad_adj)
np.savetxt("./grad_adj.txt", grad_adj)

In [None]:
gad_m2 = np.loadtxt("./grad_adj_m2.txt")
gad_m3 = np.loadtxt("./grad_adj_m3.txt")
gfd = np.loadtxt("./grad_fd.txt")

In [None]:
np.linalg.norm(gfd-gad_m2)/np.linalg.norm(gfd)

In [None]:
np.mean(np.abs((gfd-gad_m2)/gfd))

In [None]:
np.mean(np.abs((gfd-gad_m3)/gfd))

In [None]:
np.abs((gad_m2-gad_m3)/gad_m3)

In [None]:
np.linalg.norm(gad_m2-gad_m3)/np.linalg.norm(gad_m3)

In [None]:
# np.linalg.norm(g3, axis=-1)

In [None]:
# np.linalg.norm(g2-g3)/np.linalg.norm(g3)

In [None]:
class Output():
  """Callable FOM output used for finite-difference gradients.

  Each call solves the full-order model from a given initial state, saves
  the trajectory to `<path>/x_<k>.txt` (k = 1, 2, ... counts the calls) and
  returns `system.C` applied to the state at the last returned time.
  """

  def __init__(self, path):
    """Create the output folder (if missing) and reset the call counter."""
    self.counter = 0
    self.path = path
    os.makedirs(path, exist_ok=True)

  def __call__(
    self,
    t: np.ndarray,
    y0: np.ndarray,
    rho: float
  ) -> np.ndarray:
    """Solve the FOM on the time array `t` from `y0` and return `C @ x(t[-1])`.

    Side effects: sets `system.use_rom = False` and writes one text file.
    """
    self.counter += 1
    # Setting up
    system.use_rom = False
    y0 = system.set_up(y0, rho)
    # Solve fom; transposed so that rows are time instants
    # (assumes solve_fom returns (n_state, n_time); TODO confirm)
    x = system.solve_fom(t, y0, rho)[0].T
    np.savetxt(os.path.join(self.path, f"x_{self.counter}.txt"), x)
    # Compute output
    return system.C @ x[-1]

In [None]:
# def output(
#   tf: float,
#   y0: np.ndarray,
#   rho: float
# ) -> np.ndarray:
#   # Setting up
#   system.use_rom = False
#   y0 = system.set_up(y0, rho)
#   # Time vector
#   t = np.array([tf])
#   # Solve fom
#   x = system.solve_fom(t, y0, rho)[0].reshape(-1)
#   # Compute output
#   return system.C @ x

In [None]:
output = Output("./sols/")
grad_fd = sp.optimize.approx_fprime(
  xk=y0,
  f=lambda z: output(t, z, rho),
  epsilon=eps_vec
)
np.savetxt("./grad_fd.txt", grad_fd)

In [None]:
np.linalg.norm(grad_fd)
# np.linalg.norm(np.loadtxt("./grad_adj.txt"))

In [None]:
np.linalg.norm(grad_fd-grad_adj)/np.linalg.norm(grad_adj)

In [None]:
y = np.vstack([np.linalg.norm(system.C @ np.loadtxt(f"./sols/x_{i+1}.txt").T, axis=0) for i in range(36)])

In [None]:
plt.semilogx(tfull, y.T)
plt.show()
plt.close()

In [None]:
for i in range(36):
  for xi in x:
    plt.loglog(tfull, xi[:,i])
  plt.show()
  plt.close()

In [None]:
# Deliberate halt so that "Run All" stops here (previously done by
# evaluating the undefined name `stop`).
raise RuntimeError("Execution intentionally stopped here.")

In [None]:
def lin_adj(
  tf: float,
  y0: np.ndarray,
  rho: float
) -> np.ndarray:
  """Output sensitivities of the linearized FOM, via an eigendecomposition.

  Linearizes the FOM at the set-up initial state, reads the operators `A`
  and `C` from `system`, and evaluates
  `V^{-T} exp(tf*Lambda) V^T C^T = (C exp(tf*A))^T` with `A = V Lambda V^{-1}`.

  Args:
    tf: Final time.
    y0: Initial state, passed through `system.set_up`.
    rho: Density, passed to `system.set_up`.

  Returns:
    Array of shape (n_out, n_state) for the single time `tf`, where n_out is
    the number of rows of `C` (assumes `C` is 2-D).

  Side effects: sets `system.use_rom = False` and (presumably) updates the
  linear operators stored on `system` through `compute_lin_fom_ops`.
  """
  # Setting up
  system.use_rom = False
  y0 = system.set_up(y0, rho)
  # Time vector
  t = np.array([tf])
  # Compute linear operators
  system.compute_lin_fom_ops(y0)
  A, C = [getattr(system, k) for k in ("A", "C")]
  # Eigendecomposition A = V diag(l) V^{-1} (assumes A is diagonalizable)
  l, V = sp.linalg.eig(A)
  Vinv = sp.linalg.inv(V)
  # Allocate memory: one (n_state, n_out) block per time
  shape = [len(t)] + list(C.T.shape)
  g = np.zeros(shape)
  # Compute solution
  VC = V.T @ C.T
  for (i, ti) in enumerate(t):
    # Row scaling == diag(exp(ti*l)) @ VC, without building the diagonal matrix
    gi = Vinv.T @ (np.exp(ti*l)[:,None] * VC)
    # For real A the result is real up to round-off; drop the imaginary part
    # explicitly instead of relying on a complex->float cast (ComplexWarning)
    g[i] = np.real(gi)
  # Manipulate tensor: (nt, n_state, n_out) -> (n_out*nt, n_state)
  g = np.transpose(g, axes=(1,2,0))
  g = np.reshape(g, (shape[1],-1))
  return g.T

In [None]:
grad_f_adj = lin_adj(tf, y0, rho)

In [None]:
grad_f_fd = sp.optimize.approx_fprime(
  xk=y0,
  f=lambda z: output(tf, z, rho),
  epsilon=1e-12
)

In [None]:
np.linalg.norm(grad_f_adj-grad_f_fd)/np.linalg.norm(grad_f_fd)

In [None]:
# Deliberate halt so that "Run All" stops here (previously done by
# evaluating the undefined name `stop`).
raise RuntimeError("Execution intentionally stopped here.")

In [None]:
yfom, _ = system.solve_fom(t, y0, rho)

In [None]:
w_fom = copy.deepcopy(yfom[:system.mix.nb_comp])
n_fom = system.mix.get_n(bkd.to_torch(w_fom)).numpy()
T_fom = copy.deepcopy(yfom[system.mix.nb_comp:])
T_fom[-1] = system.mix.get_Te(pe=T_fom[-1], ne=n_fom[0])

In [None]:
xref = np.mean(yfom, axis=1)
xscale = np.std(yfom, axis=1)
# xref[-3:] = 0.0
# xscale[-3:] = 1.0

In [None]:
# xref = np.linalg.norm(yfom, ord=2, axis=1)
# xscale = np.std(yfom, axis=1)
xref[0] = 0.0
xref[-3:] = 0.0
xscale[0] = 1.0
xscale[-3:] = 1.0
xref = None
xscale = None

In [None]:
# xscale = 1.0/xref
# xref = None

In [None]:
# `CoBRAS` is not imported at the top of the notebook (that import is
# commented out), so import it here; otherwise this cell raises NameError.
from romar.roms import CoBRAS

cobras = CoBRAS(
  system=system,
  tgrid={"start": tlim[0], "stop": tlim[1]*1e1, "num": 51},
  quad_mu=quad_mu,
  xref=xref,
  xscale=xscale,
  path_to_saving="./",
  saving=True
)

In [None]:
# Snapshot matrices for CoBRAS (10 measurements).
X, Y = cobras.compute_cov_mats(nb_meas=10)
(X.shape, Y.shape)

In [None]:
# W = np.diag(1.0/xref)
# Ws = np.sqrt(W)
# X = Ws @ X
# Y = Ws @ Y

In [None]:
# xnot: presumably indices of state variables excluded from the modes (TODO confirm).
cobras.compute_modes(X=X, Y=Y, xnot=[0, 1, -2, -1], pod=True)

In [None]:
basis = pickle.load(open("./cobras_basis.p", "rb"))

In [None]:
for i in range(15):
  nb = str(i+1)
  b = basis["phi"][:,i]
  pltt.plot_dist_2d(
    x=np.arange(len(b)),
    y=b,
    labels=[r"$\epsilon_i$ [eV]", r"$\%s_{%s}$" % ("phi", nb)],
    scales=["linear", "linear"],
    markersize=1,
    # figname=path + f"/{name}_{nb.zfill(2)}",
    save=False,
    show=True
  )

In [None]:
for i in range(15):
  nb = str(i+1)
  b = basis["psi"][:,i]
  pltt.plot_dist_2d(
    x=np.arange(len(b)),
    y=b,
    labels=[r"$\epsilon_i$ [eV]", r"$\%s_{%s}$" % ("psi", nb)],
    scales=["linear", "linear"],
    markersize=1,
    # figname=path + f"/{name}_{nb.zfill(2)}",
    save=False,
    show=False
  )

In [None]:
# Reduced dimension of the ROM.
rdim = 8

# basis = pickle.load(open("./pod_basis.p", "rb"))
# phi, psi = basis["phi"][:,:rdim], basis["phi"][:,:rdim]

# Context manager so the file handle is closed after loading.
with open("./cobras_basis.p", "rb") as basis_file:
  basis = pickle.load(basis_file)
phi, psi = basis["phi"][:,:rdim], basis["psi"][:,:rdim]

# system.set_rom(phi, psi, mask="./rom_mask.txt", xref=xref, xscale=xscale)
system.set_rom(phi, psi, mask=basis["mask"], xref=xref, xscale=xscale)

In [None]:
# I = np.eye(system.nb_eqs)
# system.set_rom(I, I, mask=np.diag(I), xref=xref, xscale=xscale)

In [None]:
yrom, _ = system.solve_rom(t, y0, rho)
nt = len(yrom.T)
nt

In [None]:
w_rom = copy.deepcopy(yrom[:system.mix.nb_comp])
n_rom = system.mix.get_n(bkd.to_torch(w_rom)).numpy()
T_rom = copy.deepcopy(yrom[system.mix.nb_comp:])
T_rom[-1] = system.mix.get_Te(pe=T_rom[-1], ne=n_rom[0])

In [None]:
# FOM (solid) vs ROM (dashed, same color) for the first two temperature rows.
for i in range(2):
  fom_line = plt.semilogx(t[1:nt], T_fom[i][1:nt])[-1]
  plt.semilogx(t[1:nt], T_rom[i][1:nt], ls="--", color=fom_line.get_color())
plt.show()
plt.close()

In [None]:
# for k in ("Ar", "Arp"):
#   s = system.mix.species[k]
#   pltt.plot_mom_evolution(
#     path=f"./figs/{k}/",
#     t=t[1:nt],
#     n_m={"FOM": bkd.to_torch(n_fom.T[1:nt,s.indices]), "ROM": bkd.to_torch(n_rom.T[1:nt,s.indices])},
#     molecule=s,
#     molecule_label="{%s}" % k,
#     tlim=None,
#     ylim_err=None,
#     err_scale="linear",
#     hline=None,
#     max_mom=2
#   )

In [None]:
# One figure per species: FOM (solid) vs ROM (dashed, same color) number densities.
for s in system.mix.species.values():
  print(s.name)
  for i in s.indices:
    fom_line = plt.loglog(t[1:nt], n_fom[i][1:nt])[-1]
    plt.loglog(t[1:nt], n_rom[i][1:nt], ls="--", color=fom_line.get_color())
  plt.show()
  plt.close()