## Question 2
Implement from scratch a sampling method to draw samples from a multivariate Normal (MVN) distribution in JAX. [10 Marks]

- Your code should work for any number of dimensions but please set the number of dimensions (random variables of MVN) to 10 for this task.
- You are only allowed to use jax.random.uniform. You are especially not allowed to use jax.random.normal.
- You should randomly create the mean and covariance matrix to fully specify an MVN distribution.
- Implement a sampling method from scratch using which you can draw samples from the specified MVN distribution.
- Use your sampling method to draw multiple samples from the MVN distribution and reconstruct the parameters of your MVN distribution (mean and covariance matrix) to confirm that your sampling method is working correctly.


---




## Approach Used

To generate samples $\mathbf{x} \sim \mathcal{N}(\mathbf{m}, K)$ with arbitrary mean $\mathbf{m}$ and covariance matrix $K$ using a scalar Gaussian generator (which is readily available in many programming environments) we proceed as follows: first, compute the Cholesky decomposition (also known as the "matrix square root") $L$ of the positive definite symmetric covariance matrix $K=L L^{\top}$, where $L$ is a lower triangular matrix. Then generate $\mathbf{u} \sim \mathcal{N}(\mathbf{0}, I)$ by multiple separate calls to the scalar Gaussian generator. Compute $\mathbf{x}=\mathbf{m}+L \mathbf{u}$, which has the desired distribution with mean $\mathbf{m}$ and covariance $L \mathbb{E}\left[\mathbf{u}\mathbf{u}^{\top}\right] L^{\top}=L L^{\top}=K$ (by the independence of the elements of $\mathbf{u}$).
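
The mean and covariance stated above follow from linearity of expectation. The short derivation below spells out that step using only the symbols already defined ($\mathbf{m}$, $K$, $L$, $\mathbf{u}$); it relies on $\mathbb{E}[\mathbf{u}]=\mathbf{0}$ and $\mathbb{E}\left[\mathbf{u}\mathbf{u}^{\top}\right]=I$, which hold because the elements of $\mathbf{u}$ are independent with zero mean and unit variance.

$$
\begin{aligned}
\mathbb{E}[\mathbf{x}] &= \mathbf{m} + L\,\mathbb{E}[\mathbf{u}] = \mathbf{m}, \\
\operatorname{Cov}[\mathbf{x}] &= \mathbb{E}\left[(\mathbf{x}-\mathbf{m})(\mathbf{x}-\mathbf{m})^{\top}\right] = \mathbb{E}\left[L \mathbf{u}\mathbf{u}^{\top} L^{\top}\right] = L\,\mathbb{E}\left[\mathbf{u}\mathbf{u}^{\top}\right] L^{\top} = L I L^{\top} = K .
\end{aligned}
$$

Since $\mathbf{x}$ is an affine function of the Gaussian vector $\mathbf{u}$, it is itself Gaussian, so matching the mean and covariance is enough to match the whole distribution.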

In practice it may be necessary to add a small multiple of the identity matrix $\epsilon I$ to the covariance matrix for numerical reasons. This is because the eigenvalues of the matrix $K$ can decay very rapidly and without this stabilization the Cholesky decomposition fails. The effect on the generated samples is to add additional independent noise of variance $\epsilon$. From the context $\epsilon$ can usually be chosen to have inconsequential effects on the samples, while ensuring numerical stability.

Reference: <a href="http://gaussianprocess.org/gpml/chapters/RWA.pdf">Gaussian Processes for Machine Learning | Appendix A | Gaussian Identities</a>
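
The stabilization step described above can be illustrated with a small sketch. It assumes a deliberately singular 3x3 covariance (all ones) and an illustrative jitter value `eps = 1e-3`; neither comes from the assignment. `jnp.linalg.cholesky` is used here only as a stand-in for a hand-written factorization.

```python
import jax.numpy as jnp

# A rank-1 covariance: positive semi-definite but singular, so a plain
# Cholesky factorization runs into a zero pivot.
K = jnp.ones((3, 3))

# Hypothetical jitter value, chosen only for illustration.
eps = 1e-3
K_jittered = K + eps * jnp.eye(3)

# Library Cholesky used as a stand-in for a from-scratch implementation.
L = jnp.linalg.cholesky(K_jittered)

# The factor is finite and reproduces the jittered matrix.
print(bool(jnp.all(jnp.isfinite(L))))
print(bool(jnp.allclose(L @ L.T, K_jittered, atol=1e-5)))

# Only the variances (diagonal entries) change, each by eps.
print(jnp.diag(K_jittered) - jnp.diag(K))
```

The last line shows the cost of the fix: every variance grows by `eps` while the off-diagonal covariances stay the same, which is the "additional independent noise of variance $\epsilon$" mentioned above.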



In [1]:
import jax.numpy as jnp
from jax.numpy import linalg as JLA
import jax

## Declare Dimension, Mean, Covariance Matrix


Let $M$ be an $n \times n$ Hermitian matrix (this includes real symmetric matrices). $M$ is positive definite if and only if all of its eigenvalues are positive.

Reference: <a href="https://en.wikipedia.org/wiki/">Wikipedia | Definite matrix</a>
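
The eigenvalue criterion also explains why a matrix of the form $A A^{\top} + \epsilon I$ is a safe choice of covariance: $A A^{\top}$ has non-negative eigenvalues, and adding $\epsilon I$ shifts each of them up by $\epsilon$. The sketch below checks this numerically; the dimension `d = 4` and the value `eps = 0.1` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

d = 4       # hypothetical small dimension for illustration
eps = 0.1   # hypothetical shift

A = jax.random.uniform(jax.random.PRNGKey(0), shape=(d, d))
gram = A @ A.T                      # symmetric positive semi-definite
shifted = gram + eps * jnp.eye(d)   # every eigenvalue moves up by eps

# eigvalsh is meant for symmetric/Hermitian input and returns real
# eigenvalues in ascending order.
eig_gram = jnp.linalg.eigvalsh(gram)
eig_shifted = jnp.linalg.eigvalsh(shifted)

print(eig_gram)
print(eig_shifted)
print(bool(jnp.all(eig_shifted > 0)))
```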


In [2]:
# Number of random variables
d = 10

# epsilon to stabilize Cholesky
epsilon = 0.1

# Random mean
mean = jax.random.uniform(jax.random.PRNGKey(0), shape=(d, 1), minval=0, maxval=1, dtype=jnp.float32)

# Random matrix for covariance
mat = jax.random.uniform(jax.random.PRNGKey(2), shape=(d, d), minval=0, maxval=1, dtype=jnp.float32)

# Make a symmetric matrix
sym_mat = (mat+mat.T)/2

# Multiply by the transpose to get a positive semi-definite matrix
cov = jnp.dot(sym_mat, sym_mat.T)

# Add epsilon * identity for stability (makes all eigenvalues at least epsilon)
cov = cov + epsilon*jnp.identity(d)




### Verify if Covariance Matrix is Positive Definite

In [3]:
# Calculate the eigenvalues; all of them must be positive.
# eigvalsh is used because cov is symmetric, so the eigenvalues are real.
cov_eignvals = JLA.eigvalsh(cov)
if(jnp.any(cov_eignvals <= 0)):
    print("The covariance matrix is not positive definite, generate again")


## Cholesky decomposition

In linear algebra, the Cholesky decomposition or Cholesky factorization is a decomposition of a Hermitian, positive-definite matrix into the product of a lower triangular matrix and its conjugate transpose.

References:

- <a href="https://www.youtube.com/watch?v=xloCwiVDkho&list=PLYdroRCLMg5MgczmIkeY_XVJiJ5LVDFuh&index=44">Linear Algebra: Cholesky Decomposition</a>

- <a href="https://www.youtube.com/watch?v=NppyUqgQqd0&list=PLYdroRCLMg5MgczmIkeY_XVJiJ5LVDFuh&index=44">Linear Algebra: Cholesky Decomposition Example</a>

- <a href="https://ocw.mit.edu/courses/10-34-numerical-methods-applied-to-chemical-engineering-fall-2005/resources/lecturenotes142/"> Cholesky MIT OCW</a>

In [4]:
def Chol_decomp(mat):
    '''
        Perform Cholesky decomposition of the input matrix and return L, i.e.
        the lower triangular matrix such that mat = L L^T

        Args:
            mat - (d, d) JAX Numpy array containing covariance matrix
                  (d is the number of random variables/dimensions)
        Returns:
            L - (d, d) JAX Numpy array containing the lower triangular matrix
    '''

    # Use the size of the input matrix, not the global cov or d
    n = mat.shape[0]

    # Declare an n x n zero matrix
    L = jnp.zeros((n, n))

    # Loop over the lower triangle of the covariance matrix
    for i in range(n):
        for j in range(i + 1):
            total = 0
            # Diagonal entries (variances)
            if (j == i):
                for k in range(j):
                    total += (L[j][k])**2
                L = L.at[i, j].set(jnp.sqrt(mat[i][j] - total))
            # Off-diagonal entries (covariances)
            else:
                for k in range(j):
                    total += L[i][k] * L[j][k]
                if(L[j][j] > 0):
                    L = L.at[i, j].set((mat[i][j] - total) / (L[j][j]))
    return L
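
The element-by-element loops above issue one array update per entry, which is slow in JAX. As a hedged alternative, the sketch below computes the same factorization one column at a time, so each column below the diagonal is filled by a single vector operation. The name `cholesky_by_columns` and the 2x2 test matrix are illustrative assumptions, and the result is compared against `jnp.linalg.cholesky` only as a cross-check.

```python
import jax.numpy as jnp

def cholesky_by_columns(K):
    """Column-oriented Cholesky sketch: returns lower-triangular L with K = L @ L.T."""
    n = K.shape[0]
    L = jnp.zeros_like(K)
    for j in range(n):
        # Diagonal entry: remove the squared entries already placed in row j.
        diag = jnp.sqrt(K[j, j] - jnp.sum(L[j, :j] ** 2))
        L = L.at[j, j].set(diag)
        if j + 1 < n:
            # All entries below the diagonal in column j, in one vector op.
            dots = jnp.sum(L[j + 1:, :j] * L[j, :j], axis=1)
            L = L.at[j + 1:, j].set((K[j + 1:, j] - dots) / diag)
    return L

# Small symmetric positive definite matrix, chosen for illustration.
K = jnp.array([[4.0, 2.0],
               [2.0, 3.0]])
L_cols = cholesky_by_columns(K)
print(L_cols)
print(bool(jnp.allclose(L_cols @ L_cols.T, K, atol=1e-5)))
print(bool(jnp.allclose(L_cols, jnp.linalg.cholesky(K), atol=1e-5)))
```

Unlike `Chol_decomp`, this sketch does not guard against a zero diagonal entry, so it assumes the input is strictly positive definite.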


### Validate Cholesky Decomposition

**NOTE**: If the decomposition did not work correctly, you might want to check the following:
- Whether the values were rounded before comparing the covariance matrix $K$ with $L L^{\top}$.
- If rounding is not the issue, try increasing the value of epsilon, which makes the Cholesky decomposition more stable. Because $\epsilon I$ is added to the covariance matrix $K$, a larger epsilon also increases the variance of the samples, so you should also increase `sample_size` when calling `gen_std_MVN` (<a href="https://ocw.mit.edu/courses/res-6-012-introduction-to-probability-spring-2018/resources/the-weak-law-of-large-numbers/"> Weak Law of Large Numbers </a>) to keep the mean and covariance obtained after sampling close to the original values. Example values: `d=30, epsilon=10, sample_size=100000`

In [5]:
L = Chol_decomp(cov)

# Verify that the product of L and L transpose returns the covariance matrix
decomp_val = jnp.round(cov, 4) == jnp.round(jnp.dot(L, L.T), 4)
if not jnp.all(decomp_val):
    print("Seems like the Decomposition did not work correctly.")
else:
    print("Decomposition worked Correctly")


Decomposition worked Correctly


## Generating Standard Normal Using CLT

Suppose $X_{1}, X_{2}, \ldots, X_{n}$ are independent random variables with the same underlying distribution. In this case, we say that the $X_{i}$ are independent and identically-distributed, or i.i.d. In particular, the $X_{i}$ all have the same mean $\mu$ and standard deviation $\sigma$.
Let $\bar{X}_{n}$ be the average of $X_{1}, \ldots, X_{n}$ :
$$
\bar{X}_{n}=\frac{X_{1}+X_{2}+\cdots+X_{n}}{n}=\frac{1}{n} \sum_{i=1}^{n} X_{i}
$$
Note that $\bar{X}_{n}$ is itself a random variable. The law of large numbers and central limit theorem tell us about the value and distribution of $\bar{X}_{n}$, respectively.

CLT: As $n$ grows, the distribution of $\bar{X}_{n}$ converges to the normal distribution $N\left(\mu, \sigma^{2} / n\right)$.

References:

- <a href="https://ocw.mit.edu/courses/18-05-introduction-to-probability-and-statistics-spring-2014/resources/mit18_05s14_reading6b/">MIT OCW | Introduction to Probability and Statistics</a>

- <a href="https://ocw.mit.edu/courses/6-041-probabilistic-systems-analysis-and-applied-probability-fall-2010/resources/mit6_041f10_l19/"> MIT OCW | Probabilistic Systems Analysis and Applied Probability</a>
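
As a minimal sketch of the CLT idea, the block below sums `n_terms` independent $U(0,1)$ draws and standardizes the sum with its known mean $n/2$ and variance $n/12$, rather than with statistics estimated from the sample. The function name `approx_std_normal` and the choice `n_terms=12` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def approx_std_normal(key, num_samples, n_terms=12):
    """Approximate N(0, 1) draws from sums of n_terms uniform draws."""
    u = jax.random.uniform(key, shape=(num_samples, n_terms))
    s = jnp.sum(u, axis=1)
    # A U(0, 1) variable has mean 1/2 and variance 1/12, so the sum has
    # mean n_terms / 2 and variance n_terms / 12.
    return (s - n_terms / 2.0) / jnp.sqrt(n_terms / 12.0)

z = approx_std_normal(jax.random.PRNGKey(0), num_samples=20000)
print(jnp.mean(z), jnp.var(z))
print(jnp.max(jnp.abs(z)))
```

Because each sum is bounded, the approximation has no tails beyond $\pm\sqrt{3\,n}$ (that is, $\pm 6$ for `n_terms=12`); a larger `n_terms` widens that range at the cost of more uniform draws.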

In [6]:
def gen_std_MVN(sample_size=10000, d=10):
    '''
        Returns samples from an approximate standard multivariate Gaussian,
        based on the dimensions/number of random variables and the sample
        size. This is based on the idea of the Central Limit Theorem.

        Args:
            d -  integer | the number of random variables; also the number of
                 uniform draws summed to form each approximately normal value
            sample_size - integer | # of samples to draw for each random variable

        Returns:
            std_MVN - (d, sample_size) JAX Numpy array containing the standard
                       multivariate Gaussian samples
    '''

    std_MVN = jnp.zeros((d, sample_size))

    for i in range(d):

        # Draw samples from the uniform distribution
        unf = jax.random.uniform(jax.random.PRNGKey(i), shape=(sample_size, d), minval=0, maxval=100, dtype=jnp.float32)
        # Sum d uniform draws per sample to get an approximately normal scalar
        normal = jnp.sum(unf, axis=1)
        # Convert to standard normal using the sample mean and sample variance
        std_normal = (normal - jnp.mean(normal)) / jnp.sqrt(jnp.var(normal))
        std_MVN = std_MVN.at[i, :].set(std_normal)

    return std_MVN
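
`gen_std_MVN` seeds row `i` with `jax.random.PRNGKey(i)`, so rows 0 and 2 reuse the keys that generated `mean` and `mat`. A hedged alternative is to derive every key from one master key with `jax.random.split`, so that no key is used twice. The variable names and the small sizes below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

master_key = jax.random.PRNGKey(0)
# One subkey per independent use; the names are hypothetical.
key_mean, key_cov, key_samples = jax.random.split(master_key, 3)

d, sample_size = 3, 5
mean = jax.random.uniform(key_mean, shape=(d, 1))
mat = jax.random.uniform(key_cov, shape=(d, d))

# One subkey per row of the (d, sample_size) matrix of uniform draws.
row_keys = jax.random.split(key_samples, d)
rows = [jax.random.uniform(k, shape=(sample_size,)) for k in row_keys]
uniforms = jnp.stack(rows)
print(uniforms.shape)
```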


In [7]:
std_MVN = gen_std_MVN(d=d)
samples = mean + jnp.dot(L, std_MVN)

### Validating the results

In [8]:
absolut_diff_mean = jnp.sum(jnp.abs(jnp.mean(samples, axis=1) - mean.squeeze()))
absolut_diff_cov = jnp.sum(jnp.abs(jnp.cov(samples) - cov))

In [9]:
print(f"The absolute difference between all the values of the original mean vs the mean obtained from sampling is {absolut_diff_mean}")

The absolute difference between all the values of the original mean vs the mean obtained from sampling is 1.6391277313232422e-06


In [10]:
print(f"The average of absolute difference between all the values of the original covariance vs the covariance obtained from sampling is {absolut_diff_cov/(d*d)}")

The average of absolute difference between all the values of the original covariance vs the covariance obtained from sampling is 0.012451081536710262
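
The average absolute difference reported above depends on the scale of the covariance entries. A scale-free alternative is the relative Frobenius error $\lVert \hat{K}-K \rVert_F / \lVert K \rVert_F$. The sketch below is a self-contained illustration under stated assumptions: a hand-picked 2x2 covariance, standard normals approximated by sums of 12 uniform draws, and `jnp.linalg.cholesky` standing in for a hand-written factorization. The function name `relative_cov_error` is hypothetical.

```python
import jax
import jax.numpy as jnp

def relative_cov_error(key, K, m, num_samples, n_terms=12):
    """Relative Frobenius error between the sample covariance and K."""
    d = K.shape[0]
    # Approximate standard normals: centred, scaled sums of U(0, 1) draws.
    u = jax.random.uniform(key, shape=(d, num_samples, n_terms))
    z = (jnp.sum(u, axis=-1) - n_terms / 2.0) / jnp.sqrt(n_terms / 12.0)
    # Library Cholesky used as a stand-in for a from-scratch implementation.
    L = jnp.linalg.cholesky(K)
    samples = m + L @ z                # shape (d, num_samples)
    K_hat = jnp.cov(samples)           # rows are treated as variables
    return jnp.linalg.norm(K_hat - K) / jnp.linalg.norm(K)

# Hand-picked parameters, for illustration only.
K = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])
m = jnp.array([[1.0], [-1.0]])
key = jax.random.PRNGKey(0)
for n in (100, 10000):
    print(n, relative_cov_error(key, K, m, n))
```

The sampling error of a covariance estimate shrinks roughly like $1/\sqrt{N}$, so the relative error is expected to drop as the number of samples grows.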


In [11]:
print("Original Mean \n",mean.squeeze())
print("\n")
print("Sample Mean \n",jnp.mean(samples,axis=1))

Original Mean 
 [0.35490513 0.60419905 0.4275843  0.23061597 0.32985854 0.43953657
 0.25099766 0.27730572 0.7678207  0.71474564]


Sample Mean 
 [0.35490483 0.60419923 0.42758447 0.230616   0.32985836 0.43953642
 0.2509974  0.27730566 0.76782095 0.7147457 ]


In [12]:
print("Original Covariance Matrix \n",cov)
print("\n")
print("Sample Covariance Matrix \n",jnp.cov(samples))

Original Covariance Matrix 
 [[2.8252301 2.163516  1.8826021 2.1715758 2.6237981 2.8255095 2.7453787
  2.6538415 1.9432937 2.3962338]
 [2.163516  2.6559076 2.0175126 2.2711544 2.4282575 2.3687296 2.4298837
  2.4429803 1.9467937 2.2675002]
 [1.8826021 2.0175126 2.778164  2.2047598 2.5771775 2.5989413 2.5783846
  2.4079242 1.9919269 2.4645398]
 [2.1715758 2.2711544 2.2047598 2.5554626 2.573891  2.616433  2.574296
  2.430259  1.9815549 2.3292205]
 [2.6237981 2.4282575 2.5771775 2.573891  3.2553551 3.0281265 3.1051197
  2.992762  2.156797  2.6898634]
 [2.8255095 2.3687296 2.5989413 2.616433  3.0281265 3.5964012 3.1709008
  2.9781177 2.277362  2.7988725]
 [2.7453787 2.4298837 2.5783846 2.574296  3.1051197 3.1709008 3.5501003
  3.177404  2.230322  3.0325356]
 [2.6538415 2.4429803 2.4079242 2.430259  2.992762  2.9781177 3.177404
  3.294318  2.2810152 2.6716766]
 [1.9432937 1.9467937 1.9919269 1.9815549 2.156797  2.277362  2.230322
  2.2810152 2.1277618 2.0248706]
 [2.3962338 2.2675002 2.46453