In [1]:
import json
import logging

import numpy as np

from graphical_models.gaussian_mixture.univariate import (
    UnivariateGaussianMixture,
    learn_em,
)


# setup

## logging

In [2]:
class UnivariateGaussianMixtureHandler(logging.Handler):
    """Logging handler that turns EM progress records into plot updates.

    Every handled record is expected to carry a ``data`` mapping with the keys
    ``iteration``, ``average_log_likelihood``, ``improvement``, ``is_done`` and
    ``gaussian_mixture``.  Records are buffered, and the plots are refreshed
    only when one of the following holds:

    * the average log-likelihood rose by at least ``min_improvement`` since
      the last refresh,
    * the record is flagged ``is_done``,
    * more than ``max_buffered`` records are waiting in the buffer.

    Records without a ``data`` attribute are skipped; reading ``record.data``
    unconditionally would raise ``AttributeError`` inside ``emit``.

    Parameters
    ----------
    level:
        Standard ``logging.Handler`` level.
    min_improvement:
        Minimum rise in average log-likelihood that triggers a refresh.
    max_buffered:
        A refresh is forced once the buffer holds more than this many records.
    """

    def __init__(self, level=logging.NOTSET, min_improvement=0.01, max_buffered=100):
        super().__init__(level)
        self._min_improvement = min_improvement
        self._max_buffered = max_buffered
        # Initialise the state here so a record whose iteration is not 1
        # (e.g. the handler was attached mid-run) does not hit missing attributes.
        self._start_new_run()

    def _start_new_run(self):
        """Forget the last plotted likelihood and drop all buffered records."""
        self._previous_value = float("-inf")
        self._clear_buffer()

    def _clear_buffer(self):
        """Drop the records accumulated since the last plot refresh."""
        self._iterations = ()
        self._likelihoods = ()
        self._improvements = ()

    def emit(self, record):
        """Buffer one progress record and refresh the plots when warranted."""
        data = getattr(record, "data", None)
        if data is None:
            # Not a structured progress record: nothing to plot.
            return

        if data["iteration"] == 1:
            # First iteration of a (new) training run.
            self._start_new_run()

        likelihood = data["average_log_likelihood"]
        # Tuples are kept on purpose: update_progress_plot concatenates them
        # onto the tuples held by the figure traces.
        self._iterations = self._iterations + (data["iteration"],)
        self._likelihoods = self._likelihoods + (likelihood,)
        self._improvements = self._improvements + (data["improvement"],)

        gain_since_refresh = likelihood - self._previous_value
        # Written as "not (gain < threshold)" so a NaN gain still forces a refresh.
        should_refresh = (
            not (gain_since_refresh < self._min_improvement)
            or data["is_done"]
            or len(self._iterations) > self._max_buffered
        )
        if not should_refresh:
            return

        update_learned_gmm_plots(data["gaussian_mixture"])
        update_progress_plot(
            self._iterations,
            self._likelihoods,
            self._improvements)

        # The buffered history has been handed to the plot; start a new batch.
        self._previous_value = likelihood
        self._clear_buffer()


# Root logger prints everything from DEBUG upwards.
logging.basicConfig(level='DEBUG')

# Attach the plotting handler exactly once.  Re-running this cell used to stack
# one more handler per run, so every progress record updated the plots twice.
_gmm_logger = logging.getLogger("graphical_models.gaussian_mixture.univariate")
for _stale_handler in list(_gmm_logger.handlers):
    # Compare by class name: re-running the notebook re-creates the class
    # object, so isinstance() against the new class would miss old instances.
    if type(_stale_handler).__name__ == "UnivariateGaussianMixtureHandler":
        _gmm_logger.removeHandler(_stale_handler)
_gmm_logger.addHandler(UnivariateGaussianMixtureHandler())

# Logger for this notebook's own messages.
logger = logging.getLogger(__name__)


In [3]:
# Sanity check: handlers installed on the root logger (basicConfig was called above).
logging.getLogger().handlers

[<StreamHandler stderr (NOTSET)>]

