In [1]:
import json
import logging

import numpy as np

from graphical_models.gaussian_mixture.univariate import (
    UnivariateGaussianMixture,
    learn_em,
)


# setup

## logging

In [2]:
class UnivariateGaussianMixtureHandler(logging.Handler):
    """Logging handler that forwards EM training progress to the notebook plots.

    Progress records are expected to carry a ``data`` mapping with the keys
    ``iteration``, ``average_log_likelihood``, ``improvement``, ``is_done`` and
    ``gaussian_mixture``. Records are buffered, and the plots are refreshed
    only when the likelihood gained at least ``_MIN_IMPROVEMENT`` since the
    last refresh, when training is done, or when more than ``_MAX_BUFFERED``
    records are waiting. Records without a ``data`` attribute are skipped.
    """

    # Sentinel "previous likelihood"; it makes the first record of a run refresh the plots.
    _INITIAL_PREVIOUS_VALUE = -1e100
    # Gain in average log-likelihood since the last refresh that forces a refresh.
    _MIN_IMPROVEMENT = 0.01
    # A refresh is forced once more than this many records are buffered.
    _MAX_BUFFERED = 10

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
        # Create the state up front so that emit() never depends on having
        # seen a record with iteration == 1 first.
        self._start_new_run()

    def _start_new_run(self):
        """Forget the likelihood baseline and any buffered history."""
        self._previous_value = self._INITIAL_PREVIOUS_VALUE
        self._clear_buffer()

    def _clear_buffer(self):
        """Drop the records buffered since the last plot refresh."""
        self._iterations = ()
        self._likelihoods = ()
        self._improvements = ()

    def emit(self, record):
        """Buffer one progress record and refresh the plots when warranted."""
        data = getattr(record, "data", None)
        if data is None:
            # Not a progress record (e.g. a plain text message on the same logger).
            return
        if data["iteration"] == 1:
            self._start_new_run()
        self._iterations = self._iterations + (data["iteration"],)
        self._likelihoods = self._likelihoods + (data["average_log_likelihood"],)
        self._improvements = self._improvements + (data["improvement"],)
        improvement_since_last = data["average_log_likelihood"] - self._previous_value
        ignore = improvement_since_last < self._MIN_IMPROVEMENT
        ignore = ignore and not data["is_done"]
        ignore = ignore and not len(self._iterations) > self._MAX_BUFFERED
        if ignore:
            return
        update_learned_gmm_plots(data["gaussian_mixture"])
        update_progress_plot(
            self._iterations,
            self._likelihoods,
            self._improvements)
        # The buffered history has been plotted; start a new buffer from here.
        self._previous_value = data["average_log_likelihood"]
        self._clear_buffer()
        return


logging.basicConfig(level='DEBUG')
_gmm_logger = logging.getLogger("graphical_models.gaussian_mixture.univariate")
# Re-running this cell must not stack handlers, otherwise every record would
# update the plots more than once. The handler class itself is redefined on
# re-run, so stale instances are matched by class name rather than isinstance.
for _stale_handler in [
    h for h in _gmm_logger.handlers
    if type(h).__name__ == UnivariateGaussianMixtureHandler.__name__
]:
    _gmm_logger.removeHandler(_stale_handler)
_gmm_logger.addHandler(UnivariateGaussianMixtureHandler())
logger = logging.getLogger(__name__)


In [3]:
# Inspect the handlers attached to the root logger.
logging.getLogger().handlers

[<StreamHandler stderr (NOTSET)>]

In [4]:
# Inspect the handlers attached to the library logger the plotting handler was registered on.
logging.getLogger("graphical_models.gaussian_mixture.univariate").handlers

[<UnivariateGaussianMixtureHandler (NOTSET)>]

In [5]:
# Inspect the handlers attached directly to this notebook's own logger.
logging.getLogger(__name__).handlers

[]

## data

In [6]:
# Seeded generator, so the randomly drawn quantities are reproducible.
np_rng = np.random.default_rng(0)
# Print floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)


