In [1]:
# Numerical core used throughout the notebook.
import numpy as np

# Solver classes from the mrashpen package. Only PLR is referenced in the
# cells shown here; MrASHR is imported but not used in them.
from mrashpen.inference.penalized_regression import PenalizedRegression as PLR
from mrashpen.inference.mrash_wrapR          import MrASHR

# `simulate` is imported from a local directory that is appended to sys.path,
# not from an installed package.
# NOTE(review): the path below is absolute and user-specific, so this cell
# fails on any other machine -- consider making the directory configurable.
import sys
sys.path.append('/home/saikat/Documents/work/sparse-regression/simulation/eb-linreg-dsc/dsc/functions')
import simulate

# Plotting imports and stylesheet. No figure is drawn in the cells shown here.
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')

In [2]:
def center_and_scale(Z):
    """
    Standardize Z to unit standard deviation and zero mean.

    A 1-D input is standardized as a single vector. A 2-D input is
    standardized column by column (statistics are taken along axis 0).
    The standard deviation is the population one (numpy default, ddof = 0).

    Parameters
    ----------
    Z : array_like
        1-D or 2-D numeric data.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as Z, scaled first and centered second.

    Raises
    ------
    ValueError
        If Z is neither 1-D nor 2-D. Previously this case fell through both
        branches and failed with an UnboundLocalError at the return.

    Notes
    -----
    A constant vector / column has zero standard deviation, so the division
    yields nan or inf for it (numpy emits a RuntimeWarning); that behavior is
    left unchanged.
    """
    # Accept lists / tuples as well, while preserving ndarray subclasses.
    Z = np.asanyarray(Z)
    dim = Z.ndim
    if dim == 1:
        Znew = Z / np.std(Z)
        Znew = Znew - np.mean(Znew)
    elif dim == 2:
        Znew = Z / np.std(Z, axis = 0)
        Znew = Znew - np.mean(Znew, axis = 0).reshape(1, -1)
    else:
        raise ValueError(
            "center_and_scale expects a 1-D or 2-D array, got %d dimension(s)" % dim
        )
    return Znew

def initialize_ash_prior(k, scale = 2, sparsity = None):
    """
    Build the initial weights and the grid for a k-component prior.

    Parameters
    ----------
    k : int
        Number of components.
    scale : float, optional
        Base of the grid; component j gets |scale**(j / k) - 1|, so the
        first grid value is always 0.
    sparsity : float or None, optional
        Weight given to the first component. When None, 1 / k is used.

    Returns
    -------
    w : numpy.ndarray, shape (k,)
        Weights. The first entry is `sparsity` (or 1 / k); the remaining
        mass is shared equally among the other k - 1 components, with the
        last entry set to whatever makes the total exactly 1.
    prior_grid : numpy.ndarray, shape (k,)
        Grid values, one per component.
    """
    # Grid: square followed by square root, i.e. the magnitude of scale**(j/k) - 1.
    exponents = np.arange(k) / k
    prior_grid = np.sqrt(np.square(np.power(scale, exponents) - 1))

    w = np.zeros(k)
    w[0] = sparsity if sparsity is not None else 1 / k
    # Equal share for the interior components 1 .. k-2.
    share = (1 - w[0]) / (k - 1)
    w[1:(k-1)] = np.full(k - 2, share)
    # The last component absorbs the remainder so that the weights sum to one.
    w[k-1] = 1 - np.sum(w)
    return w, prior_grid

In [3]:
# Simulation settings, passed positionally to simulate.equicorr_predictors.
n = 200          # presumably the number of samples -- confirm against the simulator
p = 2000         # presumably the number of predictors; also used for nopt = p + k + 1
p_causal = 10    # presumably the number of non-zero effects -- confirm
pve = 0.7        # presumably the proportion of variance explained -- confirm
k = 20           # number of prior components handed to initialize_ash_prior

# Simulate train / test data with a fixed seed (rho = 0.95 is forwarded to the simulator).
X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(n, p, p_causal, pve, rho = 0.95, seed = 10)
# Standardize the columns of both design matrices.
# NOTE(review): Xtest is standardized with its own statistics, not those of X.
X      = center_and_scale(X)
Xtest  = center_and_scale(Xtest)
# Initial prior weights and grid (sparsity left at its default of None).
winit, sk = initialize_ash_prior(k, scale = 2)
# Initial residual variance: the variance of y. np.var already removes the
# mean, so the explicit centering is redundant but harmless.
s2init = np.var(y - np.mean(y))

