Run the FORTRAN code in `mr-ash-pen/src/mrashpen/flibs` with the commands below, and compare its output with the Python results computed in this notebook.
```
gfortran env_precision.f90 global_parameters.f90 futils.f90 normal_means_ash_scaled.f90 plr_mrash.f90 main.f90 -lblas -llapack -o runtest
./runtest
```

In [1]:
import numpy as np
from mrashpen.models.plr_ash import PenalizedMrASH as PenMrASH
from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled
from mrashpen.models.normal_means_ash import NormalMeansASH

In [2]:
# Toy problem dimensions; b has length p, y has length n, sk and wk have length k.
n, p, k = 5, 6, 3

# Scalar passed as `std` to the models below (std**2 = 0.81).
std = 0.9

b = np.array([1.21, 2.32, 0.01, 0.03, 0.11, 3.12])
y = np.array([3.5, 4.5, 1.2, 6.5, 2.8])

# sk / wk: presumably mixture component scales and weights (wk sums to 1) -- TODO confirm.
sk = np.array([0.1, 0.5, 0.9])
wk = np.array([0.5, 0.25, 0.25])

# Flat list of the design-matrix entries, one group of n values per column of X.
# Each row of this literal therefore becomes one column of X after the reshape.
XT = np.array([
    8.79,  6.11, -9.15,  9.57, -3.49,
    9.84,  9.93,  6.91, -7.93,  1.64,
    4.02,  0.15,  9.83,  5.04,  4.86,
    8.83,  9.80, -8.99,  5.45, -0.27,
    4.85,  0.74, 10.00, -6.02,  3.16,
    7.98,  3.01,  5.80,  4.27, -5.31,
])
# (p, n) reshape followed by a transpose gives the (n, p) matrix X.
X = XT.reshape(p, n).T

In [3]:
# Show the (n, p) design matrix built from XT.
X

array([[ 8.79,  9.84,  4.02,  8.83,  4.85,  7.98],
       [ 6.11,  9.93,  0.15,  9.8 ,  0.74,  3.01],
       [-9.15,  6.91,  9.83, -8.99, 10.  ,  5.8 ],
       [ 9.57, -7.93,  5.04,  5.45, -6.02,  4.27],
       [-3.49,  1.64,  4.86, -0.27,  3.16, -5.31]])

In [4]:
# Two penalized-regression models on identical inputs; they differ only in
# the `is_prior_scaled` flag.
pmash = PenMrASH(
    X, y, b, std, wk, sk,
    debug = True,
    is_prior_scaled = False,
)
pmash_scaled = PenMrASH(
    X, y, b, std, wk, sk,
    debug = True,
    is_prior_scaled = True,
)

# Matching normal-means models. The unscaled class receives
# sqrt(std**2 / pmash._dj) as its second argument, whereas the scaled class
# receives std and pmash._dj (as `d`) separately.
nmash = NormalMeansASH(
    b,
    np.sqrt(np.square(std) / pmash._dj),
    wk,
    sk,
)
nmash_scaled = NormalMeansASHScaled(
    b, std, wk, sk,
    d = pmash._dj,
)

In [5]:
# Unpack the three gradients exposed by the unscaled-prior model
# (presumably w.r.t. b, wk and sigma^2, going by the names -- TODO confirm).
bgrad, wgrad, s2grad = pmash.gradients

2022-04-19 18:10:04,778 | mrashpen.models.plr_ash | DEBUG | Calculating PLR objective with sigma2 = 0.81


In [6]:
# Same three gradients, this time from the scaled-prior model.
bgrad_scaled, wgrad_scaled, s2grad_scaled = pmash_scaled.gradients

2022-04-19 18:10:06,512 | mrashpen.models.plr_ash | DEBUG | Calculating PLR objective with sigma2 = 0.81


In [7]:
# Reference values for the two penalized-regression models, keyed by model name.
# The scaled model is listed first so that key order and evaluation order are
# preserved.
mrashpen_res = {
    'plrash_scaled': dict(
        objective = pmash_scaled.objective,
        bgrad = bgrad_scaled,
        wgrad = wgrad_scaled,
        s2grad = s2grad_scaled,
    ),
    'plrash': dict(
        objective = pmash.objective,
        bgrad = bgrad,
        wgrad = wgrad,
        s2grad = s2grad,
    ),
}

In [8]:

# One row per normal-means model:
#   (result key, normal-means model, owning PLR model, scale logML s2-derivatives by dj?)
nm_cases = (
    ('nmash',        nmash,        pmash,        False),
    ('nmash_scaled', nmash_scaled, pmash_scaled, True),
)

