In [1]:
import numpy as np
import pandas as pd
from sklearn.gaussian_process.kernels import Matern, RBF
from sklearn.linear_model import RidgeCV
from sklearn.ensemble import RandomForestRegressor

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from doc.mse_estimator import ErrorComparer
from doc.data_generation import gen_rbf_X, gen_matern_X, create_clus_split, gen_cov_mat
from doc.plotting_utils import gen_model_barplots
from spe.relaxed_lasso import RelaxedLasso
from spe.tree import Tree
from spe.smoothers import LinearRegression
from spe.estimators import kfoldcv, kmeanscv, better_test_est_split, cp_general_train_test

In [2]:
# Seed NumPy's global random state so that the np.random draws made later in
# this notebook are repeatable from a fresh kernel.
SEED = 1
np.random.seed(SEED)


In [3]:
## ---------------------------------------------------------------
## Simulation size
## ---------------------------------------------------------------
# Number of realizations; forwarded to ErrorComparer.compare as `niter`.
niter = 100

## ---------------------------------------------------------------
## Data generation parameters
## ---------------------------------------------------------------
n = 20 ** 2     # number of observations (the grid below is sqrt(n) x sqrt(n))
p = 5           # number of columns of X / length of beta
s = 5           # number of entries of beta that receive a nonzero draw
delta = 0.75    # weight of the kernel part of the noise covariance; the
                # remaining (1 - delta) goes on the identity
snr = 0.4       # forwarded to compare(); presumably signal-to-noise ratio -- TODO confirm
tr_frac = 0.25  # NOTE(review): not referenced by any cell visible in this notebook

# Kernel used for the noise covariance ('rbf', 'matern', anything else -> identity).
noise_kernel = 'matern'
noise_length_scale = 1.0
noise_nu = 0.5

# Kernel used to generate the design matrix X ('rbf', 'matern', anything else -> iid normal).
X_kernel = 'matern'
X_length_scale = 5.0
X_nu = 2.5

## ---------------------------------------------------------------
## ErrorComparer parameters
## ---------------------------------------------------------------
alpha = 0.05
nboot = 100     # NOTE(review): not referenced by any cell visible in this notebook
k = 5           # NOTE(review): not referenced by any cell visible in this notebook
max_depth = 2   # shared depth limit for the tree and the random forest
lambd = 0.05    # NOTE(review): not referenced by any cell visible in this notebook

# Models under comparison; `model_names` below lists them in the same order.
models = [
    Tree(max_depth=max_depth),
    RidgeCV(alphas=[0.01, 0.1, 1.0]),
    RandomForestRegressor(n_estimators=100, max_depth=max_depth),
]

# Error estimators handed to compare(), with one kwargs dict per estimator
# (presumably matched by position -- verify against ErrorComparer.compare).
ests = [better_test_est_split, cp_general_train_test]
est_kwargs = [
    dict(),
    dict(use_trace_corr=False, alpha=alpha),
]

# Display labels used by the bar plots.
model_names = ["Decision Tree", "Ridge CV", "Random Forest"]
est_names = ["EstCov", "Oracle"]

In [4]:
# Comparison driver imported from doc.mse_estimator; its compare() method is
# called once per model in each experiment cell below.
err_cmp = ErrorComparer()

In [5]:
# Place the n observations on a square nx-by-ny lattice spanning [0, 20] x [0, 20].
nx = ny = int(np.sqrt(n))

# int() truncates, so a non-square n would yield a grid with fewer than n
# points. The grid-sized arrays (c_x, c_y, coord) would then silently disagree
# with the n-sized ones built elsewhere (np.eye(n), tr_idx, random X).
# Fail early with a clear message instead.
assert nx * ny == n, f"n={n} must be a perfect square to fill an nx-by-ny grid"

xs = np.linspace(0, 20, nx)
ys = np.linspace(0, 20, ny)

