# VB Lasso by Laplace posterior
+ We consider the following approximated posterior distribution q(w):
$$
p(y|x,w) = N(y|x^T w,1), p(w) = Laplace(w, 0, \beta)  \\
q(w) \propto \prod_{j=1}^M \exp(-\frac{1}{\sigma_j} |w_j - \mu_j|),
$$
where $\sigma_j \in \mathbb{R}_+, \mu_j, w_j \in \mathbb{R}$

In [1]:
import numpy as np
from scipy.optimize import minimize
from sklearn.linear_model import LassoCV, Lasso

In [2]:
# mu = -1
# sigma = 0.1

# val = np.random.laplace(loc = mu, scale = sigma/np.sqrt(2), size = 10000)

# (1/sigma**2 - 2*np.sqrt(2)/sigma**3*np.abs(val - mu) + 2/sigma**4*(val - mu)**2).mean()

# 1/sigma**2

# (3-np.sqrt(2))/sigma**2

## Problem settings

In [81]:
n = 100
data_seed = 20210103
M = 100
zero_ratio = 0.5
n_zero_ind = int(M*zero_ratio) # # of zero elements in the parameter

## Data Generation

In [82]:
np.random.seed(data_seed)

In [83]:
true_w = np.random.normal(scale = 3, size = M)
zero_ind = np.random.choice(M, size = n_zero_ind)
true_w[zero_ind] = 0

## Learning settings

In [115]:
iteration = 500
ln_seed = 20210105
np.random.seed(ln_seed)
pri_beta = 0.1
rho = 0.1

In [116]:
### initialization
est_mu = np.random.normal(size = M)
est_ln_sigma = np.random.normal(size = M)
est_sigma = np.exp(est_ln_sigma)
est_pri_beta = pri_beta

In [117]:
def calc_energy(post_mu: np.ndarray, post_ln_sigma: np.ndarray, X:np.ndarray, y:np.ndarray, pri_beta: float) -> float:
    post_sigma = np.exp(post_ln_sigma)
    n, M = X.shape
    energy = 0
    energy += ((y-X@post_mu)**2).sum()/2 + (X**2).sum(axis=0)@post_sigma**2/2 + n/2*np.log(2*np.pi) - M - M*np.log(pri_beta)
    energy += (-np.log(post_sigma) + np.sqrt(2)/pri_beta*np.abs(post_mu) + post_sigma/pri_beta*np.exp(-post_sigma/np.sqrt(2)*np.abs(post_mu))).sum()    
    return energy

In [126]:
def calc_energy_dash(post_mu: np.ndarray, post_ln_sigma: np.ndarray, X:np.ndarray, y:np.ndarray, pri_beta: float) -> float:
    post_sigma = np.exp(post_ln_sigma)
    n, M = X.shape
    energy = 0
    energy += ((y-X@post_mu)**2).sum()/2 + (X**2).sum(axis=0)@post_sigma**2/2 + n/2*np.log(2*np.pi) - n*M - n*M*np.log(pri_beta)
    energy += n*(-np.log(post_sigma) + np.sqrt(2)/pri_beta*np.abs(post_mu) + post_sigma/pri_beta*np.exp(-post_sigma/np.sqrt(2)*np.abs(post_mu))).sum()    
    return energy

In [127]:
def calc_energy_wrapper(est_params: np.ndarray, X:np.ndarray, y:np.ndarray, pri_beta: float) -> np.ndarray:
    post_mu = est_params[:M]
    post_ln_sigma = est_params[M:]
    return calc_energy(post_mu, post_ln_sigma, X, y, pri_beta)
    pass

In [131]:
def df_param(X :np.ndarray, y :np.ndarray, mu: np.ndarray, sigma: np.ndarray, pri_beta: float):
    pdf_mu = np.exp(-sigma/np.sqrt(2)*np.abs(mu))*sigma/pri_beta
    dFdm = -X.T @ (y - X @ mu) + np.sqrt(2)/pri_beta*np.sign(mu) - pdf_mu*sigma/np.sqrt(2)*np.sign(mu) 
    dFds = (X**2).sum(axis = 0)*sigma - 1/sigma + pdf_mu*(1-sigma/np.sqrt(2)*np.abs(mu))
    
    return dFdm, dFds

In [132]:
def df_param_dash(X :np.ndarray, y :np.ndarray, mu: np.ndarray, sigma: np.ndarray, pri_beta: float):
    n = len(y)
    pdf_mu = np.exp(-sigma/np.sqrt(2)*np.abs(mu))*sigma/pri_beta
    dFdm = -X.T @ (y - X @ mu) + n*np.sqrt(2)/pri_beta*np.sign(mu) - n*pdf_mu*sigma/np.sqrt(2)*np.sign(mu) 
    dFds = (X**2).sum(axis = 0)*sigma - n/sigma + n*pdf_mu*(1-sigma/np.sqrt(2)*np.abs(mu))
    
    return dFdm, dFds