In [4]:
# Sanity check: column-wise sum of squares of the standardized X. With zero
# mean and unit population variance, every entry equals the number of rows.
np.sum(np.square(X), axis = 0)

array([200., 200., 200., ..., 200., 200., 200.])

In [5]:
# Sanity check: per-column population standard deviation of X (should be 1).
np.std(X, axis = 0)

array([1., 1., 1., ..., 1., 1., 1.])

In [6]:
# Starting point for the optimizers: a copy of the simulated btrue, so that
# later use of binit cannot alias btrue.
binit = btrue.copy()

In [7]:
# Show the initial residual variance (the variance of y).
s2init

1.1723683608566409

In [8]:
# Show the initial coefficient vector (summarized by numpy's repr).
binit

array([0., 0., 0., ..., 0., 0., 0.])

In [9]:
# Show the initial prior weights; with sparsity = None they are uniform at 1 / k.
winit

array([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
       0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05])

In [10]:
'''
mr.ash.pen
'''
# Fit PenalizedRegression with method 'L-BFGS-B' and function_call = 'fortran'.
# A later cell repeats the same call with function_call = 'python' for timing.
# Weights and residual variance are both optimized (optimize_w / optimize_s).
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
                function_call = 'fortran',
                debug = False, display_progress = False, calculate_elbo = False)
# Start from binit / winit / s2init defined above.
# NOTE(review): is_binit_coef = False presumably means binit is taken as the
# optimization variable (theta) and not as regression coefficients -- confirm.
plr_lbfgs.fit(X, y, sk, binit = binit, winit = winit, s2init = s2init, is_binit_coef = False)

mr.ash.pen terminated at iteration 258.


In [11]:
# Fitted `theta` vector from the Fortran-backed PLR fit.
plr_lbfgs.theta

array([ 0.01044527, -0.02093494, -0.01016991, ..., -0.0203594 ,
       -0.00853223, -0.00203439])

In [12]:
# Last entry of the recorded prior path (private attribute).
# NOTE(review): these values are unnormalized and partly negative, unlike
# plr_lbfgs.prior -- presumably the softmax parameters of the weights; confirm.
plr_lbfgs._prior_path[-1]

array([ 16.76373324, -30.90183969, -18.53787182, -12.38810479,
        -8.82769906,  -6.69835665,  -5.32606541,  -4.36613573,
        -3.6412724 ,  -3.05017928,  -2.52281178,  -1.99712808,
        -1.41327558,  -0.71915941,   0.12339646,   1.13787037,
         2.3364266 ,   3.7259555 ,   5.35511334,  11.0327627 ])

In [13]:
# Fitted prior weights from the PLR fit.
plr_lbfgs.prior

array([9.96752557e-01, 1.98468717e-21, 4.64830994e-16, 2.17823618e-13,
       7.66248305e-12, 6.44363877e-11, 2.54161463e-10, 6.63745916e-10,
       1.37026955e-09, 2.47465402e-09, 4.19321646e-09, 7.09330790e-09,
       1.27178217e-08, 2.54603022e-08, 5.91263295e-08, 1.63065220e-07,
       5.40614509e-07, 2.16946391e-06, 1.10633510e-05, 3.23339287e-03])

In [14]:
# Fitted residual variance from the PLR fit.
plr_lbfgs.residual_var

0.43837860155871317

In [15]:
# Raw optimizer result kept on the model (presumably a scipy OptimizeResult,
# judging by the printed fields -- confirm).
plr_lbfgs.fitobj

      fun: -5041.140923593008
 hess_inv: <2021x2021 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 2.50590746e-04, -7.00599914e-04, -9.51026311e-05, ...,
        3.34496755e-03, -1.34200589e-02,  1.35025264e-02])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 312
      nit: 258
     njev: 312
   status: 0
  success: True
        x: array([ 1.04452711e-02, -2.09349424e-02, -1.01699125e-02, ...,
        5.35511334e+00,  1.10327627e+01,  4.38378602e-01])

In [16]:
from libmrashpen_lbfgs_driver import lbfgsb_driver

