<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#imports-&amp;-config" data-toc-modified-id="imports-&amp;-config-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>imports &amp; config</a></span></li><li><span><a href="#make-data" data-toc-modified-id="make-data-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>make data</a></span></li><li><span><a href="#make-plots" data-toc-modified-id="make-plots-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>make plots</a></span><ul class="toc-item"><li><span><a href="#define-subplots" data-toc-modified-id="define-subplots-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>define subplots</a></span></li><li><span><a href="#make-figure" data-toc-modified-id="make-figure-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>make figure</a></span></li><li><span><a href="#update-data" data-toc-modified-id="update-data-3.3"><span class="toc-item-num">3.3&nbsp;&nbsp;</span>update data</a></span></li></ul></li></ul></div>

# imports & config

In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats

from graphical_models.gaussian_mixture.univariate import UnivariateGaussianMixture

In [2]:
np.set_printoptions(suppress=True)

# make data

In [3]:
seed = 0
np_rng = np.random.default_rng(seed)

k = 5  # number of mixture components
n = 5000  # number of data points to sample

In [4]:
# define the gaussian mixture and its parameters

# weights: the relative contribution of each component
alpha_dirichlet = np.ones(shape=(k,)) * 3 # make a well-behaved prior for weights of mixture components
weights = np_rng.dirichlet(alpha=alpha_dirichlet)
# locs: the means of each component
locs = np_rng.normal(size=(k,)) * 10
# scales: the standard deviations (data are univariate, so no covariance)
scales = np.abs(np_rng.normal(size=(k,))) + 0.5

gmm = UnivariateGaussianMixture(weights, locs, scales)

In [5]:
gmm

GaussianMixture(
	weights=array([0.18321651, 0.2454824 , 0.119972  , 0.3446932 , 0.1066359 ]),
	locs=array([ -6.23274463,   0.41325979, -23.25030775,  -2.18791664,
       -12.45910947]),
	scales=array([1.23226735, 1.04425898, 0.81630016, 0.91163054, 1.54251337])
)

In [6]:
x = gmm.sample(seed, n)

In [7]:
x.shape

(5000,)

# make plots

## define subplots

In [8]:
# nbinsx = n//50

In [9]:
data_hist = go.Histogram(
    name='data',
    x=x,
#     nbinsx=nbinsx,
#     histnorm='probability'
)

In [10]:
mixtures_bar = go.Bar(
    name='gmm as bar',
    x=gmm.locs,
    y=gmm.weights,
    width=gmm.scales * 4,
#     opacity=0.5,
#     width=0.2,
#     error_x=dict(
#         type='data',
#         array=gmm.scales
#     ),
    marker_color=px.colors.qualitative.Plotly,
)

In [11]:
mixtures_points = go.Scatter(
    name='gmm components',
    x=gmm.locs,
    y=gmm.weights,
    mode='markers',
    error_x=dict(
        type='data',
        array=gmm.scales * 2,
#         width=3,
        thickness=1,
    ),
    marker=dict(
#         size=gmm.scales * 50,
        size=10,
        color=px.colors.qualitative.Plotly,
        symbol='circle',
        line=dict(
            color='black',
            width=3
        )
#         symbol='line-ew',
#         line_width=5,
#         line_color=px.colors.qualitative.Plotly,
))

In [12]:
def make_scatter_for_component(component_index):
    x = np.linspace(
        start=gmm.locs[component_index] - 4 * gmm.scales[component_index],
        stop=gmm.locs[component_index] + 4 * gmm.scales[component_index],
        num=51)
    y = gmm.weights[component_index] * scipy.stats.norm.pdf(
        x,
        loc=gmm.locs[component_index],
        scale=gmm.scales[component_index]
    )
    return go.Scatter(
        name=f'gmm component {component_index}',
        x=x,
        y=y,
        mode='lines',
        opacity=0.5,
        line=dict(
            shape='spline',
        )
    )


def make_scatter_for_mixture():
    x = np.linspace(
        start=np.min(gmm.locs - 4 * gmm.scales),
        stop=np.max(gmm.locs + 4 * gmm.scales),
        num=200)
    y = np.sum(
        [
            weight * scipy.stats.norm.pdf(x, loc=loc, scale=scale)
            for loc, scale, weight in zip(gmm.locs, gmm.scales, gmm.weights)
        ],
        axis=0
    )
    print(x.shape)
    print(y.shape)
    return go.Scatter(
        name=f'gmm density',
        x=x,
        y=y,
        line=dict(dash='dot', color='black'),
        mode='lines',
#         line_shape='spline'
    )


## make figure

In [13]:
fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=(
        "Sampled Data",
        "GMM - Density Plot",
        "GMM - Scatter Plot",
    ),
    vertical_spacing=0.05,
#     specs=[[{"secondary_y": True}], [{"secondary_y": False}]],
)
fig.update_layout(height=800, width=1000, title_text="Data + Gaussian Mixtures")

# add data
fig.add_trace(data_hist, row=1, col=1)

# add GMM density
fig.add_trace(make_scatter_for_mixture(), row=2, col=1)
# add mixture densities
for k_ in range(k):
    fig.add_trace(make_scatter_for_component(k_), row=2, col=1)

# add GMM components
fig.add_trace(mixtures_points, row=3, col=1)
fig.update_yaxes(
    title_text="mixture weights",
    range=[0, max(gmm.weights) *1.2],
    row=3,
    col=1)

f = go.FigureWidget(fig)
f

(200,)
(200,)


FigureWidget({
    'data': [{'name': 'data',
              'type': 'histogram',
              'uid': 'fdcfc47b…

## update data

In [14]:
from graphical_models.gaussian_mixture.univariate import UnivariateGaussianMixture

def update_visualization(gmm):
    print(gmm)
    w, l, s = gmm.weights, gmm.locs, gmm.scales
#     print(w, l, s)
    return


def update_continuously(gmm):
    i = 0
    while i < 5:
        i += 1
        w, l, s = gmm.weights, gmm.locs, gmm.scales
        w = 0.99 * w + 0.01 * np_rng.dirichlet(alpha=np.ones((k,)))
        l += np_rng.normal(size=l.shape)
        s += np_rng.normal(size=s.shape)
        gmm = UnivariateGaussianMixture(w, l, s)
        update_visualization(gmm)
    return


update_continuously(gmm)

GaussianMixture(
	weights=array([0.18148396, 0.24339188, 0.12154554, 0.34333132, 0.11024731]),
	locs=array([ -6.13873233,  -0.33023946, -24.17203312,  -2.64564246,
       -12.23891435]),
	scales=array([0.22264917, 0.83508341, 0.65707515, 1.45247612, 1.75717249])
)
GaussianMixture(
	weights=array([0.18104377, 0.2426447 , 0.12231188, 0.34460762, 0.10939202]),
	locs=array([ -7.39779786,   1.18368432, -22.8261577 ,  -1.86433106,
       -11.97445872]),
	scales=array([-0.09127364,  2.29310409,  2.61733346,  3.25411099,  3.07227626])
)
GaussianMixture(
	weights=array([0.18280507, 0.24108344, 0.126045  , 0.34133032, 0.10873617]),
	locs=array([ -7.0026758 ,   1.61354801, -22.13011497,  -3.04844903,
       -12.63616129]),
	scales=array([-0.52770889,  1.12330218,  4.35670134,  2.75820026,  3.40124589])
)
GaussianMixture(
	weights=array([0.18280325, 0.23977852, 0.12704654, 0.34229774, 0.10807394]),
	locs=array([ -6.95064683,   2.2972342 , -21.1261534 ,  -3.66635608,
       -10.81414993]),
	scales=