In [133]:
train_X = np.random.normal(size = (n, M))
train_Y = train_X @ true_w + np.random.normal(size = n)

test_X = np.random.normal(size = (n, M))
test_Y = test_X @ true_w + np.random.normal(size = n)

In [134]:
for ite in range(iteration):
    res = minimize(
        fun=calc_energy_wrapper, x0=np.hstack([est_mu, est_ln_sigma]), 
        args=(train_X, train_Y, est_pri_beta), method = "L-BFGS-B", options={"disp":True, "maxiter": 1}
    )

    est_mu = res.x[:M]
    est_ln_sigma = res.x[M:]
    est_sigma = np.exp(est_ln_sigma)
    est_pri_beta = (np.sqrt(2)*np.abs(est_mu) + est_sigma*np.exp(-est_sigma/np.sqrt(2)*np.abs(est_mu))).mean()

    dFdm, dFds = df_param(train_X, train_Y, est_mu, est_sigma, est_pri_beta)
    
    print(res.fun, (dFdm**2).mean(), (dFds**2).mean())

1931.0152797244828 6998.757629084297 361.3090189004456
776.185448237904 3602.430510076956 297.35180439975255
-142.0698916096917 2168.1732437818796 224.90760728795473
-829.4138848621969 1106.3757354395493 168.24963320815365
-1426.313077754592 1290.547256599485 82.99374326707839
-1655.5284416055438 1041.2277001236819 68.44429027653547
-1856.5010511454238 1123.9462084722973 43.245869438792994
-1931.7114919905835 1008.7646315277815 36.53585005248871
-1896.2149430082209 1435.569593361696 34.83220889071314
-1923.6508631050474 969.1640803885936 37.17323443430589
-1943.3151772102956 1230.2854000743034 28.199805526101294
-1908.470679649974 588.440027636655 21.890480991641446
-1836.6631604094127 598.7997135798199 19.34774099931843
-1781.8630141637932 717.2761229083942 18.969374566188073
-1755.8945085795021 667.1947584459471 18.89603474316035
-1723.9023767966792 578.1063629269329 18.465917316003083
-1708.3282529723638 526.6807946375837 18.410774932980015
-1699.4896299073534 484.6772097374365 18.2

In [136]:
post_mu = res.x[:M]
post_sigma = np.exp(res.x[M:])

In [139]:
clf = LassoCV(fit_intercept=False)
clf.fit(train_X, train_Y)

LassoCV(fit_intercept=False)

In [140]:
print(np.sqrt(((test_Y - test_X@clf.coef_)**2).mean()))
print(np.sqrt(((test_Y - test_X@post_mu)**2).mean()))
print(np.sqrt(((test_Y - test_X@true_w)**2).mean()))

2.938704474599947
7.082817415565786
0.8810319949267867


In [16]:
dFdm = -train_X.T @ (train_Y - train_X @ est_mu) + 1/pri_beta * np.sign(est_mu) * (1 - np.exp(-np.abs(est_mu) / est_sigma))
dFds = -1/est_sigma + (train_X**2).sum(axis = 0) + 1/pri_beta * (1 + np.abs(est_mu)/est_sigma) * np.exp(-np.abs(est_mu) / est_sigma)

In [17]:
est_mu += rho * (dFdm * est_sigma **2)
est_sigma += rho * (dFds * est_sigma **2)

In [56]:
print((dFdm**2).sum(), (dFds**2).sum())

2.689193402136793e+21 542528.8645508956


In [10]:
est_mu + 0.1 * (dFds * est_sigma**2)

array([ 6.30501282e-01,  7.37448667e+01,  2.57156858e+02,  2.57702044e+01,
        3.39545755e+01,  1.00792194e+00,  2.47826443e+01,  8.28315896e+00,
       -3.82877467e-01,  4.01311542e+01,  7.60040766e+01,  3.42574308e+01,
        5.44876974e-01,  1.31971408e+02, -2.31088035e-02,  4.75951255e-01,
        1.19668983e+01,  5.09313390e+02,  4.33652351e+00,  1.65465368e+00,
       -3.42007692e-01,  1.10922144e+01,  1.26127582e+00,  3.42180033e+01,
        1.76372057e+01, -2.01576455e+00,  1.35855242e+00,  2.73004442e+01,
        1.00778317e+00,  2.49418258e+01, -2.15797015e+00,  1.18996469e+01,
        1.97418817e-01,  1.96923160e+01,  5.48687120e+00,  3.62454122e+00,
        4.89942007e+00,  3.79221774e+00, -1.08543985e+00,  6.05900177e+01,
        1.46812050e+02,  6.13162631e+01,  2.52597511e+01,  2.87435486e+00,
        2.01384358e+02,  3.40583173e+01, -8.12667663e-01,  1.24548863e+01,
        5.73056256e+00,  1.20841744e+02])

