In [12]:
import numpy as np
import pandas as pd
from sklearn.gaussian_process.kernels import Matern, RBF

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from doc.mse_estimator import ErrorComparer
from doc.data_generation import gen_rbf_X, gen_matern_X, create_clus_split, gen_cov_mat
from spe.estimators import kfoldcv, kmeanscv, better_test_est_split, cp_smoother_train_test 
from spe.smoothers import LinearRegression, BSplineRegressor

In [13]:
# Fix NumPy's global (legacy) random state so the random draws below are repeatable.
np.random.seed(seed=1)

In [14]:
## number of realizations to run
niter = 100

## data generation parameters
n = 20**2       # number of locations; the grid cell below takes sqrt(n) points per side
p = 30          # number of columns of X / length of beta
s = 30          # number of entries of beta that receive a nonzero draw
delta = 0.75    # weight on the kernel covariance; (1 - delta) goes on the identity
snr = 0.4       # forwarded to ErrorComparer.compare as `snr`
tr_frac = .25   # forwarded to create_clus_split

noise_kernel = 'matern'     # 'rbf', 'matern', or anything else for identity covariance
noise_length_scale = 1.
noise_nu = .5

X_kernel = 'matern'         # 'rbf', 'matern', or anything else for i.i.d. Gaussian X
X_length_scale = 5.
X_nu = 2.5

## ErrorComparer parameters
alpha = .05
nboot = 100
k = 5           # number of folds handed to kfoldcv / kmeanscv via est_kwargs
# max_depth and lambd are not used by the models configured below.
max_depth = 4
lambd = .025

# `models` and `model_names` are parallel lists: model_names[j] becomes the
# subplot title for models[j], so the label must describe the actual model.
models = [BSplineRegressor()]
model_names = ["B-Spline Regression"]

# `ests` and `est_kwargs` are parallel lists passed together to
# ErrorComparer.compare. The plotting cell treats the first result as the
# reference (test) error and labels the remaining ones with `est_names`.
ests = [
    better_test_est_split,
    cp_smoother_train_test,
    kfoldcv,
    kmeanscv,
]
est_kwargs = [
    {},
    {},
    {'k': k},
    {'k': k},
]

## plot parameters
title = "Simulated Model Comparisons"
est_names = ["GenCp", "KFCV", "SPCV"]

# Catch mismatched parallel lists here rather than as a mislabeled plot later.
assert len(models) == len(model_names), "models and model_names must have the same length"
assert len(ests) == len(est_kwargs), "ests and est_kwargs must have the same length"

## output/save parameters
idx = -1

savedir='~'

In [15]:
# Disabled: would create `savedir`, write the run parameters to params.csv and
# define the output paths for the error table and the bar chart.
# NOTE(review): this cell cannot be re-enabled as written, judging by the cells
# shown in this notebook:
#   - `os` is never imported;
#   - `n_estimators` is never defined;
#   - `savedir + 'params.csv'` concatenates without a path separator, so with
#     savedir='~' it yields '~params.csv' rather than a file inside the home dir.
# if not os.path.exists(os.path.expanduser(savedir)):
#     os.makedirs(os.path.expanduser(savedir))

# params = pd.DataFrame({'niter': niter,
#                        'n': n, 
#                        'p': p, 
#                        's': s,
#                        'snr': snr,
#                        'n_estimators': n_estimators,
#                        'max_depth': max_depth,
#                        'delta': delta,
#                        'nk': noise_kernel, 
#                        'nls': noise_length_scale, 
#                        'nn': noise_nu, 
#                        'xk': X_kernel,
#                        'xls': X_length_scale,
#                        'xn': X_nu}, index=[idx])
# params.to_csv(os.path.expanduser(savedir + 'params.csv'))
# dffp = os.path.expanduser(savedir + "err_df.csv")
# barfp = os.path.expanduser(savedir + 'barchart.jpeg')

In [16]:
# Comparer object whose `compare` method is called once per model further down.
err_cmp = ErrorComparer()

