In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap
from jax import jit, lax, grad
import matplotlib.pyplot as plt
import cmocean as cmo
import importlib


import gpytorch
import torch
import linear_operator
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DiagLinearOperator,
    LinearOperator,
    DenseLinearOperator,
)

from jax.config import config
config.update("jax_enable_x64", True)

  from .autonotebook import tqdm as notebook_tqdm


In [183]:
## import modules
import preconditioner as precond
import conjugate_gradient as cg
import pivoted_cholesky as pc
# import pivoted_cholesky_ref as pc_ref # to use this script we need "torch", please comment out if not needed.
import calc_logdet
import calc_trace
import mmm
def reload():
    importlib.reload(precond)
    importlib.reload(cg)
    importlib.reload(pc)
    importlib.reload(calc_logdet)
    importlib.reload(calc_trace)
    importlib.reload(mmm)
reload()

In [3]:
import warnings
warnings.filterwarnings("always")

In [4]:
from stopro.data_generator.sinusoidal import Sinusoidal
from stopro.data_preparer.data_preparer import DataPreparer
from stopro.sub_modules.load_modules import load_params, load_data
from stopro.sub_modules.loss_modules import hessian, logposterior
from stopro.sub_modules.init_modules import get_init, reshape_init
import stopro.GP.gp_sinusoidal_independent as gp_sinusoidal_independent
from stopro.GP.kernels import define_kernel
# from stopro.solver.optimizers import optimize_by_adam
from stopro.data_handler.data_handle_module import *

In [6]:
project_name = 'test'
simulation_name = 'data'

In [7]:
params_main, params_prepare, lbls = load_params(f"{project_name}/{simulation_name}/data_input")
params_model = params_main["model"]
params_optimization = params_main["optimization"]
params_plot = params_prepare["plot"]
vnames = params_prepare["vnames"]
params_setting = params_prepare["setting"]
params_generate_training = params_prepare["generate_training"]
params_generate_test = params_prepare["generate_test"]
params_kernel_arg = params_prepare["kernel_arg"]

# prepare initial hyper-parameter
init = get_init(
    params_model["init_kernel_hyperparameter"],
    params_model["kernel_type"],
    system_type=params_model["system_type"],
)

In [8]:
# prepare data
hdf_operator = HdfOperator(f"{project_name}/{simulation_name}")
r_test, μ_test, r_train, μ_train, f_train = load_data(lbls, vnames, hdf_operator)
delta_y_train = jnp.empty(0)
for i in range(len(r_train)):
    delta_y_train = jnp.append(delta_y_train, f_train[i] - μ_train[i])

args_predict = r_test, μ_test, r_train, delta_y_train, params_model["epsilon"]

In [9]:
# setup model
Kernel = define_kernel(params_model)
gp_model = gp_sinusoidal_independent.GPSinusoidalWithoutPIndependent(
    use_difp=params_setting["use_difp"],
    use_difu=params_setting["use_difu"],
    lbox=jnp.array([2.5, 0.0]),
    infer_governing_eqs=params_prepare["generate_test"]["infer_governing_eqs"],
    Kernel=Kernel,
    index_optimize_noise=params_model["index_optimize_noise"],
)
gp_model.set_constants(*args_predict)

In [177]:
cov_scale = 0.0
length = 2.3
init = jnp.array(
    [
        cov_scale,
        length,
        length,
        cov_scale,
        length,
        length,
        cov_scale,
        length,
        length,
    ]
)
right_matrix = jax.random.normal(
    jax.random.PRNGKey(0), (params_prepare["num_points"]["training"]["sum"], 11)
)
K = gp_model.trainingK_all(init, r_train)
K_x_right_matrix_naive = jnp.matmul(K, right_matrix)

In [178]:
Ks = gp_model.trainingKs(init, r_train)
for i in range(len(Ks)):
    for j in list(range(len(Ks) - len(Ks[i])))[::-1]:
        Ks[i] = [Ks[j][i]] + Ks[i]

In [179]:
r1s = r_train
r2s = r_train
sec1 = gp_model.sec_tr
sec2 = gp_model.sec_tr
Kss = Ks

In [185]:
K_x_right_matrix = mmm.mmm(r_train, r_train, right_matrix, Ks, gp_model.sec_tr, gp_model.sec_tr)

In [186]:
mean_rel_error = jnp.mean(
        jnp.abs(K_x_right_matrix_naive - K_x_right_matrix) / K_x_right_matrix_naive
)

In [187]:
mean_rel_error

DeviceArray(-3.72537117e-16, dtype=float64)

In [188]:
cov_scale = 0.0
length = 0.0
init = jnp.array(
    [
        cov_scale,
        length,
        length,
        cov_scale,
        length,
        length,
        cov_scale,
        length,
        length,
    ]
)
right_matrix = jax.random.normal(
    jax.random.PRNGKey(0), (params_prepare["num_points"]["training"]["sum"], 11)
)
K = gp_model.trainingK_all(init, r_train)
K_x_right_matrix_naive = jnp.matmul(K, right_matrix)

In [189]:
Ks = gp_model.trainingKs(init, r_train)
for i in range(len(Ks)):
    for j in list(range(len(Ks) - len(Ks[i])))[::-1]:
        Ks[i] = [Ks[j][i]] + Ks[i]

In [190]:
r1s = r_train
r2s = r_train
sec1 = gp_model.sec_tr
sec2 = gp_model.sec_tr
Kss = Ks

In [191]:
K_x_right_matrix = mmm.mmm(r_train, r_train, right_matrix, Ks, gp_model.sec_tr, gp_model.sec_tr)

In [192]:
mean_rel_error = jnp.mean(
        jnp.abs(K_x_right_matrix_naive - K_x_right_matrix) / K_x_right_matrix_naive
)

In [193]:
mean_rel_error

DeviceArray(1.17200347e-15, dtype=float64)