for key, nm, plr, scale_logml in nm_cases:
    dj = plr._dj
    # The s2-derivatives of logML are multiplied by dj only for the scaled
    # model; the unscaled model uses a factor of one.
    fact = dj if scale_logml else np.ones(p)

    mb, mb_bgrad, mb_wgrad, mb_s2grad = plr.shrinkage_operator(nm)
    lj, lj_bgrad, lj_wgrad, lj_s2grad = plr.penalty_operator(nm)

    mrashpen_res[key] = dict(
        logML = nm.logML,
        logML_deriv = nm.logML_deriv,
        logML_wderiv = nm.logML_wderiv,
        logML_s2deriv = nm.logML_s2deriv * fact,
        logML_deriv2 = nm.logML_deriv2,
        logML_deriv_wderiv = nm.logML_deriv_wderiv,
        logML_deriv_s2deriv = nm.logML_deriv_s2deriv * fact,
        shrinkage_mb = mb,
        shrinkage_mb_bgrad = mb_bgrad,
        shrinkage_mb_wgrad = mb_wgrad,
        # Operator s2-gradients are multiplied by dj for both models.
        shrinkage_mb_s2grad = mb_s2grad * dj,
        penalty_lj = lj,
        penalty_lj_bgrad = lj_bgrad,
        penalty_lj_wgrad = lj_wgrad,
        penalty_lj_s2grad = lj_s2grad * dj,
    )

In [9]:
# Display the collected reference values (nested dict of scalars and arrays).
mrashpen_res

