Animate bivariate normal distribution.

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/8/8e/MultivariateNormal.png/330px-MultivariateNormal.png"></img>

- Reproduce the above figure, which shows samples from a bivariate normal distribution together with the marginal PDFs, from scratch using JAX and matplotlib.
- Add interactivity to the figure by adding sliders with ipywidgets. You should be able to vary the parameters of the bivariate normal distribution (mean and covariance matrix) using ipywidgets.

---



## Introduction

### Understanding Bivariate Normal Distribution

Let $U$ and $V$ be two independent normal random variables, and consider two new random variables $X$ and $Y$ of the form
$$
\begin{aligned}
&X=a U+b V \\
&Y=c U+d V
\end{aligned}
$$
where $a, b, c, d$ are some scalars. Each of the random variables $X$ and $Y$ is normal, since it is a linear function of independent normal random variables. Furthermore, because $X$ and $Y$ are linear functions of the same two independent normal random variables, their joint PDF takes a special form, known as the **bivariate normal PDF**.
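
As a rough illustration of this construction, the sketch below assumes that $U$ and $V$ are *standard* normal (zero mean, unit variance), which the text above does not require. Under that assumption the variances of $X$ and $Y$ are $a^2+b^2$ and $c^2+d^2$, and their covariance is $ac+bd$. The helper names `implied_moments` and `simulate_covariance` are made up for this example, and plain Python is used only to keep the sketch dependency-free.

```python
import random


def implied_moments(a, b, c, d):
    """Var(X), Var(Y), Cov(X, Y) for X = aU + bV, Y = cU + dV,
    assuming U and V are independent standard normals."""
    var_x = a * a + b * b
    var_y = c * c + d * d
    cov_xy = a * c + b * d
    return var_x, var_y, cov_xy


def simulate_covariance(a, b, c, d, n=50_000, seed=0):
    """Estimate Cov(X, Y) by drawing n pairs (U, V) and forming X and Y."""
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(n):
        u, v = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        xs.append(a * u + b * v)
        ys.append(c * u + d * v)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    return sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / n


# Hypothetical coefficients chosen only for illustration
var_x, var_y, cov_xy = implied_moments(1.0, 2.0, 3.0, -1.0)
sample_cov = simulate_covariance(1.0, 2.0, 3.0, -1.0)
print("theory:", var_x, var_y, cov_xy)
print("simulated covariance:", sample_cov)
```

The simulated covariance should land close to the theoretical value, with the gap shrinking as `n` grows.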

References:

- Bivariate Normal Distribution
    - http://athenasc.com/Bivariate-Normal.pdf
    - https://webspace.maths.qmul.ac.uk/i.goldsheid/MTH5118/Notes11-09.pdf
- <a href="https://en.wikipedia.org/wiki/Definite_matrix#Negative-definite.2C_semidefinite_and_indefinite_matrices">Definite Matrix</a>

- <a href="https://stackoverflow.com/questions/619335/a-simple-algorithm-for-generating-positive-semidefinite-matrices">Creating a positive semidefinite matrix</a>
    
- <a href="https://www.cuemath.com/algebra/covariance-matrix/">Variance Covariance Matrix Info</a>
    
- Official documentation of JAX, plotly, ipywidgets and matplotlib

In [1]:
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display
import time
import jax
from jax.random import multivariate_normal
from jax import numpy as jnp
from jax.numpy import linalg as JLA
import random as rand
import plotly.graph_objects as go

In [2]:
# Add Sliders to change parameters
var_X = widgets.FloatSlider(description='var_X', min=1,
                            value=1.5, step=0.1, continuous_update=False)
var_Y = widgets.FloatSlider(description='var_Y', min=1,
                            value=1.0, step=0.1, continuous_update=False)
mean_X = widgets.FloatSlider(description='mean_X', value=0,
                             step=0.1, continuous_update=False)
mean_Y = widgets.FloatSlider(description='mean_Y', value=0,
                             step=0.1, continuous_update=False)
corr_XY = widgets.FloatSlider(description='corr_XY', min=-0.9999, max=0.9999, 
                              value=0.4, step=0.01, continuous_update=False)
sample_size = widgets.IntSlider(description='sample_size', min=100,
                                max=1000, continuous_update=False)

# Create a dict for easy access (the last-but-one slider holds the
# correlation, not the covariance, so it is keyed as corr_XY)
slider_dict = {"var_X": var_X,
               "var_Y": var_Y,
               "mean_X": mean_X,
               "mean_Y": mean_Y,
               "corr_XY": corr_XY,
               "sample_size": sample_size}


## Generating Bivariate Normal Data using JAX

$X$ and $Y$ are bivariately normally distributed with mean vector components $\mu_{1}$ and $\mu_{2}$ and variance-covariance matrix shown below:
$$
\left(\begin{array}{l}
X \\
Y
\end{array}\right) \sim N\left[\left(\begin{array}{l}
\mu_{1} \\
\mu_{2}
\end{array}\right),\left(\begin{array}{cc}
\sigma_{1}^{2} & \rho \sigma_{1} \sigma_{2} \\
\rho \sigma_{1} \sigma_{2} & \sigma_{2}^{2}
\end{array}\right)\right]
$$
In this case the variances of the two variables lie on the diagonal and the covariance between them lies on the off-diagonal. This covariance is equal to the correlation ($\rho$) times the product of the two standard deviations ($\sigma_{1}$ and $\sigma_{2}$).
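
The relationship between the correlation and the covariance matrix can be sketched in plain Python as follows. The function names are hypothetical, and the positive-definiteness check uses the closed-form eigenvalues of a symmetric $2 \times 2$ matrix rather than a library routine.

```python
import math


def covariance_matrix(var_x, var_y, rho):
    """Build [[var_x, cov], [cov, var_y]] with cov = rho * sigma_x * sigma_y."""
    cov_xy = rho * math.sqrt(var_x) * math.sqrt(var_y)
    return [[var_x, cov_xy], [cov_xy, var_y]]


def eigenvalues_sym_2x2(m):
    """Eigenvalues of a symmetric 2x2 matrix, smallest first."""
    (a, b), (_, d) = m
    centre = (a + d) / 2.0
    radius = math.hypot((a - d) / 2.0, b)
    return centre - radius, centre + radius


def is_positive_definite(m):
    """A symmetric matrix is positive definite iff all eigenvalues are > 0."""
    return min(eigenvalues_sym_2x2(m)) > 0.0


# Values matching the default slider settings: var_X=1.5, var_Y=1.0, corr_XY=0.4
sigma = covariance_matrix(1.5, 1.0, 0.4)
print(sigma)
print(eigenvalues_sym_2x2(sigma), is_positive_definite(sigma))

# A "correlation" outside [-1, 1] does not give a valid covariance matrix
print(is_positive_definite(covariance_matrix(1.5, 1.0, 1.2)))
```

For positive variances, the determinant $\sigma_1^2 \sigma_2^2 (1-\rho^2)$ is positive exactly when $|\rho| < 1$, which is why the matrix is positive definite for any correlation strictly between $-1$ and $1$.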

In [3]:
def generate_bivariate_norm():
    '''
        Sample from a bivariate normal distribution based on the mean and
        covariance matrix set by the sliders and return JAX Numpy arrays

        Args:
            None
        Returns:
            x1 - (sample_size, ) JAX Numpy array containing
                 samples of the random variable X
            x2 - (sample_size, ) JAX Numpy array containing
                 samples of the random variable Y
    '''

    key = jax.random.PRNGKey(0)
    covariance = corr_XY.value*jnp.sqrt(var_X.value)*jnp.sqrt(var_Y.value)
    cov = jnp.array([[var_X.value, covariance], [covariance, var_Y.value]])
    mean = jnp.array([mean_X.value, mean_Y.value])

    # Ensure the variance-covariance matrix is positive definite by
    # checking its eigenvalues; if it is not, generate a new random
    # positive definite variance-covariance matrix

    # Calculate eigenvalues (eigvalsh returns real eigenvalues for a
    # symmetric matrix, whereas eigvals returns complex ones)
    cov_eigvals = JLA.eigvalsh(cov)

    # Check eigenvalues
    if jnp.any(cov_eigvals <= 0):
        print(f"The current covariance matrix {cov} is not positive definite, "
              "hence a new random covariance matrix will be created")
        rand_int = rand.randint(0, 10000)
        rand_mat = jax.random.uniform(jax.random.PRNGKey(rand_int),
                                      shape=(2, 2), minval=0.0, maxval=1.0)
        cov = jnp.dot(rand_mat, rand_mat.T)
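        # Store the correlation (covariance divided by the product of the
        # standard deviations) in the slider, not the raw covariance
        corr_XY.value = float(cov[0, 1] / jnp.sqrt(cov[0, 0] * cov[1, 1]))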

    x1, x2 = multivariate_normal(key, mean, cov, (sample_size.value,)).T
    return x1, x2
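
One common way to draw correlated normal samples by hand, shown here only as a sketch and not as a description of what `multivariate_normal` does internally, is to factor the covariance matrix as $\Sigma = L L^{T}$ (Cholesky) and transform independent standard normals. This ties back to the $X = aU + bV$, $Y = cU + dV$ construction from the introduction, with $b = 0$. The helper names are hypothetical, and plain Python is used to keep the example dependency-free.

```python
import math
import random


def cholesky_2x2(cov):
    """Lower-triangular factor (l11, l21, l22) of a 2x2 positive definite matrix."""
    (a, b), (_, d) = cov
    l11 = math.sqrt(a)
    l21 = b / l11
    l22 = math.sqrt(d - l21 * l21)
    return l11, l21, l22


def sample_bivariate_normal(mean, cov, n, seed=0):
    """Draw n samples as mean + L @ (u, v) with u, v independent standard normals."""
    rng = random.Random(seed)
    l11, l21, l22 = cholesky_2x2(cov)
    xs, ys = [], []
    for _ in range(n):
        u, v = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        xs.append(mean[0] + l11 * u)
        ys.append(mean[1] + l21 * u + l22 * v)
    return xs, ys


def sample_correlation(xs, ys):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    return sxy / math.sqrt(sxx * syy)


# Covariance matrix for var_X=1.5, var_Y=1.0, correlation 0.4
cov = [[1.5, 0.4 * math.sqrt(1.5)], [0.4 * math.sqrt(1.5), 1.0]]
xs, ys = sample_bivariate_normal([0.0, 0.0], cov, n=20_000)
rho_hat = sample_correlation(xs, ys)
print("sample correlation:", rho_hat)
```

The sample correlation should be close to the requested value of 0.4, which gives a quick sanity check on any sampler for this distribution.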


## Marginal Distribution of random variables in Bivariate Normal Distribution

The marginal distributions of the bivariate normal $N\left(\mu_{1}, \mu_{2}, \sigma_{1}^{2}, \sigma_{2}^{2}, \rho\right)$ are themselves normal: the random variables $X$ and $Y$ have density functions
$$
f_{X}(x)=\frac{1}{\sqrt{2 \pi} \sigma_{1}} e^{-\frac{\left(x-\mu_{1}\right)^{2}}{2 \sigma_{1}^{2}}}, \quad f_{Y}(y)=\frac{1}{\sqrt{2 \pi} \sigma_{2}} e^{-\frac{\left(y-\mu_{2}\right)^{2}}{2 \sigma_{2}^{2}}}
$$
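
As a small sanity check on these densities, the sketch below evaluates the normal PDF with the standard library and integrates it numerically. The names `normal_pdf` and `trapezoid_integral` are made up for this example, and the parameter values are only illustrative.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (math.sqrt(2.0 * math.pi) * sigma)


def trapezoid_integral(f, lo, hi, steps=10_000):
    """Approximate the integral of f over [lo, hi] with the trapezoid rule."""
    h = (hi - lo) / steps
    total = 0.5 * (f(lo) + f(hi))
    for i in range(1, steps):
        total += f(lo + i * h)
    return total * h


# Illustrative parameters: mean 0 and variance 1.5
mu, sigma = 0.0, math.sqrt(1.5)

# The density peaks at the mean with height 1 / (sqrt(2 pi) * sigma)
peak = normal_pdf(mu, mu, sigma)

# Integrating over +/- 8 standard deviations captures essentially all the mass
area = trapezoid_integral(lambda x: normal_pdf(x, mu, sigma),
                          mu - 8.0 * sigma, mu + 8.0 * sigma)
print("peak height:", peak)
print("area under the curve:", area)
```

Note that the marginal density of $X$ depends only on $\mu_1$ and $\sigma_1$, so changing the correlation $\rho$ leaves the marginal curves unchanged.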


In [4]:
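def marginal_PDF(x):
    '''
        Return the marginal normal PDF of a random variable evaluated at
        its samples, using the sample mean and variance as estimates of
        the distribution parameters

        Args:
            x - (sample_size, ) JAX Numpy array containing
                samples of a random variable
        Returns:
            m_PDF - (sample_size, ) JAX Numpy array containing
                    marginal PDF of a random variable
    '''
    mean = jnp.mean(x)
    var = jnp.var(x)
    std_dev = jnp.sqrt(var)
    # Normal density: exp(-(x - mean)^2 / (2 var)) / (sqrt(2 pi) std_dev)
    m_PDF = jnp.exp(-((x-mean)**2)/(2*var))/(jnp.sqrt(2*jnp.pi)*std_dev)

    return m_PDF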


## Graphing and Plotting Data

### Function to handle updates from ipywidgets

In [5]:
def response(v):
    '''
        Updates graph data based on interactive slider values

        Args:
            v : traitlets change object, contains the new slider value
        Returns:
            None
    '''

    # Generate a Bivariate Normal Distribution again
    x1, x2 = generate_bivariate_norm()

    # Batch update graph data
    with fig.batch_update():
        fig.data[0].x = x1
        fig.data[0].y = x2
        fig.data[0].z = sample_size.value*[0]

        fig.data[1].x = jnp.sort(x1)
        fig.data[1].y = sample_size.value*[jnp.min(x2)]
        fig.data[1].z = marginal_PDF(jnp.sort(x1))

        fig.data[2].x = sample_size.value*[jnp.min(x1)]
        fig.data[2].y = jnp.sort(x2)
        fig.data[2].z = marginal_PDF(jnp.sort(x2))


### Generate data for default graph

In [None]:
# Generate a Bivariate Normal Distribution using JAX
x1, x2 = generate_bivariate_norm()

# Plot for marginal PDF of rv x1
fig_mar_x1 = go.Scatter3d(
    x=jnp.sort(x1),
    y=sample_size.value*[jnp.min(x2)],
    z=marginal_PDF(jnp.sort(x1)),
    mode='lines',
    name="Marginal Dist of x",
    marker=dict(
        size=2,
        colorscale='Viridis',
        opacity=0.8
    ))

# Plot for marginal PDF of rv x2
fig_mar_x2 = go.Scatter3d(
    x=sample_size.value*[jnp.min(x1)],
    y=jnp.sort(x2),
    z=marginal_PDF(jnp.sort(x2)),
    mode='lines',
    name="Marginal Dist of y",
    marker=dict(
        size=2,
        colorscale='Viridis',
        opacity=0.8
    ))

# Scatter Plot for both rv's
fig_XY = go.Scatter3d(
    x=x1,
    y=x2,
    z=sample_size.value*[0],
    mode='markers',
    marker=dict(
        size=2,
        colorscale='Viridis',
        opacity=0.8
    ))

# Add graphs to plotly widget
fig = go.FigureWidget(data=[fig_XY, fig_mar_x1, fig_mar_x2])

# Update Layout of graph
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0),
                  paper_bgcolor="rgba(241, 241, 241, 0.8)",
                  width=1000,
                  height=600,
                  legend=dict(
                            bgcolor="rgba(241, 241, 241, 1)",
                            bordercolor="Black",
                            borderwidth=2),
                  scene=dict(
                            xaxis_title='R.V x',
                            yaxis_title='R.V y',
                            zaxis_title='Marginal Dist of R.V',
                            zaxis=dict(range=[0, 1]))
                  )

# Register the update function on the widgets for value changes
for slider in slider_dict.values():
    slider.observe(response, names='value')

# Set alignment of graph and widgets and display data
container_slider = widgets.VBox(list(slider_dict.values()))
widgets.VBox([container_slider, fig])





### Matplotlib graph

In [6]:
from ipywidgets import *
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


def update(mean_x=mean_X, var_x=var_X, mean_y=mean_Y,
           var_y=var_Y, corr_xy=corr_XY, sample_sze=sample_size):
    '''
        Redraw the Matplotlib figure whenever a slider changes

        The distribution parameters are read from the shared sliders
        inside generate_bivariate_norm(); the arguments are listed so
        that interact() re-runs this function on every slider change
    '''
    # Create a single 3D axes and label all three axes
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.set_xlabel('R.V x')
    ax.set_ylabel('R.V y')
    ax.set_zlabel('Marginal Dist of R.V')

    x1, x2 = generate_bivariate_norm()
    z = sample_sze*[0]

    # Marginal PDF of x, drawn along the wall at y = min(x2)
    x_line_1 = jnp.sort(x1)
    y_line_1 = sample_sze*[float(jnp.min(x2))]
    z_line_1 = marginal_PDF(jnp.sort(x1))

    # Marginal PDF of y, drawn along the wall at x = min(x1)
    x_line_2 = sample_sze*[float(jnp.min(x1))]
    y_line_2 = jnp.sort(x2)
    z_line_2 = marginal_PDF(jnp.sort(x2))

    ax.scatter3D(x1, x2, z)
    ax.plot3D(x_line_1, y_line_1, z_line_1)
    ax.plot3D(x_line_2, y_line_2, z_line_2)

    plt.show()

interact(update, mean_x=mean_X, var_x=var_X, mean_y=mean_Y,
         var_y=var_Y, corr_xy=corr_XY, sample_sze=sample_size)

