In [None]:
import copy
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io  # explicit: `scipy.io.loadmat` is not guaranteed to exist after `import scipy` alone
import torch
import torchid.metrics  # pip install pytorch-ident
import xarray

# Add ../.. to the import path so the local `lru` package can be found
# (assumes the notebook is run from two directories below it — TODO confirm).
sys.path.append("../..")
from lru.architectures import DLRU, DLRUConfig  # noqa: E402
from lru.reduction import lru_reduction_pipeline  # noqa: E402

In [None]:
# Load the F16GVT benchmark validation record (MATLAB .mat file) from the local data folder.
data_folder = ("F16GVT_Files", "BenchmarkData")
file_name = "F16Data_FullMSine_Level6_Validation.mat"
file_path = Path(*data_folder, file_name)
data = scipy.io.loadmat(file_path)

In [None]:
# Input ("Force") and output ("Acceleration") records, transposed so that rows index
# samples — the layout the scalers below are assumed to expect (TODO confirm against
# the raw .mat array shapes).
y_test = np.transpose(data["Acceleration"])
u_test = np.transpose(data["Force"])

In [None]:
# Checkpoint stems under ckpt/ — one trained model per regularizer setting.
RUNS = ["ckpt_large_no_reg", "ckpt_large_reg_modal", "ckpt_large_reg_hankel"]
# Reduction methods passed to `lru_reduction_pipeline(method=...)`.
REDUCTIONS = [
    "balanced_truncation",
    "balanced_singular_perturbation",
    "modal_truncation",
    "modal_singular_perturbation",
]
# Reference average FIT (%) described as "the worst" — presumably the lowest full-order
# FIT among RUNS (TODO confirm how it was obtained).
FIT_REFERENCE = 83.25865396027804
# A reduced model is accepted if its FIT stays within 1% (relative) of the reference.
FIT_THRESHOLD_FACTOR = 0.99
FIT_THRESHOLD = FIT_REFERENCE * FIT_THRESHOLD_FACTOR

In [None]:
# Expected state dimension of the trained models (asserted against each checkpoint's cfg.d_state).
d_state = 100
# Candidate reduced orders, tested from the full order down to a single mode.
MODES = np.arange(d_state, 0, -1)
# Result grids pre-filled with sentinels (NaN / -1) rather than np.empty, so any slot the
# sweep never writes is obviously missing instead of holding arbitrary memory contents
# that could be saved to disk or plotted as if it were a real result.
FIT_MEAN_ALL = np.full((len(RUNS), len(REDUCTIONS), d_state), np.nan)
MIN_ORDER_ALL = np.full((len(RUNS), len(REDUCTIONS)), -1, dtype=np.int64)

In [None]:
def _reduce_model(model, modes, method):
    """Return a deep copy of `model` whose every block's LRU is reduced to `modes` states.

    For each block, the state-space parameters (lambdas, B, C, D) are exported to NumPy,
    passed through `lru_reduction_pipeline(..., modes=modes, method=method)`, cast back to
    complex64 (lambdas, B, C) / float32 (D) tensors and written back with
    `block.lru.replace_ss_params`. The input `model` itself is not modified.
    """
    model_reduced = copy.deepcopy(model)
    for block in model_reduced.blocks:
        lambdas, B, C, D = [param.detach().numpy() for param in block.lru.ss_params()]
        lambdas_red, B_red, C_red, D_red = lru_reduction_pipeline(
            lambdas, B, C, D, modes=modes, method=method
        )
        params_red = [
            lambdas_red.astype(np.complex64),
            B_red.astype(np.complex64),
            C_red.astype(np.complex64),
            D_red.astype(np.float32),
        ]
        block.lru.replace_ss_params(*[torch.tensor(param) for param in params_red])
    return model_reduced


