In [137]:
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.cluster import KMeans
from sklearn.gaussian_process.kernels import Matern, RBF
from sklearn.ensemble import RandomForestRegressor

import plotly.express as px

import plotly.graph_objects as go

from scipy.linalg import toeplitz, block_diag

from spe.mse_estimator import ErrorComparer
from spe.data_generation import gen_rbf_X, gen_matern_X, create_clus_split, gen_cov_mat
from spe.forest import BlurredForest
from spe.estimators import kfoldcv, kmeanscv, better_test_est_split, cp_rf_train_test, cp_general_train_test, bag_kfoldcv, bag_kmeanscv

import os

In [173]:
niter = 50

n=30**2
p=30
s=30

delta = 1.

n_estimators = 50
max_depth = 8

snr = 1.

noise_kernel = 'matern'
noise_length_scale = 1.
noise_nu = 0.5

X_kernel = 'matern'
X_length_scale = 5.
X_nu = 2.5

idx = -1

savedir='~'

const_mu = False

In [174]:
# noise_kernel = kernel
# X_kernel = kernel

# noise_length_scale = length_scale
# X_length_scale = length_scale

# noise_nu = nu
# X_nu = nu

In [175]:
if not os.path.exists(os.path.expanduser(savedir)):
    os.makedirs(os.path.expanduser(savedir))

params = pd.DataFrame({'niter': niter,
                       'n': n, 
                       'p': p, 
                       's': s,
                       'snr': snr,
                       'n_estimators': n_estimators,
                       'max_depth': max_depth,
                       'delta': delta,
                       'nk': noise_kernel, 
                       'nls': noise_length_scale, 
                       'nn': noise_nu, 
                       'xk': X_kernel,
                       'xls': X_length_scale,
                       'xn': X_nu}, index=[idx])
params.to_csv(os.path.expanduser(savedir + 'params.csv'))
dffp = os.path.expanduser(savedir + "err_df.csv")
# barfp = os.path.expanduser(savedir + 'barchart.jpeg')

In [176]:
err_cmp = ErrorComparer()

In [177]:
nx = ny = int(np.sqrt(n))
xs = np.linspace(0, 10, nx)
ys = np.linspace(0, 10, ny)
c_x, c_y = np.meshgrid(xs, ys)
c_x = c_x.flatten()
c_y = c_y.flatten()
coord = np.stack([c_x, c_y]).T

In [178]:
if noise_kernel == 'rbf':
    Sigma_t = gen_cov_mat(c_x, c_y, RBF(length_scale=noise_length_scale))
elif noise_kernel == 'matern':
    Sigma_t = gen_cov_mat(c_x, c_y, Matern(length_scale=noise_length_scale, nu=noise_nu))
else:
    Sigma_t = np.eye(n)
    
Sigma_t = delta*Sigma_t + (1-delta)*np.eye(n)

if noise_kernel == 'rbf' or noise_kernel == 'matern':
    Chol_t = np.linalg.cholesky(Sigma_t)
else:
    Chol_t = np.eye(n)

In [179]:
Chol_t_inv = np.linalg.inv(Chol_t).T
np.allclose(np.linalg.inv(Sigma_t), Chol_t_inv @ Chol_t_inv.T)

True

In [180]:
X = np.random.randn(n,n)
np.allclose(
    np.linalg.inv(
        X.T @ np.linalg.inv(Sigma_t) @ X
    ) @ X.T @ np.linalg.inv(Sigma_t), 
    np.linalg.pinv(Chol_t_inv.T @ X) @ Chol_t_inv.T
)

True

In [181]:
models = [
    BlurredForest(n_estimators=n_estimators,  
                  max_depth=max_depth, 
                  bootstrap_type='blur'),
    BlurredForest(n_estimators=n_estimators,  
                  max_depth=max_depth, 
                  bootstrap_type='blur'),
]

ests = [
    better_test_est_split,
    better_test_est_split
]

est_kwargs = [
    {'full_refit':True,
    'chol_eps': Chol_t,
    'gls': False},
    {'full_refit':True,
    'chol_eps': Chol_t,
    'gls': True},
]

In [182]:
if X_kernel == 'rbf':
    X = gen_rbf_X(c_x, c_y, p)
elif X_kernel == 'matern':
    X = gen_matern_X(
        c_x, 
        c_y, 
        p, 
        length_scale=X_length_scale, 
        nu=X_nu
    )
else:
    X = np.random.randn(n,p)

beta = np.zeros(p)
idx = np.random.choice(p,size=s)
beta[idx] = np.random.uniform(-1,1,size=s)

In [183]:
(bfols_err,
 bfgls_err,) = err_cmp.compare(models,
                     ests,
                     est_kwargs,
                     niter=niter,
                     n=n,
                     p=p,
                     s=p,
                     snr=snr, 
                     X=X,
                     beta=beta,
                     coord=coord,
                     Chol_t=Chol_t,
                     Chol_s=None,
                     tr_idx=None,
                     fair=False,
                     const_mu=const_mu,
                     )

0
0.5833333333333334
10
20
30
40


In [184]:
risk_bfols = bfols_err.mean()
risk_bfgls = bfgls_err.mean()


In [185]:
save_df = pd.DataFrame({
    'BF_OLS': (bfols_err.T),
    'BF_GLS': (bfgls_err.T),
})
# save_df.to_csv(dffp)

In [186]:
df = pd.DataFrame({
    'BF_OLS': (bfols_err.T),
    'BF_GLS': (bfgls_err.T),
})

In [187]:
(df).mean()

BF_OLS    0.267553
BF_GLS    0.269722
dtype: float64

In [188]:
df.var()

BF_OLS    0.005837
BF_GLS    0.005808
dtype: float64