In [17]:
# Build a square nx-by-ny grid of 2-D locations covering [0, 10] x [0, 10].
nx = ny = int(round(np.sqrt(n)))
# Later cells add np.eye(n) to a covariance built from these coordinates
# (presumably one row/column per grid point), so a non-square n would only
# surface as an obscure shape error there. Fail early with a clear message.
if nx * ny != n:
    raise ValueError(f"n must be a perfect square to form a square grid, got n={n}")

xs = np.linspace(0, 10, nx)
ys = np.linspace(0, 10, ny)
c_x, c_y = np.meshgrid(xs, ys)
# Flatten the two coordinate grids into 1-D arrays of length nx * ny.
c_x = c_x.flatten()
c_y = c_y.flatten()
# One row per grid point: column 0 holds x, column 1 holds y.
coord = np.stack([c_x, c_y], axis=1)

In [18]:
# Pick the kernel for the noise covariance; any other value of
# `noise_kernel` means "no spatial structure" (identity covariance).
if noise_kernel == 'rbf':
    noise_kern = RBF(length_scale=noise_length_scale)
elif noise_kernel == 'matern':
    noise_kern = Matern(length_scale=noise_length_scale, nu=noise_nu)
else:
    noise_kern = None

has_noise_kernel = noise_kern is not None
base_cov = gen_cov_mat(c_x, c_y, noise_kern) if has_noise_kernel else np.eye(n)

# Cov_st is the delta-scaled base covariance; Sigma_t mixes it with a
# (1 - delta)-weighted identity. Cov_st is not referenced again in the cells shown here.
Cov_st = delta * base_cov
Sigma_t = delta * base_cov + (1 - delta) * np.eye(n)

# Cholesky factor of the mixed covariance; the identity is used directly
# when no kernel was selected.
Chol_t = np.linalg.cholesky(Sigma_t) if has_noise_kernel else np.eye(n)

In [19]:
# Design matrix: generated by the helper for the chosen kernel, or i.i.d.
# standard normal entries when `X_kernel` is neither 'matern' nor 'rbf'.
if X_kernel == 'matern':
    X = gen_matern_X(c_x, c_y, p, length_scale=X_length_scale, nu=X_nu)
elif X_kernel == 'rbf':
    X = gen_rbf_X(c_x, c_y, p)
else:
    X = np.random.randn(n, p)

# Choose s distinct coefficient positions and give them Uniform(-1, 1) values;
# every other entry of beta stays zero.
# NOTE(review): `idx` is also assigned (-1) in the parameter cell; this line rebinds it.
idx = np.random.choice(p, size=s, replace=False)
beta = np.zeros(p)
beta[idx] = np.random.uniform(-1, 1, size=s)

In [20]:
# Both grid dimensions handed to create_clus_split are the integer square root of n.
grid_side = int(np.sqrt(n))
tr_idx = create_clus_split(grid_side, grid_side, tr_frac)
# Alternative kept for reference: mark every location via an all-True mask.
# tr_idx = np.ones(n, dtype=bool)

In [21]:
# Keyword arguments that are identical for every model's comparison run.
compare_kwargs = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_t,
    Chol_ystar=None,
    Cov_y_ystar=None,
    tr_idx=tr_idx,
    fair=False,
)

# One `compare` result per model, stored in the same order as `models`.
model_errs = []
for model in models:
    model_errs.append(err_cmp.compare(model, ests, est_kwargs, **compare_kwargs))

  0%|          | 0/100 [00:00<?, ?it/s]

True


  2%|▏         | 2/100 [00:00<00:21,  4.49it/s]

True
True


  3%|▎         | 3/100 [00:00<00:22,  4.38it/s]

True


  5%|▌         | 5/100 [00:01<00:23,  4.12it/s]

True


  6%|▌         | 6/100 [00:01<00:21,  4.29it/s]

True


  7%|▋         | 7/100 [00:01<00:21,  4.38it/s]

