### Varimax Rotation

In [None]:
import os
import sys
import numpy as np
import scipy as sp
import joblib as jl
import dill as pickle

from tqdm import tqdm
from typing import *

In [None]:
import matplotlib.pyplot as plt

# Matplotlib style sheet applied to every figure produced by this notebook
style = (
  "/home/zanardi/Workspace/Research/styles/matplotlib/"
  "paper_1column.mplstyle"
)
plt.style.use(style)

In [None]:
# Add the ROMAr source folder to the module search path; this must run
# before the `romar` imports below, otherwise they cannot be resolved.
sys.path.append("/home/zanardi/Codes/ML/ROMAr/romar/")
from romar import env
from romar import utils
from romar.systems import BoxAd
from romar import postproc as pp

Set environment

In [None]:
# Numerical environment: NumPy backend on CPU, float64 precision, fixed seed
env_opts = dict(
  backend="numpy",
  device="cpu",
  device_idx=0,
  nb_threads=2,
  epsilon=None,
  floatx="float64",
  seed=0
)
env.set(**env_opts)

Set inputs

In [None]:
# Test-case indices to evaluate, unpacked into `range` (end excluded)
irange = [0, 100]
# Number of parallel joblib workers
nb_workers = 10
# Root folder of the run
prefix = "/home/zanardi/Codes/ML/ROMAr/runs/run_04_cc/"
paths = dict(
  # > ROM bases to compare (label -> pickled basis file)
  roms={
    "CoBRAS": f"{prefix}/max_mom_2/models/cobras/basis.p",
    "CoBRAS-Varimax": f"{prefix}/max_mom_2/models/cobras/basis_varimax_psi.p"
  },
  # > Folder with the test solutions
  data=f"{prefix}/data/test/",
  # > Thermochemical database folder
  dtb="/home/zanardi/Codes/ML/ROMAr/romar/examples/database/",
  # > Folder where figures are written
  out="./figs/varimax/"
)
# Labels (keys of `paths["roms"]`) of the reference and rotated ROMs
roms_id = dict(ref="CoBRAS", rot="CoBRAS-Varimax")
# Time window [s]: only instants within [min, max] are evaluated
tlim = [1e-9, 1e-3]
# Reduced dimensions to test: 7, 8, 9, 10
rdims = np.arange(7, 11)

In [None]:
# Create the figure folder (no error if it already exists)
os.makedirs(name=paths["out"], exist_ok=True)

Initialize 0D thermochemical system

In [None]:
# 0D system with species "Ar", "Arp" and "em"; every input file is read
# from the thermochemical database folder `paths["dtb"]`
system = BoxAd(
  species={
    name: f"{paths['dtb']}/species/{name}.json"
    for name in ("Ar", "Arp", "em")
  },
  kin_dtb=f"{paths['dtb']}/rates/kin_fit.p",
  rad_dtb=f"{paths['dtb']}/rates/rad_fit.p",
  use_rad=True,
  use_proj=False,
  use_tables=False
)

Evaluate ROM energies

In [None]:
def evaluate_parallel(
  irange,
  nb_workers,
  **kwargs
):
  """Run `evaluate_rom_energy` over a range of test cases with joblib.

  Args:
    irange: Start and stop of the case indices, unpacked into `range`
      (stop excluded).
    nb_workers: Number of jobs passed to `joblib.Parallel`.
    **kwargs: Extra keyword arguments forwarded to every
      `evaluate_rom_energy` call (e.g. `system`, `path`, `tlim`).

  Returns:
    The list produced by `joblib.Parallel`, one entry per case index
    (`None` entries mark cases for which `evaluate_rom_energy` returned
    nothing).
  """
  # Progress bar over the case indices, printed to stdout
  cases = tqdm(
    iterable=range(*irange),
    ncols=80,
    desc="  Cases",
    file=sys.stdout
  )
  runner = jl.Parallel(nb_workers)
  # Lazily build one delayed task per case index
  tasks = (
    jl.delayed(env.make_fun_parallel(evaluate_rom_energy))(index=idx, **kwargs)
    for idx in cases
  )
  return runner(tasks)