for idx_run, run in enumerate(RUNS):  # one checkpoint per regularizer setting
    # The checkpoint stores a config object and fitted scalers next to the weights, so it
    # has to be fully unpickled (weights_only=False; torch >= 2.6 defaults to True, and the
    # flag needs torch >= 1.13). Only load checkpoints from a trusted source.
    ckpt = torch.load(Path("ckpt") / f"{run}.pt", map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    scaler_u = ckpt["scaler_u"]
    scaler_y = ckpt["scaler_y"]
    assert cfg.d_state == d_state, f"{run}: expected d_state={d_state}, got {cfg.d_state}"

    # Rebuild the trained network: DLRU is the model class, DLRUConfig only its settings.
    config = DLRUConfig(
        d_model=cfg.d_model, d_state=cfg.d_state, n_layers=cfg.n_layers, ff=cfg.ff
    )
    model = DLRU(1, 3, config)
    model.load_state_dict(ckpt["model"])
    # Inference only: matters if DLRU contains train-mode-dependent layers (not visible here).
    model.eval()

    # The scaled input is the same for every reduction method and order of this run.
    ut = torch.tensor(scaler_u.transform(u_test)).unsqueeze(0).float()

    for idx_red, reduction_method in enumerate(REDUCTIONS):
        print(f"{run} {reduction_method}")

        FIT_MEAN = []
        for modes in MODES:
            model_reduced = _reduce_model(model, modes, reduction_method)
            with torch.no_grad():
                y_test_hat = model_reduced(ut, mode="scan").squeeze(0).to("cpu").numpy()
            y_test_hat = scaler_y.inverse_transform(y_test_hat)
            # Average of the per-channel FIT indices returned by torchid.
            FIT_MEAN.append(torchid.metrics.fit_index(y_test, y_test_hat).mean())
        FIT_MEAN = np.array(FIT_MEAN)

        # Smallest tested order whose average FIT exceeds the threshold. If no order does,
        # record -1 and keep sweeping instead of crashing on .min() of an empty array.
        modes_ok = MODES[FIT_MEAN > FIT_THRESHOLD]
        if modes_ok.size:
            MIN_ORDER_ALL[idx_run, idx_red] = modes_ok.min()
        else:
            print(f"  no tested order reaches FIT > {FIT_THRESHOLD:.2f}")
            MIN_ORDER_ALL[idx_run, idx_red] = -1
        FIT_MEAN_ALL[idx_run, idx_red, :cfg.d_state] = FIT_MEAN

In [None]:
# Smallest retained order whose average FIT exceeds FIT_THRESHOLD, labelled by
# regularizer run (rows) and reduction method (columns) instead of a bare int array.
xarray.DataArray(
    MIN_ORDER_ALL,
    dims=["run", "truncation_method"],
    coords=[RUNS, REDUCTIONS],
)

In [None]:
# Persist the FIT sweep as a labelled array (run × reduction method × retained order)
# so it can be reloaded without re-running the expensive sweep.
fit_mean_all = xarray.DataArray(
    FIT_MEAN_ALL,
    dims=["run", "truncation_method", "modes"],
    coords=[RUNS, REDUCTIONS, MODES],
)
fit_mean_all.to_netcdf("fit_mean_all.nc")

In [None]:
# Full FIT sweep table: every run, reduction method and retained order.
fit_mean_all

In [None]:
# Average validation FIT versus number of retained modes for one run/reduction pair;
# the red line marks FIT_THRESHOLD. The x-axis is inverted so order decreases left to right.
PLOT_RUN = "ckpt_large_no_reg"
PLOT_REDUCTION = "balanced_truncation"

fig, ax = plt.subplots()
ax.plot(
    MODES,
    fit_mean_all.loc[PLOT_RUN, PLOT_REDUCTION],
    label=f"{PLOT_RUN} / {PLOT_REDUCTION}",
)
ax.axhline(FIT_THRESHOLD, color="red", label="FIT threshold")
ax.invert_xaxis()
ax.grid()
ax.set_xlabel("Number of retained modes (-)")
ax.set_ylabel("Average FIT (%)")
ax.set_title("Average validation FIT vs. reduced model order")
ax.legend();