In [7]:
# Ground-truth ("oracle") mixture with k components. The draw order
# (weights, then locations, then scales) is kept fixed so the seeded
# generator yields the same parameters on every run.
k = 5
alpha_dirichlet = np.full(shape=(k,), fill_value=5.0)
weights = np_rng.dirichlet(alpha=alpha_dirichlet)
locs = 5 * np_rng.normal(size=k)
scales = np.abs(np_rng.normal(size=k))
oracle = UnivariateGaussianMixture(weights, locs, scales)

logger.info("oracle: %s", oracle)

INFO:__main__:oracle: GaussianMixture(
	weights=array([0.18918143, 0.23693148, 0.13786091, 0.30952975, 0.12649643]),
	locs=array([ -3.11637231,   0.2066299 , -11.62515387,  -1.09395832,
        -6.22955474]),
	scales=array([0.73226735, 0.54425898, 0.31630016, 0.41163054, 1.04251337])
)


In [8]:
# Draw 10,000 samples from the oracle mixture with a fixed seed; `x` is the training data.
x = oracle.sample(seed=1, n=10_000)


# figure

In [9]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [10]:
# Histogram trace of the sampled data; no bin count or normalisation is
# passed, so plotly's defaults apply.
data_hist = go.Histogram(x=x, name='data')

def make_empty_gmm_plot(name):
    """Return an empty marker trace called ``name`` for displaying a mixture.

    The trace starts without points. Horizontal error bars and per-point
    colours (plotly's qualitative ``Plotly`` palette) are configured up front
    so that only the data arrays need to be filled in later.
    """
    outline = {'color': 'black', 'width': 3}
    marker_style = {
        'size': 10,
        'color': px.colors.qualitative.Plotly,
        'symbol': 'circle',
        'line': outline,
    }
    horizontal_bars = {'type': 'data', 'array': [], 'thickness': 1}
    return go.Scatter(
        name=name,
        x=[],
        y=[],
        mode='markers',
        error_x=horizontal_bars,
        marker=marker_style,
    )


In [11]:
# Three stacked panels sharing one x axis: the data histogram on top, then
# the generative and the learned mixture with weights on a log axis.
fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    subplot_titles=[
        "data",
        "generative gmm (log scale)",
        "learned gmm (log scale)",
    ],
)
fig.update_layout(
    height=800,
    title_text="gaussian mixture learning",
    showlegend=False,
)

# Trace order is significant: the traces end up at indices 0, 1 and 2 of fig.data.
fig.add_trace(data_hist, row=1, col=1)
for _row, _trace_name in (
    (2, 'generative gmm (log scale)'),
    (3, 'learned gmm (log scale)'),
):
    fig.add_trace(make_empty_gmm_plot(_trace_name), row=_row, col=1)

fig.update_xaxes(title_text='x', range=[min(x), max(x)])

# Both mixture panels get the same log-scaled weight axis; on a log axis the
# range is expressed in powers of ten.
for _row in (2, 3):
    fig.update_yaxes(
        title_text="mixture weights",
        range=[-2, 0],
        type="log",
        row=_row,
        col=1,
    )

# Convert to a widget so the figure can be updated after it is displayed.
fig = go.FigureWidget(fig)


def update_gmm_plot(gmm, trace):
    """Show the mixture ``gmm`` on the trace at index ``trace`` of ``fig``.

    Component locations become the x values, mixture weights the y values and
    component scales the horizontal error bars.
    """
    target = fig.data[trace]
    # Apply the three property changes as one batched update, so the figure is
    # never rendered with x, y and error bars of differing lengths and only
    # one update is sent to the front end.
    with fig.batch_update():
        target.x = gmm.locs
        target.y = gmm.weights
        target.error_x.array = gmm.scales
    return


def set_oracle_plots(gmm):
    """Draw the generative ("oracle") mixture ``gmm`` on trace 1."""
    update_gmm_plot(gmm, 1)


def update_learned_gmm_plots(gmm):
    """Draw the mixture ``gmm`` learned so far on trace 2."""
    update_gmm_plot(gmm, 2)


In [12]:
def make_empty_line_plot():
    """Return a line trace that has no points yet."""
    empty_line = go.Scatter(mode='lines', x=[], y=[])
    return empty_line


# Two stacked panels sharing the iteration axis: the average log-likelihood
# on top and its per-iteration change (log-scaled) below.
fig_progress = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    subplot_titles=[
        "average_log_likelihood",
        "Δ average_log_likelihood",
    ],
)
fig_progress.update_layout(showlegend=False, height=800)

