In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Gaussian mixture model-based clustering

In this notebook, we are going to take a look at how to cluster Gaussian-distributed data.

Imagine you have data that are multi-modal.
A task of interest, naturally, is to cluster that data.
Let's see how we can accomplish this using nothing but the NumPy API and some gradients from JAX. (And maybe some vmaps too!)

## Generate mixture Gaussians

As always, when exploring a method,
we start with a highly simplified version of it
that contains a ton of constraints that we know of,
which we can always break later.
Thus, we'll start with a simple version of our problem:
the setting where we have bimodal Gaussian data,
and we want to identify the cluster centers.

In [None]:
from jax import vmap, grad
import jax.numpy as np
from jax.scipy import stats
from jax import random
from jax.scipy.special import logsumexp
from functools import partial

Let's start by describing how the data are generated. We'll set it up such that there are two cluster centers: one at -2 and one at 5. We'll also impose a 1:5 ratio of data between those two.

In [None]:
import matplotlib.pyplot as plt
weights_true = np.array([1, 5])
locs_true = np.array([-2., 5.])
scale_true = np.array([1.1, 2])

base_n_draws = 1000
key = random.PRNGKey(100)

k1, k2 = random.split(key)

draws_1 = scale_true[0] * random.normal(k1, shape=(base_n_draws * weights_true[0],)) + locs_true[0]
draws_2 = scale_true[1] * random.normal(k2, shape=(base_n_draws * weights_true[1],)) + locs_true[1]
data_mixture = np.concatenate([draws_1, draws_2])
plt.hist(data_mixture);

Our learning task at hand is thus to learn the two cluster centers, and their relative weighting.

When faced with probabilistically-generated data, it is oftentimes desirable to impose a likelihood distribution
on the observed data.
Doing so gives us a quantitative measure of "goodness of fit" for our parameters.
Here, because we observe bimodal data, we might hazard a guess that
a two-component mixture distribution likelihood would be good. 

Here, the likelihood of each data point is the weighted sum of its likelihoods under each of the components. In other words, for each data point, we calculate the likelihood of observing that datum under each component's distribution, multiply that likelihood by the component's weight, and sum up the weighted component likelihoods. Because we assume that each data point is independently drawn, we then multiply the per-datum likelihoods together to get the joint likelihood of all data observed.
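
In symbols, the description above can be written as follows. The notation is shorthand assumed here for illustration and is not defined elsewhere in this notebook: $w_k$, $\mu_k$ and $\sigma_k$ are the weight, mean and scale of component $k$, $K$ is the number of components, and $x_n$ is the $n$-th of $N$ data points.

$$
p(x_n) = \sum_{k=1}^{K} w_k \, \mathcal{N}(x_n \mid \mu_k, \sigma_k),
\qquad
\log p(x_1, \ldots, x_N) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} w_k \, \mathcal{N}(x_n \mid \mu_k, \sigma_k)
$$

The product over independent data points becomes a sum once we take logs, which is the form we work with in code.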

To see some of the JAX programming idioms in action, we are going to build things up from the core as usual.

Let's write the log likelihood of one datum under one component.

In [None]:
from dl_workshop.gaussian_mixture import loglike_one_component

loglike_one_component??

The summation here is because we are operating in logarithmic space.
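
The source of `loglike_one_component` is not reproduced in this text, so here is a minimal sketch of what such a function might look like, assuming it adds the log of the component weight to the Gaussian log-density. The name `loglike_one_component_sketch` is hypothetical; this is not the `dl_workshop` implementation.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax.scipy import stats


def loglike_one_component_sketch(component_weight, component_mu, log_component_scale, datum):
    """Hypothetical stand-in: log(weight) + Normal log-density of one datum."""
    component_scale = np.exp(log_component_scale)  # back-transform to a positive scale
    return np.log(component_weight) + stats.norm.logpdf(
        datum, loc=component_mu, scale=component_scale
    )


# Weight 1 and log-scale 0 (i.e. scale 1) reduce to the unit-Gaussian log-density.
sketch_value = loglike_one_component_sketch(1.0, 0.0, 0.0, 0.0)
reference_value = stats.norm.logpdf(0.0, loc=0.0, scale=1.0)
```

Multiplying a likelihood by a weight turns into adding the log-weight, which is where the summation comes from.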

You might ask, why do we use the log of the component scale
(and, later on, the log of the component weights)
rather than the values themselves?
This is a math trick that lets the optimizer work in an unbounded space.
When doing gradient descent,
we can never guarantee that a gradient update on a parameter that ought to be positive-only
will give us a positive number.
Thus, for positive numbers, we operate in logarithmic space
and exponentiate whenever we need the actual value.
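
A tiny numerical illustration of why this helps, using made-up numbers (the starting scale and the size of the update are assumptions chosen purely for illustration): the same additive update that drives a raw scale negative leaves the back-transformed scale positive.

```python
import jax.numpy as np

scale = np.array(0.1)    # a positive-only parameter
update = np.array(-0.5)  # illustrative update, e.g. -learning_rate * gradient

# Updating the raw parameter directly can leave the valid (positive) region.
scale_after_raw_update = scale + update

# Updating in log space and exponentiating always yields a positive number.
log_scale_after_update = np.log(scale) + update
scale_after_log_update = np.exp(log_scale_after_update)
```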


We can quickly write a test here. If the component weight is 1.0, the component $\mu$ is 0, the log of the component scale is 0 (i.e. a scale of 1), and the observed datum is also 0, the result should equal the log-likelihood of 0 under a unit Gaussian.

In [None]:
loglike_one_component(
    component_weight=1.0, 
    component_mu=0., 
    log_component_scale=0.,  # log(1) = 0, i.e. a unit scale
    datum=0.) == (
    stats.norm.logpdf(x=0, loc=0, scale=1)
)

Leveraging what we know now, let's write a function that calculates the total log likelihood of our data
under all of the component probability distributions.

In [None]:
from jax.scipy.special import logit

def normalize_weights(weights):
    """Normalize a weights vector to sum to 1."""
    return weights / np.sum(weights)

def loglike_across_components(
    log_component_weights,
    component_mus,
    log_component_scales,
    datum
):
    """Log likelihood of datum under all components of the mixture."""
    component_weights = normalize_weights(
        np.exp(log_component_weights)
    )
    loglike_components = vmap(
        partial(
            loglike_one_component,
            datum=datum
        )
    )(component_weights, component_mus, log_component_scales)
    return logsumexp(loglike_components)

Inside that function, we first calculated elementwise the log-likelihood of observing that datum under each component.
That only gives us per-component log-likelihoods though.
Because our datum could have been drawn from any of those components,
the total likelihood is a _sum_ of the per-component likelihoods.
Summing likelihoods while working in log space
calls for the [logsumexp](https://en.wikipedia.org/wiki/LogSumExp) function,
which first exponentiates each of the log-likelihoods,
sums them up,
and then takes the log of the sum.
(We could have written our own version of the function,
but I think it makes a ton of sense
to trust the numerically-stable,
professionally-implemented version provided
in JAX's SciPy API!)

Let us now test-drive this function,
which should give us a scalar value at the end.

In [None]:
weights_bad = np.array([1 - 0.0001, 0.0001])
weights_bad

In [None]:
loglike_across_components(
    log_component_weights=np.log(weights_true),
    component_mus=locs_true,
    log_component_scales=np.log(scale_true),
    datum=data_mixture[1],
)

Great, that worked!
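
As an aside on why the numerically-stable `logsumexp` is worth trusting: with very negative log-likelihoods, exponentiating naively underflows to zero. The sketch below uses two illustrative values that are not taken from our data.

```python
import jax.numpy as np
from jax.scipy.special import logsumexp

# Two very negative per-component log-likelihoods (illustrative values).
loglikes = np.array([-1000.0, -1001.0])

# Naive route: exp() underflows to 0, so the log of the sum is -inf.
naive = np.log(np.sum(np.exp(loglikes)))

# Stable route: logsumexp factors out the maximum before exponentiating.
stable = logsumexp(loglikes)
```

The stable result is roughly `-1000 + log(1 + exp(-1))`, a perfectly finite number, while the naive route loses all information.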

Now that we've got the log-likelihood of a single datum across all components,
we can now `vmap` the function across all data given to us.

In [None]:
def mixture_loglike(log_component_weights, component_mus, log_component_scales, data):
    """Log likelihood of data (not datum!) under all components of the mixture."""
    ll_per_data = vmap(
        partial(
            loglike_across_components,
            log_component_weights,
            component_mus,
            log_component_scales
        )
    )(data)
    return np.sum(ll_per_data)

In [None]:
log_weights_true = np.log(weights_true)

In [None]:
mixture_loglike(
    log_component_weights=np.log(weights_true),
    component_mus=locs_true,
    log_component_scales=np.log(scale_true),
    data=data_mixture,
)
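
The `vmap(partial(...))` idiom used in `mixture_loglike` can be checked in isolation. The sketch below uses a hypothetical single-Gaussian helper, `datum_loglike`, rather than the mixture functions above, and confirms that mapping over the data gives the same numbers as an explicit Python loop.

```python
from functools import partial

import jax.numpy as np
from jax import vmap
from jax.scipy import stats


def datum_loglike(mu, log_scale, datum):
    """Log-likelihood of one datum under one Gaussian (hypothetical helper)."""
    return stats.norm.logpdf(datum, loc=mu, scale=np.exp(log_scale))


data = np.array([-1.0, 0.0, 2.5])

# partial() fixes the parameters; vmap() maps the remaining argument over the data.
ll_vmapped = vmap(partial(datum_loglike, 0.0, 0.0))(data)

# Reference: the same computation, one datum at a time.
ll_looped = np.array([datum_loglike(0.0, 0.0, d) for d in data])

# Independent data: the joint log-likelihood is the sum of the per-datum values.
total_loglike = np.sum(ll_vmapped)
```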

If we play around with the mixture loglike though, we'll notice that it isn't the end of the story.
The component weights can be "hacked" to produce higher log-likelihood values,
by shrinking the weight of one of the components towards zero.
We thus need to postulate a generative story for the weights,
which will provide an anchoring distribution.
One reasonable thing is to postulate that they came from a Dirichlet distribution
that gives equal weight to each of the components.
In [None]:
def weights_loglike(log_component_weights):
    component_weights = np.exp(log_component_weights)
    component_weights = normalize_weights(component_weights)
    return stats.dirichlet.logpdf(x=component_weights, alpha=2 * np.ones_like(component_weights))

In [None]:
weights_loglike(log_weights_true)

In [None]:
weights_bad = np.array([3., 2.])
log_weights_bad = np.log(weights_bad)
weights_loglike(log_weights_bad), weights_loglike(log_weights_true)
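
To make the anchoring effect more tangible, here is a small sketch with weights that are deliberately more extreme than the ones above. It calls `stats.dirichlet.logpdf` directly with the same symmetric `alpha` of 2 used in `weights_loglike`; the two weight vectors are illustrative assumptions.

```python
import jax.numpy as np
from jax.scipy import stats

alpha = 2.0 * np.ones(2)  # same symmetric concentration as in weights_loglike

balanced = np.array([0.5, 0.5])
extreme = np.array([0.999, 0.001])  # one component is nearly switched off

logp_balanced = stats.dirichlet.logpdf(balanced, alpha)
logp_extreme = stats.dirichlet.logpdf(extreme, alpha)
```

The nearly-degenerate weights score several log units lower than the balanced ones, which is the penalty that discourages the optimizer from collapsing a component.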

Now that we have composed together our generative story for the data,
let's pause for a moment and break down our model a bit.
This will serve as a review of what we've done.

Firstly, we have our "model", i.e. the log-likelihood of our data
conditioned on some parameter set and their values.

Secondly, the parameters that we tweak and adjust to find maximum likelihood values are:

1. Component weights
2. Component central tendencies/means
3. Component scales (standard deviations)

What we're going to attempt next is to optimize those parameters, leveraging what we've learned before.

## Gradient descent to find maximum likelihood values

JAX's optimizers are built to find the minima of a function.
However, we're faced with a _maximum_ likelihood problem.
We can get around the problem by simply flipping the sign of the log-likelihood and minimizing that instead.
Let's see this in action.

As always, we begin with a loss function to minimize, and its derivative:

In [None]:
from jax import grad

def loss(params, data):
    log_component_weights, component_mus, log_component_scales = params
    loglike_mixture = mixture_loglike(
        log_component_weights,
        component_mus,
        log_component_scales,
        data
    )
    loglike_weights = weights_loglike(log_component_weights)
    
    total = loglike_mixture + loglike_weights
    return -total

dloss = grad(loss)
dloss
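
To see the sign-flip idea in isolation, here is a toy sketch that is much simpler than our mixture model: it fits the mean of a single unit-scale Gaussian by minimizing the negative log-likelihood. Plain gradient descent with a hand-picked step size stands in for Adam here, and the data and the name `toy_negative_loglike` are assumptions for illustration.

```python
import jax.numpy as np
from jax import grad
from jax.scipy import stats

toy_data = np.array([1.0, 2.0, 3.0, 4.0])  # illustrative data; sample mean is 2.5


def toy_negative_loglike(mu, data):
    """Negative log-likelihood of data under Normal(mu, 1)."""
    return -np.sum(stats.norm.logpdf(data, loc=mu, scale=1.0))


dtoy = grad(toy_negative_loglike)  # gradient with respect to mu (first argument)

mu = 0.0
for _ in range(100):
    mu = mu - 0.1 * dtoy(mu, toy_data)  # plain gradient descent step
```

Minimizing the negated log-likelihood lands on the sample mean, which is the maximum likelihood estimate for this toy model.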

Next, we initialize our three parameters with random numbers.

In [None]:
N_MIXTURE_COMPONENTS = 2

k1, k2, k3, k4 = random.split(key, 4)
log_component_weights_init = random.normal(k1, shape=(N_MIXTURE_COMPONENTS,))
component_mus_init = random.normal(k2, shape=(N_MIXTURE_COMPONENTS,))
log_component_scales_init = random.normal(k3, shape=(N_MIXTURE_COMPONENTS,))

params_init = log_component_weights_init, component_mus_init, log_component_scales_init
params_true = np.log(weights_true), locs_true, np.log(scale_true)

Let's now test-drive the functions to make sure they execute properly.

In [None]:
loss(params_true, data_mixture)

In [None]:
loss(params_init, data_mixture)

In [None]:
dloss(params_init, data_mixture)

As you can see, `grad` returns a tuple of gradients, one array per parameter, in the same order as `params`.
(If we also wanted the loss value back, `value_and_grad` would return it as the first element of a tuple.)
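
For reference, the difference between `grad` and `value_and_grad` can be seen on a toy loss. The function `toy_loss` and its parameters are hypothetical and unrelated to our mixture model.

```python
import jax.numpy as np
from jax import grad, value_and_grad


def toy_loss(params):
    """Toy loss over a tuple of parameters (illustrative only)."""
    a, b = params
    return np.sum(a ** 2) + 3.0 * b


toy_params = (np.array([1.0, 2.0]), np.array(2.0))

# grad returns only the gradients, in the same tuple structure as the parameters.
grads_only = grad(toy_loss)(toy_params)

# value_and_grad returns a (loss value, gradients) pair.
value, grads = value_and_grad(toy_loss)(toy_params)
```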

Now, we are going to use JAX's optimizers inside a `lax.scan`-ed training loop
to get fast training going.

We begin with the elementary "step" function.

In [None]:
def step(i, state, get_params_func, dloss_func, update_func, data):
    params = get_params_func(state)
    g = dloss_func(params, data)
    state = update_func(i, g, state)
    return state

We then make the elementary step function a scannable one using `lax.scan`.
This will allow us to "scan" the function across an array
that represents the number of optimization steps we will be using.

In [None]:
def make_step_scannable(get_params_func, dloss_func, update_func, data):
    def inner(previous_state, iteration):
        new_state = step(
            i=iteration,
            state=previous_state,
            get_params_func=get_params_func,
            dloss_func=dloss_func,
            update_func=update_func,
            data=data,
        )
        return new_state, previous_state
    return inner
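
If the `(carry, output)` convention of `lax.scan` is unfamiliar, a toy running total shows the same pattern as `inner`: the function receives the previous carry and the current element, and returns the new carry plus whatever should be recorded in the history. The function `running_total` is a hypothetical example.

```python
import jax.numpy as np
from jax import lax


def running_total(previous_total, x):
    """One scan step: returns (new carry, value to record in the history)."""
    new_total = previous_total + x
    return new_total, previous_total  # record the state *before* this step


xs = np.array([1.0, 2.0, 3.0, 4.0])
final_total, history = lax.scan(running_total, np.array(0.0), xs)
```

Just like our optimizer loop, the history holds the state before each step, while the final carry holds the state after the last one.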

Now we actually instantiate the scannable step.

In [None]:
# Older JAX releases exposed this module as jax.experimental.optimizers.
from jax.example_libraries.optimizers import adam

adam_init, adam_update, adam_get_params = adam(0.5)

step_scannable = make_step_scannable(
    get_params_func=adam_get_params,
    dloss_func=dloss,
    update_func=adam_update,
    data=data_mixture, 
)

Then, we `lax.scan` `step_scannable` over 1000 iterations (constructed as an `np.arange()` array).

In [None]:
from jax import lax

initial_state = adam_init(params_init)

final_state, state_history = lax.scan(step_scannable, initial_state, np.arange(1000))

Let's now unpack our parameters!

In [None]:
params_opt = adam_get_params(final_state)
log_component_weights_opt, component_mus_opt, log_component_scales_opt = params_opt

Let's first check that we indeed have "learned". The loss function value should be pretty darn close to the loss function when we put in true params.
Keep in mind that because we have data that are an imperfect sample of the ground truth distribution,
it is possible that our optimized params' negative log likelihood will be lower than that of the true params.

In [None]:
loss(params_opt, data_mixture), loss(params_true, data_mixture)

Next up: what do the component probabilities look like? Do they reflect what we expect?

In [None]:
normalize_weights(np.exp(log_component_weights_opt)), normalize_weights(weights_true)

That's so rad! We're at the 1:5 ratio that was prescribed at the beginning!

And how about the component means?

In [None]:
component_mus_opt, locs_true

Also really close! And finally, the component scales:

In [None]:
np.exp(log_component_scales_opt), scale_true

Very nice, really close to the ground truth too.

## Let's visualize the mixture distributions.

We're going to visualize the mixture distributions to help us get a handle over what exactly happened during training.

To start, we need a function that plots the mixture distributions.

In [None]:
from jax.scipy.stats import norm

def plot_component_norm_pdfs(log_component_weights, component_mus, log_component_scales, xmin, xmax, ax, title):
    component_weights = normalize_weights(np.exp(log_component_weights))
    component_scales = np.exp(log_component_scales)
    x = np.linspace(xmin, xmax, 1000).reshape(-1,1)
    pdfs = component_weights * norm.pdf(x, loc=component_mus, scale=component_scales)
    for component in range(pdfs.shape[1]):
        ax.plot(x, pdfs[:, component])
    ax.set_title(title)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
# Plot ground truth
plot_component_norm_pdfs(np.log(weights_true), locs_true, np.log(scale_true), -7.5, 10, axes[0], title="Ground Truth")

# Plot initialized
plot_component_norm_pdfs(
    log_component_weights_init,
    component_mus_init,
    log_component_scales_init,
    xmin=-7.5,
    xmax=10,
    ax=axes[1],
    title="Initialized",
)

# Plot optimized
plot_component_norm_pdfs(
    log_component_weights_opt, 
    component_mus_opt,
    log_component_scales_opt,
    xmin=-7.5,
    xmax=10,
    ax=axes[2],
    title="Optimized",
)

### Learning over history

Let's also take a look at the mixture PDFs over training iterations.

In [None]:
(
    log_component_weights_history,
    component_mus_history,
    log_component_scales_history
) = adam_get_params(state_history)

In [None]:
%%capture
from celluloid import Camera

fig, ax = plt.subplots()
cam = Camera(fig)

for w, m, s in zip(log_component_weights_history[::10], component_mus_history[::10], log_component_scales_history[::10]):
    ax.hist(data_mixture, bins=40, density=True, color="blue")
    plot_component_norm_pdfs(
        w, m, s, xmin=-20, xmax=20, ax=ax, title=None,
    )
    cam.snap()
    
animation = cam.animate()

In [None]:
from IPython.display import HTML

HTML(animation.to_html5_video())

A few comments on the dynamics here:

1. At first, one Gaussian is used to approximate the entire distribution. It's not a good fit, but it is a passable approximation.
1. However, our optimization routine continues to push forward, eventually finding the bimodal pattern. Once this happens, the PDFs fit very nicely to the data samples.

This phenomenon is also reflected in the loss:

In [None]:
def get_loss(state):
    params = adam_get_params(state)
    loss_score = loss(params, data_mixture)
    return loss_score

losses = vmap(get_loss)(state_history)
plt.plot(losses)
plt.yscale("log");


You should notice the first plateau, followed by the second plateau.
This corresponds to the two phases of learning.

Now, thus far, we have set up the problem in a fashion that is essentially "trivial".
What if, however, we wanted to try fitting a mixture Gaussian where we didn't know exactly how many mixture components there _ought_ to be?

## Generalizing this to "unknown" numbers of modes

We're going to see how we can generalize this to an "unknown" number of modes.

To make the problem a bit harder, we'll start by expanding our data to have more mixture components:

In [None]:
weights_true = np.array([1, 5, 0.9, 3])
locs_true = np.array([-2., -5., 3., 8.])
scale_true = np.array([1.1, 2, 1., 1.5,])

base_n_draws = 1000

keys = random.split(key, 4)

draws = []
for i in range(4):
    shape = int(base_n_draws * weights_true[i]),
    draw = scale_true[i] * random.normal(keys[i], shape=shape) + locs_true[i]
    draws.append(draw)
data_mixture = np.concatenate(draws)
plt.hist(data_mixture);

From the histogram, it should be easy to tell that this is not going to be an easy problem to solve.
Firstly, the mixture distributions in _reality_ have 4 components.
But what we get looks more like 2 components... or really?
Could it be that we're lying by using a histogram?

In [None]:
plt.hist(data_mixture, bins=40);

Aha! The case against histograms reveals itself. It turns out there are lots of problems with using histograms, and I shan't go deeper into them here, but obscuring data is one of those issues. To learn more, see [a blog post I wrote on the matter](https://ericmjl.github.io/blog/2018/7/14/ecdfs/).

Let us now go back to pretending that we don't know the _actual_ number of mixture components. How would we handle this situation?

One practical way to handle this is to provide a very large number of possible components, and then let the optimization routine figure out how to get us to the maximum likelihood estimates of each Gaussian component's weight, mean, and scale.

In [None]:
N_MIXTURE_COMPONENTS = 20

k1, k2, k3 = random.split(key, 3)
log_component_weights_init = random.normal(k1, shape=(N_MIXTURE_COMPONENTS,))
component_mus_init = random.normal(k2, shape=(N_MIXTURE_COMPONENTS,))
log_component_scales_init = random.normal(k3, shape=(N_MIXTURE_COMPONENTS,))

params_init = log_component_weights_init, component_mus_init, log_component_scales_init
params_true = np.log(weights_true), locs_true, np.log(scale_true)

In [None]:
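from jax import jit

# Rebuild the scannable step so that it closes over the new four-component data.
step_scannable = make_step_scannable(adam_get_params, dloss, adam_update, data_mixture)
initial_state = adam_init(params_init)

final_state, state_history = lax.scan(jit(step_scannable), initial_state, np.arange(10000))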

In [None]:
params_opt = adam_get_params(final_state)
log_component_weights_opt, component_mus_opt, log_component_scales_opt = params_opt

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
# Plot ground truth
plot_component_norm_pdfs(np.log(weights_true), locs_true, np.log(scale_true), -10, 10, axes[0], title="Ground Truth")

# Plot initialized
plot_component_norm_pdfs(
    log_component_weights_init,
    component_mus_init,
    log_component_scales_init,
    xmin=-10,
    xmax=10,
    ax=axes[1],
    title="Initialized",
)

# Plot optimized
plot_component_norm_pdfs(
    log_component_weights_opt, 
    component_mus_opt,
    log_component_scales_opt,
    xmin=-10,
    xmax=10,
    ax=axes[2],
    title="Optimized",
)

In [None]:
(
    log_component_weights_history,
    component_mus_history,
    log_component_scales_history
) = adam_get_params(state_history)

In [None]:
%%capture
from celluloid import Camera

fig, ax = plt.subplots()
cam = Camera(fig)

for w, m, s in zip(log_component_weights_history[::100], component_mus_history[::100], log_component_scales_history[::100]):
    ax.hist(data_mixture, bins=40, density=True, color="blue")
    plot_component_norm_pdfs(
        w, m, s, xmin=-20, xmax=20, ax=ax, title=None,
    )
    cam.snap()
    
animation = cam.animate()

In [None]:
from IPython.display import HTML

HTML(animation.to_html5_video())

In [None]:
losses = vmap(get_loss)(state_history)
plt.plot(losses)
plt.yscale("log");

When I look at the mixture distribution PDFs generated from the optimized weights,
I see something a tad unsatisfactory.
Our optimization routine has given us a mix of Gaussians that struggles to model the ground truth data convincingly.
By Occam's razor, we would want to find the _parsimonious_ set of mixture components that give us our data,
i.e. assign the largest amount of weight to a few components, and assign vanishingly small weights to the rest.
In other words, we should be able to do better on this learning task.

If that sounds appealing to you, then read on. We're going to walk into the world of Bayesian non-parametrics!

## Dirichlet Process Priors

From the previous section, it appeared that simply providing a large number of Gaussian mixture components, initialized with random weighting, was insufficient for learning the true number of Gaussian components (at least in simulated data). We need a better way of approaching the problem.

We'll try formulating the problem slightly differently. Earlier on, we evaluated the likelihood of our weights vector under a Dirichlet distribution with equally-distributed concentrations. The prior of equally distributed concentrations reflects our belief that every component could contribute more or less equally to the observed data. If instead we wanted to express the prior belief that a constrained set of components were responsible for the data, we need a Dirichlet process prior with a concentration term that governs how many components to give weighting to.

### Dirichlet Processes

A Dirichlet process expresses the idea that there are an infinite number of possible states. It is governed by a "concentration" parameter, which specifies how "concentrated" probability mass is assigned across the infinite number of states. From a practical perspective, though, we don't use "infinite" states, but rather a "countably large number" of states, just as we did above. 

Let's first explore how to generate a Dirichlet-distributed set of weights by using the "stick-breaking" process.
The key idea is simple.
We take a length 1 stick, draw a probability value from a Beta distribution, break the length 1 stick into two at the point drawn, and record the left side's length. We then take the right side, draw another probability value from a Beta distribution, break that stick into two portions at the point drawn, record the absolute length of the left side, and break the right side again. We repeat this until we have the countably large number of states that we desire. In code, the process looks like a `lax.scan`-ed function:

In [None]:
def stick_breaking_weights(beta_draws):
    """Return weights from a stick breaking process.
    
    :param beta_draws: i.i.d draws from a Beta distribution.
        This should be a row vector.
    """
    def weighting(occupied_probability, beta_i):
        """
        :param occupied_probability: The cumulative occupied probability taken up.
        :param beta_i: Current value of beta to consider.
        """
        weight = (1 - occupied_probability) * beta_i
        return occupied_probability + weight, weight
    
    occupied_probability, weights = lax.scan(weighting, np.array(0.), beta_draws)
    
    weights = weights / np.sum(weights)
    return occupied_probability, weights
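
In symbols, the scan above can be summarized as follows, where $\beta_i$ is the $i$-th entry of `beta_draws`, $w_i$ is the weight emitted at step $i$, and $K$ is the number of draws (this notation is shorthand assumed here for illustration):

$$
w_i = \beta_i \left(1 - \sum_{j=1}^{i-1} w_j\right) = \beta_i \prod_{j=1}^{i-1} (1 - \beta_j),
\qquad
\tilde{w}_i = \frac{w_i}{\sum_{k=1}^{K} w_k}
$$

The term in parentheses is one minus the `occupied_probability` carry, i.e. the length of stick still left to break, and $\tilde{w}_i$ is the normalized weight that the function returns.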

We can visualize what one draw with 50 possible slots looks like:

In [None]:
concentration = 3
beta_draws = random.beta(key=key, a=1, b=concentration, shape=(50,))
occupied_probability, weights = stick_breaking_weights(beta_draws)
plt.plot(weights);

As you can see here, we have most of the probability mass concentrated on the first few states.

If we plotted multiple draws from the same concentration value, what might it look like?

In [None]:
beta_draws = random.beta(key=key, a=1, b=concentration, shape=(20, 50))
occupied_probability, weights = vmap(stick_breaking_weights)(beta_draws)

import seaborn as sns
sns.heatmap(weights)

As is visible, over 20 realizations, most of the probability mass is concentrated in the first few states.

Now, what if we wanted to see the effect of varying concentration? This is another `vmap`!

In [None]:
concentrations = np.array([0.5, 1, 3, 5, 10, 20])

def weights_one_concentration(concentration, num_draws, num_components):
    beta_draws = random.beta(key=key, a=1, b=concentration, shape=(num_draws, num_components))
    occupied_probability, weights = vmap(stick_breaking_weights)(beta_draws)
    return occupied_probability, weights

occupied_probabilities, weights = vmap(partial(weights_one_concentration, num_draws=20, num_components=50))(concentrations)
weights.shape

`weights` is now an array of shape (6, 20, 50), which corresponds to 6 concentrations, 20 i.i.d. draws each, with 50 component weights available.

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(3*3, 3*2))

for ax, weights_mat, conc in zip(axes.flatten(), weights, concentrations):
    sns.heatmap(weights_mat, ax=ax)
    ax.set_title(f"Concentration = {conc}")
plt.tight_layout()

As we increase the concentration value, the probabilities get more diffuse.
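
One way to see why: each stick fraction is drawn from a Beta(1, concentration) distribution, whose mean is 1 / (1 + concentration). A larger concentration means each break takes a smaller share of what is left, so the mass spreads over more states. The sketch below checks this by simulation; the seed and the sample size are arbitrary choices of mine.

```python
import jax.numpy as np
from jax import random

demo_key = random.PRNGKey(0)  # illustrative seed, not the notebook's key
num_samples = 20000

# Average size of a single stick fraction at a low and a high concentration.
mean_low = np.mean(random.beta(demo_key, a=1.0, b=0.5, shape=(num_samples,)))
mean_high = np.mean(random.beta(demo_key, a=1.0, b=20.0, shape=(num_samples,)))

# Analytical means, 1 / (1 + concentration), for comparison.
expected_low = 1.0 / (1.0 + 0.5)
expected_high = 1.0 / (1.0 + 20.0)
```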

From this forward process of generating Dirichlet-distributed weights,
instead of evaluating the log likelihood of the component weights
under a "fixed" Dirichlet distribution prior,
we can instead evaluate it under a Dirichlet process with a "concentration" prior.
The requirement here is that we be able to recover correctly the i.i.d. Beta draws
that generated the Dirichlet process weights.

Let's try that out.

In [None]:
def beta_draw_from_weights(weights, tol=1e-8):
    def beta_from_w(accounted_probability, weights_i):
        """
        :param accounted_probability: The cumulative probability accounted for.
        :param weights_i: Current value of weights to consider.
        """
        denominator = 1 - accounted_probability
        log_denominator = np.log(denominator)
        
        log_beta_i = np.log(weights_i) - log_denominator

        newly_accounted_probability = accounted_probability + weights_i
        
        return newly_accounted_probability, np.exp(log_beta_i)
    final, betas = lax.scan(beta_from_w, np.array(0.), weights)
    return final, betas

In [None]:
concentration = 3
beta_draws = random.beta(key=key, a=1, b=concentration, shape=(50,))
occupied_probability, weights = stick_breaking_weights(beta_draws)
final, beta_hat = beta_draw_from_weights(weights)
plt.plot(beta_draws, label="original")
plt.plot(beta_hat, label="inferred");

As is visible from the plot above, we were able to recover about 1/2 to 2/3 of the Beta draws
before the divergence in the two curves shows up.
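
A small hand-worked case hints at where the divergence comes from. With only three sticks (far fewer than the 50 used above, so the effect is heavily exaggerated), inverting the raw weights recovers the Beta draws exactly, while inverting the normalized weights distorts the later ones. The cumulative-product formulation and the helper name `invert_stick_breaking` are my own simplifications, not the scan-based functions above.

```python
import jax.numpy as np

betas = np.array([0.5, 0.5, 0.5])  # illustrative stick fractions

# Forward: each weight is its fraction times the stick length remaining.
remaining = np.concatenate([np.array([1.0]), np.cumprod(1 - betas)[:-1]])
weights = betas * remaining  # [0.5, 0.25, 0.125], sums to 0.875


def invert_stick_breaking(w):
    """Recover stick fractions, assuming the weights came from a length 1 stick."""
    accounted = np.concatenate([np.array([0.0]), np.cumsum(w)[:-1]])
    return w / (1 - accounted)


betas_exact = invert_stick_breaking(weights)
betas_from_normalized = invert_stick_breaking(weights / np.sum(weights))
```

After normalization the leftover stick is gone, so the last recovered fraction is pushed all the way to 1 and the earlier ones are inflated too.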

One of the difficulties that we have is that when we get back the observed weights in real life,
we have no access to how much of the length 1 "stick" is leftover.
This alongside numerical underflow issues arising from small numbers
means we can only use about 1/2 of the drawn weights
to recover the Beta-distributed draws
from which we can evaluate our log likelihoods.
Let's try performing that evaluation.

In [None]:
def component_probs_loglike(log_component_probs, log_concentration):
    """
    :param log_concentration: Real-valued scalar.
    """
    concentration = np.exp(log_concentration)
    component_probs = normalize_weights(np.exp(log_component_probs))
    _, beta_draws = beta_draw_from_weights(component_probs)
    num_components = len(beta_draws) // 2  # static Python int, safe to slice with under jit/scan
    return np.sum(stats.beta.logpdf(x=beta_draws[:num_components], a=1, b=concentration))

component_probs_loglike(np.log(weights), log_concentration=1.0)

Once again, let's build up our understanding by seeing how the log likelihood of our weights
under a stick-breaking process with Beta-distributed draws
changes as we vary the concentration parameter.

In [None]:
log_concentrations = np.linspace(-3, 3, 10000)
logps = vmap(partial(component_probs_loglike, np.log(weights)))(log_concentrations)
plt.plot(np.exp(log_concentrations), logps)
plt.xlabel("concentration")
plt.ylabel("logp")

Looks quite good. Let's see if we can visualize how the log probability changes with multiple weights draws from a Dirichlet process.

In [None]:
num_draws = 20
num_components = 50
concentration = 3
beta_draws = random.beta(key=key, a=1, b=concentration, shape=(num_draws, num_components))
occupied_probability, weights = vmap(stick_breaking_weights)(beta_draws)

def logp_curve(log_weights_vector, log_concentrations):
    """Logp curve for one weights vector."""
    logps = vmap(partial(component_probs_loglike, log_weights_vector))(log_concentrations)
    return logps

logps = vmap(partial(logp_curve, log_concentrations=log_concentrations))(np.log(weights))

for logp in logps:
    plt.plot(np.exp(log_concentrations), logp)
    plt.xlabel("concentration")
    plt.ylabel("logp")

Over multiple realizations of weights, we see that we should be able to approximately recover the true concentration value
if we used gradient descent.
This gives us hope!
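
As a sanity check on that hope, the Beta(1, b) likelihood has a closed-form maximizer that we can compare against a grid search. The sketch below works directly on simulated Beta draws rather than on weights recovered from a stick-breaking process, so it is a simplified stand-in for the curves above; the seed, sample size and grid are arbitrary choices of mine.

```python
import jax.numpy as np
from jax import random, vmap
from jax.scipy import stats

demo_key = random.PRNGKey(42)  # illustrative seed
true_concentration = 3.0
draws = random.beta(demo_key, a=1.0, b=true_concentration, shape=(2000,))


def beta_loglike(log_concentration):
    """Log-likelihood of the draws under Beta(1, exp(log_concentration))."""
    return np.sum(stats.beta.logpdf(draws, a=1.0, b=np.exp(log_concentration)))


# Grid search over log-concentration, vmapped as in the cells above.
log_grid = np.linspace(-3.0, 3.0, 601)
grid_estimate = np.exp(log_grid[np.argmax(vmap(beta_loglike)(log_grid))])

# Closed-form maximizer of n * log(b) + (b - 1) * sum(log(1 - x)).
closed_form_estimate = -draws.shape[0] / np.sum(np.log1p(-draws))
```

Both estimates should sit close to the concentration that generated the draws, which is what makes gradient-based recovery plausible.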

Let us now write down the log likelihood for the full probabilistic model.
We can leverage some of the components we have already written before.
`mixture_loglike` is the one that we want to start with.

In [None]:
def joint_loglike(log_component_weights, log_concentration, component_mus, log_component_scales, observed_data):
    
    # logpdf of weights under concentrations prior
    logp_weights = component_probs_loglike(log_component_weights, log_concentration)
    
    logp_observed_data = mixture_loglike(log_component_weights, component_mus, log_component_scales, observed_data)
    return logp_weights + logp_observed_data

In [None]:
def joint_loss(params, data):
    log_component_weights, log_concentration, component_mus, log_component_scales = params
    
    nll = -joint_loglike(*params, observed_data=data)
    
    return nll + np.squeeze(log_concentration ** 2)

djoint_loss = grad(joint_loss)
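
Written out, the loss defined above has roughly the following form. This assumes that `loglike_one_component` adds the log-weight to the Gaussian log-density, as described earlier. The notation is shorthand for illustration: $w_k$ are the normalized weights, $\hat{\beta}_i$ are the stick fractions recovered by `beta_draw_from_weights`, $\alpha$ is the exponentiated `log_concentration`, $\sigma_k$ are the exponentiated log-scales, and $K$ and $N$ are the numbers of components and data points.

$$
\text{loss} = -\left[\sum_{i=1}^{\lfloor K/2 \rfloor} \log \mathrm{Beta}(\hat{\beta}_i \mid 1, \alpha) + \sum_{n=1}^{N} \log \sum_{k=1}^{K} w_k \, \mathcal{N}(x_n \mid \mu_k, \sigma_k)\right] + (\log \alpha)^2
$$

The first sum is the stick-breaking prior on the weights, the second is the mixture log-likelihood, and the final term is the penalty that keeps the log concentration from drifting far from zero.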

In [None]:
k1, k2, k3, k4 = random.split(key, 4)
n_components = 50

log_component_weights_init = random.normal(k1, shape=(n_components,))
log_concentration_init = random.normal(k2, shape=(1,))
component_mus_init = random.normal(k3, shape=(n_components,))
log_component_scales_init = random.normal(k4, shape=(n_components,))

params_init = log_component_weights_init, log_concentration_init, component_mus_init, log_component_scales_init



In [None]:
joint_loss(params_init, data_mixture)

In [None]:
adam_init, adam_update, adam_get_params = adam(0.005)
step_scannable = make_step_scannable(
    get_params_func=adam_get_params,
    dloss_func=djoint_loss,
    update_func=adam_update,
    data=data_mixture, 
)

In [None]:
initial_state = adam_init(params_init)

final_state, state_history = lax.scan(step_scannable, initial_state, np.arange(1000))

In [None]:
params_history = adam_get_params(state_history)
params_history

Now, we'd like to learn the concentration parameter for the component probabilities.

Now that we can calculate the log-likelihood of the component weights, let's look at it jointly with the mixture log-likelihood, this time with a quartic penalty on the log concentration.

In [None]:
def joint_loss(params, data):
    log_component_weights, log_concentration, component_mus, log_component_scales = params
    
    # log-likelihood of the component weights under the Beta stick-breaking prior
    comp_probs_logp = component_probs_loglike(log_component_weights, log_concentration)
    
    # mixture distribution log-likelihood
    mixture_logp = mixture_loglike(log_component_weights, component_mus, log_component_scales, data)
    
    total_logp = comp_probs_logp + mixture_logp
    regularization = np.squeeze(np.power(log_concentration, 4))
    return -total_logp + regularization

In [None]:
djoint_loss = grad(joint_loss)

concentration_init = 3.

params_init = log_component_weights_init, np.log(concentration_init), component_mus_init, log_component_scales_init
joint_loss(params_init, data_mixture)

In [None]:
# Reuse the scan-based training loop from above to optimize the parameters.
adam_init, adam_update, adam_get_params = adam(0.005)
step_scannable = make_step_scannable(adam_get_params, djoint_loss, adam_update, data_mixture)
final_state, _ = lax.scan(step_scannable, adam_init(params_init), np.arange(2000))
params_opt = adam_get_params(final_state)

In [None]:
log_component_weights_opt, log_concentration_opt, component_mus_opt, log_component_scales_opt = params_opt

In [None]:
component_probs_opt = normalize_weights(np.exp(log_component_weights_opt))
component_probs_opt

In [None]:
component_mus_opt

In [None]:
concentration_opt = np.exp(log_concentration_opt)
concentration_opt, concentration_init

In [None]:
fig, ax = plt.subplots()
plot_component_norm_pdfs(log_component_weights_opt, component_mus_opt, log_component_scales_opt, -10, 20, ax, title="Optimized")

In [None]:
fig, ax = plt.subplots()
plot_component_norm_pdfs(log_component_weights_init, component_mus_init, log_component_scales_init, -10, 20, ax, title="Initialized")