Compile and run the Fortran test driver in `mr-ash-pen/src/mrashpen/flibs`, then compare its output with the Python results computed below:
```
gfortran env_precision.f90 global_parameters.f90 futils.f90 normal_means_ash_scaled.f90 plr_mrash.f90 main.f90 -lblas -llapack -o runtest
./runtest
```

In [1]:
import numpy as np
from mrashpen.models.plr_ash import PenalizedMrASH as PenMrASH
from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled

In [2]:
# Fixed toy problem. NOTE(review): these inputs are presumably the same values
# hard-coded in the Fortran driver (main.f90) -- keep the two in sync.
std = 0.9                                           # sigma; the PLR objective logs sigma2 = std**2 = 0.81
b  = np.array([1.21, 2.32, 0.01, 0.03, 0.11, 3.12]) # coefficients, length p
sk = np.array([0.1, 0.5, 0.9])                      # prior mixture component scales, length k
wk = np.array([0.5, 0.25, 0.25])                    # prior mixture weights, length k
y  = np.array([3.5, 4.5, 1.2, 6.5, 2.8])            # responses, length n
# X is stored column by column (column-major / Fortran order): every consecutive
# group of n values below is one column of X.
XT = np.array([8.79, 6.11,-9.15, 9.57,-3.49, 9.84,
               9.93, 6.91,-7.93, 1.64, 4.02, 0.15, 
               9.83, 5.04, 4.86, 8.83, 9.80,-8.99,
               5.45,-0.27, 4.85, 0.74,10.00,-6.02,
               3.16, 7.98, 3.01, 5.80, 4.27,-5.31])

# Derive the dimensions from the data instead of hard-coding them separately,
# so editing an input array cannot silently leave n, p or k stale.
n = y.shape[0]
p = b.shape[0]
k = wk.shape[0]
assert sk.shape == (k,), "sk and wk must have one entry per mixture component"
# softmax(ak) (used below to reparameterize wk) always sums to one.
assert np.isclose(np.sum(wk), 1.0), "mixture weights wk must sum to one"
assert XT.size == n * p, "XT must contain n * p values"
X = XT.reshape(n, p, order = 'F')

In [3]:
# Design matrix, shape (n, p): each consecutive group of n entries of XT is one column.
X

array([[ 8.79,  9.84,  4.02,  8.83,  4.85,  7.98],
       [ 6.11,  9.93,  0.15,  9.8 ,  0.74,  3.01],
       [-9.15,  6.91,  9.83, -8.99, 10.  ,  5.8 ],
       [ 9.57, -7.93,  5.04,  5.45, -6.02,  4.27],
       [-3.49,  1.64,  4.86, -0.27,  3.16, -5.31]])

In [4]:
# Penalized mr.ash regression model on the toy problem; debug=True produces the
# DEBUG log lines seen when the objective is evaluated.
pmash = PenMrASH(
    X, y, b, std, wk, sk,
    debug=True,
    is_prior_scaled=True,
)

In [5]:
# PLR objective, presumably evaluated at the constructor's b, wk, sk and
# sigma2 = std**2 (the debug log reports sigma2 = 0.81). Compare with Fortran.
pmash.objective

2022-04-13 11:46:22,411 | mrashpen.models.plr_ash | DEBUG | Calculating PLR objective with sigma2 = 0.81


3205.58796066751

In [6]:
# Gradients of the objective. Judging by the names and output sizes, these are
# w.r.t. b (length p), the mixture weights wk (length k) and sigma2 (scalar)
# -- confirm against PenMrASH.gradients.
bgrad, wgrad, s2grad = pmash.gradients

In [7]:
# Gradient w.r.t. the coefficients b: one entry per coefficient (length p).
bgrad

array([ 693.89890522, 1257.60334498,  293.85995256,  615.7151167 ,
        440.13054884,  956.56182369])

In [8]:
# Gradient w.r.t. the mixture weights wk: one entry per component (length k).
wgrad

array([-14.83246902,   6.95343333,  -1.28849528])

In [9]:
s2grad

-3956.0004809984725

In [10]:
# Normal-means ASH model on the same b, std, wk, sk, reusing the d_j values
# stored on the PLR model (a private attribute of PenMrASH).
dj = pmash._dj
nmash = NormalMeansASHScaled(b, std, wk, sk, d=dj)
nmash.logML

