In [23]:
import numpy as np
from mrashpen.models.plr_ash import PenalizedMrASH as PenMrASH
from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled
import libmrashpen_plr_mrash as flib
from libmrashpen_plr_mrash import plr_mrash as flib_penmrash
# Seed NumPy's global RNG. Note: the data-simulation cell calls _ash_data(seed = 100),
# which re-seeds the global RNG before drawing anything, so this seed does not affect
# the simulated data.
np.random.seed(200)

In [24]:
def _ash_data(n = 200, p = 2000, p_causal = 5, pve = 0.5, rho = 0.0, k = 6, seed = None):

    def sd_from_pve (X, b, pve):
        return np.sqrt(np.var(np.dot(X, b)) * (1 - pve) / pve)

    if seed is not None: np.random.seed(seed)

    '''
    ASH prior
    '''
    wk = np.zeros(k)
    wk[1:(k-1)] = np.repeat(1/(k-1), (k - 2)) 
    wk[k-1] = 1 - np.sum(wk)
    sk = np.arange(k)
    '''
    Equicorr predictors
    X is sampled from a multivariate normal, with covariance matrix V.
    V has unit diagonal entries and constant off-diagonal entries rho.
    '''
    iidX    = np.random.normal(size = n * p).reshape(n, p)
    comR    = np.random.normal(size = n).reshape(n, 1)
    X       = comR * np.sqrt(rho) + iidX * np.sqrt(1 - rho)
    bidx    = np.random.choice(p, p_causal, replace = False)
    b       = np.zeros(p)
    b[bidx] = np.random.normal(size = p_causal)
    sigma   = sd_from_pve(X, b, pve)
    y       = np.dot(X, b) + sigma * np.random.normal(size = n)
    return X, y, b, sigma, wk, sk

def center_and_scale(Z):
    '''
    Standardize Z to unit (population, ddof = 0) standard deviation and zero mean.

    A 1-D array is treated as a single variable; a 2-D array is standardized
    column by column. Z is first divided by its standard deviation and then
    centered. A constant variable (zero standard deviation) is only centered,
    so it becomes all zeros instead of NaN/inf.

    Raises
    ------
    ValueError
        If Z is not 1-D or 2-D (previously this fell through to an
        UnboundLocalError).
    '''
    dim = Z.ndim
    if dim == 1:
        scale = np.std(Z)
        # Avoid division by zero for a constant vector.
        if scale == 0:
            scale = 1.0
        Znew = Z / scale
        Znew = Znew - np.mean(Znew)
    elif dim == 2:
        scale = np.std(Z, axis = 0)
        # Constant columns keep scale 1 (same convention as scikit-learn's StandardScaler).
        scale = np.where(scale == 0, 1.0, scale)
        Znew = Z / scale
        Znew = Znew - np.mean(Znew, axis = 0).reshape(1, -1)
    else:
        raise ValueError('center_and_scale expects a 1-D or 2-D array, got ndim = {}'.format(dim))
    return Znew

In [25]:
# Show the f2py wrapper signature of the Fortran objective/gradient routine
# (flib_penmrash is the same plr_mrash module imported under an alias).
print(flib_penmrash.objective_gradients.__doc__)

obj,bgrad,wgrad,s2grad = objective_gradients(x,y,b,stddev,wk,sk,djinv,[n,p,k])

Wrapper for ``objective_gradients``.

Parameters
----------
x : input rank-2 array('d') with bounds (n,p)
y : input rank-1 array('d') with bounds (n)
b : input rank-1 array('d') with bounds (p)
stddev : input float
wk : input rank-1 array('d') with bounds (k)
sk : input rank-1 array('d') with bounds (k)
djinv : input rank-1 array('d') with bounds (p)

Other Parameters
----------------
n : input int, optional
    Default: shape(x,0)
p : input int, optional
    Default: shape(x,1)
k : input int, optional
    Default: len(wk)

Returns
-------
obj : float
bgrad : rank-1 array('d') with bounds (p)
wgrad : rank-1 array('d') with bounds (k)
s2grad : float



