# Animate bivariate normal distribution

**Using libraries such as jax, math, scipy and matplotlib for sampling and plotting, and ipywidgets for the interactive controls**

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import random
import math
import jax
import matplotlib.pyplot as plt
from scipy.stats import norm
from ipywidgets import interact
from ipywidgets import interactive
from ipywidgets import fixed
from ipywidgets import interact_manual
import ipywidgets as widgets
%matplotlib inline
from mpl_toolkits import mplot3d

In [None]:
# Number of points drawn from the bivariate normal on every redraw;
# edit this value to plot more or fewer samples.
num_samples = 1000

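**Generating two jointly normal random variables X1, X2 with the function jax.random.multivariate_normal, which takes the mean vector and Sigma (the covariance matrix) as inputs and draws samples from the multivariate normal (MVN) distribution. Note that with the default Sigma below, X1 and X2 are not standard normal: their variances are 1 and 2 and their covariance is 0.6.**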

In [None]:
# Initial PRNG key, mean vector and Sigma (covariance matrix).
# mu and sigma match the default arguments of change_mean_sigma:
# both means 0, variances 1 and 2, covariance 0.6.

key = random.PRNGKey(2)  # fixed seed, so every run draws the same samples
mu = jnp.array([0, 0])
sigma = jnp.array([
    [1, 0.6],
    [0.6, 2],
])

**Using the function to generate samples from MVN**

In [None]:
def sample_generator(mu, sigma, rng_key=None, n_samples=None):
    """Draw samples from the bivariate normal N(mu, sigma).

    Args:
        mu: length-2 mean vector.
        sigma: 2x2 covariance matrix; must be symmetric positive definite.
        rng_key: optional JAX PRNG key. Defaults to the notebook-level ``key``.
        n_samples: optional number of samples. Defaults to the notebook-level
            ``num_samples``.

    Returns:
        ``(X1, X2, L)`` where X1 and X2 are column vectors of shape
        ``(n_samples, 1)`` holding the two coordinates of each sample, and L
        is the lower-triangular Cholesky factor of sigma
        (``sigma == L @ L.T``), used to draw the confidence ellipse.
    """
    if rng_key is None:
        rng_key = key
    if n_samples is None:
        n_samples = num_samples

    # Derive the sampling sub-key with fold_in rather than arithmetic on the
    # key array (the old ``key + 40``): PRNG keys are meant to be opaque, and
    # arithmetic fails outright on typed keys from jax.random.key.
    # The derived key is the same on every call, so draws are reproducible.
    sample_key = jax.random.fold_in(rng_key, 40)
    samples = jax.random.multivariate_normal(sample_key, mu, sigma,
                                             shape=(n_samples,))
    L = jnp.linalg.cholesky(sigma)

    # Split the (n_samples, 2) draws into two column vectors.
    X1 = samples[:, 0].reshape((n_samples, 1))
    X2 = samples[:, 1].reshape((n_samples, 1))
    return X1, X2, L

**bivariate_plot is the plotting function; it mainly uses matplotlib. An axis-aligned ellipse can be parametrized as $(a\cos t, b\sin t)$. In the same way, the ellipse here is generated by multiplying the Cholesky factor $L$ of Sigma (where $\Sigma = LL^T$) with the parametric points of the unit circle $(\cos t, \sin t)$. This gives the contour at Mahalanobis distance 1 (one standard deviation), which is then scaled by 3. The pdf curves of the random variables X1 and X2 are computed with the norm.pdf function from scipy.**

In [None]:
def bivariate_plot(mu, sigma, L, X1, X2):
    """Draw the samples, their 3-sigma ellipse and both marginal pdfs in 3-D.

    The samples and the ellipse lie in the z = 0 plane. The marginal pdf of
    X1 is drawn against the wall y = 4.9, and that of X2 (in red) against
    the wall x = -5. The x and y axes are fixed to [-5, 5].

    Args:
        mu: length-2 mean vector.
        sigma: 2x2 covariance matrix.
        L: lower-triangular Cholesky factor of ``sigma`` (``sigma == L @ L.T``).
        X1, X2: sample coordinates, e.g. the ``(n, 1)`` columns returned by
            ``sample_generator``.
    """
    n_ellipse_points = 100
    # Use the real sample count instead of the global num_samples, so the
    # z-coordinates always match the number of points passed in.
    n_samples = len(X1)

    # Grids over the visible [-5, 5) range for the two marginal pdfs.
    a = jnp.arange(-5, 5, 0.1)
    b = jnp.arange(-5, 5, 0.1)
    pdf_x = norm.pdf(np.array(a), mu[0], math.sqrt(sigma[0][0]))
    pdf_y = norm.pdf(np.array(b), mu[1], math.sqrt(sigma[1][1]))

    # Since sigma = L @ L.T, L maps the unit circle onto the contour at
    # Mahalanobis distance 1; scaling by 3 gives the 3-sigma ellipse.
    t = jnp.linspace(0, 2 * math.pi, num=n_ellipse_points)
    C = jnp.array([jnp.cos(t), jnp.sin(t)])
    E = 3 * (L @ C)

    plt.figure()
    ax = plt.axes(projection='3d')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('z')
    ax.scatter3D(xs=X1, ys=X2, zs=jnp.zeros((n_samples,)),
                 s=2, c='black')
    ax.plot3D(E[0] + mu[0], E[1] + mu[1], jnp.zeros((n_ellipse_points,)),
              'darkgreen')

    # Marginal pdfs projected onto the back (y = 4.9) and side (x = -5) walls.
    ax.plot3D(a, jnp.ones(len(a)) * 4.9, pdf_x)
    ax.plot3D(jnp.ones(len(b)) * (-5), b, pdf_y, 'red')
    ax.set_xlim3d(-5, 5)
    ax.set_ylim3d(-5, 5)
    ax.grid()

    plt.show()

**Rebuilding the mean and Sigma from the values chosen by the user, then rerunning the sampling and plotting with the new mean and Sigma**

In [None]:
def change_mean_sigma(
        RV1_mean=0, RV2_mean=0, RV1_var=1,
        RV2_var=2, cov=0.6):
    """Rebuild mu and Sigma from the widget values, then resample and replot.

    Args:
        RV1_mean, RV2_mean: means of X1 and X2.
        RV1_var, RV2_var: variances of X1 and X2.
        cov: covariance between X1 and X2.

    Raises:
        ValueError: if Sigma = [[RV1_var, cov], [cov, RV2_var]] is not
            positive definite.
    """
    # A symmetric 2x2 matrix is positive definite iff its trace and its
    # determinant are both positive. Check up front, because JAX's Cholesky
    # signals failure with NaNs rather than raising, which would otherwise
    # leave the plot empty without any error.
    det = RV1_var * RV2_var - cov ** 2
    if not (RV1_var + RV2_var > 0 and det > 0):
        raise ValueError(
            "Sigma is not positive definite: need RV1_var + RV2_var > 0 and "
            f"cov**2 < RV1_var * RV2_var (got RV1_var={RV1_var}, "
            f"RV2_var={RV2_var}, cov={cov})"
        )

    mu = jnp.array([RV1_mean, RV2_mean])
    sigma = jnp.array([[RV1_var, cov], [cov, RV2_var]])
    X1, X2, L = sample_generator(mu, sigma)
    bivariate_plot(mu, sigma, L, X1, X2)

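**The mean and covariance can be changed through the interactive controls provided by ipywidgets. Sigma (the covariance matrix) is a 2x2 matrix here. A 2x2 matrix A is positive definite iff A is symmetric, trace(A) > 0 and det(A) > 0. Sigma is symmetric by construction, so this requires RV1_var + RV2_var > 0 and (cov^2) < (RV1_var*RV2_var), where RV1 is X1 and RV2 is X2. If these conditions are violated, Sigma is not positive definite. The slider ranges below always satisfy them: the largest possible cov^2 is 1.0, while the smallest possible RV1_var*RV2_var is 1.0*1.01 = 1.01.**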

In [None]:
if __name__ == "__main__":
    # Slider ranges keep cov**2 < RV1_var * RV2_var for every setting
    # (largest cov**2 is 1.0, smallest RV1_var * RV2_var is 1.0 * 1.01),
    # so Sigma always stays positive definite.
    interact(
        change_mean_sigma,
        RV1_mean=(-3.0, 3.0),
        RV2_mean=(-3.0, 3.0),
        RV1_var=(1.0, 4.0),
        RV2_var=(1.01, 4.0),
        cov=(0.0, 1.0),
    );

interactive(children=(FloatSlider(value=0.0, description='RV1_mean', max=3.0, min=-3.0), FloatSlider(value=0.0…

**Note: instead of sampling with jax.random.multivariate_normal, we can use our own sampling methods and still generate similar plots. Examples are Gibbs sampling, or computing $L z + \mu$ where $z \sim \mathcal{N}(0, I_2)$ is a vector of two independent standard normal variables.**