array([-3.06487518, -6.18177123,  0.73109578,  0.77572404,  0.35992288,
       -9.45746262])

In [17]:
# Derivative of logML, one value per coefficient (presumably w.r.t. b) -- compare with Fortran.
nmash.logML_deriv

array([-2.37739752, -3.52362732, -0.64357177, -2.32604672, -6.65298927,
       -4.71726323])

In [19]:
# Evaluate log_sum_wkLjk on the derive=1 variant of logLjk.
logLjk_d1 = nmash.logLjk(derive=1)
nmash.log_sum_wkLjk(logLjk_d1)

array([-1.47055059, -4.84490893,  5.81448277,  6.04539061,  5.3812026 ,
       -8.12512828])

In [11]:
# For these inputs, the private attribute pmash._dj equals the column sums of
# squares of X, i.e. diag(X^T X); check that explicitly before displaying it.
np.testing.assert_allclose(pmash._dj, np.sum(X ** 2, axis = 0))
pmash._dj

array([302.0837, 308.7531, 161.833 , 284.6044, 170.2961, 152.8095])

In [12]:
def softmax(x, base = np.exp(1), axis = 0):
    """Numerically stable softmax of ``x`` along ``axis``.

    Computes ``base**x / sum(base**x)`` by scaling ``x`` with ``log(base)`` and
    subtracting the maximum along ``axis`` before exponentiating. This keeps
    large inputs from overflowing, and each slice is normalized independently.

    Parameters
    ----------
    x : array_like
        Input values.
    base : float or None, optional
        Base of the exponential; defaults to e (the usual softmax). If None,
        ``x`` is used unscaled, which is equivalent to base e.
    axis : int, optional
        Axis along which the output sums to one. Defaults to 0.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``x`` whose entries along ``axis`` are
        non-negative and sum to one.
    """
    x = np.asarray(x)
    if base is not None:
        x = x * np.log(base)
    # Subtract the per-slice maximum rather than the global one: with a global
    # max, a slice whose values all lie far below it underflows to 0/0 = nan.
    e_x = np.exp(x - np.max(x, axis = axis, keepdims = True))
    return e_x / np.sum(e_x, axis = axis, keepdims = True)

# Exponent scale of the softmax base. ak = log(wk) / smlogbase is the inverse
# map, so softmax(ak, base = exp(smlogbase)) gives back wk.
smlogbase = 1.0
ak = np.log(wk) / smlogbase
ak

[-0.69314718 -1.38629436 -1.38629436]


In [13]:
# Round-trip check: softmax of ak with base exp(smlogbase) must recover wk.
wk_recovered = softmax(ak, base = np.exp(smlogbase))
np.testing.assert_allclose(wk_recovered, wk)
wk_recovered

array([0.5 , 0.25, 0.25])

In [14]:
# Jacobian of w = softmax(ak, base = exp(smlogbase)) with respect to ak:
#   d w_i / d a_j = smlogbase * w_i * (delta_ij - w_j)
# (rows index w_i, columns index a_j; the matrix is symmetric).
akjac = smlogbase * wk[:, np.newaxis] * (np.eye(k) - wk)
akjac

array([[ 0.25  , -0.125 , -0.125 ],
       [-0.125 ,  0.1875, -0.0625],
       [-0.125 , -0.0625,  0.1875]])

In [15]:
# Chain rule: dL/da_j = sum_i dL/dw_i * dw_i/da_j, i.e. wgrad @ akjac.
# Summing wgrad * akjac over axis 1 would instead compute akjac @ wgrad, which
# only agrees here because the softmax Jacobian happens to be symmetric.
agrad = wgrad @ akjac
agrad

array([-4.41623451,  3.23835833,  1.17787618])

In [16]:
# Elementwise product with wgrad broadcast across rows: entry [i, j] is
# wgrad[j] * akjac[i, j], so its row sums equal akjac @ wgrad.
wgrad * akjac

array([[-3.70811726, -0.86917917,  0.16106191],
       [ 1.85405863,  1.30376875,  0.08053096],
       [ 1.85405863, -0.43458958, -0.24159287]])