In [4]:
# Sanity check: the library logger should carry the plotting handler.
logging.getLogger("graphical_models.gaussian_mixture.univariate").handlers

[<UnivariateGaussianMixtureHandler (NOTSET)>]

In [5]:
# Sanity check: the notebook's own logger has no handlers attached directly.
logging.getLogger(__name__).handlers

[]

## data

In [6]:
# Print small floats in fixed-point form rather than scientific notation.
np.set_printoptions(suppress=True)
# Seeded generator: the oracle parameters drawn from it are reproducible.
np_rng = np.random.default_rng(0)


In [7]:
k = 8  # number of mixture components

# Draw the ground-truth ("oracle") parameters.  The three draws from np_rng
# stay in this order so the seeded values do not change.
alpha_dirichlet = np.full(shape=(k,), fill_value=5.0)
weights = np_rng.dirichlet(alpha=alpha_dirichlet)
locs = 5 * np_rng.normal(size=(k,))
scales = np.abs(np_rng.normal(size=(k,)))

oracle = UnivariateGaussianMixture(weights, locs, scales)
logger.info("oracle: %s", oracle)

INFO:__main__:oracle: GaussianMixture(
	weights=array([0.14860343, 0.18611145, 0.10829077, 0.24313792, 0.09936389,
       0.10357373, 0.03698758, 0.07393123]),
	locs=array([-2.72129491, -1.58150078,  2.05815268,  5.21256685, -0.64267331,
        6.83231735, -3.32597337,  1.75755035]),
	scales=array([0.90347018, 0.0940123 , 0.74349925, 0.92172538, 0.45772583,
       0.22019512, 1.00961818, 0.20917557])
)


In [8]:
# Draw 10,000 samples from the oracle with a fixed seed; `x` is the data the
# mixture is later fitted to.
x = oracle.sample(seed=1, n=10_000)


# figure

In [9]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [10]:
# Observed data range: reused as the shared x-axis range of the figure and as
# the end points of the grid on which the density curves are evaluated.
x_axis_min_and_max = (min(x), max(x))

In [11]:
# Histogram of the sampled data, using plotly's default binning and raw counts.
data_hist = go.Histogram(name='data', x=x)

def make_empty_gmm_pointwise_plot(name):
    """Return an empty marker trace called ``name``, to be filled in later.

    The trace starts with no points and an empty horizontal error-bar array.
    Markers are black-outlined circles coloured from plotly's qualitative
    ``Plotly`` palette.
    """
    horizontal_error_bars = dict(type='data', array=[], thickness=1)
    marker_outline = dict(color='black', width=3)
    marker_style = dict(
        size=10,
        color=px.colors.qualitative.Plotly,
        symbol='circle',
        line=marker_outline,
    )
    return go.Scatter(
        name=name,
        x=[],
        y=[],
        mode='markers',
        error_x=horizontal_error_bars,
        marker=marker_style,
    )

def make_empty_gmm_density_plot(name, num_points=501):
    """Return a flat (all-zero) line trace spanning the observed data range.

    The x grid is fixed at creation time; ``update_gmm_density_plot`` later
    evaluates a mixture's pdf on exactly this grid, so ``num_points`` sets the
    resolution of the plotted density.

    Parameters
    ----------
    name:
        Trace name.
    num_points:
        Number of evenly spaced grid points between the two values of
        ``x_axis_min_and_max``.  The previous hard-coded 101 points gave a
        grid step of roughly 0.1 or more for this notebook's data, which is
        coarser than the narrowest oracle component (scale of about 0.094),
        so that peak was undersampled.

    Returns
    -------
    go.Scatter
        A spline-shaped line trace with ``y`` set to zeros.
    """
    x_values = np.linspace(*x_axis_min_and_max, num=num_points)
    y_values = np.zeros_like(x_values)
    return go.Scatter(
        name=name,
        x=x_values,
        y=y_values,
        mode='lines',
        line_shape='spline',
    )
    


In [18]:
# Four stacked panels sharing one x axis: data, oracle density, learned
# density, learned components.
fig = make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    subplot_titles=[
        "data",
        "generative gaussian mixture (complete density)",
        "learned gaussian mixture (complete density)",
        "learned gaussian mixture (individual components)",
    ],
)
fig.update_layout(
    height=1000,
    title_text="gaussian mixture learning",
    showlegend=False,
)

