In [323]:
import jax.numpy as jnp
from jax.numpy import linalg as JLA
import jax
import numpy as np
import plotly.graph_objects as go
from numpy import linalg as LA

## Declare the Dimension, Mean Vector, and Covariance Matrix

In [324]:
# Problem size and a random target distribution N(mean, cov).
d = 10

# NOTE(review): PRNGKey(0) is reused elsewhere in this notebook (the sample
# generators seed their first dimension with it); JAX advises splitting keys
# rather than reusing them.
mean = jax.random.uniform(jax.random.PRNGKey(0), shape=(d, 1), minval=0, maxval=10, dtype=jnp.float32)

# Symmetrise a random matrix and multiply it by its transpose (positive
# semidefinite), then shift the spectrum by d so every eigenvalue is >= d and
# the covariance is strictly positive definite (required by Cholesky).
mat = jax.random.uniform(jax.random.PRNGKey(2), shape=(d, d), minval=0, maxval=2, dtype=jnp.float32)
sym_mat = (mat + mat.T) / 2
cov = sym_mat @ sym_mat.T + d * jnp.identity(d)

### Verify that the Covariance Matrix is Positive Semidefinite

In [325]:
# `cov` is symmetric by construction, so use the symmetric eigensolver:
# `eigvalsh` returns real eigenvalues, whereas the general `eigvals` returns
# complex ones, which `< 0` would then have to compare as complex numbers.
cov_eignvals = JLA.eigvalsh(cov)
if jnp.any(cov_eignvals < 0):
  print("The covariance matrix is not positive semidefinite, generate again")

In [326]:
def Chol_decomp(mat):
    """Cholesky-Banachiewicz factorisation of a symmetric positive-definite matrix.

    Builds the lower-triangular ``L`` with ``L @ L.T == mat`` row by row.
    Only the lower triangle (``j <= i``) of ``mat`` is read.

    Parameters
    ----------
    mat : array_like, shape (n, n)
        Symmetric positive-definite matrix.

    Returns
    -------
    jax.Array, shape (n, n)
        Lower-triangular Cholesky factor. If ``mat`` is not positive definite
        the result is not a valid factor: a diagonal pivot becomes zero or NaN
        (square root of a non-positive number) and the entries below such a
        pivot are left at zero.

    Raises
    ------
    ValueError
        If ``mat`` is not a square 2-D matrix.
    """
    mat = jnp.asarray(mat)
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        raise ValueError(f"Chol_decomp expects a square 2-D matrix, got shape {mat.shape}")

    # Size the factor from the argument itself, not from notebook globals.
    n = mat.shape[0]
    L = jnp.zeros((n, n))

    for i in range(n):
        for j in range(i + 1):
            # sum_{k<j} L[i, k] * L[j, k]; for i == j this is the sum of squares
            # of the already-computed part of row i. Empty slices give 0.
            partial = jnp.dot(L[i, :j], L[j, :j])
            if j == i:
                # Diagonal (variance) entry.
                L = L.at[i, i].set(jnp.sqrt(mat[i, i] - partial))
            elif L[j, j] > 0:
                # Off-diagonal (covariance) entry; skipped when the pivot is
                # zero or NaN to avoid dividing by it.
                L = L.at[i, j].set((mat[i, j] - partial) / L[j, j])

    return L

### Validate the Cholesky Decomposition

In [327]:
L = Chol_decomp(cov)

# Element-wise tolerance check rather than exact equality after rounding to
# 4 decimals: float32 values that agree to ~1e-6 can still round to different
# 4th decimals and report a spurious failure.
decomp_val = jnp.isclose(jnp.dot(L, L.T), cov, rtol=1e-5, atol=1e-4)
if jnp.all(decomp_val):
  print("Decomposition worked Correctly")
else:
  print("Seems like the Decomposition did not work correctly. Check if you have rounded of the values")

Decomposition worked Correctly


In [688]:
def gen_std_MVN(sample_size=10000,d=10, n_terms=None):
  """Draw a (d, sample_size) array of approximately standard-normal samples.

  Row ``i`` uses the key ``jax.random.PRNGKey(i)``. Uniforms on [0, 100) are
  drawn with shape ``(n_terms, sample_size)`` and summed along axis 0
  (central limit theorem). The resulting ``sample_size`` sums are
  standardised to zero sample mean and unit sample variance.

  Parameters
  ----------
  sample_size : int, default 10000
      Number of samples (columns). Must be at least 2, otherwise the sample
      variance used for standardisation is zero/undefined and the result NaN.
  d : int, default 10
      Number of dimensions (rows).
  n_terms : int or None, default None
      Uniforms summed per sample. ``None`` keeps the original behaviour of
      summing ``d`` of them. Increase it for a closer-to-Gaussian marginal
      without changing the dimension.

  Returns
  -------
  jax.Array, shape (d, sample_size), float32

  Raises
  ------
  ValueError
      If ``sample_size < 2`` or ``n_terms < 1``.
  """
  if sample_size < 2:
    raise ValueError(f"sample_size must be >= 2 to standardise samples, got {sample_size}")
  if d == 0:
    return jnp.zeros((0, sample_size), dtype=jnp.float32)
  if n_terms is None:
    n_terms = d
  if n_terms < 1:
    raise ValueError(f"n_terms must be >= 1, got {n_terms}")

  # NOTE(review): keys PRNGKey(0..d-1) are also used by gen_std_MVN_1, and
  # PRNGKey(0) also draws `mean` in the setup cell. JAX advises splitting keys
  # instead of reusing them.
  rows = []
  for i in range(d):
    unf = jax.random.uniform(jax.random.PRNGKey(i), shape=(n_terms, sample_size), minval=0, maxval=100, dtype=jnp.float32)
    normal = jnp.sum(unf, axis=0)
    rows.append((normal - jnp.mean(normal)) / jnp.sqrt(jnp.var(normal)))

  return jnp.stack(rows, axis=0)