In [26]:
# Show the wrapper signature of the Fortran normal-means ASH log-marginal-likelihood routine.
print (flib.normal_means_ash_scaled.normal_means_ash_lml.__doc__)

lml,lml_bd,lml_wd,lml_s2d,lml_bd_bd,lml_bd_wd,lml_bd_s2d = normal_means_ash_lml(y,stddev,wk,sk,djinv,[ndim,ncomp])

Wrapper for ``normal_means_ash_lml``.

Parameters
----------
y : input rank-1 array('d') with bounds (ndim)
stddev : input float
wk : input rank-1 array('d') with bounds (ncomp)
sk : input rank-1 array('d') with bounds (ncomp)
djinv : input rank-1 array('d') with bounds (ndim)

Other Parameters
----------------
ndim : input int, optional
    Default: len(y)
ncomp : input int, optional
    Default: len(wk)

Returns
-------
lml : rank-1 array('d') with bounds (ndim)
lml_bd : rank-1 array('d') with bounds (ndim)
lml_wd : rank-2 array('d') with bounds (ndim,ncomp)
lml_s2d : rank-1 array('d') with bounds (ndim)
lml_bd_bd : rank-1 array('d') with bounds (ndim)
lml_bd_wd : rank-2 array('d') with bounds (ndim,ncomp)
lml_bd_s2d : rank-1 array('d') with bounds (ndim)



In [27]:
# Simulate the regression data and the initial prior.
# Set STANDARDIZE = True to standardize the columns of X and center y; this is an
# explicit switch instead of commented-out code.
STANDARDIZE = False
X, y, b, sigma, wk, sk = _ash_data(seed = 100)
if STANDARDIZE:
    X = center_and_scale(X)
    y = y - np.mean(y)

In [28]:
# Python reference implementation: PLR objective and its gradients (bgrad, wgrad, s2grad).
pmash = PenMrASH(X, y, b, sigma, wk, sk, debug = True, is_prior_scaled = True)
obj = pmash.objective
bgrad, wgrad, s2grad = pmash.gradients
print(obj, bgrad, wgrad, s2grad, sep = '\n')

2021-12-07 23:35:46,988 | mrashpen.models.plr_ash | DEBUG | Calculating PLR objective with sigma2 = 3.166077616860522
1965.4559045114502
[-8.2179048   1.10356175  7.85761507 ...  6.22617378  6.25767307
 11.67515168]
[-61705.85288764  -4373.82809809  -2191.38264212  -1461.47750028
  -1096.25423584   -877.05752367]
1.2039605670265132


In [29]:
# Fortran implementation of the same objective and gradients; djinv = 1 / ||X_j||^2.
djinv = 1 / np.sum(np.square(X), axis = 0)
f_obj, f_bgrad, f_wgrad, f_s2grad = flib_penmrash.objective_gradients(X, y, b, sigma, wk, sk, djinv)
print (f_obj)
print (f_bgrad)
print (f_wgrad)
print (f_s2grad)
# Fail loudly if the Fortran and Python results disagree, rather than relying on
# eyeballing the printed values.
np.testing.assert_allclose(f_obj, obj, rtol = 1e-6)
np.testing.assert_allclose(f_bgrad, bgrad, rtol = 1e-6, atol = 1e-8)
np.testing.assert_allclose(f_wgrad, wgrad, rtol = 1e-6)
np.testing.assert_allclose(f_s2grad, s2grad, rtol = 1e-6)

1965.4559045114556
[-8.2179048   1.10356175  7.85761507 ...  6.22617378  6.25767307
 11.67515168]
[-61705.85288764  -4373.82809809  -2191.38264212  -1461.47750028
  -1096.25423584   -877.05752367]
1.2039605670339597


In [11]:
# Fortran log-marginal likelihood of the normal-means ASH model and its derivatives.
(lml, lml_bd, lml_wd, lml_s2d,
 lml_bd_bd, lml_bd_wd, lml_bd_s2d) = flib.normal_means_ash_scaled.normal_means_ash_lml(
    b, sigma, wk, sk, djinv)

In [12]:
# d_j = ||X_j||^2. The Python model takes d itself; the Fortran calls take its inverse djinv.
dj = (X ** 2).sum(axis = 0)
nmash = NormalMeansASHScaled(b, sigma, wk, sk, d = dj)

