In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso, LinearRegression

import plotly.express as px
import plotly.graph_objects as go

from spe.mse_estimator import ErrorComparer

In [2]:
err_cmp = ErrorComparer()

In [3]:
alphas = [0.05, 0.1, 0.2, 0.5, 0.8, 1.]
na = len(alphas)
niter = 500

test_err = np.zeros((na, niter))
test_err_alpha = np.zeros((na, niter))
cb_err = np.zeros((na, niter))
blur_err = np.zeros((na, niter))

n=100
p=20
s=20
X = np.random.randn(n,p)
beta = np.zeros(p)
idx = np.random.choice(p,size=s)
beta[idx] = np.random.uniform(-1,1,size=s)

for i,alpha in enumerate(alphas):
    print(i)
    (test_err[i,:],
     test_err_alpha[i,:],
     cb_err[i,:],
     blur_err[i,:]) = err_cmp.compareBlurLinearIID(n=n,
                                         p=p,
                                         X=X,
                                         beta=beta,
                                         model=LinearRegression(),
                                         alpha=alpha,
                                         niter=niter,
                                         est_risk=True)
    

0
1
2
3
4
5


In [4]:
risk = test_err.mean(axis=1)
risk_alpha = test_err_alpha.mean(axis=1)
risk_cb = cb_err.mean(axis=1)
risk_blur = blur_err.mean(axis=1)
risk, risk_alpha, risk_cb, risk_blur

(array([2.34085976, 2.40733057, 2.15594048, 2.2772265 , 2.13646869,
        2.29780957]),
 array([2.4312816 , 2.4033745 , 2.74926531, 3.55759144, 4.03164345,
        4.56949081]),
 array([2.38863936, 2.48843957, 2.69205666, 3.47333189, 4.07614863,
        4.42542491]),
 array([2.16572078, 2.28303875, 2.25356965, 2.37485315, 2.35076554,
        2.26791321]))

In [5]:
cb_df = pd.DataFrame(cb_err.T)
blur_df = pd.DataFrame(blur_err.T)

In [6]:
blur_df

Unnamed: 0,0,1,2,3,4,5
0,5.915851,1.046028,4.066821,1.388659,3.565509,1.668968
1,-1.725807,-0.384717,2.363485,1.289095,4.260051,2.234188
2,0.982120,2.761910,2.268700,1.515250,1.912953,2.844134
3,0.866790,3.354935,2.703584,3.280134,2.664889,1.022247
4,1.953031,1.702743,3.214920,2.736015,1.012824,0.382558
...,...,...,...,...,...,...
495,1.630051,2.495394,1.617694,3.647081,2.028327,4.015063
496,3.183096,3.625538,1.744180,4.331681,4.292293,2.110333
497,0.215142,1.836885,1.192327,2.272167,2.796289,2.719034
498,0.704502,1.549663,2.800713,3.190608,1.649287,1.807094


In [9]:
fig = px.box(blur_df)
fig.update_traces(boxmean=True)
fig.add_trace(go.Scatter(x=[-1,6], 
                         y=[test_err.mean(),test_err.mean()], 
                         mode='lines', 
                         name='err'))

In [8]:
fig = px.box(cb_df)
fig.update_traces(boxmean=True)
fig.add_trace(go.Scatter(x=[-1,6], 
                         y=[test_err.mean(),test_err.mean()], 
                         mode='lines', 
                         name='err'))
fig.add_trace(go.Scatter(x=np.arange(6), 
                         y=risk_alpha, 
                         mode='markers', 
                         name='err_alpha'))