In [1]:
import os
import numpy as np

from mrashpen.inference.penalized_regression import PenalizedRegression as PLR
from mrashpen.inference.mrash_wrapR          import MrASHR

import sys
# Directory holding the DSC simulation helpers that provide the `simulate` module.
# Override it with the DSC_FUNCTIONS_DIR environment variable on other machines;
# the default is the original author's checkout.
DSC_FUNCTIONS_DIR = os.environ.get(
    'DSC_FUNCTIONS_DIR',
    '/home/saikat/Documents/work/sparse-regression/simulation/eb-linreg-dsc/dsc/functions')
# Guard the append so that re-running this cell does not keep growing sys.path.
if DSC_FUNCTIONS_DIR not in sys.path:
    sys.path.append(DSC_FUNCTIONS_DIR)
import simulate

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')

In [2]:
def center_and_scale(Z):
    """Standardize Z to zero mean and unit (population) standard deviation.

    Parameters
    ----------
    Z : array_like
        A 1-D vector, or a 2-D matrix whose columns are standardized
        independently (statistics are taken over rows, axis 0).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as Z. A constant column (zero std) yields
        non-finite values, exactly as a plain division would.

    Raises
    ------
    ValueError
        If Z is not 1-D or 2-D. Previously this case failed with an
        UnboundLocalError.
    """
    Z = np.asarray(Z)
    if Z.ndim not in (1, 2):
        raise ValueError(
            "center_and_scale expects a 1-D or 2-D array, got ndim={}".format(Z.ndim))
    # Scaling before centering is fine: the std is shift-invariant, and the mean
    # is recomputed on the scaled values. For 1-D input axis 0 spans the whole
    # vector; for 2-D input the (p,) result broadcasts across the rows.
    Znew = Z / np.std(Z, axis = 0)
    Znew = Znew - np.mean(Znew, axis = 0)
    return Znew

