# Prototype of VEM in M&M ASH model
This notebook prototypes the VEM part of the M&M ASH model.

In [None]:
# Read the saved R object, print its structure, and attach it so that its
# elements can be referred to by name.
dat <- readRDS('/home/gaow/Documents/GTExV8/Thyroid.Lung.FMO2.filled.rds')
str(dat)
attach(dat)

In [None]:
%get X Y --from R
import numpy as np
from scipy.stats import multivariate_normal

## Data preview

In [None]:
# Preview X, one of the two objects pulled from the R session by the %get magic.
X

In [None]:
# Preview Y, one of the two objects pulled from the R session by the %get magic.
Y

## Initial values

We need to initialize the M&M model with effect size estimates and a version of the MASH results. We can use:

1. LD-convoluted effect estimates for the mash model, to learn the weights $\pi_t$ corresponding to each $V_t$
2. Multivariate LASSO, to get the ordering of initialization of $X$ and the initial $X$ values

### Multivariate LASSO 
For initialization of effects.
FIXME: not working yet.

In [None]:
from sklearn import linear_model
import numpy as np
# Fit a multi-task (multi-response) LASSO of Y on X without an intercept.
# lasso_path() returns a plain (alphas, coefs, dual_gaps) tuple, which has no
# `coef_` attribute, so its result could not be used through `model.coef_`.
# A fitted estimator exposes `coef_` with one row per column of Y.
# NOTE(review): the original passed 0.1 positionally to lasso_path (its `eps`
# argument); it is used here as the penalty strength `alpha` -- confirm the
# intended value.
model = linear_model.MultiTaskLasso(alpha = 0.1, fit_intercept = False).fit(X, Y)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Scatter the first row of model.coef_ against the second row.
# No `c=` colour values are supplied, so a `cmap` argument would be ignored
# (recent matplotlib versions emit a warning for it); it is not passed.
plt.scatter(model.coef_[0], model.coef_[1])
# Keep a handle on the current axes for further tweaking.
ax = plt.gca()
plt.show()

In [None]:
# Largest value in the second row of model.coef_.
max(model.coef_[1])

### MASH priors and weights
FIXME: not working yet

## Initialization

In [None]:
# Convert Y to a plain ndarray. Y is presumably a pandas DataFrame here (the
# original called DataFrame.as_matrix(), which was removed in pandas 1.0);
# np.asarray() accepts both DataFrames and arrays.
Y = np.asarray(Y)
# B holds the effect estimates, one row per column of X and one column per
# column of Y. Start from zero: a NaN-filled B would turn the residual
# Y - X @ B, and everything derived from it, into NaN.
B = np.zeros((X.shape[1], Y.shape[1]))
# Non-zero enables the sanity-check assertions in the update code.
test = 1
# Loop indices, pre-declared.
kk  = 0
tt = 0
# Mixture weights, one per prior covariance matrix in V.
pi = np.array([0.2,0.8])
# t is the number of mixture components, so it must match len(pi) and
# V.shape[0].
t = len(pi)
# Prior covariance matrices, one per mixture component (placeholder values).
V = np.ones((len(pi), Y.shape[1], Y.shape[1]))
# Residual precision matrix (identity to start with) and its inverse.
Lambda = np.diag(np.ones(Y.shape[1]))
iLambda = np.linalg.inv(Lambda)

## Core updates

In [None]:
# One pass of the core updates: per-effect posterior moments under each mixture
# component, the mixture-weighted effect estimates B, and the residual
# precision Lambda.
n_effects = X.shape[1]
n_cond = Y.shape[1]
# Number of mixture components. pi and V both hold len(pi) entries, so every
# per-component array and loop below is sized by it.
n_comp = len(pi)
# X is N by K matrix, X_norm is a length-K vector of the L2 norms of its columns
X_norm = np.linalg.norm(X, ord = 2, axis = 0)
X_std = X / X_norm
if test:
    # Sanity check: every standardized column has unit sum of squares
    np.testing.assert_array_almost_equal(np.sum(np.square(X_std), axis = 0), np.ones(n_effects))
# B is K by M matrix; R_all is the N by M residual under the current B
R_all = Y - X @ B
# Work arrays, NaN-filled so that any entry left unset is easy to spot
R = np.ones((n_effects, Y.shape[0], n_cond)) * np.nan
E = np.ones((n_cond, n_effects)) * np.nan
Rx = np.ones((n_cond, n_effects)) * np.nan
wdE = np.ones((n_effects, n_comp)) * np.nan
Sigma = {tt: np.ones((n_effects, n_cond, n_cond)) * np.nan for tt in range(n_comp)}
Mu = np.ones((n_effects, n_cond, n_comp)) * np.nan
gamma = np.ones((n_effects, n_comp)) * np.nan
# Loop invariants
zero_mean = np.zeros(n_cond)
eye = np.identity(n_cond)
# K is number of effects
for kk in range(n_effects):
    # R[kk] is N x M, where M is number of conditions: the residual with the
    # contribution of effect kk added back
    R[kk,:,:] = R_all + np.outer(X[:,kk], B[kk,:])
    # E[:,kk] is a length-M vector
    E[:,kk] = R[kk,:,:].T @ X_std[:,kk]
    # Rx[:,kk] is a length-M vector
    Rx[:,kk] = R[kk,:,:].T @ X[:,kk]
    # Prior-weighted density of E[:,kk] under each mixture component
    # NOTE(review): X_norm[kk] is the L2 norm, not its square -- confirm this
    # is the intended scaling here and in the Sigma update below.
    for tt in range(n_comp):
        wdE[kk,tt] = multivariate_normal.pdf(E[:,kk], zero_mean, V[tt] + iLambda / X_norm[kk]) * pi[tt]
    wdE_sum = sum(wdE[kk,:])
    for tt in range(n_comp):
        # Can be made faster via low-rank approximation
        Sigma[tt][kk,:,:] = np.linalg.inv(eye + V[tt] * X_norm[kk] @ Lambda) @ V[tt]
        Mu[kk,:,tt] = Sigma[tt][kk,:,:] @ Lambda @ Rx[:,kk]
        # Normalized component weight for effect kk
        gamma[kk,tt] = wdE[kk,tt] / wdE_sum
    # dot product for weighted sums
    B[kk,:] = Mu[kk,:,:] @ gamma[kk,:]
# Recalculate h, a M by M diagonal matrix
h_diag = np.zeros(n_cond)
for kk in range(n_effects):
    for tt in range(n_comp):
        h_diag += gamma[kk,tt] * ((Mu[kk,:,tt] - B[kk,:]) ** 2 + np.diag(Sigma[tt][kk,:,:]))
H = np.diag(h_diag)
# Update Lambda
# NOTE(review): R_all was computed before B was updated above -- confirm
# whether the residual should be refreshed before it is used here.
# Delta is built as an explicit diagonal matrix; adding the 1-D diagonal of
# R_all.T @ R_all directly to H would broadcast it across every row.
Delta = np.diag(np.diag(R_all.T @ R_all)) + H
# Keep Lambda an M by M diagonal matrix, the same form it is initialized in,
# so that the `@ Lambda` products above still work on the next pass.
Lambda = np.diag(X.shape[0] / np.diag(Delta))
# Keep the inverse in step with the updated (diagonal) Lambda.
iLambda = np.diag(np.diag(Delta) / X.shape[0])