# One initially empty curve per panel; they end up at indices 0 and 1 of
# fig_progress.data.
for _row in (1, 2):
    fig_progress.add_trace(make_empty_line_plot(), row=_row, col=1)

fig_progress.update_xaxes(title_text='iteration #')
fig_progress.update_yaxes(type="log", row=2, col=1)

# Convert to a widget so the curves can grow after the figure is displayed.
fig_progress = go.FigureWidget(fig_progress)


def update_progress_plot(iterations, log_likelihoods, improvements):
    """Append a batch of training history to the two progress curves.

    Parameters
    ----------
    iterations : sequence
        Iteration numbers of the new points (shared x values of both curves).
    log_likelihoods : sequence
        Average log-likelihood per iteration, appended to the first curve.
    improvements : sequence
        Improvement per iteration, appended to the second curve.

    Any sequence is accepted; the values are converted to tuples before being
    concatenated with the points already on the curves.
    """
    likelihood_curve = fig_progress.data[0]
    improvement_curve = fig_progress.data[1]
    new_x = tuple(iterations)

    # Compute the extended arrays before opening the batch, so every read sees
    # the currently displayed values.
    likelihood_x = tuple(likelihood_curve.x) + new_x
    likelihood_y = tuple(likelihood_curve.y) + tuple(log_likelihoods)
    improvement_x = tuple(improvement_curve.x) + new_x
    improvement_y = tuple(improvement_curve.y) + tuple(improvements)

    # Send all four changes as one update, so x and y of a curve never
    # disagree in length while the figure is being redrawn.
    with fig_progress.batch_update():
        likelihood_curve.x = likelihood_x
        likelihood_curve.y = likelihood_y
        improvement_curve.x = improvement_x
        improvement_curve.y = improvement_y
    return


In [13]:
# Display the mixture figure widget; the cells below fill in and update its traces.
fig

FigureWidget({
    'data': [{'name': 'data',
              'type': 'histogram',
              'uid': '62af7036…

In [14]:
# Display the training-progress figure widget; its curves start out empty.
fig_progress

FigureWidget({
    'data': [{'mode': 'lines',
              'type': 'scatter',
              'uid': '0b23426c-…

# train GMM

In [15]:
# Draw the ground-truth mixture in the "generative gmm" panel before training starts.
set_oracle_plots(oracle)

In [16]:
# Learn a mixture from the samples `x` with `k` components (presumably via EM,
# going by the function name); the run logs one DEBUG line per iteration.
# NOTE(review): the third positional argument is the oracle mixture; its role
# inside learn_em is not visible from this notebook — confirm it is intended.
# NOTE(review): max_iter is passed as the float 1e10 — confirm learn_em is
# meant to accept a non-integer limit.
gmm_learned = learn_em(x, k, oracle, max_iter=1e10)


DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=1, improvement=1.031931e+01
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=2, improvement=4.769650e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=3, improvement=1.337420e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=4, improvement=5.776745e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=5, improvement=4.053237e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=6, improvement=3.629810e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=7, improvement=3.417139e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=8, improvement=3.166520e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=9, improvement=2.838801e-03
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=10, improvement=2.458930e-03
DEBUG:graphical_mod

DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=84, improvement=5.161193e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=85, improvement=4.617080e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=86, improvement=4.504938e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=87, improvement=4.641312e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=88, improvement=4.949667e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=89, improvement=5.398701e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=90, improvement=5.977495e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=91, improvement=6.684958e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=92, improvement=7.524730e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=93, improvement=8.502109e-05
DEBUG:grap

DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=167, improvement=1.845750e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=168, improvement=1.522825e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=169, improvement=1.239781e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=170, improvement=9.979024e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=171, improvement=7.957228e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=172, improvement=6.298497e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=173, improvement=4.958186e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=174, improvement=3.888065e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=175, improvement=3.041365e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=176, improvement=2.375795e-05


In [None]:
# Print the learned mixture followed by the ground truth the data was sampled from.
for _heading, _mixture in (
    ("done. learned gaussian mixture:", gmm_learned),
    ("and the 'oracle' (truth) was:", oracle),
):
    print(_heading)
    print(_mixture)