True


  8%|▊         | 8/100 [00:01<00:20,  4.45it/s]

True


  9%|▉         | 9/100 [00:02<00:20,  4.50it/s]

True
True


 10%|█         | 10/100 [00:02<00:20,  4.40it/s]

True


 11%|█         | 11/100 [00:02<00:20,  4.34it/s]

True


 12%|█▏        | 12/100 [00:02<00:20,  4.35it/s]

True


 14%|█▍        | 14/100 [00:03<00:19,  4.42it/s]

True
True


 15%|█▌        | 15/100 [00:03<00:20,  4.23it/s]

True


 17%|█▋        | 17/100 [00:03<00:19,  4.34it/s]

True
True


 18%|█▊        | 18/100 [00:04<00:18,  4.37it/s]

True


 19%|█▉        | 19/100 [00:04<00:18,  4.31it/s]

True


 21%|██        | 21/100 [00:04<00:18,  4.33it/s]

True
True


 22%|██▏       | 22/100 [00:05<00:18,  4.32it/s]

True


 24%|██▍       | 24/100 [00:05<00:17,  4.40it/s]

True


 25%|██▌       | 25/100 [00:05<00:16,  4.44it/s]

True
True


 27%|██▋       | 27/100 [00:06<00:16,  4.42it/s]

True
True


 29%|██▉       | 29/100 [00:06<00:15,  4.44it/s]

True


 30%|███       | 30/100 [00:06<00:15,  4.48it/s]

True


 31%|███       | 31/100 [00:07<00:15,  4.52it/s]

True


 32%|███▏      | 32/100 [00:07<00:14,  4.54it/s]

True
True


 34%|███▍      | 34/100 [00:07<00:14,  4.43it/s]

True
True


 35%|███▌      | 35/100 [00:08<00:14,  4.34it/s]

True


 37%|███▋      | 37/100 [00:08<00:14,  4.28it/s]

True


 38%|███▊      | 38/100 [00:08<00:14,  4.37it/s]

True
True


 40%|████      | 40/100 [00:09<00:13,  4.34it/s]

True
True


 41%|████      | 41/100 [00:09<00:13,  4.30it/s]

True


 42%|████▏     | 42/100 [00:09<00:13,  4.30it/s]

True


 44%|████▍     | 44/100 [00:10<00:12,  4.39it/s]

True


 45%|████▌     | 45/100 [00:10<00:12,  4.47it/s]

True


 46%|████▌     | 46/100 [00:10<00:11,  4.53it/s]

True


 47%|████▋     | 47/100 [00:10<00:11,  4.58it/s]

True


 48%|████▊     | 48/100 [00:10<00:11,  4.60it/s]

True


 49%|████▉     | 49/100 [00:11<00:11,  4.62it/s]

True


 50%|█████     | 50/100 [00:11<00:10,  4.62it/s]

True


 51%|█████     | 51/100 [00:11<00:10,  4.64it/s]

True


 52%|█████▏    | 52/100 [00:11<00:10,  4.61it/s]

True


 53%|█████▎    | 53/100 [00:12<00:10,  4.59it/s]

True


 54%|█████▍    | 54/100 [00:12<00:09,  4.63it/s]

True


 55%|█████▌    | 55/100 [00:12<00:09,  4.61it/s]

True


 56%|█████▌    | 56/100 [00:12<00:09,  4.63it/s]

True


 57%|█████▋    | 57/100 [00:12<00:09,  4.62it/s]

True


 58%|█████▊    | 58/100 [00:13<00:09,  4.66it/s]

True


 59%|█████▉    | 59/100 [00:13<00:08,  4.68it/s]

True
True


 61%|██████    | 61/100 [00:13<00:08,  4.52it/s]

True


 62%|██████▏   | 62/100 [00:13<00:08,  4.56it/s]

True


 63%|██████▎   | 63/100 [00:14<00:08,  4.57it/s]

True
True


 65%|██████▌   | 65/100 [00:14<00:07,  4.48it/s]

