In [None]:
import numpy as np
from sklearn.gaussian_process.kernels import Matern, RBF

import plotly
from IPython.display import display, Markdown

from docs.mse_estimator import ErrorComparer, gen_X_beta
from docs.data_generation import gen_chol_cov, gen_coords, gen_tr_idxs
from docs.plotting_utils import gen_model_barplots
from docs.sim_utils import *
from docs import fitting_models
from spe import estimators

In [None]:
# Seed NumPy's legacy global RNG once, before any data generation, so that a
# Restart & Run All reproduces the same results. NOTE(review): this assumes the
# docs.* generators and ErrorComparer draw from the global np.random state —
# confirm before switching to np.random.default_rng.
SEED = 1
np.random.seed(SEED)

In [None]:
# Derived simulation settings. Names such as sqrt_n, spe_est_str,
# spe_est_str_list, model_names and model_kwargs must already be defined
# (injected parameters or the `docs.sim_utils` star import). est_strs and
# fig_name_prefix are optional overrides: they fall back to defaults only when
# the name is genuinely undefined, so unrelated errors are no longer swallowed.
n = sqrt_n**2  # total number of observations

try:
    est_strs
except NameError:
    est_strs = get_est_array(spe_est_str_list)

model_md_str = get_model_md_str(spe_est_str)

try:
    fig_name_prefix
except NameError:
    fig_name_prefix = model_md_str

est_md_str = get_est_md_str(spe_est_str)

# zip() would silently drop models if these disagree in length, and the plots
# below label the per-model results with model_names.
assert len(model_names) == len(model_kwargs), (
    f"model_names ({len(model_names)}) and model_kwargs "
    f"({len(model_kwargs)}) must have the same length"
)
models = [getattr(fitting_models, m)(**kwargs) for m, kwargs in zip(model_names, model_kwargs)]
ests = [getattr(estimators, e) for e in est_strs]


In [None]:
# Title/intro for the rendered notebook. If a `markdown_str` parameter was
# supplied, it is treated as an f-string template and expanded against the
# names defined above; otherwise a default description is used. Only a missing
# `markdown_str` triggers the default — errors inside a supplied template now
# surface instead of being silently replaced.
try:
    markdown_str
except NameError:
    markdown_str = f"""# {model_md_str} Models\n Here we demonstrate the effectiveness of ```{est_md_str}``` to estimate MSE on simulated data."""
else:
    # NOTE(review): eval() executes arbitrary expressions in the template;
    # acceptable only because markdown_str is a trusted notebook parameter.
    markdown_str = eval(f'f"""{markdown_str}"""')

display(Markdown(markdown_str))


# Arbitrary Models
Here we demonstrate the effectiveness of ```spe.estimators.cp_arbitrary``` to estimate MSE on simulated data.


In [None]:
# Single comparer shared by both simulation settings below; its `compare`
# method produces the per-model results that are collected and plotted.
err_cmp = ErrorComparer()

In [None]:
# Observation locations; the first two columns are used as the x and y
# coordinates by the covariance and design-matrix generators below.
coord = gen_coords(sqrt_n, gsize)
c_x, c_y = coord[:, 0], coord[:, 1]

In [None]:
# Noise structure: Chol_y is passed to every compare() call (presumably a
# Cholesky factor of Sigma_Y); Cov_y_ystar is used only in the correlated
# Y / Y* simulation.
Chol_y, Cov_y_ystar = gen_chol_cov(
    delta,
    c_x,
    c_y,
    noise_kernel,
    noise_length_scale,
    noise_nu,
)

In [None]:
# Design matrix and coefficient vector, generated from the same coordinates
# (x_kernel / x_length_scale / x_nu presumably control spatial correlation of
# the features — see gen_X_beta).
X, beta = gen_X_beta(
    n, p, s,
    x_kernel,
    c_x, c_y,
    x_length_scale,
    x_nu,
)

In [None]:
# Training indices reused by every simulation (use_spatial_split presumably
# toggles a spatial vs. random split — see gen_tr_idxs).
tr_idx = gen_tr_idxs(
    n,
    tr_frac,
    use_spatial_split,
)

## Simulate $Y, Y^* \overset{iid}{\sim} \mathcal{N}(\mu, \Sigma_Y)$

In [None]:
# Independent setting: Y and Y* are iid, so no cross-covariance is supplied.
# All simulation settings are identical across models; collect them once.
indep_sim_kwargs = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_y,
    Chol_ystar=None,
    Cov_y_ystar=None,
    tr_idx=tr_idx,
    fair=fair,
    est_sigma=est_sigma,
    friedman_mu=friedman_mu,
)

model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **indep_sim_kwargs)
    model_errs.append(errs)

In [None]:
plotly.offline.init_notebook_mode()

# Per-model bar plots for the independent (iid Y, Y*) setting.
indep_plot_kwargs = dict(
    title=f"{fig_name_prefix} Models: NSN",
    has_elev_err=has_elev_err,
    err_bars=True,
    color_discrete_sequence=colors,
    fig_name=f"{fig_name_prefix.lower()}_ind",
)
fig = gen_model_barplots(model_errs, model_names, est_names, **indep_plot_kwargs)
fig.show()

## Simulate $\begin{pmatrix} Y \\ Y^* \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ \mu \end{pmatrix}, \begin{pmatrix}\Sigma_Y & \Sigma_{Y, Y^*} \\ \Sigma_{Y^*, Y} & \Sigma_{Y}  \end{pmatrix}\right)$

In [None]:
# Correlated setting: identical to the independent run except that the
# cross-covariance Sigma_{Y,Y*} (Cov_y_ystar) is supplied.
corr_sim_kwargs = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_y,
    Chol_ystar=None,
    Cov_y_ystar=Cov_y_ystar,
    tr_idx=tr_idx,
    fair=fair,
    est_sigma=est_sigma,
    friedman_mu=friedman_mu,
)

corr_model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **corr_sim_kwargs)
    corr_model_errs.append(errs)

In [None]:
# Per-model bar plots for the correlated (Y, Y*) setting.
corr_plot_kwargs = dict(
    title=f"{fig_name_prefix} Models: SSN",
    has_elev_err=has_elev_err,
    err_bars=True,
    color_discrete_sequence=colors,
    fig_name=f"{fig_name_prefix.lower()}_corr",
)
corr_fig = gen_model_barplots(corr_model_errs, model_names, est_names, **corr_plot_kwargs)
corr_fig.show()