# First implementation of our model

Have a look at the model [here](https://github.com/commons-research/common_dws_public_storage/blob/main/docs/anticipated_lotus/model/Metabolites.pdf). 

* Simulate $P_m$ and $Q_s$ from a Poisson distribution
* Simulate $\sigma_{tf}$ using a Wishart distribution

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import pandas as pd
import scipy.special

### Choose the number of molecules and species

In [2]:
# Labels of the entity types and, in the same order, how many entities of each
# type to simulate (per the heading above: molecules and species).
T = ['m', 's']
n_t = [1000, 10]

In [3]:
# Every entity-type label must come with exactly one size.
assert len(T)==len(n_t)

### Create some $\mu$, $\alpha$ and $\beta$ for each molecule and each species

In [4]:
def simulate_from_prior(T, n_t, rng=None):
    """Draw one sample of every prior parameter for each entity type.

    Parameters
    ----------
    T : sequence
        Labels of the entity types; they become the keys of the returned
        dictionaries.
    n_t : sequence of int
        Number of entities of each type, in the same order as ``T``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, the global ``np.random`` state is
        used, so existing calls (and ``np.random.seed``) behave as before.
        Pass e.g. ``np.random.default_rng(42)`` for a reproducible draw.

    Returns
    -------
    mu, alpha, beta, sigma : dict
        Four dictionaries keyed by the labels in ``T``, all stored as float16:

        * ``mu[t]``    -- ``n`` standard-normal draws.
        * ``alpha[t]`` -- a single exponential draw (scale 2), shape ``(1,)``.
        * ``beta[t]``  -- ``n`` exponential draws (scale 3).
        * ``sigma[t]`` -- one Wishart draw with ``df = n`` and scale ``I / n``.

    Raises
    ------
    ValueError
        If ``T`` and ``n_t`` do not have the same length.
    """
    if len(T) != len(n_t):
        raise ValueError(
            f"T and n_t must have the same length, got {len(T)} and {len(n_t)}"
        )

    # Fall back to the legacy global state so unseeded callers are unaffected.
    sampler = np.random if rng is None else rng
    pairs = list(zip(T, n_t))

    # The four groups are drawn in this fixed order (all mu, then alpha, beta,
    # sigma) so that a seeded global state yields the same stream as before.
    mu = {t: sampler.normal(loc=0, scale=1, size=n).astype("float16")
          for t, n in pairs}

    alpha = {t: sampler.exponential(scale=2, size=1).astype("float16")
             for t, _ in pairs}

    beta = {t: sampler.exponential(scale=3, size=n).astype("float16")
            for t, n in pairs}

    # random_state=None makes scipy use the global numpy state as well.
    sigma = {t: stats.wishart.rvs(df=n,
                                  scale=1/n*np.eye(n),
                                  size=1,
                                  random_state=rng).astype("float16")
             for t, n in pairs}
    return mu, alpha, beta, sigma

In [5]:
def simulate_data(T, n_t, p=0.1, rng=None):
    """Simulate a binary array with independent Bernoulli(``p``) entries.

    Parameters
    ----------
    T : sequence
        Entity-type labels. Not used by the computation; kept so the call
        signature mirrors ``simulate_from_prior``.
    n_t : sequence of int
        Size of each axis of the returned array.
    p : float, optional
        Probability that an entry equals 1. Defaults to 0.1, the value that
        was previously hard-coded.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, the global ``np.random`` state is
        used, so existing calls behave as before.

    Returns
    -------
    numpy.ndarray
        float16 array of zeros and ones with shape ``tuple(n_t)``.
    """
    sampler = np.random if rng is None else rng
    return sampler.binomial(n=1, p=p, size=tuple(n_t)).astype("float16")

In [6]:
# One draw of every prior parameter; each result is a dict keyed by the labels in T.
μ, α, β, σ= simulate_from_prior(T, n_t)

In [7]:
# Simulated binary data x, shaped by the sizes in n_t.
x = simulate_data(T, n_t)

In [8]:
# γ: a single exponential draw (scale 1); it multiplies P_m in the research-effort formula.
γ = np.random.exponential(scale=1, size=1)

In [9]:
# δ: a single exponential draw (scale 0.1); it multiplies Q_s in the research-effort formula.
δ = np.random.exponential(scale=0.1, size=1)

Change to 1 + Poisson? --> No, we shouldn't change to 1 + Poisson. In fact, the Lotus database will very likely be smaller than the true data **x**, which is why some molecules or some species might not have any research papers yet.

In [10]:
#P_m = np.random.poisson(0.5, size=n_t[0])

In [11]:
# Simulated count matrix L: independent Poisson(0.01) draws, shaped by the sizes in n_t.
L = np.random.poisson(0.01, size=[i for i in n_t]).astype("float16")

In [12]:
# Q_s: column totals of L (sum over axis 0), one value per column.
Q_s = L.sum(axis=0)

In [13]:
# P_m: row totals of L (sum over axis 1), one value per row.
P_m = L.sum(axis=1)

In [14]:
#Q_s = 1 + np.random.poisson(0.5, size=n_t[1])

In [15]:
def research_effort(γ, δ, P_m, Q_s):
    """Return the matrix ``1 - exp(-γ * P_m[i] - δ * Q_s[j])``.

    ``P_m`` and ``Q_s`` are 1-D arrays; entry ``(i, j)`` of the result pairs
    the i-th element of ``P_m`` with the j-th element of ``Q_s``.
    """
    # Column vector for the P_m part, row vector for the Q_s part, so that the
    # subtraction broadcasts to a full (len(P_m), len(Q_s)) grid.
    row_part = -γ * P_m[:, np.newaxis]
    col_part = δ * Q_s[np.newaxis, :]
    return 1 - np.exp(row_part - col_part)

In [16]:
#R = research_effort(γ, δ, P_m, Q_s)

In [17]:
#look at the condition in the model to calculate the proba of L_sm
def compute_prob_of_L(L, x, γ, δ, P_m, Q_s):
    """Compute, entry by entry, the probability attached to ``L`` given ``x``.

    With ``R = 1 - exp(-γ * P_m[i] - δ * Q_s[j])`` evaluated in float16, each
    entry of the result is:

    * ``1``      where ``x == 0`` and ``L == 0``
    * ``R``      where ``x == 1`` and ``L >= 1``
    * ``1 - R``  where ``x == 1`` and ``L == 0``
    * ``0``      everywhere else (e.g. ``x == 0`` with ``L >= 1``)

    Returns a float16 array.
    """
    exponent = -γ * P_m[:, None] - δ * Q_s[None, :]
    R_ms = 1 - np.exp(exponent, dtype="float16")

    x_is_zero, x_is_one = (x == 0), (x == 1)
    L_is_zero, L_is_positive = (L == 0), (L >= 1)

    # The three cases are mutually exclusive; anything not matched keeps 0.
    prob_L = np.select(
        [x_is_zero & L_is_zero, x_is_one & L_is_positive, x_is_one & L_is_zero],
        [np.ones_like(R_ms), R_ms, 1 - R_ms],
        default=0,
    )
    return prob_L.astype("float16")

In [18]:
# Entry-wise probability of the simulated L given x, using the sampled γ and δ.
prob_L = compute_prob_of_L(L, x, γ, δ, P_m, Q_s)

In [19]:
#df = pd.DataFrame(np.array([L.flatten(), prob_L.flatten()]))

## Calculate proba of **$x$**

In [20]:
def prob_X(mu):
    """Apply the logistic function to every cross-combination sum of ``mu``.

    ``mu`` maps each key to a 1-D array. The result has one axis per key (in
    the dictionary's order); each entry is ``expit`` of the sum of the
    corresponding elements, one taken from each array.
    """
    # One grid per array in mu; with 'ij' indexing, grid k varies along axis k only.
    grids = np.meshgrid(*mu.values(), indexing='ij')

    # Adding the grids yields, at each index tuple, the sum of the selected
    # elements; the sum is stored in half precision before the logistic map.
    logits = np.asarray(grids).sum(axis=0).astype("float16")
    return scipy.special.expit(logits)

In [21]:
prob_X(μ).shape

(1000, 10)

In [22]:
def prob_sigma(alpha, sigma):
    """Sum every cross-combination of the flattened ``alpha[key] * sigma[key]``.

    For each key of ``alpha`` the product ``alpha[key] * sigma[key]`` is
    flattened to 1-D and cast to float16. The result has one axis per key;
    each entry is the sum of one element taken from each flattened product.
    """
    flattened = []
    for key in alpha.keys():
        scaled = alpha[key] * sigma[key]
        flattened.append(np.ndarray.flatten(scaled).astype('float16'))

    # With 'ij' indexing, grid k varies only along axis k, so adding the grids
    # enumerates every combination of entries across the keys.
    grids = np.meshgrid(*flattened, indexing='ij')
    return np.asarray(grids).sum(axis=0)

In [23]:
# Cross-combination sums of the flattened, α-scaled σ matrices (one axis per key of α).
cov_mat = prob_sigma(alpha=α, sigma=σ)

In [24]:
α

{'m': array([2.094], dtype=float16), 's': array([1.19], dtype=float16)}

In [25]:
σ['m'][0,1]

-0.02852

In [31]:
pd.DataFrame(prob_X(μ))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,0.745028,0.231614,0.116362,0.614741,0.335895,0.194674,0.882428,0.201185,0.535949,0.141994
1,0.936981,0.605333,0.401453,0.890294,0.720368,0.551574,0.974483,0.561696,0.854642,0.457137
2,0.946399,0.645120,0.442877,0.905990,0.753285,0.593138,0.978385,0.602997,0.874507,0.499512
3,0.730290,0.218169,0.108756,0.596433,0.319121,0.182863,0.874185,0.189071,0.516839,0.132852
4,0.807764,0.302288,0.159217,0.696475,0.421135,0.257952,0.915205,0.265880,0.624179,0.192236
...,...,...,...,...,...,...,...,...,...,...
995,0.950595,0.664758,0.464416,0.913058,0.769080,0.613931,0.980129,0.623606,0.883638,0.521227
996,0.925903,0.562898,0.360264,0.872130,0.683842,0.508056,0.969785,0.518302,0.831554,0.414188
997,0.826991,0.330039,0.177241,0.722918,0.452657,0.283186,0.924552,0.291586,0.653790,0.212886
998,0.795817,0.286668,0.149407,0.680348,0.402862,0.243642,0.909103,0.251280,0.606382,0.180684


In [32]:
pd.DataFrame(cov_mat)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,90,91,92,93,94,95,96,97,98,99
0,3.273438,2.111328,1.753906,2.468750,1.759766,2.339844,2.187500,2.285156,1.981445,1.811523,...,1.811523,2.228516,2.695312,2.601562,2.128906,1.921875,2.406250,2.164062,2.441406,3.339844
1,0.988281,-0.173828,-0.530762,0.184692,-0.524902,0.054657,-0.097656,0.001678,-0.303223,-0.473145,...,-0.473145,-0.055756,0.411377,0.316650,-0.156372,-0.362061,0.121460,-0.120850,0.157227,1.056641
2,1.123047,-0.039185,-0.396240,0.319336,-0.390381,0.189209,0.036987,0.136230,-0.168701,-0.338379,...,-0.338379,0.078857,0.545898,0.451416,-0.021729,-0.227295,0.256104,0.013763,0.291748,1.191406
3,0.982422,-0.179688,-0.537109,0.178711,-0.531250,0.048706,-0.103577,-0.004272,-0.309082,-0.479004,...,-0.479004,-0.061707,0.405518,0.310791,-0.162354,-0.367920,0.115479,-0.126831,0.151245,1.050781
4,1.118164,-0.043457,-0.400635,0.314941,-0.394775,0.185059,0.032715,0.132080,-0.172852,-0.342773,...,-0.342773,0.074585,0.541992,0.447021,-0.026001,-0.231689,0.251709,0.009491,0.287598,1.186523
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999995,1.096680,-0.064819,-0.421875,0.293701,-0.416016,0.163574,0.011322,0.110596,-0.194336,-0.364014,...,-0.364014,0.053192,0.520508,0.425781,-0.047394,-0.252930,0.230347,-0.011902,0.266113,1.165039
999996,1.068359,-0.093567,-0.450684,0.264893,-0.444824,0.134888,-0.017395,0.081909,-0.223022,-0.392822,...,-0.392822,0.024490,0.491699,0.396973,-0.076111,-0.281738,0.201660,-0.040619,0.237427,1.136719
999997,1.034180,-0.127319,-0.484375,0.231201,-0.478516,0.101196,-0.051117,0.048187,-0.256836,-0.426514,...,-0.426514,-0.009232,0.458008,0.363281,-0.109863,-0.315430,0.167969,-0.074341,0.203735,1.102539
999998,1.064453,-0.097656,-0.454834,0.260742,-0.448975,0.130859,-0.021454,0.077881,-0.227051,-0.396973,...,-0.396973,0.020432,0.487549,0.392822,-0.080200,-0.285889,0.197632,-0.044678,0.233398,1.132812