True
True


 67%|██████▋   | 67/100 [00:15<00:07,  4.46it/s]

True
True


 68%|██████▊   | 68/100 [00:15<00:07,  4.45it/s]

True


 69%|██████▉   | 69/100 [00:15<00:07,  4.38it/s]

True


 71%|███████   | 71/100 [00:16<00:06,  4.43it/s]

True


 72%|███████▏  | 72/100 [00:16<00:06,  4.48it/s]

True


 73%|███████▎  | 73/100 [00:16<00:06,  4.50it/s]

True


 74%|███████▍  | 74/100 [00:16<00:05,  4.56it/s]

True


 75%|███████▌  | 75/100 [00:16<00:05,  4.58it/s]

True


 76%|███████▌  | 76/100 [00:17<00:05,  4.58it/s]

True


 77%|███████▋  | 77/100 [00:17<00:05,  4.59it/s]

True


 78%|███████▊  | 78/100 [00:17<00:04,  4.61it/s]

True
True


 80%|████████  | 80/100 [00:18<00:04,  4.47it/s]

True


 81%|████████  | 81/100 [00:18<00:04,  4.51it/s]

True


 82%|████████▏ | 82/100 [00:18<00:03,  4.52it/s]

True


 83%|████████▎ | 83/100 [00:18<00:03,  4.56it/s]

True


 84%|████████▍ | 84/100 [00:18<00:03,  4.55it/s]

True


 85%|████████▌ | 85/100 [00:19<00:03,  4.58it/s]

True


 86%|████████▌ | 86/100 [00:19<00:03,  4.61it/s]

True


 87%|████████▋ | 87/100 [00:19<00:02,  4.63it/s]

True


 88%|████████▊ | 88/100 [00:19<00:02,  4.64it/s]

True


 89%|████████▉ | 89/100 [00:19<00:02,  4.62it/s]

True


 90%|█████████ | 90/100 [00:20<00:02,  4.61it/s]

True


 91%|█████████ | 91/100 [00:20<00:01,  4.61it/s]

True


 92%|█████████▏| 92/100 [00:20<00:01,  4.60it/s]

True
True


 93%|█████████▎| 93/100 [00:20<00:01,  4.53it/s]

True


 94%|█████████▍| 94/100 [00:21<00:01,  4.50it/s]

True


 95%|█████████▌| 95/100 [00:21<00:01,  4.39it/s]

True


 96%|█████████▌| 96/100 [00:21<00:00,  4.14it/s]

True


 97%|█████████▋| 97/100 [00:21<00:00,  4.00it/s]

True


 99%|█████████▉| 99/100 [00:22<00:00,  4.11it/s]

True
True


100%|██████████| 100/100 [00:22<00:00,  4.43it/s]


In [22]:
# One subplot per model, titled with the matching entry of `model_names`.
fig = make_subplots(
    rows=1, cols=len(models),
    subplot_titles=model_names)

for col, errs in enumerate(model_errs, start=1):
    # errs[0] is used as the reference (test) error; errs[1:] are labelled
    # with `est_names`, in order.
    test_risk = errs[0].mean()
    est_df = pd.DataFrame(
        {name: errs[j] for j, name in enumerate(est_names, start=1)}
    )

    # Mean and spread of each estimator, expressed relative to the reference risk.
    rel_mean = est_df.mean() / test_risk
    rel_std = est_df.std() / test_risk

    fig.add_trace(go.Bar(
        x=est_df.columns,
        y=rel_mean,
        marker_color=px.colors.qualitative.Plotly,
        text=np.around(rel_mean, 3),
        textposition='outside',
        error_y=dict(
            type='data',
            color='black',
            array=rel_std,
        ),
    ), row=1, col=col)

    # A relative MSE of 1 means the estimate matches the reference risk.
    fig.add_hline(y=1., line_color='red', row=1, col=col)

    fig.update_xaxes(title_text="Method", row=1, col=col)
    fig.update_yaxes(title_text="Relative MSE", row=1, col=col)

fig.show()