In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
from jax import vmap, grad
import jax.numpy as np
from jax.scipy import stats
from jax import random
from jax.scipy.special import logsumexp
from functools import partial

## Generate a mixture of Gaussians

In [None]:
import matplotlib.pyplot as plt

# Relative sizes and locations of the two ground-truth mixture components.
weights_true = np.array([1, 5])
means_true = np.array([-2, 3])

base_n_draws = 1000
key = random.PRNGKey(100)

# A JAX PRNG key must never be consumed twice: passing the same key to two
# sampling calls yields correlated (not independent) streams. Split it so each
# component gets its own subkey.
key_draws_1, key_draws_2 = random.split(key)

# Shapes must be concrete Python ints, so convert the array scalars explicitly.
n_draws_1 = base_n_draws * int(weights_true[0])
n_draws_2 = base_n_draws * int(weights_true[1])

draws_1 = random.normal(key_draws_1, shape=(n_draws_1,)) + means_true[0]
draws_2 = random.normal(key_draws_2, shape=(n_draws_2,)) + means_true[1]
data_mixture = np.concatenate([draws_1, draws_2])
plt.hist(data_mixture)

First, we use a two-component mixture distribution to calculate the log-likelihood of the data under a given set of parameters.

The likelihood of each data point is the sum of its likelihoods under each of the components.

The likelihood of a data point under a component $i$ is the likelihood of drawing that component * likelihood of observing that data point under that component's distribution. In Math:

$$L_i = P(x|\mu_i, w_i)P(w_i)$$

In [None]:
def loglike_one_component(component_prob, component_mu, datum):
    """
    Log likelihood of a single datum under one mixture component.

    :param component_prob: Probability of drawing this component.
    :param component_mu: Mean of this component's unit-variance normal.
    :param datum: The observed value.
    :returns: log(component_prob) plus the normal log-density of the datum.
    """
    log_weight = np.log(component_prob)
    log_density = stats.norm.logpdf(datum, loc=component_mu, scale=1)
    return log_weight + log_density

Now, we test-drive it:

In [None]:
# Sanity check: a datum sitting exactly on the component mean.
loglike_one_component(
    component_prob=0.25,
    component_mu=0.0,
    datum=0.0,
)

In [None]:
def loglike_across_components(component_probs, component_mus, datum):
    """
    Log likelihood of a single datum under the whole mixture.

    :param component_probs: Component weights; rescaled here to sum to one.
    :param component_mus: Component means, one per weight.
    :param datum: The observed value.
    :returns: logsumexp over the per-component log likelihoods.
    """
    normalized_probs = component_probs / np.sum(component_probs, axis=-1)
    loglike_for_datum = partial(loglike_one_component, datum=datum)
    # vmap pairs each weight with its mean and evaluates them in one pass.
    per_component_loglikes = vmap(loglike_for_datum)(normalized_probs, component_mus)
    return logsumexp(per_component_loglikes)

In [None]:
# Weights are deliberately unnormalised; the function rescales them itself.
loglike_across_components(
    component_probs=np.array([10, 0.1]),
    component_mus=np.array([0., 3]),
    datum=0.,
)

In [None]:
def mixture_loglike(component_probs, component_mus, data):
    """
    Total log likelihood of a dataset (not a single datum) under the mixture.

    :param component_probs: Component weights.
    :param component_mus: Component means.
    :param data: Array of observations, mapped over along its leading axis.
    :returns: Sum of the per-observation mixture log likelihoods.
    """
    loglike_of_datum = partial(loglike_across_components, component_probs, component_mus)
    return np.sum(vmap(loglike_of_datum)(data))

In [None]:
# Two equally weighted components evaluated on a small hand-made dataset.
mixture_loglike(
    component_probs=np.array([0.1, 0.1]),
    component_mus=np.array([0., 3]),
    data=np.array([0., 0., 0., 3., 3., 3.,]),
)

In [None]:
def loss(params, data):
    """
    Negative mixture log likelihood, suitable for minimisation.

    :param params: Pair of (log component weights, component means).
    :param data: Array of observations.
    :returns: Negated output of `mixture_loglike`.
    """
    log_probs, mus = params
    # Weights are parameterised in log space; exponentiate to make them positive.
    return -mixture_loglike(np.exp(log_probs), mus, data)

dloss = grad(loss)

N_MIXTURE_COMPONENTS = 10