In [13]:
# Python vs Fortran cross-derivative of the log marginal likelihood (b, sigma^2),
# shown and checked at the non-zero coefficients.
print (nmash.logML_deriv_s2deriv[b!=0])
print (lml_bd_s2d[b!=0])
np.testing.assert_allclose(nmash.logML_deriv_s2deriv[b != 0], lml_bd_s2d[b != 0],
                           rtol = 1e-6, atol = 1e-12)

[-0.02834827  0.01346295 -0.00345402  0.05447382  0.04534753]
[-0.02834827  0.01346295 -0.00345402  0.05447382  0.04534753]


In [14]:
# bvar = sigma^2 / d_j, then the shrinkage operator from both implementations
# (Fortran first, Python second).
bvar = np.square(sigma) * djinv
(f_mb, f_mb_bgrad,
 f_mb_wgrad, f_mb_s2grad) = flib.plr_mrash.plr_shrinkage_operator(
    b, bvar, djinv, lml_bd, lml_bd_bd, lml_bd_wd, lml_bd_s2d)
mb, mb_bgrad, mb_wgrad, mb_s2grad = pmash.shrinkage_operator(nmash)

In [15]:
# Python vs Fortran derivative of the shrinkage operator w.r.t. b at the non-zero coefficients.
print (mb_bgrad[b != 0])
print (f_mb_bgrad[b != 0])
np.testing.assert_allclose(mb_bgrad[b != 0], f_mb_bgrad[b != 0], rtol = 1e-6)

[0.9975574  0.99744715 0.99741918 0.99811912 0.99783109]
[0.9975574  0.99744715 0.99741918 0.99811912 0.99783109]


In [16]:
# Penalty operator and its gradients from both implementations (Fortran first, Python second).
(f_lambdaj, f_l_bgrad,
 f_l_wgrad, f_l_s2grad) = flib.plr_mrash.plr_penalty_operator(
    b, bvar, djinv,
    lml, lml_bd, lml_wd, lml_s2d,
    lml_bd_bd, lml_bd_wd, lml_bd_s2d)
lambdaj, l_bgrad, l_wgrad, l_s2grad = pmash.penalty_operator(nmash)

In [17]:
# Python vs Fortran gradient of the penalty w.r.t. the prior weights (non-zero entries).
print(l_wgrad[l_wgrad != 0])
print(f_l_wgrad[f_l_wgrad != 0])
np.testing.assert_allclose(l_wgrad[l_wgrad != 0], f_l_wgrad[f_l_wgrad != 0], rtol = 1e-6)

[-61892.0279618   -4373.82027989  -2191.3824997   -1461.48009754
  -1096.25703017   -877.06009271]
[-61892.0279618   -4373.82027989  -2191.3824997   -1461.48009754
  -1096.25703017   -877.06009271]


In [18]:
# Recompute the PLR objective by hand from the pieces above:
#   0.5 * ||y - X mb||^2 / sigma2 + sum(lambdaj) + 0.5 * (n - p) * (log(2 pi) + log(sigma2))
# Stored as obj_manual so it no longer overwrites `obj` (pmash.objective) from earlier.
# NOTE(review): the printed value (1965.2548...) differs from pmash.objective and the
# Fortran objective (1965.4559...); a term or an input differs -- investigate before
# relying on this formula.
sigma2 = np.square(sigma)
n_samples, n_features = X.shape
resid = y - np.dot(X, mb)
rTr = np.sum(np.square(resid))
obj_manual = (0.5 * rTr / sigma2) + np.sum(lambdaj)
obj_manual += 0.5 * (n_samples - n_features) * (np.log(2. * np.pi) + np.log(sigma2))
print(obj_manual)

1965.2547899418032


In [19]:
# Display the residual variance (sigma squared) used in the manual objective.
sigma2

3.166077616860522

In [18]:
# Show the f2py wrapper signature of the Fortran penalty operator
# (flib_penmrash is the same plr_mrash module imported under an alias).
print(flib_penmrash.plr_penalty_operator.__doc__)

