In [38]:
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display
import time
import jax
from jax.random import multivariate_normal
import numpy as np
import plotly.graph_objects as go
from numpy import linalg as LA

# --- Interactive controls ---------------------------------------------------
# continuous_update=False: observers fire once when a slider is released rather
# than on every intermediate position while dragging.
_slider_opts = dict(step=0.1, continuous_update=False)
var_X = widgets.FloatSlider(description='var_X', min=1, value=1.5, **_slider_opts)
var_Y = widgets.FloatSlider(description='var_Y', min=1, value=1.0, **_slider_opts)
mean_X = widgets.FloatSlider(description='mean_X', value=0, **_slider_opts)
mean_Y = widgets.FloatSlider(description='mean_Y', value=0, **_slider_opts)
cov_XY = widgets.FloatSlider(description='cov_XY', value=0.4, **_slider_opts)
sample_size = widgets.IntSlider(description='sample_size', min=100, max=1000, continuous_update=False)

# Each slider's description doubles as its key; insertion order sets the
# top-to-bottom order of the control panel.
slider_dict = {
    slider.description: slider
    for slider in (var_X, var_Y, mean_X, mean_Y, cov_XY, sample_size)
}

# --- Initial sample from the bivariate normal -------------------------------
key = jax.random.PRNGKey(0)  # fixed seed: the same slider values give the same sample
mean = np.array([mean_X.value, mean_Y.value])
cov = np.array([[var_X.value, cov_XY.value],
                [cov_XY.value, var_Y.value]])
x1, x2 = multivariate_normal(key, mean, cov, (sample_size.value,)).T

# --- Traces -----------------------------------------------------------------
# NOTE(review): marginal_PDF is defined in a separate, later cell (In [1]); on a
# Restart & Run All in file order this cell raises NameError. TODO move that
# definition above this cell.
_trace_marker = dict(size=2, colorscale='Viridis', opacity=0.8)
x1_sorted, x2_sorted = np.sort(x1), np.sort(x2)

# Joint sample scattered on the z = 0 floor.
fig_XY = go.Scatter3d(x=x1, y=x2, z=sample_size.value * [0],
                      mode='markers', marker=_trace_marker)

# Marginal curve of X drawn on the wall y = min(x2).
fig_mar_x1 = go.Scatter3d(x=x1_sorted, y=sample_size.value * [np.min(x2)],
                          z=marginal_PDF(x1_sorted),
                          mode='lines', marker=_trace_marker)

# Marginal curve of Y drawn on the wall x = min(x1).
fig_mar_x2 = go.Scatter3d(x=sample_size.value * [np.min(x1)], y=x2_sorted,
                          z=marginal_PDF(x2_sorted),
                          mode='lines', marker=_trace_marker)

# Trace order matters: `response` updates fig.data[0], [1] and [2] by position.
fig = go.FigureWidget(data=[fig_XY, fig_mar_x1, fig_mar_x2])
fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0),
    scene=dict(zaxis=dict(range=[0, 1])),
)

container_slider = widgets.VBox(list(slider_dict.values()))