# Reusing one PRNG key for both draws would make the initial means an exact
# multiple of the (signed) initial log-weights. `fold_in` derives a fresh key
# from `key`, and `split` then gives each quantity its own independent subkey.
key_probs_init, key_mus_init = random.split(random.fold_in(key, 1))

log_component_probs_init = np.abs(random.normal(key_probs_init, shape=(N_MIXTURE_COMPONENTS,)))
component_mus_init = 10 * random.normal(key_mus_init, shape=(N_MIXTURE_COMPONENTS,))
observed_data = np.array([0., 0., 0., 3., 3., 3., 3., 3., 3., 3.,])

params_init = log_component_probs_init, component_mus_init

# Smoke test: the loss and its gradient at the initial parameters.
loss(params_init, observed_data), dloss(params_init, observed_data)

Let's try optimizing!

In [None]:
from jax import jit

# The optimizers module was relocated from `jax.experimental.optimizers` to
# `jax.example_libraries.optimizers`; try the current location first and fall
# back to the old one so the notebook runs on either JAX generation.
try:
    from jax.example_libraries.optimizers import adam
except ImportError:
    from jax.experimental.optimizers import adam


def optimize_params(params, data, dloss, n_iter, step_size=0.05):
    """
    Generic optimizer loop, using ADAM optimizer.

    Exists purely for convenience.

    :param params: The params to optimize (any pytree of arrays).
    :param data: The data, gets passed into the dloss function.
    :param dloss: A function that returns the gradient of the loss with
        respect to `params` (e.g. `grad(loss)`), with the same pytree
        structure as `params`. Accepts only `params` and `data`.
    :param n_iter: Number of iterations to optimize for.
    :param step_size: ADAM step size; defaults to the previously
        hard-coded value of 0.05.
    :returns: The params after `n_iter` ADAM updates.
    """
    init, update, get_params = adam(step_size)
    get_params = jit(get_params)

    @jit
    def step(i, state):
        # One ADAM update: read params, differentiate, fold the gradient in.
        current_params = get_params(state)
        g = dloss(current_params, data)
        return update(i, g, state)

    state = init(params)
    for i in range(n_iter):
        state = step(i, state)
    return get_params(state)

In [None]:
# Fit the mixture parameters to the simulated dataset.
params_opt = optimize_params(
    params_init,
    data_mixture,
    dloss,
    n_iter=1000,
)

In [None]:
# Negative mixture log likelihood of the simulated data at the optimised parameters.
loss(params_opt, data_mixture)

In [None]:
log_component_probs_opt = params_opt[0]
# Map the log-weights back to probabilities: exponentiate, then rescale to sum to one.
component_probs_opt = np.exp(log_component_probs_opt)
component_probs_opt = component_probs_opt / component_probs_opt.sum()
component_probs_opt

In [None]:
# Second entry of the optimised parameter tuple: the component means.
means = params_opt[1]
means

## What do the mixture PDFs look like here?

In [None]:
from jax.scipy.stats import norm

def plot_component_norm_pdfs(component_probs, component_mus, xmin, xmax):
    """
    Plot each mixture component's weighted, unit-variance normal PDF.

    The weights are rescaled to sum to one before plotting, exactly as
    `loglike_across_components` does, so the curves are the component
    densities the model actually uses and plots of raw (unnormalised)
    and already-normalised weights are directly comparable.

    :param component_probs: Component weights (normalised or not).
    :param component_mus: Component means, one per weight.
    :param xmin: Left edge of the plotting grid.
    :param xmax: Right edge of the plotting grid.
    """
    component_probs = component_probs / np.sum(component_probs)
    # Column vector so broadcasting against the means yields one column per component.
    x = np.linspace(xmin, xmax, 1000).reshape(-1, 1)
    pdfs = component_probs * norm.pdf(x, loc=component_mus)
    for component in range(pdfs.shape[1]):
        plt.plot(x, pdfs[:, component])

In [None]:
# Component PDFs at the initial (pre-optimisation) parameters.
plot_component_norm_pdfs(
    np.exp(log_component_probs_init),
    component_mus_init,
    -10,
    20,
)

Now, we'd like to learn the concentration parameter for the component probs.

In [None]:
from jax import lax