lambdaj,l_bgrad,l_wgrad,l_s2grad = plr_penalty_operator(b,bvar,djinv,lml,lml_bd,lml_wd,lml_s2d,lml_bd_bd,lml_bd_wd,lml_bd_s2d)

Wrapper for ``plr_penalty_operator``.

Parameters
----------
b : input rank-1 array('d') with bounds (f2py_b_d0)
bvar : input rank-1 array('d') with bounds (f2py_bvar_d0)
djinv : input rank-1 array('d') with bounds (f2py_djinv_d0)
lml : input rank-1 array('d') with bounds (f2py_lml_d0)
lml_bd : input rank-1 array('d') with bounds (f2py_lml_bd_d0)
lml_wd : input rank-2 array('d') with bounds (f2py_lml_wd_d0,f2py_lml_wd_d1)
lml_s2d : input rank-1 array('d') with bounds (f2py_lml_s2d_d0)
lml_bd_bd : input rank-1 array('d') with bounds (f2py_lml_bd_bd_d0)
lml_bd_wd : input rank-2 array('d') with bounds (f2py_lml_bd_wd_d0,f2py_lml_bd_wd_d1)
lml_bd_s2d : input rank-1 array('d') with bounds (f2py_lml_bd_s2d_d0)

Returns
-------
lambdaj : rank-1 array('d') with bounds (size(b))
l_bgrad : rank-1 array('d') with bounds (size(b))
l_wgrad : rank-1 array('d') with bounds (

In [44]:
# Variance grid passed to calculate_logljk: sk^2 + 1/d_j for every coefficient (rows)
# and prior component (columns).
sigma2 = np.square(sigma)
sk2    = np.square(sk)[np.newaxis, :]
v2pk   = sk2 + np.divide(1, dj.reshape(-1, 1))

In [47]:
# Fortran per-(coefficient, component) log-likelihood terms for the variance grid v2pk;
# logljk1 / logljk2 are presumably derivatives (TODO confirm against the Fortran source).
# NOTE(review): logljk0 does not match nmash.logLjk(derive = 0). For the displayed rows,
# successive Fortran columns differ by 0.5 * log-ratios of (sk^2 + 1/dj), while the
# Python columns differ by 0.5 * log-ratios of (sk^2 + 1). The two therefore seem to use
# different variance parameterizations. Confirm what calculate_logljk expects, and
# whether its second argument should be sigma2 or sigma (the other wrappers here take
# `stddev`).
logljk0,logljk1,logljk2 \
    = flib.normal_means_ash_scaled.calculate_logljk(b,sigma2,v2pk)

In [52]:
# Fortran log-likelihood matrix, for comparison with nmash.logLjk(derive = 0).
logljk0

array([[ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715],
       [ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715],
       [ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715],
       ...,
       [ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715],
       [ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715],
       [ 2.48232943, -0.16932302, -0.86060104, -1.26571924, -1.55327984,
        -1.77636715]])

In [53]:
# Python reference: component log-likelihood matrix from NormalMeansASHScaled.
nmash.logLjk(derive = 0)

array([[-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752],
       [-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752],
       [-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752],
       ...,
       [-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752],
       [-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752],
       [-0.16682925, -0.51340284, -0.97154821, -1.3181218 , -1.58343592,
        -1.79587752]])

In [62]:
# NOTE(review): reads a private attribute of PenMrASH. The displayed values are all
# exactly 200 (= n), which column sums of squares of the unstandardized simulated X
# would not give; this output likely reflects stale kernel state (e.g. X standardized
# in a deleted cell) -- Restart & Run All before trusting it.
pmash._dj

array([200., 200., 200., ..., 200., 200., 200.])

In [63]:
# NOTE(review): dj is displayed as a 2-D column of 200s, but the
# `dj = np.sum(np.square(X), axis = 0)` cell above produces a 1-D array of column sums
# of squares. This output therefore comes from kernel state not produced by the visible
# cells (deleted or re-ordered cell); Restart & Run All before trusting it.
dj

array([[200.],
       [200.],
       [200.],
       ...,
       [200.],
       [200.],
       [200.]])