def response(v):
    """Resample the bivariate normal from the current slider values and redraw.

    Registered as an observer on every slider in ``slider_dict``; ``v`` is the
    traitlets change notification and is not used.

    ``[[var_X, cov_XY], [cov_XY, var_Y]]`` is a valid covariance matrix only if
    it is positive semidefinite, i.e. ``cov_XY**2 <= var_X * var_Y`` for
    non-negative variances. If the sliders describe an invalid matrix, cov_XY
    is pulled back to the largest slider-step multiple strictly inside the
    valid range (keeping its sign). Assigning that value fires this observer
    again with a valid matrix, and that call does the redraw. This keeps the
    plot consistent with the values shown on the sliders.
    """
    cov = np.array([[var_X.value, cov_XY.value], [cov_XY.value, var_Y.value]])
    mean = np.array([mean_X.value, mean_Y.value])

    # eigvalsh is the solver for symmetric matrices: it always returns real eigenvalues.
    if np.any(LA.eigvalsh(cov) < 0):
        # assumes var_X, var_Y >= 0 (their sliders have min=1)
        max_abs_cov = np.sqrt(var_X.value * var_Y.value)
        step = cov_XY.step
        new_abs_cov = np.floor(max_abs_cov / step) * step
        if new_abs_cov >= max_abs_cov:
            # Stay strictly inside the bound so the matrix is comfortably PSD.
            new_abs_cov -= step
        new_abs_cov = max(new_abs_cov, 0.0)
        new_cov_XY = float(np.copysign(new_abs_cov, cov_XY.value))
        print(f"The current covariance matrix {cov} is not positive semidefinite; "
              f"cov_XY is reset to {new_cov_XY}")
        # The value strictly changes (|old| > max_abs_cov >= new), so the observer
        # re-enters `response` with a valid matrix and redraws; nothing left to do here.
        cov_XY.value = new_cov_XY
        return

    key = jax.random.PRNGKey(0)  # fixed seed: the same slider values give the same sample
    n = sample_size.value
    x1, x2 = np.asarray(multivariate_normal(key, mean, cov, (n,))).T
    x1_sorted = np.sort(x1)
    x2_sorted = np.sort(x2)

    with fig.batch_update():
        # Trace 0: joint sample scattered on the z = 0 floor.
        fig.data[0].x = x1
        fig.data[0].y = x2
        fig.data[0].z = n * [0]

        # Trace 1: marginal_PDF curve for X on the wall y = min(x2).
        fig.data[1].x = x1_sorted
        fig.data[1].y = n * [np.min(x2)]
        fig.data[1].z = marginal_PDF(x1_sorted)

        # Trace 2: marginal_PDF curve for Y on the wall x = min(x1).
        fig.data[2].x = n * [np.min(x1)]
        fig.data[2].y = x2_sorted
        fig.data[2].z = marginal_PDF(x2_sorted)
        


# Re-run `response` whenever any control settles on a new value; only the
# 'value' trait is observed, so changes to other traits do not trigger a redraw.
for slider in slider_dict.values():
    slider.observe(handler=response, names=['value'])

# Last expression of the cell: render the control panel next to the 3-D figure.
widgets.VBox(children=[container_slider, fig])

VBox(children=(VBox(children=(FloatSlider(value=1.5, continuous_update=False, description='var_X', min=1.0), F…

In [1]:
def marginal_PDF(x, mean=None, var=None):
    """Evaluate a univariate normal probability density at the points ``x``.

    By default the normal's parameters are estimated from ``x`` itself: the
    sample mean and the population variance ``np.var(x)`` (ddof=0).

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the density. They are also the sample used
        to estimate the parameters when ``mean``/``var`` are omitted.
    mean : float, optional
        Mean of the normal. Defaults to ``np.mean(x)``.
    var : float, optional
        Variance of the normal. Defaults to ``np.var(x)``. Must be positive;
        a zero variance (e.g. a constant or single-element ``x``) yields
        non-finite values.

    Returns
    -------
    numpy.ndarray
        Density values with the same shape as ``x``.
    """
    x = np.asarray(x)
    if mean is None:
        mean = np.mean(x)
    if var is None:
        var = np.var(x)

    # Normalisation 1 / sqrt(2*pi*var); sqrt(2*pi) replaces the hard-coded 2.5066...
    const = 1 / np.sqrt(2 * np.pi * var)
    # exp(-z) rather than 1/exp(z): same value, but no overflow far in the tails.
    return const * np.exp(-((x - mean) ** 2) / (2 * var))

Reference for positive semidefinite matrices
https://en.wikipedia.org/wiki/Definite_matrix#Negative-definite.2C_semidefinite_and_indefinite_matrices

Creating a positive semidefinite matrix
https://stackoverflow.com/questions/619335/a-simple-algorithm-for-generating-positive-semidefinite-matrices

Variance–covariance matrix overview
https://www.cuemath.com/algebra/covariance-matrix/

Bivariate Normal
- http://athenasc.com/Bivariate-Normal.pdf
- https://webspace.maths.qmul.ac.uk/i.goldsheid/MTH5118/Notes11-09.pdf

In [29]:
# Inspect the module-level covariance matrix built from the initial slider values.
# The `cov` assigned inside `response` is local to that function and does not update this one.
cov

array([[1.5, 0.4],
       [0.4, 1. ]])

0.4