def initialize_ash_prior(k, scale = 2, sparsity = None):
    """Initial mixture weights and grid of scales for an ash prior.

    Parameters
    ----------
    k : int
        Number of mixture components (must be >= 1). Component 0 has grid value 0.
    scale : float, default 2
        Base of the grid. Component j gets |scale**(j/k) - 1| for j = 0..k-1.
        The exponent is j/k, not j/(k-1), so the largest value is
        scale**((k-1)/k) - 1.
    sparsity : float, optional
        Weight of component 0. If None, every component gets 1/k. The remaining
        mass 1 - w[0] is split evenly over components 1..k-1. Ignored when k == 1.

    Returns
    -------
    w : numpy.ndarray, shape (k,)
        Mixture weights. The last entry is set so that they sum to 1.
    prior_grid : numpy.ndarray, shape (k,)
        Grid values described above.

    Raises
    ------
    ValueError
        If k < 1.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))
    w = np.zeros(k)
    if k == 1:
        # A single component has to carry all of the mass. The general branch
        # below would divide 0 by 0 and call np.repeat with a negative count.
        w[0] = 1.0
    else:
        w[0] = 1 / k if sparsity is None else sparsity
        w[1:(k-1)] = (1 - w[0]) / (k - 1)
        # Derive the last weight from the others so the weights sum to 1 exactly.
        w[k-1] = 1 - np.sum(w)
    sk2 = np.square((np.power(scale, np.arange(k) / k) - 1))
    prior_grid = np.sqrt(sk2)
    return w, prior_grid

In [3]:
# Simulation settings. NOTE(review): pve is presumably the proportion of variance
# explained and rho the pairwise predictor correlation; confirm in
# simulate.equicorr_predictors.
n, p = 200, 2000          # samples, predictors
p_causal = 10             # number of causal predictors
pve = 0.7
k = 20                    # number of mixture components in the prior

X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(
    n, p, p_causal, pve, rho = 0.95, seed = 10)

# Standardize predictors column-wise.
# NOTE(review): Xtest is standardized with its own column statistics, not those of X.
X     = center_and_scale(X)
Xtest = center_and_scale(Xtest)

# Starting values for every method below: uniform mixture weights (1/k each) on
# a grid of k scales, and the sample variance of y as the residual variance.
winit, sk = initialize_ash_prior(k, scale = 2)
s2init = np.var(y - np.mean(y))

In [4]:
# center_and_scale gives each column zero mean and unit population std, so every
# column's sum of squares should equal n. The printed array elides most columns,
# so check all of them explicitly before displaying.
x_col_sumsq = np.sum(np.square(X), axis = 0)
np.testing.assert_allclose(x_col_sumsq, n)
x_col_sumsq

array([200., 200., 200., ..., 200., 200., 200.])

In [5]:
# Every standardized column should have unit (population) std. The printed array
# elides most columns, so assert on all of them before displaying.
x_col_std = np.std(X, axis = 0)
np.testing.assert_allclose(x_col_std, 1.0)
x_col_std

array([1., 1., 1., ..., 1., 1., 1.])

In [6]:
# Start the optimizers from a copy of the true simulated coefficients (an oracle start).
# NOTE(review): confirm this is intended. The fit below passes is_binit_coef = False,
# so binit may be interpreted as theta rather than as coefficients.
binit = btrue.copy()

In [7]:
# Initial residual variance: the sample variance of y.
s2init

1.1723683608566409

In [8]:
# Initial coefficient vector (copy of btrue).
binit

array([0., 0., 0., ..., 0., 0., 0.])

In [9]:
# Initial mixture weights from initialize_ash_prior: 1/k for each of the k components.
winit

array([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
       0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05])

In [12]:
# mr.ash.pen: scipy L-BFGS-B on the Python objective/gradient, also optimizing
# the prior mixture weights (optimize_w) and residual variance (optimize_s).
plr_lbfgs = PLR(
    method           = 'L-BFGS-B',
    optimize_w       = True,
    optimize_s       = True,
    is_prior_scaled  = True,
    function_call    = 'python',
    debug            = False,
    display_progress = False,
    calculate_elbo   = False,
)
plr_lbfgs.fit(X, y, sk,
              binit = binit, winit = winit, s2init = s2init,
              is_binit_coef = False)

mr.ash.pen terminated at iteration 374.


In [13]:
# Fitted parameter vector theta. It matches the leading entries of plr_lbfgs.fitobj.x.
# NOTE(review): with is_binit_coef = False this is presumably the pre-shrinkage
# parameter rather than the posterior-mean coefficients; confirm.
plr_lbfgs.theta

array([ 0.00827775, -0.01586799, -0.00934386, ..., -0.01891988,
       -0.0078695 , -0.00292602])

In [14]:
# Last entry of the recorded prior path (a private attribute). These are the
# unnormalized logits: their softmax reproduces plr_lbfgs.prior.
plr_lbfgs._prior_path[-1]

array([ 22.04245413, -36.88411924, -22.00432855, -14.54781646,
       -10.19073587,  -7.57034881,  -5.87597124,  -4.68921511,
        -3.79360013,  -3.06489404,  -2.41730387,  -1.77556332,
        -1.06800853,  -0.23396459,   0.76611683,   1.94130216,
         3.24896859,   4.52763989,   5.35762713,  16.31711957])

In [15]:
# Fitted mixture weights. Almost all of the mass sits on component 0 (grid value 0).
plr_lbfgs.prior

array([9.96748255e-01, 2.55328816e-26, 7.40134793e-20, 1.28124705e-16,
       9.99744295e-15, 1.37375275e-13, 7.47768912e-13, 2.45001668e-12,
       5.99970208e-12, 1.24337669e-11, 2.37600612e-11, 4.51390005e-11,
       9.15881062e-11, 2.10891831e-10, 5.73310112e-10, 1.85680418e-09,
       6.86559877e-09, 2.46602975e-08, 5.65532004e-08, 3.25165363e-03])

In [16]:
# Fitted residual variance (also the last entry of plr_lbfgs.fitobj.x).
plr_lbfgs.residual_var

0.43833099528365405

In [17]:
# Raw scipy L-BFGS-B result: final objective (fun), gradient (jac), iteration
# counts, and the packed solution x = [theta, prior logits, residual variance].
plr_lbfgs.fitobj

      fun: -5041.146950859367
 hess_inv: <2021x2021 LbfgsInvHessProduct with dtype=float64>
      jac: array([-1.17304005e-06,  5.39617913e-05,  1.31685177e-05, ...,
        1.72867928e-05, -3.86212122e-04,  4.96064654e-03])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 442
      nit: 374
     njev: 442
   status: 0
  success: True
        x: array([ 8.27774657e-03, -1.58679856e-02, -9.34385706e-03, ...,
        5.35762713e+00,  1.63171196e+01,  4.38330995e-01])

In [18]:
from libmrashpen_lbfgs_driver import lbfgsb_driver

In [19]:
# Run the compiled L-BFGS-B driver on the same problem and from the same start.
# The driver optimizes nopt = p + k + 1 parameters: theta, k prior logits, s2.
ftol, gtol = 1e-9, 1e-9
nopt = p + k + 1

lbfgs_m      = 10                        # L-BFGS memory size ("M = 10" in the log)
lbfgs_iprint = 0                         # 0 prints the summary log; -1 is silent
factr        = ftol / np.finfo(float).eps  # scipy-style conversion of ftol to factr
maxiter, maxfun = 1000, 10000            # iteration / function-evaluation limits

f_theta, f_wk, f_s2, f_obj, f_grad, f_nfev, f_niter, f_task = \
    lbfgsb_driver.min_plr_shrinkop(X, y, binit, winit, s2init, sk,
                                   nopt,
                                   True, True, True, 1.0,  # NOTE(review): meaning not visible here
                                   lbfgs_m, lbfgs_iprint,
                                   factr, gtol, maxiter, maxfun)

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =         2021     M =           10

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
 2021    247    282    248     0     0   3.565D-02  -5.041D+03

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             

 Total User time 2.619E+00 seconds.



In [20]:
# theta from the compiled driver. It differs from plr_lbfgs.theta because the
# two runs stop at different points.
f_theta

array([ 0.01009444, -0.01879536, -0.01119446, ..., -0.00419092,
       -0.00980171, -0.00361926])

In [21]:
# Mixture weights returned by the driver (compare with plr_lbfgs.prior).
f_wk

array([9.96745281e-01, 2.23785418e-20, 9.05796818e-16, 2.26697513e-13,
       7.88065775e-12, 7.11215311e-11, 2.97958068e-10, 8.12643402e-10,
       1.73025127e-09, 3.19444606e-09, 5.49318638e-09, 9.35819571e-09,
       1.67498625e-08, 3.31360493e-08, 7.50916253e-08, 1.98539322e-07,
       6.14352934e-07, 2.22553685e-06, 1.03759391e-05, 3.24115917e-03])

In [22]:
# Residual variance returned by the driver (compare with plr_lbfgs.residual_var).
f_s2

0.4394442629102773

In [23]:
# Final objective from the driver (compare with plr_lbfgs.fitobj.fun; lower is better).
f_obj

-5040.889706224509

In [22]:
# Gradient at the driver's solution, packed as [theta (p), logits (k), s2 (1)].
# The cells below cross-check each part against a direct objective call.
f_grad

array([ 2.26235204e-04, -3.42191535e-04, -2.02569985e-04, ...,
        3.16099967e-03, -4.99098637e-05, -5.28579074e-03])

In [23]:
from libmrashpen_plr_mrash import plr_mrash as flib_penmrash
# Re-evaluate the objective and its gradients at the driver's solution with the
# standalone objective routine, to cross-check what the driver returned.
dj    = np.square(X).sum(axis = 0)   # per-column sums of squares
djinv = 1.0 / dj
obj, bgrad, wgrad, s2grad = flib_penmrash.plr_obj_grad_shrinkop(
    X, y, f_theta, f_s2, f_wk, sk, djinv)

In [24]:
# The directly recomputed objective must reproduce the driver's final value.
np.testing.assert_allclose(obj, f_obj)
obj

-5040.889706224509

In [25]:
# Gradient w.r.t. theta from the direct objective call.
bgrad

array([ 2.26235204e-04, -3.42191535e-04, -2.02569985e-04, ...,
        1.81484759e-03, -1.73435910e-04, -7.66787520e-05])

In [26]:
# The driver's gradient w.r.t. theta (first p entries) must match the direct call.
# The printed arrays elide all but a few entries, so compare every entry explicitly.
np.testing.assert_allclose(f_grad[:p], bgrad)
f_grad[:p]

array([ 2.26235204e-04, -3.42191535e-04, -2.02569985e-04, ...,
        1.81484759e-03, -1.73435910e-04, -7.66787520e-05])

In [27]:
# Chain rule through the softmax w = softmax(a): dw_j/da_i = w_i (delta_ij - w_j),
# so agrad_i = sum_j wgrad_j * w_i * (delta_ij - w_j).
akjac = f_wk[:, np.newaxis] * (np.eye(k) - f_wk)
agrad = (wgrad * akjac).sum(axis = 1)
agrad

array([-5.45004731e-03,  4.52505465e-18,  5.25561709e-13,  2.03591114e-10,
        8.86649960e-09,  9.14476220e-08,  4.17936240e-07,  1.21101716e-06,
        2.69261897e-06,  5.11912180e-06,  8.93508809e-06,  1.51825867e-05,
        2.64755902e-05,  4.93981167e-05,  1.01012881e-04,  2.27012177e-04,
        5.48627340e-04,  1.35277251e-03,  3.16099967e-03, -4.99098637e-05])

In [28]:
# The driver's gradient w.r.t. the k prior logits must equal the chain-rule
# gradient agrad computed from the direct call.
np.testing.assert_allclose(f_grad[p:p+k], agrad)
f_grad[p:p+k]

array([-5.45004731e-03,  4.52505465e-18,  5.25561709e-13,  2.03591114e-10,
        8.86649960e-09,  9.14476220e-08,  4.17936240e-07,  1.21101716e-06,
        2.69261897e-06,  5.11912180e-06,  8.93508809e-06,  1.51825867e-05,
        2.64755902e-05,  4.93981167e-05,  1.01012881e-04,  2.27012177e-04,
        5.48627340e-04,  1.35277251e-03,  3.16099967e-03, -4.99098637e-05])

In [29]:
# Gradient w.r.t. the residual variance from the direct objective call.
s2grad

-0.005285790738071228

In [30]:
# The driver's gradient w.r.t. the residual variance (last entry) must match s2grad.
np.testing.assert_allclose(f_grad[-1], s2grad)
f_grad[-1]

-0.005285790738071228

In [31]:
# Re-display the initial coefficients after the fits, presumably to confirm they
# were not modified in place.
binit

array([0., 0., 0., ..., 0., 0., 0.])

In [32]:
# Re-display the initial weights after the fits, presumably to confirm they were
# not modified in place.
winit

array([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
       0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05])

In [33]:
# Re-display the initial residual variance after the fits.
s2init

1.1723683608566409

In [42]:
# Translate the driver's task message into scipy's L-BFGS-B warnflag:
# 0 = converged, 1 = iteration/evaluation limit reached, 2 = any other stop.
# The limits must match maxfun / maxiter passed to min_plr_shrinkop.
task_str = f_task.strip(b'\x00').strip()
if not task_str.startswith(b'CONV'):
    limit_hit = f_nfev > 10000 or f_niter >= 1000
    warnflag = 1 if limit_hit else 2
else:
    warnflag = 0

print(warnflag)

0


In [37]:
%%timeit -n 2 -r 3

'''
mr.ash.pen
'''
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
                function_call = 'python',
                debug = False, display_progress = False, calculate_elbo = False)
plr_lbfgs.fit(X, y, sk, binit = binit, winit = winit, s2init = s2init, is_binit_coef = False)

mr.ash.pen terminated at iteration 374.
mr.ash.pen terminated at iteration 374.
mr.ash.pen terminated at iteration 374.
mr.ash.pen terminated at iteration 374.
mr.ash.pen terminated at iteration 374.
mr.ash.pen terminated at iteration 374.
3.67 s ± 25.5 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)


In [38]:
%%timeit -n 2 -r 3

ftol = 1e-9
gtol = 1e-9
nopt = p + k + 1

f_theta, f_wk, f_s2, f_obj, f_grad, f_nfev, f_niter, f_task = \
    lbfgsb_driver.min_plr_shrinkop(X, y, binit, winit, s2init, sk,
                                   nopt, True, True, True, 1.0, 10, -1,
                                   ftol / np.finfo(float).eps, gtol, 1000, 10000)

656 ms ± 15.7 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)


In [36]:
%%timeit -n 2 -r 3
'''
mr.ash.alpha
'''
mrash_r = MrASHR(option = "r2py", debug = False)
mrash_r.fit(X, y, sk, binit = binit, winit = winit, s2init = s2init)

Mr.ASH terminated at iteration 333.
Mr.ASH terminated at iteration 333.
Mr.ASH terminated at iteration 333.
Mr.ASH terminated at iteration 333.
Mr.ASH terminated at iteration 333.
Mr.ASH terminated at iteration 333.
1.18 s ± 204 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)
