### Application: Dirichlet mixture model

We will now see some of these ingredients in action in a simple but more realistic setting by writing a Dirichlet mixture model in GenJAX.

The goal here is to cluster datapoints on the real line.
To do so, we model a fixed number of clusters, each as a 1D Gaussian with fixed variance, and we want to infer their means.

In more detail, the model of the world postulates a fixed number of 1D Gaussians.
Each Gaussian is assigned a weight representing the proportion of datapoints that belong to its cluster.
Finally, each datapoint belongs to exactly one cluster, chosen with probability proportional to the cluster weights.

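We turn this into a generative model as follows.
We have a fixed prior mean and variance for where the cluster centers might be.
We sample a mean for each cluster.
We sample the cluster weights from a symmetric Dirichlet distribution.
For each datapoint,
- we sample a cluster assignment for that datapoint with probability proportional to the cluster weights;
- we sample the datapoint's value from the Gaussian of its assigned cluster.

Given the observed datapoints, we then infer the cluster means, the cluster weights and the assignments with Gibbs sampling: we repeatedly resample each group of latent variables from its exact conditional distribution given all the others.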

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import categorical, dirichlet, gen, normal, pretty
from genjax._src.core.pytree import Const, Pytree

# Turn on GenJAX's pretty-printing of its objects.
pretty()
key = jax.random.key(0)

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Prior over each cluster centre (mean and variance).
PRIOR_MEAN = 10.0
PRIOR_VARIANCE = 10.0
# Fixed variance of every cluster's Gaussian.
OBS_VARIANCE = 1.0
# Total Dirichlet concentration; the model splits it evenly across clusters.
ALPHA = 1.0
N_DATAPOINTS = 10_000
N_CLUSTERS = 10
# Number of Gibbs sweeps performed by `infer`.
N_ITER = 1_000


@Pytree.dataclass
class Cluster(Pytree):
    """Pytree holding cluster centre(s).

    Inside `generate_cluster` it wraps a single sampled mean. After
    `generate_cluster.repeat(...)`, or when rebuilt from a trace, it wraps an
    array with one mean per cluster. That array is indexed as
    `clusters.mean[idx]` in `generate_datapoint`.
    """

    # A JAX scalar or a 1-D array of per-cluster means (never a Python float).
    mean: jax.Array


@gen
def generate_cluster(mean, var):
    """Sample one cluster centre from a Normal with mean `mean` and variance `var`.

    The draw is recorded at address "mean" and returned wrapped in a `Cluster`.
    `var` is a *variance* (callers pass `PRIOR_VARIANCE`, or posterior
    variances). GenJAX's `normal` is parameterised by a standard deviation,
    so `var` is square-rooted before use.
    """
    cluster_mean = normal(mean, jnp.sqrt(var)) @ "mean"
    return Cluster(cluster_mean)


@gen
def generate_cluster_weight(alphas):
    """Draw mixture weights from Dirichlet(`alphas`).

    The weights are recorded at address "probs" and returned.
    """
    return dirichlet(alphas) @ "probs"


@gen
def generate_datapoint(probs, clusters):
    """Sample a cluster assignment and then an observation from that cluster.

    Args:
        probs: mixture weights, one per cluster.
        clusters: a `Cluster` whose `mean` holds one mean per cluster.

    Returns:
        The observation sampled at address "obs".
    """
    # `categorical` is given log-weights here, hence the log of `probs`.
    idx = categorical(jnp.log(probs)) @ "idx"
    # OBS_VARIANCE is a variance; `normal` expects a standard deviation.
    obs = normal(clusters.mean[idx], jnp.sqrt(OBS_VARIANCE)) @ "obs"
    return obs


@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    """Full mixture model: cluster centres, then mixture weights, then datapoints.

    Args:
        n_clusters: static number of clusters.
        n_datapoints: static number of datapoints.
        alpha: total Dirichlet concentration, split evenly across clusters.

    Returns:
        The observations produced by the repeated `generate_datapoint`.
    """
    k = n_clusters.unwrap()
    n = n_datapoints.unwrap()

    # One centre per cluster, stored under "clusters".
    sample_clusters = generate_cluster.repeat(n=k)
    clusters = sample_clusters(PRIOR_MEAN, PRIOR_VARIANCE) @ "clusters"

    # Symmetric Dirichlet prior; inlined, so "probs" lives at the top level.
    concentration = alpha / k * jnp.ones(k)
    probs = generate_cluster_weight.inline(concentration)

    # n independent (assignment, observation) pairs under "datapoints".
    sample_points = generate_datapoint.repeat(n=n)
    return sample_points(probs, clusters) @ "datapoints"

In [None]:
# Synthetic observations: ten uniform blobs of 1000 points each, at
# [0, 1), [3, 4), [5, 6), ..., [19, 20). Blob number j uses PRNG key j.
datapoints = C["datapoints", "obs"].set(
    jnp.concatenate([
        offset + jax.random.uniform(jax.random.key(seed), shape=(1000,))
        for seed, offset in enumerate((0, *range(3, 20, 2)))
    ])
)

In [None]:
def infer(datapoints, *, key=None, n_iter=N_ITER):
    """Run blocked Gibbs sampling on `generate_data` conditioned on `datapoints`.

    Args:
        datapoints: choice map constraining the observations at
            ("datapoints", "obs").
        key: optional JAX PRNG key; defaults to `jax.random.key(3421)`.
        n_iter: number of Gibbs sweeps (defaults to `N_ITER`).

    Returns:
        The trace after the final sweep.
    """
    if key is None:
        key = jax.random.key(3421)
    args = (Const(N_CLUSTERS), Const(N_DATAPOINTS), ALPHA)
    key, subkey = jax.random.split(key)
    tr, _ = generate_data.importance(subkey, datapoints, args)

    # One jitted function for the whole sweep. It is compiled once per `infer`
    # call instead of being re-wrapped with `jax.jit` on every iteration.
    @jax.jit
    def sweep(key, tr):
        # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
        key, subkey = jax.random.split(key)
        tr = update_cluster_means(subkey, tr)

        # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        tr = update_datapoint_assignment(subkey, tr)

        # Gibbs update on `probs`
        key, subkey = jax.random.split(key)
        tr = update_cluster_weights(subkey, tr)
        return key, tr

    # TODO: replace this Python loop with `jax.lax.scan` over `sweep`.
    for _ in range(n_iter):
        key, tr = sweep(key, tr)

    return tr


def update_cluster_means(key, trace):
    """Gibbs step: resample every cluster mean ("clusters", i, "mean") in parallel.

    Given the current assignments, each cluster mean has a Normal prior
    (PRIOR_MEAN, PRIOR_VARIANCE) and n_k observations with known variance
    OBS_VARIANCE. Its exact conditional is therefore Normal, with
        precision = 1 / PRIOR_VARIANCE + n_k / OBS_VARIANCE
        mean      = (PRIOR_MEAN / PRIOR_VARIANCE + sum_k / OBS_VARIANCE) / precision
    where sum_k is the sum of the points assigned to cluster k. For an empty
    cluster (n_k = 0) this reduces to the prior.

    PRIOR_VARIANCE and OBS_VARIANCE are treated as variances; the model's
    Normal draws must interpret them the same way for this step to be exact.
    See https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture5.pdf
    """
    choices = trace.get_choices()
    datapoint_indexes = choices["datapoints", "idx"]
    datapoints = choices["datapoints", "obs"]
    n_clusters = trace.get_args()[0].unwrap()

    # Per-cluster sufficient statistics: number of assigned points and their sum.
    category_counts = jnp.bincount(datapoint_indexes, length=n_clusters)
    cluster_sums = jnp.bincount(
        datapoint_indexes, weights=datapoints, length=n_clusters
    )

    # Precision form of the Normal-Normal conjugate update. Unlike the
    # `OBS_VARIANCE / n` form, it never divides by a zero count.
    posterior_precisions = 1 / PRIOR_VARIANCE + category_counts / OBS_VARIANCE
    posterior_variances = 1 / posterior_precisions
    posterior_means = posterior_variances * (
        PRIOR_MEAN / PRIOR_VARIANCE + cluster_sums / OBS_VARIANCE
    )

    # Draw the new means directly; the noise is scaled by the standard deviation.
    sample_key, update_key = jax.random.split(key)
    new_means = posterior_means + jnp.sqrt(posterior_variances) * jax.random.normal(
        sample_key, shape=posterior_means.shape
    )

    # Gibbs resampling of cluster means
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        update_key, C["clusters", "mean"].set(new_means), argdiffs
    )
    return new_trace


def update_datapoint_assignment(key, trace):
    """Gibbs step: resample every datapoint's cluster assignment in parallel.

    Conditioned on the cluster means and weights, the assignments
    ("datapoints", j, "idx") are independent of one another. So instead of
    redrawing them from the prior P(idx | probs), all of them are redrawn at
    once from their local posteriors P(idx | probs, obs).

    Each local posterior is exact, computed by enumerating the clusters:
    P(i | y) = P(i, y) / sum_k P(k, y). The categorical draw below is given
    unnormalised log-weights, so the joint log-densities log P(i, y) are used
    as-is, without normalising.
    """
    choices = trace.get_choices()
    n_clusters = trace.get_args()[0].unwrap()

    # Quantities shared by every (datapoint, cluster) pair.
    observations = choices["datapoints", "obs"]
    clusters = Cluster(choices["clusters", "mean"])
    probs = choices["probs"]

    def joint_log_density(obs_value, cluster_idx):
        # log P(idx = cluster_idx, obs = obs_value | probs, clusters)
        constraint = C["obs"].set(obs_value).at["idx"].set(cluster_idx)
        log_density, _ = generate_datapoint.assess(constraint, (probs, clusters))
        return log_density

    # Evaluate on the full grid -> shape (n_datapoints, n_clusters).
    over_clusters = jax.vmap(joint_log_density, in_axes=(None, 0))
    local_densities = jax.vmap(over_clusters, in_axes=(0, None))(
        observations, jnp.arange(n_clusters)
    )

    # Exact conditional draw: one categorical per datapoint (row). Using
    # `generate_datapoint.vmap().importance` would perhaps be more general,
    # but slower here.
    sample_key, update_key = jax.random.split(key)
    new_datapoint_indexes = (
        genjax.categorical.vmap()
        .simulate(sample_key, (local_densities,))
        .get_choices()
    )

    # Gibbs resampling of datapoint assignment to clusters
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        update_key, C["datapoints", "idx"].set(new_datapoint_indexes), argdiffs
    )
    return new_trace


def update_cluster_weights(key, trace):
    """Gibbs step: resample the mixture weights at address "probs".

    With a Dirichlet(alpha / K, ..., alpha / K) prior and per-cluster counts
    n_k, the exact conditional is Dirichlet(alpha / K + n_k).
    See https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial
    """
    model_args = trace.get_args()
    n_clusters = model_args[0].unwrap()
    # Use the concentration the trace was generated with, not the global ALPHA,
    # so that this update always matches the model's prior.
    alpha = model_args[2]

    # Count number of points per cluster
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"], length=n_clusters
    )

    # Conjugate update for the Dirichlet distribution
    new_alpha = alpha / n_clusters * jnp.ones(n_clusters) + category_counts

    # `generate_cluster_weight` records its draw at "probs", so its choice map is
    # keyed by "probs". Extract the value itself; setting the whole map under
    # C["probs"] would constrain ("probs", "probs") and leave the weights unchanged.
    sample_key, update_key = jax.random.split(key)
    new_probs = generate_cluster_weight.simulate(
        sample_key, (new_alpha,)
    ).get_choices()["probs"]

    # Gibbs resampling of cluster weights
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        update_key, C["probs"].set(new_probs), argdiffs
    )
    return new_trace

Running inference.

In [None]:
# Condition the model on the synthetic observations and run the Gibbs sampler.
posterior_trace = infer(datapoints=datapoints)

Plotting results

In [None]:
datapoint = datapoints["datapoints", "obs"]
posterior_means = posterior_trace.get_choices()["clusters", "mean"]
posterior_weights = posterior_trace.get_choices()["probs"]

# Every cluster is a Gaussian with the same fixed variance OBS_VARIANCE.
# Draw +/- one standard deviation of that Gaussian around each posterior mean.
cluster_std = jnp.sqrt(OBS_VARIANCE)

# Plotting datapoints
plt.scatter(datapoint, jnp.zeros_like(datapoint), alpha=0.1, label="Datapoints", s=20)

for i, (mean, weight) in enumerate(zip(posterior_means, posterior_weights)):
    # Posterior mean, with marker size proportional to the posterior weight
    plt.scatter(
        mean,
        0,
        color=f"C{i}",
        label=f"Cluster {i + 1} Mean (Prob: {weight:.6f})",
        s=100 + weight * 600,
        alpha=1,
    )
    # Spread of the cluster's Gaussian
    plt.errorbar(mean, 0, xerr=cluster_std, fmt="o", color=f"C{i}", capsize=5)

plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3)

plt.show()