In [189]:
fig = go.Figure()
fig.add_trace(go.Bar(
    x=df.columns, 
    y=(df).mean(),
    marker_color=px.colors.qualitative.Plotly,
    text=np.around((df).mean(),3),
    textposition='outside',
    error_y=dict(
        type='data',
        color='black',
        symmetric=False,
        array=(df).quantile(.75) - (df).mean(),
        arrayminus=(df).mean() - (df).quantile(.25))
#         array=(df).mean() + (df).std(),
#         arrayminus=(df).mean() - (df).std())
))
# fig.add_trace(go.Bar(
#     name='Experimental',
#     x=['GenCp', 'KFCV', 'SPCV'], y=(df).mean(),
#     error_y=dict(type='data', array=[1, 2])
# ))
# fig.update_layout(barmode='group')
# fig.add_hline(y=1., line_color='red')
fig.update_layout(
#     title=f"FC_\u03B4{delta}_snr{snr}_nk{noise_kernel}_nls{noise_length_scale}_nv{noise_nu}_Xk{X_kernel}_Xls{X_length_scale}_Xv{X_nu}",
    title=f"Blurred Forest: W vs Y Refit, Constant \u03bc" if const_mu \
            else f"Blurred Forest: W vs Y Refit, Linear \u03bc, SNR: {snr}",
    xaxis_title="Method",
    yaxis_title="MSE",
#     legend_title="Legend Title",
#     font=dict(
#         family="Courier New, monospace",
#         size=18,
#         color="RebeccaPurple"
#     )
)
# barfp = os.path.expanduser(savedir + fig.layout.title['text'] + ".jpeg")
# fig.write_image(os.path.expanduser(barfp))
fig.show()

In [None]:
# ^^ snr .01, linear mu, 50 depth 8 trees

In [154]:
fig = go.Figure()
fig.add_trace(go.Bar(
    x=df.columns, 
    y=(df).mean(),
    marker_color=px.colors.qualitative.Plotly,
    text=np.around((df).mean(),3),
    textposition='outside',
    error_y=dict(
        type='data',
        color='black',
        symmetric=False,
        array=(df).quantile(.75) - (df).mean(),
        arrayminus=(df).mean() - (df).quantile(.25))
#         array=(df).mean() + (df).std(),
#         arrayminus=(df).mean() - (df).std())
))
# fig.add_trace(go.Bar(
#     name='Experimental',
#     x=['GenCp', 'KFCV', 'SPCV'], y=(df).mean(),
#     error_y=dict(type='data', array=[1, 2])
# ))
# fig.update_layout(barmode='group')
# fig.add_hline(y=1., line_color='red')
fig.update_layout(
#     title=f"FC_\u03B4{delta}_snr{snr}_nk{noise_kernel}_nls{noise_length_scale}_nv{noise_nu}_Xk{X_kernel}_Xls{X_length_scale}_Xv{X_nu}",
    title=f"Blurred Forest: W vs Y Refit, Constant \u03bc" if const_mu \
            else f"Blurred Forest: W vs Y Refit, Linear \u03bc",
    xaxis_title="Method",
    yaxis_title="MSE",
#     legend_title="Legend Title",
#     font=dict(
#         family="Courier New, monospace",
#         size=18,
#         color="RebeccaPurple"
#     )
)
# barfp = os.path.expanduser(savedir + fig.layout.title['text'] + ".jpeg")
# fig.write_image(os.path.expanduser(barfp))
fig.show()

In [None]:
# ^^ snr .01, linear mu, 100 depth 6 trees

In [136]:
fig = go.Figure()
fig.add_trace(go.Bar(
    x=df.columns, 
    y=(df).mean(),
    marker_color=px.colors.qualitative.Plotly,
    text=np.around((df).mean(),3),
    textposition='outside',
    error_y=dict(
        type='data',
        color='black',
        symmetric=False,
        array=(df).quantile(.75) - (df).mean(),
        arrayminus=(df).mean() - (df).quantile(.25))
#         array=(df).mean() + (df).std(),
#         arrayminus=(df).mean() - (df).std())
))
# fig.add_trace(go.Bar(
#     name='Experimental',
#     x=['GenCp', 'KFCV', 'SPCV'], y=(df).mean(),
#     error_y=dict(type='data', array=[1, 2])
# ))
# fig.update_layout(barmode='group')
# fig.add_hline(y=1., line_color='red')
fig.update_layout(
#     title=f"FC_\u03B4{delta}_snr{snr}_nk{noise_kernel}_nls{noise_length_scale}_nv{noise_nu}_Xk{X_kernel}_Xls{X_length_scale}_Xv{X_nu}",
    title=f"Blurred Forest: W vs Y Refit, Constant \u03bc" if const_mu \
            else f"Blurred Forest: W vs Y Refit, Linear \u03bc",
    xaxis_title="Method",
    yaxis_title="MSE",
#     legend_title="Legend Title",
#     font=dict(
#         family="Courier New, monospace",
#         size=18,
#         color="RebeccaPurple"
#     )
)
# barfp = os.path.expanduser(savedir + fig.layout.title['text'] + ".jpeg")
# fig.write_image(os.path.expanduser(barfp))
fig.show()

In [19]:
"""
1.5 snr doesnt work, neither does .4 snr. 
.1 snr sort of works?
Maybe we need a more middle ground snr? 
No snr 1 also favors OLS.
Lets try very small snr...
Seems to help for small snr.
"""

'\n1.5 snr doesnt work, neither does .4 snr. \n.1 snr sort of works?\nMaybe we need a more middle ground snr? \nNo snr 1 also favors OLS.\nLets try very small snr...\nSeems to help for small snr.\n'