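# Multivariate Gaussian

This notebook fits a multivariate Gaussian (mean vector and covariance matrix) to synthetic correlated data and resamples from it. It then shows how to repair a covariance matrix that is not positive semi-definite before sampling.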

- [numpy positive semi-definite warning](https://stackoverflow.com/questions/41515522/numpy-positive-semi-definite-warning)

## Sample data

In [34]:
import pandas as pd
import numpy as np

# Fixed seed so the random draws (and every output shown below) are reproducible.
np.random.seed(37)
size = 1_000

# Linear chain a -> b -> c: each variable is an affine function of the previous
# one plus unit-variance Gaussian noise, so all three columns are correlated.
a = np.random.normal(loc=1, scale=1, size=size)
b = np.random.normal(loc=1 + 2 * a, scale=1)
c = np.random.normal(loc=1 + 3 * b, scale=1)

df = pd.DataFrame(dict(a=a, b=b, c=c))
df.shape

(1000, 3)

In [35]:
# Peek at the first five generated rows.
df.head(n=5)

Unnamed: 0,a,b,c
0,0.945536,2.231118,7.747877
1,1.674308,4.276228,12.709732
2,1.346647,2.705018,8.896767
3,-0.300346,0.909744,1.78162
4,2.518512,6.479523,21.004581


## Compute mean and covariance, then resample from the fitted Gaussian

In [36]:
# Empirical mean vector and covariance matrix of the generated data, then a
# same-sized synthetic sample from a multivariate normal with those moments;
# the sample's statistics (next cells) should track these.
# NOTE: this rebinds `c` from the raw `c` column to the covariance DataFrame.
m, c = df.mean(), df.cov()
s = np.random.multivariate_normal(mean=m, cov=c, size=size)

In [37]:
# Fitted mean vector as a plain numpy array.
m.to_numpy()

array([ 1.01277839,  3.00947207, 10.01831793])

In [13]:
# Fitted covariance matrix as a plain numpy array.
c.to_numpy()

array([[ 0.9634615 ,  1.92320946,  5.75834371],
       [ 1.92320946,  4.81764836, 14.37726705],
       [ 5.75834371, 14.37726705, 43.87527426]])

In [40]:
# Mean of the synthetic sample; compare with the fitted mean `m`.
pd.DataFrame(data=s, columns=list('abc')).mean()

a    0.969306
b    2.930238
c    9.780896
dtype: float64

In [41]:
# Covariance of the synthetic sample; compare with the fitted covariance `c`.
pd.DataFrame(data=s, columns=list('abc')).cov()

Unnamed: 0,a,b,c
a,1.009078,1.995169,6.004501
b,1.995169,4.961893,14.870129
c,6.004501,14.870129,45.498812


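## Repairing a covariance matrix that is not positive semi-definite

The fitted parameters are perturbed by hand below: the mean of `a` is moved to 1.5 and the variance of `a` is set to 0 while its covariances are kept. This gives the covariance matrix a negative eigenvalue. `np.random.multivariate_normal` expects a positive semi-definite covariance, so the following cells detect the negative eigenvalue and shift the diagonal to repair the matrix before sampling.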

In [26]:
# Deliberately perturbed parameters: the mean of `a` is moved to 1.5 and Var(a)
# is set to 0.0 while its covariances with `b` and `c` are left unchanged.
# With a zero variance but non-zero covariances the matrix is no longer
# positive semi-definite, which the following cells detect and repair.
m = pd.Series([1.5, 3.009472, 10.018318], index=['a', 'b', 'c'])
c = pd.DataFrame(
    [
        [0.0, 1.92320946, 5.75834371],
        [1.92320946, 4.81764836, 14.37726705],
        [5.75834371, 14.37726705, 43.87527426],
    ],
    # Label both axes like `m` (and like the earlier `df.cov()`), so the
    # covariance lines up with the mean by variable name, not just position.
    index=m.index,
    columns=m.index,
)

In [27]:
# Smallest eigenvalue of the covariance; a negative value means `c` is not
# positive semi-definite. `eigvalsh` is numpy's routine for symmetric/Hermitian
# matrices: it returns real eigenvalues directly (no np.real needed to drop
# spurious imaginary parts) and reads only the lower triangle, so it relies on
# `c` being symmetric, which it is by construction.
min_eig = np.linalg.eigvalsh(c).min()
min_eig

-0.7480182977029629

In [28]:
# Adding k * I to a symmetric matrix shifts every eigenvalue by k. With
# k = -SHIFT_FACTOR * min_eig (positive when min_eig < 0), the smallest
# eigenvalue becomes (SHIFT_FACTOR - 1) * |min_eig| > 0.
# NOTE(review): this also inflates every variance by SHIFT_FACTOR * |min_eig|;
# a smaller factor or eigenvalue clipping would distort `c` less.
SHIFT_FACTOR = 10

# Recompute from the current `c` instead of trusting the earlier `min_eig`, so
# re-running this cell is a no-op once `c` has been repaired (the original
# in-place `-=` with a stale `min_eig` kept inflating the diagonal on re-runs).
min_eig = np.linalg.eigvalsh(c).min()
if min_eig < 0:
    c = c - SHIFT_FACTOR * min_eig * np.eye(*c.shape)
c

Unnamed: 0,0,1,2
0,7.480183,1.923209,5.758344
1,1.923209,12.297831,14.377267
2,5.758344,14.377267,51.355457


In [30]:
# Sample from the repaired parameters. check_valid='raise' makes a covariance
# that is still not positive semi-definite an error. The default 'warn' would
# only emit a RuntimeWarning and return meaningless samples. `c` was repaired
# above, so this check should pass; it does not change the draws.
s = np.random.multivariate_normal(m, c, size=size, check_valid='raise')

In [31]:
# Covariance of the sample drawn from the repaired parameters; compare with `c`.
pd.DataFrame(data=s, columns=[*'abc']).cov()

Unnamed: 0,a,b,c
a,7.651128,1.810444,5.699311
b,1.810444,11.553545,12.799748
c,5.699311,12.799748,51.446776


In [32]:
# Mean of the sample drawn from the repaired parameters; compare with `m`.
pd.DataFrame(data=s, columns=[*'abc']).mean()

a     1.508072
b     3.043604
c    10.272000
dtype: float64