In [16]:
est_sigma + 0.1 * (dFds * est_sigma**2)

array([2.25757256e+00, 7.56023673e+01, 2.60003562e+02, 2.60926779e+01,
       3.52623437e+01, 1.30019205e+00, 2.52314193e+01, 9.47386933e+00,
       1.81878410e-01, 4.18130682e+01, 7.82506706e+01, 3.54967294e+01,
       5.83919137e-01, 1.35395147e+02, 4.41249399e-02, 7.13589330e-01,
       1.36337361e+01, 5.14971583e+02, 3.82355226e+00, 4.60986353e-01,
       5.70212864e-03, 1.16247360e+01, 6.63534276e-01, 3.52783802e+01,
       1.74111090e+01, 1.40961292e-01, 1.17115727e+00, 2.98735137e+01,
       9.31138137e-01, 2.53395838e+01, 3.66390783e-01, 1.27713890e+01,
       3.75304519e-01, 2.00764047e+01, 6.72967913e+00, 3.67531632e+00,
       5.34529263e+00, 5.92476623e+00, 1.81456601e-01, 6.11720515e+01,
       1.49515153e+02, 6.24096767e+01, 2.56243824e+01, 1.96813449e+00,
       2.03849137e+02, 3.53228655e+01, 8.24814105e-02, 1.23586751e+01,
       6.44384301e+00, 1.24870838e+02])

In [22]:
est_mu

array([-1.21298429e+00,  1.39156558e-01,  6.35718810e-01,  7.76853396e-01,
       -1.17578905e-03, -4.89157194e-02,  6.61150335e-01, -4.99452990e-01,
       -4.76905446e-01, -2.91121709e-01, -3.96103021e-01,  8.00172260e-02,
        1.16305211e-01, -7.83693786e-01, -3.72149984e-02, -6.46044412e-02,
       -8.03995712e-01, -3.60124517e-01,  1.01896499e+00,  1.36878006e+00,
       -3.41814469e-01,  2.13188064e-01,  7.89948965e-01,  2.50055039e-01,
        1.23570020e+00, -2.07896325e+00,  4.59612633e-01, -1.33208976e+00,
        3.00631589e-01,  7.71761520e-01, -2.36390763e+00, -9.90179745e-02,
       -5.66988684e-02,  6.06024937e-01, -6.33522308e-01,  4.24033838e-01,
        6.06537565e-02, -1.49561072e+00, -1.16896228e+00,  1.17945905e+00,
       -7.24521304e-02,  5.12204259e-01,  8.42049785e-01,  1.28598127e+00,
        8.39340437e-01, -7.25821758e-02, -8.38987296e-01,  9.15702365e-01,
       -1.84879779e-01, -1.48530652e+00])

In [24]:
dFds * est_sigma**2

array([ 1.84348558e+01,  7.36057102e+02,  2.56521139e+03,  2.49933510e+02,
        3.39557513e+02,  1.05683766e+01,  2.41214939e+02,  8.78261195e+01,
        9.40279787e-01,  4.04222759e+02,  7.64001796e+02,  3.41774135e+02,
        4.28571763e+00,  1.32755102e+03,  1.41061949e-01,  5.40555696e+00,
        1.27708940e+02,  5.09673514e+03,  3.31755851e+01,  2.85873621e+00,
       -1.93222992e-03,  1.08790264e+02,  4.71326852e+00,  3.39679483e+02,
        1.64015055e+02,  6.31986981e-01,  8.98939786e+00,  2.86325340e+02,
        7.07151578e+00,  2.41700643e+02,  2.05937474e+00,  1.19986649e+02,
        2.54117685e+00,  1.90862910e+02,  6.12039351e+01,  3.20050738e+01,
        4.83876631e+01,  5.28782846e+01,  8.35224362e-01,  5.94105586e+02,
        1.46884502e+03,  6.08040588e+02,  2.44177014e+02,  1.58837360e+01,
        2.00545017e+03,  3.41308995e+02,  2.63196332e-01,  1.15391839e+02,
        5.91544234e+01,  1.22327050e+03])