In [17]:
# Call the compiled L-BFGS-B driver directly, with the same starting values
# as the PLR fit above.
ftol = 1e-9
gtol = 1e-9
# Number of optimization variables: p entries of theta, k prior parameters
# and 1 residual variance.
nopt = p + k + 1

# NOTE(review): the positional arguments after `nopt` are undocumented here.
# Assumptions to confirm against the Fortran interface:
#   True, True, True -> presumably optimize_w / optimize_s / is_prior_scaled,
#                       mirroring the PLR call above
#   1.0              -> unknown
#   10               -> presumably the number of L-BFGS corrections (the log prints M = 10)
#   0                -> presumably the print level (the timed call further down passes -1
#                       and produces no log)
#   ftol / eps       -> presumably the `factr` tolerance (relative reduction of F
#                       in units of machine epsilon)
#   gtol             -> presumably the projected-gradient tolerance
#   1000, 10000      -> presumably iteration / function-evaluation limits
f_theta, f_wk, f_s2, f_obj, f_grad, f_nfev, f_niter, f_task = \
    lbfgsb_driver.min_plr_shrinkop(X, y, binit, winit, s2init, sk,
                                   nopt, True, True, True, 1.0, 10, 0,
                                   ftol / np.finfo(float).eps, gtol, 1000, 10000)

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =         2021     M =           10

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
 2021    318    376    319     0     0   3.020D-02  -5.041D+03

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             

 Total User time 2.476E+01 seconds.



In [18]:
# theta returned by the Fortran driver; compare with plr_lbfgs.theta above.
f_theta

array([ 0.00660328, -0.01767628, -0.00785757, ..., -0.01982615,
       -0.0063995 , -0.00199676])

In [19]:
# Prior weights returned by the Fortran driver; compare with plr_lbfgs.prior above.
f_wk

array([9.96746834e-01, 3.67207049e-22, 3.02066190e-17, 1.20587585e-14,
       5.87100018e-13, 6.55144453e-12, 3.15115735e-11, 9.45790018e-11,
       2.16068202e-10, 4.21284498e-10, 7.57388259e-10, 1.34053227e-09,
       2.48170004e-09, 5.04723374e-09, 1.15982127e-08, 3.01081442e-08,
       8.53058466e-08, 2.47176903e-07, 7.30379473e-07, 3.25205102e-03])

In [20]:
# NOTE(review): repeats the f_theta display from two cells earlier; candidate for removal.
f_theta

array([ 0.00660328, -0.01767628, -0.00785757, ..., -0.01982615,
       -0.0063995 , -0.00199676])

In [21]:
# Residual variance returned by the Fortran driver; compare with plr_lbfgs.residual_var.
f_s2

0.4394845826980046

In [22]:
# Final objective value from the Fortran driver; compare with `fun` in plr_lbfgs.fitobj.
f_obj

-5040.895274257286

In [23]:
# Final gradient from the Fortran driver. The cells below slice it as
# [:p], [p:p+k] and [-1] and compare each part with an independent evaluation.
f_grad

array([-0.00018766, -0.00018387,  0.00019238, ...,  0.00022294,
        0.00173484,  0.03020415])

In [24]:
# Re-evaluate the objective and its gradients at the driver's solution with
# the standalone Fortran routine, to cross-check the values the driver returned.
from libmrashpen_plr_mrash import plr_mrash as flib_penmrash
# Column-wise sum of squares of X and its reciprocal; the reciprocal is what
# the routine is given.
dj = np.sum(np.square(X), axis = 0)
djinv = 1 / dj
# Returns the objective and three gradient pieces; judging by the names these
# are with respect to theta, the weights and s2 (the later cells compare them
# with the matching slices of f_grad).
obj, bgrad, wgrad, s2grad = \
    flib_penmrash.plr_obj_grad_shrinkop(X, y, f_theta, f_s2, f_wk, sk, djinv)

In [25]:
# Recomputed objective; should equal f_obj from the driver.
obj

-5040.895274257286

In [26]:
# Recomputed gradient piece to compare with the first p entries of f_grad (next cell).
bgrad

array([-1.87663313e-04, -1.83874891e-04,  1.92379314e-04, ...,
        5.41203281e-05,  2.20965794e-04,  1.02697590e-04])

In [27]:
# First p entries of the driver's gradient; should match bgrad.
f_grad[:p]

