# Prototype of VEM in M&M ASH model
This is the VEM step of M&M ASH model.

$\newcommand{\bs}[1]{\boldsymbol{#1}}$
$\DeclareMathOperator*{\diag}{diag}$
$\DeclareMathOperator*{\cov}{cov}$
$\DeclareMathOperator*{\rank}{rank}$
$\DeclareMathOperator*{\var}{var}$
$\DeclareMathOperator*{\tr}{tr}$
$\DeclareMathOperator*{\veco}{vec}$
$\DeclareMathOperator*{\uniform}{\mathcal{U}niform}$
$\DeclareMathOperator*{\argmin}{arg\ min}$
$\DeclareMathOperator*{\argmax}{arg\ max}$
$\DeclareMathOperator*{\N}{N}$
$\DeclareMathOperator*{\gm}{Gamma}$
$\DeclareMathOperator*{\dif}{d}$

## M&M ASH model

We assume the following multivariate, multiple regression model with $N$ samples, $J$ effects and $R$ conditions (and **without covariates, for the time being**)
\begin{align}
\bs{Y}_{N\times R} = \bs{X}_{N \times J}\bs{B}_{J \times R} + \bs{E}_{N \times R},
\end{align}
where
\begin{align}
\bs{E} &\sim \N_{N \times R}(\bs{0}, \bs{I}_N, \bs{\Lambda}^{-1}),\\
\bs{\Lambda} &= \diag(\lambda_1,\ldots,\lambda_R).
\end{align}

We assume true effects $\bs{b}_j$ (rows of $\bs{B}$) are iid with prior distribution of mixtures of multivariate normals

$$p(\bs{b}_j) = \sum_{t = 0}^T\pi_t\N_R(\bs{b}_j | \bs{0}, \bs{V}_t),$$

Where the $\bs{V}_t$'s are $R \times R$ covariance matrices and the $\pi_t$'s are their weights.

We place Gamma prior on $\bs{\Lambda}$

$$\lambda_r \overset{iid}{\sim} \gm(\alpha, \beta),$$

and set $\alpha = \beta = 0$ so that it is equivalent to estimating $\bs{\Lambda}$ via maximum likelihood.

We can augment the prior of $\bs{b}_j$ by indicator vector $\bs{w}_j \in \mathbb{R}^T$ for membership of $\bs{b}_j$ into one of the $T$ mixture groups. The densities involved are

\begin{align}
p(\bs{Y},\bs{B},\bs{W},\bs{\Lambda}) &= p(\bs{Y}|\bs{B}, \bs{\Lambda})p(\bs{B}|\bs{W})p(\bs{W})p(\bs{\Lambda}), \\
p(\bs{Y}|\bs{B}, \bs{\Lambda}) &= N_{N \times R}(\bs{X}\bs{B}, \bs{I}_N, \bs{\Lambda}^{-1}), \\
p(\lambda_r|\alpha,\beta) &= \frac{\beta^{\alpha}}{\Gamma(\alpha)}\lambda_r^{\alpha - 1}\exp\{-\beta\lambda_r\}, \\
p(\bs{b}_j|\bs{w}_j) &= \prod_{t = 0}^T\left[\N(\bs{b}_j|\bs{0},\bs{V}_t)\right]^{w_{jt}},\\
p(\bs{w}_j) &= \prod_{t = 0}^{T} \pi_t^{w_{jt}}.
\end{align}

**We assume $V_t$'s and their corresponding $\pi_t$'s are known. In practice we use `mashr` to estimate these quantities and provide them to M&M.**

### Variational approximation to densities

We seek a variational approximation based on

\begin{align}
q(\bs{B}, \bs{W}, \bs{\Lambda}) = q(\bs{\Lambda})\prod_{j = 1}^{J}q(\bs{b}_j,\bs{w}_j),
\end{align}

so that we can maximize over $q$ the following lower bound of the marginal log-likelihood

\begin{align}
\log p(\bs{Y}) \geq \mathcal{L}(q) = \int q(\bs{B}, \bs{W}, \bs{\Lambda}) \log\left\{\frac{p(\bs{Y},\bs{B},\bs{W},\bs{\Lambda})}{q(\bs{B}, \bs{W}, \bs{\Lambda})}\right\}\dif\bs{B}\dif\bs{W}\dif\bs{\Lambda},
\end{align}

Gao & Wei have previously developed [a version that assumes $\Lambda = I_R$](https://github.com/gaow/mvarbvs/blob/master/writeup/identity_cov/mnmash.pdf). This version generalized it to a diagonal matrix with Gamma priors. [David has developed a version](https://www.overleaf.com/11985539jvwgjhrqnrry#/45465793/) that assumes a diagonal plus low rank structure -- the model of that version is a bit different from shown here, and will be prototyped later after this version works.

### Core updates

The complete derivation of updates are documented elsewhere (in the two PDF write-ups whose links are shown above). Here I document core updates to guide implementation of the algorithm.

Let $E[\bs{R}_{-j}] := \bs{Y} - \bs{X}\bs{\mu}_{\bs{B}} + \bs{x}_j\bs{\mu}_{\bs{B}[j, ]}^{\intercal}$, then


\begin{align}
\bs{\xi}_j &= E\left[\bs{R}_{-j}\right]^{\intercal}\bs{x}_j\|\bs{x}_j\|^{-2}, \\
\bs{\Sigma}_{jt} &= \left(\bs{V}_t^{-1} + \|\bs{x}_j\|^2\bs{\Lambda}\right)^{-1}, \\
\bs{\mu}_{jt} &= \bs{\Sigma}_{jt}\bs{\Lambda}E\left[\bs{R}_{-k}\right]^{\intercal}\bs{x}_j, \\
\gamma_{jt} &= \frac{\pi_t\N(\bs{\xi}_j|\bs{0}, \bs{V}_t + \bs{\Lambda}^{-1}\|\bs{x}_j\|^{-2})}{\sum_{t = 0}^T\pi_t\N(\bs{\xi}_j|\bs{0}, \bs{V}_t + \bs{\Lambda}^{-1}\|\bs{x}_j\|^{-2})},\\
\bs{\mu}_{\bs{B}[j, ]}  &= \sum_{t = 0}^T \gamma_{jt}\bs{\mu}_{jt}
\end{align}

In [1]:
dat = readRDS('/home/gaow/Documents/GTExV8/Thyroid.Lung.FMO2.filled.rds')
str(dat)
attach(dat)

List of 2
 $ Y:'data.frame':	698 obs. of  2 variables:
  ..$ Thyroid: num [1:698] 0.163 0.436 -0.212 0.327 -0.698 ...
  ..$ Lung   : num [1:698] 0.77011 0.77799 -0.65361 0.00672 -0.36792 ...
 $ X: num [1:698, 1:7492] 1 0 0 0 0 1 1 0 1 1 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:698] "GTEX-111CU" "GTEX-111FC" "GTEX-111VG" "GTEX-111YS" ...
  .. ..$ : chr [1:7492] "chr1_170185243_G_A_b38" "chr1_170185272_T_C_b38" "chr1_170185405_C_A_b38" "chr1_170185417_G_A_b38" ...


In [2]:
%get X Y --from R
import numpy as np
from scipy.stats import multivariate_normal

Loading required package: feather


## Data preview

In [4]:
X

array([[ 1.,  0.,  0., ...,  0.,  1.,  0.],
       [ 0.,  0.,  0., ...,  0.,  1.,  0.],
       [ 0.,  1.,  0., ...,  0.,  1.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  1.,  0.],
       [ 0.,  0.,  0., ...,  0.,  1.,  0.]])

In [5]:
Y

Unnamed: 0,Thyroid,Lung
GTEX-111CU,0.163481,0.770109
GTEX-111FC,0.435890,0.777987
GTEX-111VG,-0.212373,-0.653612
GTEX-111YS,0.326649,0.006720
GTEX-1122O,-0.697831,-0.367918
GTEX-1128S,0.840123,-0.097376
GTEX-113JC,1.286902,-0.003500
GTEX-117XS,0.115426,-0.003500
GTEX-117YW,0.763787,-0.300018
GTEX-117YX,-0.329763,-0.379263


In [3]:
Y = Y.as_matrix()

## Initial values

We need to initialize M&M model with effect size estimates and a version of MASH results. We can use:

1. LD-convoluted effect estimate for mash model, to learn $\pi_t$ corresponding to $V_t$
2. Use multivariate LASSO to get the ordering of initialization of $X$, and initial $X$ values

### Multivariate LASSO 
For initialization of effects.
FIXME: not working yet.

In [None]:
from sklearn import linear_model
import numpy as np
model = linear_model.lasso_path(X, Y, 0.1, fit_intercept = False)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
plt.scatter(model.coef_[0], model.coef_[1], cmap="viridis")
ax = plt.gca()
plt.show()

In [None]:
max(model.coef_[1])

### MASH priors and weights
FIXME: not working yet

## Initialization

## Core updates

In [128]:
class MNMASH:
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y        
        self.B = np.zeros((X.shape[1], Y.shape[1]))
        # self.V = np.ones((2, Y.shape[1], Y.shape[1]))
        self.V = np.array([np.identity(2)])
        self.pi = np.random.uniform(0,1,self.V.shape[0])
        self.pi = self.pi / sum(self.pi)
        self.Lambda = np.identity(Y.shape[1]))
        self.tol = 1E-4
        self.debug = 1
        self.maxiter = 1
        # initialize intermediate variables
        self.R = np.ones((self.X.shape[1], self.Y.shape[0], self.Y.shape[1])) * np.nan
        self.E = np.ones((self.Y.shape[1], self.X.shape[1])) * np.nan
        self.Rx = np.ones((self.Y.shape[1], self.X.shape[1])) * np.nan
        self.wdE = np.ones((self.X.shape[1], self.V.shape[0])) * np.nan
        self.Sigma = {tt: np.ones((self.X.shape[1], self.Y.shape[1], self.Y.shape[1])) * np.nan for tt in range(self.V.shape[0])}
        self.Mu = np.ones((self.X.shape[1], self.Y.shape[1], self.V.shape[0])) * np.nan
        self.gamma = np.ones((self.X.shape[1], self.V.shape[0])) * np.nan
        # X is N by K matrix, X_norm is 1 by K vector of L2 norm of column k's
        self.X_norm = np.linalg.norm(self.X, ord = 2, axis = 0)
        if self.debug:
            self.X_std = self.X / self.X_norm
            np.testing.assert_array_almost_equal(np.sum(np.square(self.X_std), axis = 0), np.ones(self.X.shape[1]))
        self.X_norm = np.square(self.X_norm)
        self.X_std = self.X / self.X_norm
        self.R_all = None
        self.H = None
        self.Delta = None

    def update(self):
        '''
        core update
        '''
        iLambda = np.linalg.inv(self.Lambda)
        # B is M by K matrix
        self.R_all = self.Y - self.X @ self.B
        # K is number of effects
        for kk in range(self.X.shape[1]):
            # R[kk] is N x M, where M is number of conditions
            self.R[kk,:,:] = self.R_all + np.outer(self.X[:,kk], (self.B[kk,:].T))
            # E[kk] is M x 1
            self.E[:,kk] = (self.R[kk,:,:].T @ self.X_std[:,kk]).ravel()
            # Rx[kk] is M x 1
            self.Rx[:,kk] = (self.R[kk,:,:].T @ self.X[:,kk]).ravel()
            for tt in range(self.V.shape[0]):
                self.wdE[kk,tt] = multivariate_normal.pdf(self.E[:,kk], np.zeros(self.Y.shape[1]), \
                                                          self.V[tt] + iLambda / self.X_norm[kk]) * self.pi[tt]
            wdE_sum = sum(self.wdE[kk,:])
            for tt in range(self.V.shape[0]):
                # Can be made faster via low-rank approximation
                self.Sigma[tt][kk,:,:] = np.linalg.inv(np.identity(self.Y.shape[1]) + self.V[tt] * self.X_norm[kk] @ self.Lambda) @ self.V[tt]
                self.Mu[kk,:,tt] = self.Sigma[tt][kk,:,:] @ self.Lambda @ self.Rx[:,kk]
                self.gamma[kk,tt] = self.wdE[kk,tt] / wdE_sum
            # dot product for weighted sums
            self.B[kk,:] = self.Mu[kk,:,:] @ self.gamma[kk,:]
        # Recalculate h, a M by M matrix
        self.H = np.diag(np.sum([np.sum([self.gamma[kk,tt] * ((self.Mu[kk,:,tt] - self.B[kk,:]) ** 2 + np.diag(self.Sigma[tt][kk,:,:])) for tt in range(self.V.shape[0])], axis = 0) for kk in range(self.X.shape[1])], axis = 0))
        # Update Lambda
        self.Delta = np.diag(self.R_all.T @ self.R_all) + self.H
        self.Lambda = np.diag(self.X.shape[0] / np.diag(self.Delta))

    def vem(self):
        cnt = 0
        while cnt < self.maxiter:
            self.update()
            cnt += 1

In [172]:
res = MNMASH(X,Y)

In [174]:
res.update()



In [175]:
res.Rx

array([[ -105.68247929,    17.97861755,   204.49481699, ...,
         -150.96948701,  -616.95091084,  -182.08212114],
       [ 5599.28482416,  1402.5326597 ,  2071.44767636, ...,
         2153.99886232,  7952.26995702,  2152.01292885]])

In [176]:
res.Y

array([[ 0.16348104,  0.77010917],
       [ 0.43588995,  0.77798736],
       [-0.21237311, -0.65361193],
       ..., 
       [ 0.62036618, -0.0035004 ],
       [ 0.00279156, -0.05439095],
       [-0.14650835,  0.29935286]])

In [177]:
res.E

array([[ -0.41189173,   0.29964363,   2.06530046, ...,  -1.55582062,
         -1.22096704,  -1.85730112],
       [ 21.82290883,  23.37554433,  20.920637  , ...,  22.19810052,
         15.73781539,  21.95128224]])

In [178]:
res.R

array([[[ -4.8891728 ,  17.50965015],
        [-12.53027764,   5.74081502],
        [ 35.10013698,  30.31447758],
        ..., 
        [ -7.41623605,  33.55498201],
        [-15.46802748,  16.12465266],
        [ 13.34749681,  23.83955985]],

       [[ -4.86830939,  17.53081006],
        [-12.53027764,   5.74081502],
        [ 35.03991555,  30.26511094],
        ..., 
        [ -7.41623605,  33.55498201],
        [-15.46802748,  16.12465266],
        [ 13.34749681,  23.83955985]],

       [[ -4.86830939,  17.53081006],
        [-12.53027764,   5.74081502],
        [ 35.10013698,  30.31447758],
        ..., 
        [ -7.41623605,  33.55498201],
        [-15.46802748,  16.12465266],
        [ 13.34749681,  23.83955985]],

       ..., 
       [[ -4.86830939,  17.53081006],
        [-12.53027764,   5.74081502],
        [ 35.10013698,  30.31447758],
        ..., 
        [ -7.41623605,  33.55498201],
        [-15.46802748,  16.12465266],
        [ 13.34749681,  23.83955985]],

       [[ -

In [179]:
res.Sigma[0][0] @ res.R[0].T @ res.X[:,0]

array([-0.17460838,  5.72910948])

In [182]:
res.Mu

array([[[ -0.4112112 ],
        [ 21.80057994]],

       [[  0.29753795],
        [ 23.2736074 ]],

       [[  2.05648133],
        [ 20.86525817]],

       ..., 
       [[ -1.54904212],
        [ 22.1381448 ]],

       [[ -1.21994188],
        [ 15.72963468]],

       [[ -1.84929135],
        [ 21.89259667]]])

In [136]:
res.gamma

array([[ 1.],
       [ 1.],
       [ 1.],
       ..., 
       [ 1.],
       [ 1.],
       [ 1.]])

In [137]:
res.Lambda

array([[ 2.35504843,  0.        ],
       [ 0.        ,  3.80523012]])

In [138]:
res.Mu

array([[[-0.02086341],
        [-0.02115991]],

       [[-0.06022143],
        [-0.04936664]],

       [[-0.04123255],
        [-0.06044189]],

       ..., 
       [[ 0.03684816],
        [-0.00270303]],

       [[-0.00668126],
        [ 0.01637876]],

       [[ 0.04042619],
        [ 0.00057867]]])

In [139]:
res.B

array([[-0.02086341, -0.02115991],
       [-0.06022143, -0.04936664],
       [-0.04123255, -0.06044189],
       ..., 
       [ 0.03684816, -0.00270303],
       [-0.00668126,  0.01637876],
       [ 0.04042619,  0.00057867]])

In [140]:
res.H

array([[ 106.11214038,    0.        ],
       [   0.        ,  106.11214038]])

In [141]:
res.B

array([[-0.02086341, -0.02115991],
       [-0.06022143, -0.04936664],
       [-0.04123255, -0.06044189],
       ..., 
       [ 0.03684816, -0.00270303],
       [-0.00668126,  0.01637876],
       [ 0.04042619,  0.00057867]])

In [142]:
res.X @ res.B

array([[  5.03179042, -16.76070089],
       [ 12.96616759,  -4.96282766],
       [-35.31251009, -30.96808951],
       ..., 
       [  8.03660223, -33.55848241],
       [ 15.47081903, -16.17904361],
       [-13.49400516, -23.54020699]])

In [143]:
res.Mu

array([[[-0.02086341],
        [-0.02115991]],

       [[-0.06022143],
        [-0.04936664]],

       [[-0.04123255],
        [-0.06044189]],

       ..., 
       [[ 0.03684816],
        [-0.00270303]],

       [[-0.00668126],
        [ 0.01637876]],

       [[ 0.04042619],
        [ 0.00057867]]])

In [144]:
res.Y - res.X @ res.B

array([[ -4.86830939,  17.53081006],
       [-12.53027764,   5.74081502],
       [ 35.10013698,  30.31447758],
       ..., 
       [ -7.41623605,  33.55498201],
       [-15.46802748,  16.12465266],
       [ 13.34749681,  23.83955985]])

In [145]:
res.pi

array([ 1.])

In [146]:
res.X.shape

(698, 7492)

In [148]:
res.B.shape

(7492, 2)

In [149]:
res.Lambda

array([[ 2.35504843,  0.        ],
       [ 0.        ,  3.80523012]])

In [158]:
res.X_norm

array([ 256.57829887,   60.        ,   99.01456041, ...,   97.03527832,
        505.29693985,   98.03586438])

In [171]:
np.sum(np.square(res.X[:,0]))

256.57829886674881

In [183]:
np.min(res.Mu)

-26.731661026046595

In [186]:
res.B

array([[ -0.4112112 ,  21.80057994],
       [  0.29753795,  23.2736074 ],
       [  2.05648133,  20.86525817],
       ..., 
       [ -1.54904212,  22.1381448 ],
       [ -1.21994188,  15.72963468],
       [ -1.84929135,  21.89259667]])

In [187]:
res.Mu

array([[[ -0.4112112 ],
        [ 21.80057994]],

       [[  0.29753795],
        [ 23.2736074 ]],

       [[  2.05648133],
        [ 20.86525817]],

       ..., 
       [[ -1.54904212],
        [ 22.1381448 ]],

       [[ -1.21994188],
        [ 15.72963468]],

       [[ -1.84929135],
        [ 21.89259667]]])

In [188]:
np.min(res.gamma)

nan

In [190]:
np.min(res.Rx)

-5915.7879239429385

In [191]:
res.Lambda

array([[ nan,   0.],
       [  0.,  nan]])

In [192]:
res.Delta

array([[             nan,  468225.13784254],
       [ 249377.18134286,              nan]])

In [193]:
res.H

array([[ nan,   0.],
       [  0.,  nan]])

In [195]:
np.min(res.Sigma[0])

0.0