# meshgrid returns 2-D coordinate arrays; flatten them to one entry per grid point.
c_x, c_y = np.meshgrid(xs, ys)
c_x = c_x.flatten()
c_y = c_y.flatten()

# coord has one row per grid point and two columns: (x, y).
coord = np.stack([c_x, c_y]).T

In [6]:
# Select the spatial kernel for the noise; None means "no spatial structure".
if noise_kernel == 'rbf':
    noise_kern = RBF(length_scale=noise_length_scale)
elif noise_kernel == 'matern':
    noise_kern = Matern(length_scale=noise_length_scale, nu=noise_nu)
else:
    noise_kern = None

# Kernel matrix over the grid coordinates, or the identity when no kernel is selected.
spatial_corr = np.eye(n) if noise_kern is None else gen_cov_mat(c_x, c_y, noise_kern)

# Cov_st: the kernel part only, scaled by delta. It is later supplied as
# Cov_y_ystar in the correlated (Y, Y*) experiments.
Cov_st = delta * spatial_corr

# Sigma_t: full noise covariance = delta * kernel part + (1 - delta) * identity.
Sigma_t = Cov_st + (1 - delta) * np.eye(n)

# Cholesky factor of Sigma_t; without a spatial kernel the identity is used directly.
Chol_t = np.linalg.cholesky(Sigma_t) if noise_kern is not None else np.eye(n)

In [7]:
# Design matrix: spatially structured columns for a named kernel, otherwise
# iid standard-normal entries.
if X_kernel == 'matern':
    X = gen_matern_X(c_x, c_y, p, length_scale=X_length_scale, nu=X_nu)
elif X_kernel == 'rbf':
    X = gen_rbf_X(c_x, c_y, p)
else:
    X = np.random.randn(n, p)

# Coefficients: pick s distinct positions out of p and fill them with
# Uniform(-1, 1) draws; all other entries stay zero.
idx = np.random.choice(p, size=s, replace=False)
beta = np.zeros(p)
beta[idx] = np.random.uniform(-1, 1, size=s)

In [8]:
# Boolean mask of length n with every entry True; forwarded to compare() as tr_idx.
tr_idx = np.full(n, True, dtype=bool)

# Simulate $Y, Y^* \overset{iid}{\sim} \mathcal{N}(\mu, \Sigma_Y)$

## Estimate with True Covariances

In [9]:
# Settings shared by every model in this experiment. No cross-covariance
# between Y and Y* is supplied (Cov_y_ystar=None) and no est_sigma option is
# passed, so the estimators work with the supplied Chol_t.
indep_oracle_settings = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_t,
    Chol_ystar=None,
    Cov_y_ystar=None,
    tr_idx=tr_idx,
    fair=False,
    piecewise_const_mu=True,
)

# One compare() result per model, in the same order as `models`.
true_model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **indep_oracle_settings)
    true_model_errs.append(errs)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:10<00:00,  9.59it/s]
100%|██████████| 100/100 [00:09<00:00, 10.27it/s]
100%|██████████| 100/100 [12:30<00:00,  7.50s/it]


## Estimate with Covariance Estimated Using the Matching Model

In [10]:
from sklearn.base import clone

In [11]:
# Same simulation settings as the true-covariance run, but with
# est_sigma=True, and est_sigma_model set to a fresh clone of the model under test.
indep_estimated_settings = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_t,
    Chol_ystar=None,
    Cov_y_ystar=None,
    tr_idx=tr_idx,
    fair=False,
    piecewise_const_mu=True,
    est_sigma=True,
)

# One compare() result per model, in the same order as `models`.
model_errs = []
for model in models:
    errs = err_cmp.compare(
        model,
        ests,
        est_kwargs,
        # clone() is called per iteration so each model gets its own unfitted copy.
        est_sigma_model=clone(model),
        **indep_estimated_settings,
    )
    model_errs.append(errs)

100%|██████████| 100/100 [02:00<00:00,  1.21s/it]
100%|██████████| 100/100 [01:59<00:00,  1.20s/it]
100%|██████████| 100/100 [14:21<00:00,  8.62s/it]