def evaluate_rom_energy(
  system,
  path,
  index,
  tout=5e2,
  tlim=None
):
  """Integrate the ROM for one test case and return its latent energy history.

  The test case `index` is loaded from `path`. If `tlim` is given, only the
  time instants inside `[min(tlim), max(tlim)]` are kept. The ROM is then
  solved without decoding. The energy at each kept instant is the 2-norm
  of the first `system.rom.size_zhat` latent components.

  Args:
    system: Thermochemical system whose ROM has already been built; its
      `use_rom` flag is set to True here.
    path: Folder with the test solutions, forwarded to `utils.load_case`.
    index: Index of the test case to load.
    tout: Forwarded unchanged to `system.solve_rom`.
    tlim: Optional time window; only its min and max values are used.

  Returns:
    A dict with keys "t" (kept time instants) and "energy" (one value per
    instant). Returns None (a failed case) when the time window is empty,
    when the ROM solution is None, or when it does not cover every instant.
  """
  system.use_rom = True
  # Load test case
  icase = utils.load_case(path=path, index=index)
  t, y0, rho = [icase[k] for k in ("t", "y0", "rho")]
  # Time window
  if (tlim is not None):
    in_window = (t >= np.amin(tlim)) & (t <= np.amax(tlim))
    t = t[in_window]
  # NOTE(review): `y0` is the stored initial state of the case and is not
  # shifted when the window drops early instants -- confirm this is intended.
  if (len(t) == 0):
    # Nothing to integrate over: report it as a failed case
    return None
  # Solve ROM, keeping the latent trajectory (`decode=False`)
  z, _ = system.solve_rom(t, y0, rho, tout=tout, decode=False)
  # Only accept solutions that cover every requested time instant
  if ((z is None) or (z.shape[1] != len(t))):
    return None
  # > Compute energy: 2-norm over the first `size_zhat` latent components,
  #   one value per time instant (time runs along axis 1 of `z`)
  zhat = z[:system.rom.size_zhat]
  energy = np.linalg.norm(zhat, ord=2, axis=0)
  # > Return data
  return {
    "t": t,
    "energy": energy
  }

def build_rom(system, path_to_basis, rdim):
  """Load a pickled basis file and build `system.rom` at dimension `rdim`.

  The file must hold a dict whose "phi" and "psi" entries are indexed by
  the reduced dimension, plus the "mask", "xref" and "xscale" entries that
  are forwarded as-is to `system.rom.build`. Only load trusted files:
  unpickling can execute arbitrary code.

  Args:
    system: Object exposing `rom.build(phi=..., psi=..., mask=..., xref=...,
      xscale=...)`; its ROM is rebuilt in place.
    path_to_basis: Path of the pickled basis file.
    rdim: Reduced dimension used to select "phi" and "psi".
  """
  with open(path_to_basis, "rb") as fh:
    bases = pickle.load(fh)
  phi = bases["phi"][rdim]
  psi = bases["psi"][rdim]
  extras = {key: bases[key] for key in ("mask", "xref", "xscale")}
  system.rom.build(phi=phi, psi=psi, **extras)

In [None]:
# Keyword arguments shared by every `evaluate_parallel` call. `system` is
# the same object whose ROM `build_rom` rebuilds in place before each call.
kwargs = {
  "irange": irange,
  "nb_workers": nb_workers,
  "system": system,
  "path": paths["data"],
  "tlim": tlim,
}
# data[rom][rdim] -> list of per-case results (None for failed cases)
data = {}
for rom, path_to_basis in paths["roms"].items():
  results = data[rom] = {}
  for rdim in rdims:
    build_rom(system, path_to_basis, rdim)
    results[rdim] = evaluate_parallel(**kwargs)

In [None]:
# Mean latent energy over all successfully evaluated test cases, for each
# ROM and each reduced dimension. `stats[rom]["t"]` holds the time grid and
# `stats[rom][rdim]` the mean energy at each of those instants.
stats = {}
for rom in paths["roms"]:
  stats[rom] = {}
  for rdim in rdims:
    # Failed cases are stored as None: skip them everywhere, including when
    # picking the time grid (the first case may itself have failed)
    valid = [traj for traj in data[rom][rdim] if (traj is not None)]
    if (not valid):
      raise ValueError(
        f"No successfully evaluated cases for ROM '{rom}' with r={rdim}"
      )
    # Time grid taken from the first valid case of the first dimension
    stats[rom].setdefault("t", valid[0]["t"])
    kdata = np.vstack([traj["energy"] for traj in valid])
    stats[rom][rdim] = np.mean(kdata, axis=0)

In [None]:
# Pointwise ratio of the mean latent energy of the rotated ROM to that of
# the reference ROM, one curve per reduced dimension (keyed by legend label)
ratios = {
  fr"$r={rdim}$": stats[roms_id["rot"]][rdim] / stats[roms_id["ref"]][rdim]
  for rdim in rdims
}

Plotting

In [None]:
# Plot the energy ratio (rotated / reference ROM) versus time, one curve per
# reduced dimension. The time axis is read explicitly from the rotated ROM's
# statistics. It previously relied on the loop variable `rom` leaked by an
# earlier cell, which only worked because that ROM happened to be iterated
# last.
pp.plot_evolution(
  x=stats[roms_id["rot"]]["t"],
  y=ratios,
  ls="-",
  labels=[r"$t$ [s]", r"$\|\mathbf{z}_\mathrm{R}\|_2/\|\mathbf{z}\|_2$"],
  scales=["log", "linear"],
  legend_loc="best",
  figname=paths["out"] + "/energy.png",
  save=True,
  show=False
)