{'plrash_scaled': {'objective': 3205.58796066751,
  'bgrad': array([ 693.89890522, 1257.60334498,  293.85995256,  615.7151167 ,
          440.13054884,  956.56182369]),
  'wgrad': array([-14.83246902,   6.95343333,  -1.28849528]),
  's2grad': -3956.0004809984725},
 'plrash': {'objective': 3213.7913940545036,
  'bgrad': array([ 694.29363183, 1260.18101292,  311.65275926,  640.59877728,
          450.23955438,  958.69539922]),
  'wgrad': array([-13.23088186,   4.27533301,  -1.81356928]),
  's2grad': -4009.8783665317305},
 'nmash': {'logML': array([-2.88430004, -5.51207019,  0.65799283,  0.69708953,  0.33103137,
         -8.1729519 ]),
  'logML_deriv': array([-2.13525576, -2.8623001 , -0.56486873, -1.99020565, -5.90464463,
         -3.82683969]),
  'logML_wderiv': array([[5.39152820e-024, 7.83524957e-001, 3.21647504e+000],
         [2.27379652e-090, 4.64506480e-003, 3.99535494e+000],
         [1.68104736e+000, 4.09060715e-001, 2.28844557e-001],
         [1.69265949e+000, 3.94427245e-001, 

In [12]:
import pickle

# Write the reference values into the gradvi tests directory.
# The context manager closes (and flushes) the file even if pickle.dump raises,
# which the previous manual open()/close() pair did not guarantee.
with open("../../../gradvi/src/gradvi/tests/mrashpen_res.pkl", "wb") as mfile:
    pickle.dump(mrashpen_res, mfile)

In [10]:
# Per-predictor d_j values held by the PLR model (one entry per column of X).
pmash._dj

array([302.0837, 308.7531, 161.833 , 284.6044, 170.2961, 152.8095])

In [11]:
# NOTE(review): this rebinds `nmash`, which earlier in the notebook refers to an
# unscaled NormalMeansASH instance, to a NormalMeansASHScaled instance. Cells
# below inspect this scaled object; re-running the earlier results loop after
# this cell would silently use it in place of the unscaled model.
nmash = NormalMeansASHScaled(b, std, wk, sk, d = pmash._dj)
# Unscaled alternative, kept for reference (note the `sk * std` scaling):
#nmash = NormalMeansASH(b, np.sqrt(np.square(std) / pmash._dj), wk, sk * std)

In [12]:
# logML of the current `nmash` model, one value per entry of b.
nmash.logML

array([-3.06487518, -6.18177123,  0.73109578,  0.77572404,  0.35992288,
       -9.45746262])

In [13]:
# First derivative of logML (presumably w.r.t. b -- TODO confirm).
nmash.logML_deriv

array([-2.37739752, -3.52362732, -0.64357177, -2.32604672, -6.65298927,
       -4.71726323])

In [14]:
# Second derivative of logML (presumably w.r.t. b -- TODO confirm).
nmash.logML_deriv2

array([-6.17181043e-02, -1.50584974e+00, -6.42845767e+01, -7.66156168e+01,
       -4.85715642e+01, -1.51192336e+00])

In [15]:
# Derivative of logML w.r.t. the mixture weights; a (p, k) array.
nmash.logML_wderiv

array([[2.67457588e-028, 5.32617455e-001, 3.46738254e+000],
       [1.89767986e-106, 8.54495380e-004, 3.99914550e+000],
       [1.67117175e+000, 4.21483043e-001, 2.36173456e-001],
       [1.68474594e+000, 4.04410563e-001, 2.26097563e-001],
       [1.53343844e+000, 5.93835426e-001, 3.39287698e-001],
       [8.07215632e-154, 7.53656745e-007, 3.99999925e+000]])

In [17]:
# s2-derivative of logML multiplied element-wise by dj, i.e. the same scaling
# used when storing 'logML_s2deriv' for the scaled model.
nmash.logML_s2deriv * pmash._dj

array([ 349.94177629, 1367.43610432,  -99.25400525, -163.42241115,
        -28.19043594, 1293.96266755])

In [25]:
# Mixed derivative: w-derivative of logML_deriv; a (p, k) array.
nmash.logML_deriv_wderiv

array([[-2.93810651e-026, -1.87471963e+000,  1.87471963e+000],
       [-4.03872856e-104, -6.65364326e-003,  6.65364326e-003],
       [-1.99682410e-001,  2.50942656e-001,  1.48422165e-001],
       [-6.98607730e-001,  8.81595565e-001,  5.15619895e-001],
       [-2.91820458e+000,  3.63560642e+000,  2.20080274e+000],
       [-1.84130802e-151, -7.76049536e-006,  7.76049537e-006]])

In [26]:
# Mixed derivative: s2-derivative of logML_deriv (shown here without the dj factor).
nmash.logML_deriv_s2deriv

array([-3.28861113e-01,  5.26400175e+00,  5.22794812e+01,  2.25004059e+02,
        5.11523760e+02,  7.13204724e+00])

In [19]:
# Intermediate quantity: log_sum_wkLjk applied to logLjk(derive = 1)
# (presumably log sum_k w_k L_jk for the first-derivative terms -- TODO confirm).
nmash.log_sum_wkLjk(nmash.logLjk(derive = 1))

array([-1.47055059, -4.84490893,  5.81448277,  6.04539061,  5.3812026 ,
       -8.12512828])

In [11]:
# Repeated display of the d_j values, for reference next to the softmax checks.
pmash._dj

array([302.0837, 308.7531, 161.833 , 284.6044, 170.2961, 152.8095])

In [12]:
def softmax(x, base = np.exp(1)):
    """
    Softmax of `x` along axis 0, optionally in a custom base.

    Computes base**x / sum(base**x, axis = 0), evaluated as
    exp(x * log(base)) with the usual max-shift for numerical stability.

    Parameters
    ----------
    x : array_like
        Input values. Normalisation is performed along axis 0, so for a
        2-D input every column is normalised independently.
    base : float or None, optional
        Base of the exponential. Defaults to e; `None` skips the rescaling
        and also yields the natural-base softmax.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `x` whose entries sum to 1 along axis 0.
    """
    # Accept lists / tuples as well as arrays.
    x = np.asarray(x)
    if base is not None:
        x = x * np.log(base)
    # Shift by the maximum along the normalisation axis. Shifting by the global
    # maximum (as before) gives the same 1-D result, but for 2-D input a column
    # far below the global maximum underflows to 0 / 0 = nan.
    e_x = np.exp(x - np.max(x, axis = 0, keepdims = True))
    return e_x / np.sum(e_x, axis = 0, keepdims = True)

# log(base) of the softmax used below (base = exp(smlogbase)); 1.0 is the natural base.
smlogbase = 1.0
# Softmax inputs chosen so that softmax(ak) in that base reproduces wk.
ak = np.log(wk) / smlogbase
print(ak)

[-0.69314718 -1.38629436 -1.38629436]


In [13]:
# Sanity check: mapping ak back through softmax should return wk.
softmax(ak, base = np.exp(smlogbase))

array([0.5 , 0.25, 0.25])

In [14]:
# Jacobian of wk = softmax(ak) with respect to ak:
#   akjac[i, j] = smlogbase * wk[i] * (delta_ij - wk[j])
# The matrix is symmetric because wk[i] * wk[j] is symmetric in (i, j).
akjac = smlogbase * wk.reshape(-1, 1) * (np.eye(k) - wk)
akjac

array([[ 0.25  , -0.125 , -0.125 ],
       [-0.125 ,  0.1875, -0.0625],
       [-0.125 , -0.0625,  0.1875]])

In [15]:
# Chain rule from the wk-gradient to the ak-gradient:
#   agrad[i] = sum_j wgrad[j] * akjac[i, j]
# which equals sum_j wgrad[j] * d wk[j] / d ak[i] because akjac is symmetric.
# NOTE(review): `wgrad` is assigned from pmash.gradients (unscaled prior), but
# the output recorded for this cell matches the wgrad values of the scaled
# model -- confirm which gradient is intended and re-run.
agrad  = np.sum(wgrad * akjac, axis = 1)
agrad

array([-4.41623451,  3.23835833,  1.17787618])

In [16]:
# Element-wise terms of the chain rule; wgrad is broadcast across the rows of
# akjac, and summing each row gives agrad.
wgrad * akjac

array([[-3.70811726, -0.86917917,  0.16106191],
       [ 1.85405863,  1.30376875,  0.08053096],
       [ 1.85405863, -0.43458958, -0.24159287]])