# Implement from scratch a sampling method to draw samples from a multivariate Normal (MVN) distribution in JAX.

In [136]:
import numpy as np
import jax.numpy as jnp
from jax import random
import math

In [137]:
# Change dimensions here
dim = 10

**Generating positive definite matrix Sigma(covarience matrix) by multiplying a Square Matrix with it's own Transpose since by doing `A @ A.T` all the eigen values will be positive and `A @ A.T` is also symmetric hence it will be positive definite.**

In [138]:
key = random.PRNGKey(2)
y = random.uniform(key,(dim,dim),minval = 1,maxval = 10)
sigma = jnp.dot(y,y.T)

In [139]:
# generating mean randomly

mean = random.uniform(key,(dim,1))
mean.shape

(10, 1)

**Here we are doing cholesky decomposition of the Sigma(covarience matrix) which decomposes Sigma into `L @ L.T` where L is a Lower Triangular matrix.We are using cholesky here because say we have `Z ~ N(0,I)` and if we do affine Transformation of Z into `C = A + B @ Z` the resultant `C ~ N(A, B @ B.T)` but here we want B @ B.T to be our Sigma so we are taking `B = L`(lower triangular matrix obtained from cholesky decomposition)**

In [140]:
L = jnp.linalg.cholesky(sigma)

In [141]:
# number of samples

num_samples = 5000 * dim

**Generating uniform random variables using random.uniform**

In [142]:
U1 = random.uniform(key,(num_samples,1))

In [143]:
U2 = random.uniform(key,(num_samples,1))

**Since it was asked not to use random.normal we are generating standard normal samples from uniform distribution. We can do it in multiple ways by using CLT(Central Limit Theorom) or using Box-muller transformation,etc. I used Box-muller transformation here which takes the Uniform random variables and maps them to polar from to make them normal random variables which can be then made into standard normal by subtracting mean and dividing by standard deviation**

In [144]:
X1 = jnp.sqrt(-2*(math.pi)* jnp.log(U1)) * jnp.cos(2*(math.pi)*U2)

In [145]:
X2 = jnp.sqrt(-2*(math.pi)* jnp.log(U1)) * jnp.sin(2*(math.pi)*U2)

In [146]:
X = (X1 - jnp.mean(X1))/jnp.std(X1)
X.shape

(50000, 1)

In [147]:
X = jnp.reshape(X,(dim,num_samples//dim))
X.shape

(10, 5000)

**Affine transformation of the generated standard normal rv X to give us samples from MVN**

In [148]:
Y = L @ X + mean

In [149]:
cov = jnp.cov(Y)

**Checking the euclidian norm between the original Sigma matrix and the sigma of our sample**

In [150]:
jnp.linalg.norm(sigma - cov)

DeviceArray(31.951365, dtype=float32)

**Checking error in mean using euclidean norm but here the norm is not taken as a function, made from scratch**

In [151]:
mean_samp = jnp.mean(Y,axis = 1)

In [152]:
temp = jnp.mean(Y,axis = 1).reshape(dim,1) - mean

In [153]:
sum = 0
for i in range(dim):
    sum = sum + temp[i]**2
print(f"error in mean is {math.sqrt(sum)}")

error in mean is 0.2766456818354105


In [154]:
""" References:
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform """

' References:\nhttps://jax.readthedocs.io/en/latest/notebooks/quickstart.html\nhttps://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform '

**Note: We can also use gibbs sampling in this case to sample from MVN**