def gen_std_MVN_1(sample_size=10000,d=10, n_terms=None):
  """Draw a (d, sample_size) array of approximately standard-normal samples.

  Row ``i`` uses the key ``jax.random.PRNGKey(i)``. Uniforms on [0, 100) are
  drawn with shape ``(sample_size, n_terms)`` (sample-major layout) and summed
  along axis 1 (central limit theorem). The resulting ``sample_size`` sums are
  standardised to zero sample mean and unit sample variance.

  Parameters
  ----------
  sample_size : int, default 10000
      Number of samples (columns). Must be at least 2, otherwise the sample
      variance used for standardisation is zero/undefined and the result NaN.
  d : int, default 10
      Number of dimensions (rows).
  n_terms : int or None, default None
      Uniforms summed per sample. ``None`` keeps the original behaviour of
      summing ``d`` of them. Increase it for a closer-to-Gaussian marginal
      without changing the dimension.

  Returns
  -------
  jax.Array, shape (d, sample_size), float32

  Raises
  ------
  ValueError
      If ``sample_size < 2`` or ``n_terms < 1``.
  """
  if sample_size < 2:
    raise ValueError(f"sample_size must be >= 2 to standardise samples, got {sample_size}")
  if d == 0:
    return jnp.zeros((0, sample_size), dtype=jnp.float32)
  if n_terms is None:
    n_terms = d
  if n_terms < 1:
    raise ValueError(f"n_terms must be >= 1, got {n_terms}")

  # NOTE(review): keys PRNGKey(0..d-1) are also used by gen_std_MVN, and
  # PRNGKey(0) also draws `mean` in the setup cell. JAX advises splitting keys
  # instead of reusing them.
  rows = []
  for i in range(d):
    unf = jax.random.uniform(jax.random.PRNGKey(i), shape=(sample_size, n_terms), minval=0, maxval=100, dtype=jnp.float32)
    normal = jnp.sum(unf, axis=1)
    rows.append((normal - jnp.mean(normal)) / jnp.sqrt(jnp.var(normal)))

  return jnp.stack(rows, axis=0)



In [715]:
# Number of samples drawn by each method; all three use the same size and the
# notebook dimension `d`, so changing `d` in the setup cell stays consistent.
n = 10000

# Affine transform x = mean + L z maps standard-normal z to N(mean, L L^T).
std_MVN = gen_std_MVN(sample_size=n, d=d)
samples = mean + jnp.dot(L, std_MVN)

std_MVN_1 = gen_std_MVN_1(sample_size=n, d=d)
samples_1 = mean + jnp.dot(L, std_MVN_1)

# Baseline using NumPy's normal sampler, seeded so Restart & Run All
# reproduces the same comparison numbers.
rng = np.random.default_rng(0)
u = rng.normal(loc=0, scale=1, size=(d, n))
samples_2 = mean + jnp.dot(L, u)

In [690]:
# Display the target mean vector (shape (d, 1)) drawn in the setup cell.
mean

DeviceArray([[3.5490513],
             [6.0419903],
             [4.2758427],
             [2.3061597],
             [3.2985854],
             [4.3953657],
             [2.5099766],
             [2.7730572],
             [7.6782074],
             [7.147456 ]], dtype=float32)

In [707]:
# L1 error of the per-dimension sample mean. The generator standardises each
# row to zero sample mean, so this is expected to be ~0 up to float32 error.
jnp.abs(samples.mean(axis=1) - mean[:, 0]).sum()

DeviceArray(5.9604645e-06, dtype=float32)

In [708]:
# Same mean-error check for the sample-major generator's draws.
jnp.abs(samples_1.mean(axis=1) - mean[:, 0]).sum()

DeviceArray(5.722046e-06, dtype=float32)

In [714]:
# Sum of absolute element-wise errors between the empirical covariance (rows of
# `samples` are the variables) and the target `cov`.
jnp.abs(jnp.cov(samples) - cov).sum()

DeviceArray(10.808604, dtype=float32)

In [713]:
# Covariance error for the sample-major generator's draws.
jnp.abs(jnp.cov(samples_1) - cov).sum()

DeviceArray(14.157106, dtype=float32)

In [716]:
# Covariance error for the NumPy-normal baseline draws.
jnp.abs(jnp.cov(samples_2) - cov).sum()

DeviceArray(22.32653, dtype=float32)