array([-1.87663313e-04, -1.83874891e-04,  1.92379314e-04, ...,
        5.41203281e-05,  2.20965794e-04,  1.02697590e-04])

In [28]:
# Map the gradient with respect to the weights onto the driver's k prior
# parameters via the chain rule.
# akjac[i, j] = w_i * (delta_ij - w_j) is the Jacobian of a softmax,
# which assumes the weights are parametrized as w = softmax(a) -- confirm.
akjac = f_wk.reshape(-1, 1) * (np.eye(k) - f_wk)
# agrad[i] = sum_j wgrad[j] * akjac[i, j]; wgrad broadcasts across the rows.
agrad  = np.sum(wgrad * akjac, axis = 1)
agrad

array([-2.25048640e-03,  7.41242124e-20,  1.75044822e-14,  1.08199182e-11,
        6.60103322e-10,  8.41945449e-09,  4.41819530e-08,  1.40895392e-07,
        3.36150392e-07,  6.74955336e-07,  1.23172630e-06,  2.17457294e-06,
        3.92241198e-06,  7.52435452e-06,  1.56039499e-05,  3.44363991e-05,
        7.62209685e-05,  1.50386345e-04,  2.22943968e-04,  1.73483643e-03])

In [29]:
# Entries p .. p+k-1 of the driver's gradient; should match agrad computed above.
f_grad[p:p+k]

array([-2.25048640e-03,  7.41242124e-20,  1.75044822e-14,  1.08199182e-11,
        6.60103322e-10,  8.41945449e-09,  4.41819530e-08,  1.40895392e-07,
        3.36150392e-07,  6.74955336e-07,  1.23172630e-06,  2.17457294e-06,
        3.92241198e-06,  7.52435452e-06,  1.56039499e-05,  3.44363991e-05,
        7.62209685e-05,  1.50386345e-04,  2.22943968e-04,  1.73483643e-03])

In [30]:
# Recomputed gradient with respect to the residual variance.
s2grad

0.030204153668137224

In [31]:
# Last entry of the driver's gradient; should match s2grad. Its magnitude
# (3.02e-2) equals the Projg value printed in the L-BFGS-B log.
f_grad[-1]

0.030204153668137224

In [32]:
# Re-display binit before the timing runs, to check the earlier calls left it unchanged.
binit

array([0., 0., 0., ..., 0., 0., 0.])

In [33]:
# Re-display winit before the timing runs, to check the earlier calls left it unchanged.
winit

array([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
       0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05])

In [34]:
# Re-display s2init before the timing runs, to check it is unchanged.
s2init

1.1723683608566409

In [35]:
%%timeit -n 2 -r 3

'''
mr.ash.pen
'''
# Timed run (this cell is under %%timeit): same PLR configuration as the
# earlier fit, except function_call = 'python' instead of 'fortran'.
# NOTE(review): names assigned inside a %%timeit cell are presumably not
# exported to the notebook namespace, so plr_lbfgs from the earlier cell is
# expected to stay as it was -- confirm.
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
                function_call = 'python',
                debug = False, display_progress = False, calculate_elbo = False)
plr_lbfgs.fit(X, y, sk, binit = binit, winit = winit, s2init = s2init, is_binit_coef = False)

mr.ash.pen terminated at iteration 213.
mr.ash.pen terminated at iteration 213.
mr.ash.pen terminated at iteration 213.
mr.ash.pen terminated at iteration 213.
mr.ash.pen terminated at iteration 213.
mr.ash.pen terminated at iteration 213.
4.52 s ± 138 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)


In [36]:
%%timeit -n 2 -r 3

# Timed run (this cell is under %%timeit) of the direct Fortran driver call.
# It repeats the earlier untimed driver cell, except that the argument after
# `10` is -1 instead of 0.
# NOTE(review): that argument is presumably the solver's print level, with -1
# silencing the log so printing does not distort the timing -- confirm.
ftol = 1e-9
gtol = 1e-9
nopt = p + k + 1

f_theta, f_wk, f_s2, f_obj, f_grad, f_nfev, f_niter, f_task = \
    lbfgsb_driver.min_plr_shrinkop(X, y, binit, winit, s2init, sk,
                                   nopt, True, True, True, 1.0, 10, -1,
                                   ftol / np.finfo(float).eps, gtol, 1000, 10000)

2.74 s ± 23.5 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)