In [12]:
# Per model, keep three series in plotting order:
#   [0] first estimator's result from the true-covariance run,
#   [1] second estimator's result from the estimated-covariance run,
#   [2] second estimator's result from the true-covariance run.
comb_errs = [
    [oracle_run[0], fitted_run[1], oracle_run[1]]
    for oracle_run, fitted_run in zip(true_model_errs, model_errs)
]

In [13]:
# Bar plot of the combined errors, one group per model, with error bars.
fig = gen_model_barplots(
    comb_errs, model_names, est_names,
    err_bars=True,
    title="Comparing Performance: Estimated vs Oracle Covariance NSN",
)
fig.show()

# Simulate $\begin{pmatrix} Y \\ Y^* \end{pmatrix} \sim \mathcal{N}\left(\begin{pmatrix} \mu \\ \mu \end{pmatrix}, \begin{pmatrix}\Sigma_Y & \Sigma_{Y, Y^*} \\ \Sigma_{Y^*, Y} & \Sigma_{Y}  \end{pmatrix}\right)$

## Estimate with True Covariances

In [14]:
# Settings shared by every model in this experiment. Unlike the independent
# case, the cross-covariance between Y and Y* is supplied (Cov_y_ystar=Cov_st).
# No est_sigma option is passed.
corr_oracle_settings = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_t,
    Chol_ystar=None,
    Cov_y_ystar=Cov_st,
    tr_idx=tr_idx,
    fair=False,
    piecewise_const_mu=True,
)

# One compare() result per model, in the same order as `models`.
true_corr_model_errs = []
for model in models:
    errs = err_cmp.compare(model, ests, est_kwargs, **corr_oracle_settings)
    true_corr_model_errs.append(errs)

100%|██████████| 100/100 [00:18<00:00,  5.44it/s]
100%|██████████| 100/100 [00:14<00:00,  6.79it/s]
100%|██████████| 100/100 [12:25<00:00,  7.46s/it]


## Estimate with Covariance Estimated Using the Matching Model

In [15]:
# Correlated (Y, Y*) setting with the est_sigma="corr_resp" option, and
# est_sigma_model set to a fresh clone of the model under test.
corr_estimated_settings = dict(
    niter=niter,
    n=n,
    p=p,
    s=s,
    snr=snr,
    X=X,
    beta=beta,
    coord=coord,
    Chol_y=Chol_t,
    Chol_ystar=None,
    Cov_y_ystar=Cov_st,
    tr_idx=tr_idx,
    fair=False,
    piecewise_const_mu=True,
    est_sigma="corr_resp",
)

# One compare() result per model, in the same order as `models`.
corr_model_errs = []
for model in models:
    errs = err_cmp.compare(
        model,
        ests,
        est_kwargs,
        # clone() is called per iteration so each model gets its own unfitted copy.
        est_sigma_model=clone(model),
        **corr_estimated_settings,
    )
    corr_model_errs.append(errs)

100%|██████████| 100/100 [02:10<00:00,  1.31s/it]
100%|██████████| 100/100 [02:05<00:00,  1.26s/it]
100%|██████████| 100/100 [14:25<00:00,  8.66s/it]


In [16]:
# Per model, keep three series in plotting order:
#   [0] first estimator's result from the true-covariance run,
#   [1] second estimator's result from the estimated-covariance run,
#   [2] second estimator's result from the true-covariance run.
comb_corr_errs = [
    [oracle_run[0], fitted_run[1], oracle_run[1]]
    for oracle_run, fitted_run in zip(true_corr_model_errs, corr_model_errs)
]

In [17]:
# Bar plot of the combined errors for the correlated (Y, Y*) experiment,
# one group per model, with error bars.
fig = gen_model_barplots(
    comb_corr_errs, model_names, est_names,
    err_bars=True,
    title="Comparing Performance: Estimated vs Oracle Covariance SSN",
)
fig.show()