# One trace per row, added top to bottom.  The update helpers below address
# the traces by this order (indices 0..3 in fig.data).
fig.add_trace(data_hist, row=1, col=1)
fig.add_trace(make_empty_gmm_density_plot('generative gmm'), row=2, col=1)
fig.add_trace(make_empty_gmm_density_plot('learned gmm'), row=3, col=1)
fig.add_trace(make_empty_gmm_pointwise_plot('learned gmm (point-wise)'), row=4, col=1)

fig.update_xaxes(range=x_axis_min_and_max)
fig.update_yaxes(title_text="density", row=2, col=1)
# Mixture weights are shown on a log axis with an explicit range.
fig.update_yaxes(title_text="mixture weights", range=[-3, 0], type="log", row=4, col=1)

# Wrap the figure in a widget so trace updates are pushed to the rendered figure.
fig = go.FigureWidget(fig)


def update_gmm_pointwise_plot(gmm, trace):
    """Show each mixture component as one marker on trace ``trace`` of ``fig``.

    x is the component location, y its weight, and the horizontal error bar
    its scale.

    Parameters
    ----------
    gmm:
        Object exposing ``locs``, ``weights`` and ``scales``.
    trace:
        Index into ``fig.data`` of the marker trace to overwrite.
    """
    target = fig.data[trace]
    # Send the three property changes as one widget update, so the front end
    # never draws new locations paired with stale weights or error bars.
    with fig.batch_update():
        target.x = gmm.locs
        target.y = gmm.weights
        target.error_x.array = gmm.scales


def update_gmm_density_plot(gmm, trace):
    """Re-evaluate ``gmm.pdf`` on the x grid of trace ``trace`` and store it as y."""
    target = fig.data[trace]
    target.y = gmm.pdf(target.x)


def set_oracle_plots(gmm):
    """Draw ``gmm`` as the density curve of trace 1 (the 'generative gmm' trace)."""
    generative_density_trace = 1
    update_gmm_density_plot(gmm, trace=generative_density_trace)


def update_learned_gmm_plots(gmm):
    """Refresh both views of the learned mixture.

    Trace 2 receives the full density curve, trace 3 the per-component markers.
    """
    refreshers = (
        (update_gmm_density_plot, 2),
        (update_gmm_pointwise_plot, 3),
    )
    for refresh, trace_index in refreshers:
        refresh(gmm, trace=trace_index)



In [25]:
def make_empty_line_plot():
    """Return a line-mode scatter trace that has no points yet."""
    no_points = dict(x=[], y=[])
    return go.Scatter(mode='lines', **no_points)


# Two stacked panels sharing the iteration axis: the average log-likelihood
# and its per-iteration improvement.
fig_progress = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.12,
    subplot_titles=[
        "average_log_likelihood",
        "average_log_likelihood improvement per iteration",
    ],
)
fig_progress.update_layout(height=600, showlegend=False)

# Trace 0 (row 1) and trace 1 (row 2) start empty; update_progress_plot appends to them.
fig_progress.add_trace(make_empty_line_plot(), row=1, col=1)
fig_progress.add_trace(make_empty_line_plot(), row=2, col=1)

fig_progress.update_xaxes(title_text='iteration #')
# The improvement panel uses a log y axis.
fig_progress.update_yaxes(type="log", row=2, col=1)

# Wrap the figure in a widget so trace updates are pushed to the rendered figure.
fig_progress = go.FigureWidget(fig_progress)