def beta_draw_from_weights(weights, tol=1e-8):
    """
    Recover stick-breaking fractions from a vector of component weights.

    Each fraction is the weight divided by the probability mass that was
    still unaccounted for before that weight was taken.

    :param weights: Array of component weights, scanned along its leading axis.
    :param tol: Lower bound on the remaining probability mass. Keeps the
        division finite once the stick is exhausted (remaining mass of zero,
        or slightly negative through rounding).
    :returns: Tuple of (total accounted probability, array of fractions).
    """
    def beta_from_w(accounted_probability, weights_i):
        """
        :param accounted_probability: The cumulative probability accounted for.
        :param weights_i: Current value of weights to consider.
        """
        remaining_probability = np.maximum(1 - accounted_probability, tol)
        beta_i = weights_i / remaining_probability

        newly_accounted_probability = accounted_probability + weights_i
        return newly_accounted_probability, beta_i

    final, betas = lax.scan(beta_from_w, np.array(0.), weights)
    return final, betas

In [None]:
def component_probs_logpdf(component_probs, log_concentration):
    """
    Log-density of the component weights under a stick-breaking prior.

    The weights are normalised, converted to stick-breaking fractions, and
    each fraction is scored under a Beta(1, concentration) distribution.

    :param component_probs: Component weights (normalised here).
    :param log_concentration: Real-valued scalar.
    :returns: Summed Beta log-density of the stick-breaking fractions.
    """
    concentration = np.exp(log_concentration)
    component_probs = component_probs / np.sum(component_probs)
    _, beta_draws = beta_draw_from_weights(component_probs)
    # After normalisation the final fraction equals 1 by construction (it takes
    # whatever stick is left), where the Beta(1, b) log-density is degenerate,
    # so only the preceding fractions are scored.
    informative_draws = beta_draws[:-1]
    return np.sum(stats.beta.logpdf(x=informative_draws, a=1, b=concentration))
# Stick-breaking fractions implied by the optimised component weights.
_, beta_draws = beta_draw_from_weights(component_probs_opt)
# NOTE: only a cell's final expression is rendered, so this line shows nothing.
beta_draws

# Log-density of the optimised weights at a fixed log-concentration of 1.0.
component_probs_logpdf(component_probs_opt, log_concentration=1.0)

Now that we can calculate the log-density of the component weights, let's combine it with the mixture log-likelihood in a single joint loss.

In [None]:
def joint_loss(params, data):
    """
    Negative joint log-density of weights and data, plus a penalty.

    :param params: Triple of (log component weights, log concentration,
        component means).
    :param data: Array of observations.
    :returns: Negated sum of the weight log-density and the mixture log
        likelihood, plus the fourth power of the log concentration.
    """
    log_component_probs, log_concentration, component_mus = params
    component_probs = np.exp(log_component_probs)

    # Log-density of the weights themselves.
    weights_logp = component_probs_logpdf(component_probs, log_concentration)
    # Log likelihood of the data under the mixture.
    data_logp = mixture_loglike(component_probs, component_mus, data)

    penalty = np.power(log_concentration, 4)
    return -(weights_logp + data_logp) + penalty

In [None]:
djoint_loss = grad(joint_loss)

concentration_init = 3.

# The concentration is optimised in log space, so store its log here.
params_init = (
    log_component_probs_init,
    np.log(concentration_init),
    component_mus_init,
)
joint_loss(params_init, observed_data)

In [None]:
# Fit weights, concentration and means jointly on the simulated dataset.
params_opt = optimize_params(
    params_init,
    data_mixture,
    djoint_loss,
    n_iter=2000,
)

In [None]:
# Unpack in the same order the joint parameter tuple was assembled.
log_component_probs_opt, log_concentration_opt, component_mus_opt = params_opt

In [None]:
# Exponentiate the log-weights, then rescale so they sum to one.
component_probs_opt = np.exp(log_component_probs_opt)
component_probs_opt = component_probs_opt / np.sum(component_probs_opt)
component_probs_opt

In [None]:
# Component means after joint optimisation.
component_mus_opt

In [None]:
# Back-transform from log space and show it next to the starting value.
concentration_opt = np.exp(log_concentration_opt)
concentration_opt, concentration_init

In [None]:
# Component PDFs at the jointly optimised parameters.
plot_component_norm_pdfs(
    component_probs_opt,
    component_mus_opt,
    -10,
    20,
)

In [None]:
# Component PDFs at the initial parameters, for comparison with the fit.
plot_component_norm_pdfs(
    np.exp(log_component_probs_init),
    component_mus_init,
    -10,
    20,
)