In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso, LinearRegression

import plotly.express as px
import plotly.graph_objects as go

from spe.mse_estimator import ErrorComparer
from spe.relaxed_lasso import RelaxedLasso
from spe.tree import Tree

In [2]:
# Shared comparer instance; its compareTreeIID method is what the simulation
# cell below calls once per alpha.
err_cmp = ErrorComparer()

In [None]:
# Simulation grid: one row of results per alpha, one column per repetition.
alphas = [0.05, 0.1, 0.2, 0.5, 0.8, 1.]
na = len(alphas)
niter = 100

# Result buffers, each of shape (na, niter); row i is filled for alphas[i].
test_err = np.zeros((na, niter))
test_err_alpha = np.zeros((na, niter))
cb_err = np.zeros((na, niter))
blur_err = np.zeros((na, niter))

# Design: n rows, p columns, s nonzero coefficients in beta.
# NOTE(review): no random seed is set, so X and beta differ on every run.
n = 100
p = 30
s = 10
X = np.random.randn(n, p)
beta = np.zeros(p)
# Draw the support WITHOUT replacement. The default (replace=True) can repeat
# an index, which would silently leave fewer than s nonzero entries in beta.
idx = np.random.choice(p, size=s, replace=False)
beta[idx] = np.random.uniform(-1, 1, size=s)

for i, alpha in enumerate(alphas):
    print(i)  # progress marker: index of the alpha currently being run
    # compareTreeIID returns four per-repetition arrays, stored as row i of
    # the buffers above (presumably error estimates; confirm against
    # spe.mse_estimator).
    (test_err[i, :],
     test_err_alpha[i, :],
     cb_err[i, :],
     blur_err[i, :]) = err_cmp.compareTreeIID(n=n,
                                              p=p,
                                              X=X,
                                              beta=beta,
                                              model=Tree(max_depth=4),
                                              alpha=alpha,
                                              niter=niter,
                                              est_risk=True)
    

0
1
2
3
4


In [None]:
# Average each error matrix over its repetition axis (axis=1), leaving one
# value per alpha.
risk, risk_alpha, risk_cb = (
    err_matrix.mean(axis=1)
    for err_matrix in (test_err, test_err_alpha, cb_err)
)

In [None]:
# Frames are laid out with one row per repetition and one column per alpha.
# blur_df holds the per-repetition difference blur_err - test_err, while
# cb_df holds cb_err as-is.
blur_df = pd.DataFrame((blur_err - test_err).T)
cb_df = pd.DataFrame(cb_err.transpose())

In [None]:
# Box plot of cb_err: one box per column of cb_df (i.e. per alpha), with the
# mean drawn inside each box.
fig = px.box(cb_df)
fig.update_traces(boxmean=True)
# Horizontal reference line at the grand mean of test_err (all alphas and
# repetitions), spanning from left of the first box to right of the last.
# The right edge uses `na` rather than a literal so it tracks len(alphas).
fig.add_trace(go.Scatter(x=[-1, na],
                         y=[test_err.mean(), test_err.mean()],
                         mode='lines',
                         name='err'))
# Per-alpha mean of test_err_alpha, one marker per box position. Using `na`
# keeps the x positions the same length as risk_alpha if alphas changes.
fig.add_trace(go.Scatter(x=np.arange(na),
                         y=risk_alpha,
                         mode='markers',
                         name='err_alpha'))
# Relabel the integer box positions with the actual alpha values.
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        tickvals=np.arange(na),
        ticktext=alphas,
    )
)

In [None]:
# Box plot of blur_df, one box per alpha, with the mean drawn inside each box.
# NOTE(review): blur_df is built as blur_err - test_err, so the plotted values
# are differences even though the axis label reads "Risk" - confirm the label.
# The title is built from n and p so it cannot drift from the simulation
# settings; for n=100, p=30 it is identical to the previous literal.
fig_blur = px.box(blur_df,
                  labels={
                      "variable": "Alpha",
                      "value": "Risk"
                  },
                  title=f"Blurred n={n}, p={p}")
fig_blur.update_traces(boxmean=True)
# No 'err' reference line (the grand mean of test_err) is drawn on this figure.
# Relabel the integer box positions with the actual alpha values.
fig_blur.update_layout(
    xaxis=dict(
        tickmode='array',
        tickvals=np.arange(na),
        ticktext=alphas,
    )
)