def update_progress_plot(iterations, log_likelihoods, improvements):
    """Append a batch of training progress to the two traces of ``fig_progress``.

    Parameters
    ----------
    iterations:
        Iteration numbers of the batch (x values for both traces).
    log_likelihoods:
        Average log-likelihood per iteration (y of trace 0).
    improvements:
        Improvement per iteration (y of trace 1).

    All three may be any sequence (tuple, list, ...); they are expected to
    have equal length.
    """
    def extended(existing, new_values):
        # An unset trace property may read back as None; treat it as empty.
        head = () if existing is None else tuple(existing)
        return head + tuple(new_values)

    likelihood_trace = fig_progress.data[0]
    improvement_trace = fig_progress.data[1]
    # Apply all four changes as one widget update.  Assigning them one at a
    # time leaves x and y with different lengths between the assignments.
    with fig_progress.batch_update():
        likelihood_trace.x = extended(likelihood_trace.x, iterations)
        likelihood_trace.y = extended(likelihood_trace.y, log_likelihoods)
        improvement_trace.x = extended(improvement_trace.x, iterations)
        improvement_trace.y = extended(improvement_trace.y, improvements)


In [26]:
# Render the mixture figure widget; the update helpers write into this object.
fig

FigureWidget({
    'data': [{'name': 'data',
              'type': 'histogram',
              'uid': 'cbb754e8…

In [27]:
# Render the training-progress widget; update_progress_plot appends to its traces.
fig_progress

FigureWidget({
    'data': [{'mode': 'lines',
              'type': 'scatter',
              'uid': '2cf46787-…

In [28]:
# Fill the 'generative gmm' density trace with the oracle's pdf.
set_oracle_plots(oracle)

# train GMM

In [29]:
# Fit a mixture to the samples with EM.  Progress records logged under
# "graphical_models.gaussian_mixture.univariate" reach the plotting handler
# registered on that logger, which redraws the two figures during the run.
# NOTE(review): `oracle` is passed as the third positional argument; its role
# (initial model? reference for evaluation?) is not visible here. Confirm against learn_em.
# NOTE(review): max_iter=1e10 is a float. Confirm learn_em accepts a non-integer limit.
gmm_learned = learn_em(x, k, oracle, max_iter=1e10)


DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=1, improvement=5.616820e+00
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=2, improvement=4.209689e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=3, improvement=2.155189e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=4, improvement=1.571697e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=5, improvement=1.469269e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=6, improvement=1.456203e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=7, improvement=1.572185e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=8, improvement=1.885965e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=9, improvement=2.377293e-02
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=10, improvement=2.897693e-02
DEBUG:graphical_mod

DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=84, improvement=1.320704e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=85, improvement=1.173822e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=86, improvement=1.046185e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=87, improvement=9.346453e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=88, improvement=8.368202e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=89, improvement=7.509003e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=90, improvement=6.754725e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=91, improvement=6.093855e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=92, improvement=5.516562e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=93, improvement=5.014100e-05
DEBUG:grap

DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=167, improvement=1.686161e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=168, improvement=1.765626e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=169, improvement=1.853231e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=170, improvement=1.949815e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=171, improvement=2.056336e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=172, improvement=2.173898e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=173, improvement=2.303765e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=174, improvement=2.447399e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=175, improvement=2.606489e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=176, improvement=2.782992e-05


DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=249, improvement=9.370496e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=250, improvement=9.368738e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=251, improvement=9.382346e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=252, improvement=9.412036e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=253, improvement=9.458689e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=254, improvement=9.523322e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=255, improvement=9.607077e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=256, improvement=9.711207e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=257, improvement=9.837076e-06
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=258, improvement=9.986160e-06


DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=331, improvement=1.178679e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=332, improvement=1.162228e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=333, improvement=1.144964e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=334, improvement=1.126193e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=335, improvement=1.105219e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=336, improvement=1.081394e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=337, improvement=1.054169e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=338, improvement=1.023146e-04
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=339, improvement=9.881151e-05
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=340, improvement=9.490810e-05


DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=413, improvement=1.431981e-08
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=414, improvement=1.278558e-08
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=415, improvement=1.143518e-08
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=416, improvement=1.024558e-08
DEBUG:graphical_models.gaussian_mixture.univariate.learn_em:iteration=417, improvement=9.196675e-09


In [None]:
# Print the fitted mixture followed by the ground truth for comparison.
print("done. learned gaussian mixture:", gmm_learned, sep="\n")
print("and the 'oracle' (truth) was:", oracle, sep="\n")