# Task 2: Implement from scratch a method to draw samples from a multivariate normal (MVN) distribution in JAX

* Your code should work for any number of dimensions, but please set the number of dimensions (random variables of the MVN) to 10 for this task.
* You may only use `jax.random.uniform`. In particular, you may not use `jax.random.normal`.
* Randomly generate the mean vector and covariance matrix that fully specify the MVN distribution.
* Implement, from scratch, a sampling method that draws samples from the specified MVN distribution.
* Use your sampling method to draw many samples from the MVN distribution, then reconstruct the parameters (mean vector and covariance matrix) from those samples to confirm that your sampling method works correctly.
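Since only `jax.random.uniform` is allowed, standard normal variates have to be built from uniform ones. A minimal sketch, assuming the Box-Muller transform is an acceptable "from scratch" method (the helper name `box_muller` is hypothetical): two independent U(0, 1) variates `u1`, `u2` give `sqrt(-2 ln u1) * cos(2 pi u2)`, which is exactly N(0, 1).

```python
import jax.numpy as jnp
import jax.random as random


def box_muller(key, shape):
    """Draw standard normal variates using only jax.random.uniform.

    Box-Muller maps independent U(0, 1) variates u1, u2 to
    z0 = sqrt(-2 ln u1) cos(2 pi u2) and z1 = sqrt(-2 ln u1) sin(2 pi u2),
    which are independent N(0, 1). Only z0 is returned here.
    """
    k1, k2 = random.split(key)
    # uniform() samples [0, 1); 1 - u lies in (0, 1], so log(u1) is always finite
    u1 = 1.0 - random.uniform(k1, shape=shape)
    u2 = random.uniform(k2, shape=shape)
    return jnp.sqrt(-2.0 * jnp.log(u1)) * jnp.cos(2.0 * jnp.pi * u2)


z = box_muller(random.PRNGKey(0), (200_000,))
# Sample mean and variance should be close to 0 and 1
print(float(z.mean()), float(z.var()))
```

Discarding the sine output wastes half of the uniform draws; keeping both halves would double throughput but is not needed for correctness.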


In [1]:
import jax.numpy as jnp
import jax.random as random
key = random.PRNGKey(23)



In [3]:
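# Number of random variables in the MVN (the code works for any positive integer)
dimension = 10
n = 10000  # number of samples to draw

key, k_mean, k_cov, k_u1, k_u2 = random.split(key, 5)

# Random mean vector with entries uniform in [-5, 5)
mean_vector = random.uniform(k_mean, shape=(dimension,), minval=-5.0, maxval=5.0)

# Random symmetric positive-definite covariance: A A^T plus a small diagonal jitter
A = random.uniform(k_cov, shape=(dimension, dimension), minval=-1.0, maxval=1.0)
epsilon = 0.0001
K = A @ A.T + epsilon * jnp.identity(dimension)

# Cholesky factor, so that K = L L^T
L = jnp.linalg.cholesky(K)

# Box-Muller: turn pairs of uniform variates into independent N(0, 1) variates.
# Uniform samples on [-3, 3] are NOT normal, so they cannot be used directly.
u1 = 1.0 - random.uniform(k_u1, shape=(dimension, n))  # in (0, 1], keeps log finite
u2 = random.uniform(k_u2, shape=(dimension, n))
z = jnp.sqrt(-2.0 * jnp.log(u1)) * jnp.cos(2.0 * jnp.pi * u2)

# Affine transform: x = mu + L z has mean mu and covariance L L^T = K
x = mean_vector + (L @ z).T  # shape (n, dimension)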
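The last task bullet asks to recover the parameters from the samples. A minimal sketch of that check, assuming the sample mean and the unbiased sample covariance (`jnp.cov` with `rowvar=False`) are adequate estimators; the helper `sample_mvn` is a hypothetical wrapper around the Box-Muller plus Cholesky steps above, and the sample size of 100,000 is an arbitrary choice.

```python
import jax.numpy as jnp
import jax.random as random


def sample_mvn(key, mean, cov, n):
    """Draw n samples from N(mean, cov) using only jax.random.uniform."""
    d = mean.shape[0]
    k1, k2 = random.split(key)
    # Box-Muller: (d, n) matrix of independent N(0, 1) variates
    u1 = 1.0 - random.uniform(k1, shape=(d, n))  # in (0, 1]
    u2 = random.uniform(k2, shape=(d, n))
    z = jnp.sqrt(-2.0 * jnp.log(u1)) * jnp.cos(2.0 * jnp.pi * u2)
    # Affine map x = mean + L z, where cov = L L^T
    L = jnp.linalg.cholesky(cov)
    return mean + (L @ z).T  # shape (n, d)


dimension, n = 10, 100_000
key = random.PRNGKey(23)
key, k_mean, k_cov, k_samp = random.split(key, 4)

# Randomly generated "true" parameters
mean_true = random.uniform(k_mean, (dimension,), minval=-5.0, maxval=5.0)
A = random.uniform(k_cov, (dimension, dimension), minval=-1.0, maxval=1.0)
cov_true = A @ A.T + 1e-4 * jnp.identity(dimension)

x = sample_mvn(k_samp, mean_true, cov_true, n)

# Reconstruct the parameters from the samples
mean_hat = x.mean(axis=0)
cov_hat = jnp.cov(x, rowvar=False)

print("max |mean error|:", float(jnp.max(jnp.abs(mean_hat - mean_true))))
print("max |cov error| :", float(jnp.max(jnp.abs(cov_hat - cov_true))))
```

The estimation error shrinks roughly like 1/sqrt(n), so the printed maxima should fall as `n` grows; a fixed tolerance only makes sense relative to the chosen `n`.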

In [4]:
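# Reference only: a 2-D MVN drawn with JAX's built-in sampler. random.multivariate_normal
# draws standard normals internally, so it is not allowed in the sampler itself and is
# used here purely for comparison.
key = random.PRNGKey(67)
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3, -1])
x1 = random.multivariate_normal(key, mean, cov, (10000,)).T  # shape (2, 10000)
x1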

DeviceArray([[ 2.8958414 ,  3.3322747 ,  4.738451  , ...,  2.310788  ,
               3.790872  ,  3.5540874 ],
             [-0.38591176,  1.1789055 ,  1.1619446 , ..., -1.744822  ,
               0.53555226, -1.161852  ]], dtype=float32)
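The reference draws above become useful once they are compared with the uniform-only sampler. A minimal sketch, assuming the same 2-D mean and covariance as the reference cell; `sample_mvn` is again a hypothetical helper, and `jax.random.multivariate_normal` appears only as the comparison baseline.

```python
import jax.numpy as jnp
import jax.random as random


def sample_mvn(key, mean, cov, n):
    """Uniform-only MVN sampler (Box-Muller + Cholesky); returns shape (n, d)."""
    d = mean.shape[0]
    k1, k2 = random.split(key)
    u1 = 1.0 - random.uniform(k1, shape=(d, n))  # in (0, 1]
    u2 = random.uniform(k2, shape=(d, n))
    z = jnp.sqrt(-2.0 * jnp.log(u1)) * jnp.cos(2.0 * jnp.pi * u2)
    return mean + (jnp.linalg.cholesky(cov) @ z).T


cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3.0, -1.0])
n = 50_000

k_ours, k_ref = random.split(random.PRNGKey(67))
x_ours = sample_mvn(k_ours, mean, cov, n)
# Reference draws from JAX's built-in sampler, used only for comparison
x_ref = random.multivariate_normal(k_ref, mean, cov, (n,))

# Both empirical means and covariances should be close to (mean, cov)
for name, xs in [("uniform-only", x_ours), ("reference", x_ref)]:
    print(name, xs.mean(axis=0), jnp.cov(xs, rowvar=False))
```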