### Application: Dirichlet mixture model

We will now see some of these ingredients in action in a simple but more realistic setting, and write a Dirichlet mixture model in GenJAX.

The goal here is to cluster datapoints on the real line.
To do so, we model a fixed number of clusters, each as a 1D-Gaussian with fixed variance, and we want to infer their means.

In more detail, the model of the world postulates a fixed number of 1D Gaussians.
Each Gaussian is assigned a weight representing the proportion of datapoints that belong to its cluster.
Finally, each datapoint belongs to exactly one cluster, chosen with probability equal to that cluster's weight.

We turn this into a generative model as follows.
We have a fixed prior mean and variance for where the cluster centres might be.
We sample a mean for each cluster. 
We sample the cluster weights.
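For each datapoint,
- we sample a cluster assignment for that datapoint, with probabilities given by the cluster weights;
- we sample the datapoint's position from a Gaussian centred on the mean of its assigned cluster, with a fixed observation variance.

TODO: add explanation for what's going on and what we will do and why.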

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import categorical, dirichlet, gen, normal, pretty
from genjax._src.core.pytree import Const, Pytree

# NOTE(review): presumably switches on GenJAX's pretty-printing for the
# notebook session — confirm against the GenJAX docs.
pretty()
# Module-level PRNG key. `infer` creates its own key internally, so it does not
# consume this one.
key = jax.random.key(0)

# Hyper parameters
# Mean of the Gaussian prior over each cluster centre (see `generate_data`).
PRIOR_MEAN = 10.0
# Spread of that prior; passed as the `var` argument of `generate_cluster` and
# used as a variance in the conjugate update of `update_cluster_means`.
PRIOR_VARIANCE = 3.0
# Spread of an observation around the mean of its assigned cluster.
OBS_VARIANCE = 1.0
# Total Dirichlet concentration; `generate_data` gives each cluster
# alpha / n_clusters.
ALPHA = 1.0
# Number of datapoints and clusters handed to `generate_data` by `infer`.
N_DATAPOINTS = 1000
N_CLUSTERS = 10
# Number of Gibbs sweeps performed by the main loop of `infer`.
N_ITER = 1000

# Debugging mode
# When True, `infer` runs a Python loop, records per-sweep histories and the
# update functions print intermediate values; when False, `infer` uses
# `jax.lax.scan` and returns only the final trace.
DEBUG = True


@Pytree.dataclass
class Cluster(Pytree):
    """Pytree container for the parameters of the clusters (only the mean)."""

    # Cluster centre. Annotated as a float, but `generate_data` builds clusters
    # through `.repeat` and `generate_datapoint` indexes `clusters.mean[idx]`,
    # so in practice this holds one mean per cluster.
    mean: float


@gen
def generate_cluster(mean, var):
    """Sample one cluster centre from a Gaussian with the given mean and variance.

    Args:
        mean: Mean of the Gaussian over the cluster centre.
        var: Variance of that Gaussian.

    Returns:
        A `Cluster` whose `mean` is the sampled centre, traced at address
        "mean".
    """
    # `normal` is parameterised by (mean, standard deviation), so the variance
    # must be converted before sampling.
    cluster_mean = normal(mean, jnp.sqrt(var)) @ "mean"
    return Cluster(cluster_mean)


@gen
def generate_cluster_weight(alphas):
    """Draw the cluster weights from a Dirichlet with concentration `alphas`.

    The draw is traced at address "probs" and returned as is.
    """
    return dirichlet(alphas) @ "probs"


@gen
def generate_datapoint(probs, clusters):
    """Sample one datapoint: pick a cluster, then draw an observation around it.

    Args:
        probs: Cluster weights; their logarithms are used as categorical logits.
        clusters: `Cluster` whose `mean` is indexed by the sampled assignment.

    Returns:
        The sampled observation. The assignment is traced at address "idx" and
        the observation at address "obs".
    """
    idx = categorical(jnp.log(probs)) @ "idx"
    # `normal` takes a standard deviation, so convert the observation variance.
    obs = normal(clusters.mean[idx], jnp.sqrt(OBS_VARIANCE)) @ "obs"
    return obs


@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    """Full mixture model: cluster centres, cluster weights, then datapoints.

    Args:
        n_clusters: Static number of clusters.
        n_datapoints: Static number of datapoints.
        alpha: Total Dirichlet concentration, split evenly between clusters.

    Returns:
        The sampled observations. Cluster centres are traced under "clusters",
        datapoints under "datapoints", and the weights are inlined, so they
        appear at the top-level address "probs".
    """
    k = n_clusters.unwrap()
    n = n_datapoints.unwrap()

    # One independent centre per cluster, all drawn from the same prior.
    cluster_model = generate_cluster.repeat(n=k)
    clusters = cluster_model(PRIOR_MEAN, PRIOR_VARIANCE) @ "clusters"

    # Symmetric Dirichlet: every cluster gets concentration alpha / k.
    concentration = alpha / k * jnp.ones(k)
    probs = generate_cluster_weight.inline(concentration)

    # Datapoints share the same weights and clusters.
    datapoint_model = generate_datapoint.repeat(n=n)
    datapoints = datapoint_model(probs, clusters) @ "datapoints"

    return datapoints

In [None]:
# Synthetic observations: for each i in range(N_CLUSTERS), one group of
# uniform samples shifted by 2 * i, i.e. lying in [2 * i, 2 * i + 1).
points_per_group = int(N_DATAPOINTS / N_CLUSTERS)
groups = [
    jax.random.uniform(jax.random.key(i), shape=(points_per_group,)) + 2 * i
    for i in range(N_CLUSTERS)
]
# Constrain the "obs" choice of every datapoint in the model.
datapoints = C["datapoints", "obs"].set(jnp.concatenate(groups))

In [None]:
def infer(datapoints):
    """Run block Gibbs sampling for `generate_data` given the observations.

    An initial trace is obtained with `generate_data.importance`, constrained
    by `datapoints`. Each sweep then applies, in order, `update_cluster_means`,
    `update_datapoint_assignment` and `update_cluster_weights`, each with a
    fresh subkey.

    Args:
        datapoints: Choice map of constraints handed to
            `generate_data.importance`.

    Returns:
        If `DEBUG` is true, the tuple `(all_posterior_means,
        all_posterior_weights, all_cluster_assignment, tr)`: three lists
        holding the initial value followed by the value recorded after each
        sweep, plus the final trace. This mode runs `N_ITER + 1` sweeps in a
        Python loop (one warm-up sweep with extra prints, then `N_ITER`).
        Otherwise, only the final trace after `N_ITER` sweeps executed with
        `jax.lax.scan`.
    """
    key = jax.random.key(32421)
    args = (Const(N_CLUSTERS), Const(N_DATAPOINTS), ALPHA)
    key, subkey = jax.random.split(key)
    tr, _ = generate_data.importance(subkey, datapoints, args)

    # Wrap each block update once, rather than re-wrapping on every iteration.
    means_step = jax.jit(update_cluster_means)
    assignment_step = jax.jit(update_datapoint_assignment)
    weights_step = jax.jit(update_cluster_weights)

    if not DEBUG:

        def update(carry, _):
            key, tr = carry
            # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
            key, subkey = jax.random.split(key)
            tr = means_step(subkey, tr)

            # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
            key, subkey = jax.random.split(key)
            tr = assignment_step(subkey, tr)

            # Gibbs update on `probs`
            key, subkey = jax.random.split(key)
            tr = weights_step(subkey, tr)
            return (key, tr), None

        (key, tr), _ = jax.lax.scan(update, (key, tr), None, length=N_ITER)
        return tr

    initial_choices = tr.get_choices()
    all_posterior_means = [initial_choices["clusters", "mean"]]
    all_posterior_weights = [initial_choices["probs"]]
    all_cluster_assignment = [initial_choices["datapoints", "idx"]]

    jax.debug.print("Initial means: {v}", v=all_posterior_means[0])
    jax.debug.print("Initial weights: {v}", v=all_posterior_weights[0])

    # One entry per block update, in sweep order: the update itself, the
    # history it feeds, the address recorded after it, and the message printed
    # during the warm-up sweep to check that the means only move in their own
    # update.
    steps = (
        (
            means_step,
            all_posterior_means,
            ("clusters", "mean"),
            "Initial means V2: {v}",
        ),
        (
            assignment_step,
            all_cluster_assignment,
            ("datapoints", "idx"),
            "Initial means V3: {v}",
        ),
        (
            weights_step,
            all_posterior_weights,
            "probs",
            "Initial means V4: {v}",
        ),
    )

    # Sweep 0 is the warm-up sweep (with the extra prints); it is followed by
    # the N_ITER regular sweeps.
    for sweep in range(N_ITER + 1):
        for step, history, address, warmup_message in steps:
            key, subkey = jax.random.split(key)
            tr = step(subkey, tr)
            history.append(tr.get_choices()[address])
            if sweep == 0:
                jax.debug.print(
                    warmup_message, v=tr.get_choices()["clusters", "mean"]
                )

    return all_posterior_means, all_posterior_weights, all_cluster_assignment, tr


def update_cluster_means(key, trace):
    """Gibbs update of every cluster mean, in parallel.

    For each cluster, the datapoints currently assigned to it are combined with
    the Gaussian prior on the mean (conjugate Normal-Normal update with known
    observation variance), and a new mean is drawn from the resulting
    posterior. Clusters with no assigned datapoint keep their current mean.

    Args:
        key: PRNG key; split once, one half for sampling the new means and the
            other for the trace update.
        trace: Trace of `generate_data`.

    Returns:
        The trace updated with the chosen means at `("clusters", "mean")`.
    """
    choices = trace.get_choices()
    datapoint_indexes = choices["datapoints", "idx"]
    datapoints = choices["datapoints", "obs"]
    current_means = choices["clusters", "mean"]
    n_clusters = trace.get_args()[0].unwrap()

    # Count number of points per cluster
    category_counts = jnp.bincount(
        datapoint_indexes,
        length=n_clusters,
        minlength=n_clusters,
    )

    # Sum of the observations assigned to each cluster (0 for empty clusters)
    cluster_sums = jax.vmap(
        lambda i: jnp.sum(jnp.where(datapoint_indexes == i, datapoints, 0))
    )(jnp.arange(n_clusters))

    # Conjugate update for Normal-iid-Normal distribution, in precision form:
    #   posterior precision = 1 / prior_var + n / obs_var
    #   posterior mean      = posterior_var * (prior_mean / prior_var + sum_x / obs_var)
    # This is the same as weighting the empirical mean by
    # prior_var / (prior_var + obs_var / n) and the prior mean by
    # (obs_var / n) / (prior_var + obs_var / n), and stays finite when n == 0.
    # See https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture5.pdf
    posterior_variances = 1 / (1 / PRIOR_VARIANCE + category_counts / OBS_VARIANCE)
    posterior_means = posterior_variances * (
        PRIOR_MEAN / PRIOR_VARIANCE + cluster_sums / OBS_VARIANCE
    )

    # Gibbs resampling of cluster means; the posterior variance is passed as
    # the `var` argument of `generate_cluster`.
    key, subkey = jax.random.split(key)
    new_means = (
        generate_cluster.vmap()
        .simulate(key, (posterior_means, posterior_variances))
        .get_choices()["mean"]
    )

    # Clusters with no datapoint are not resampled: they keep their previous mean
    chosen_means = jnp.where(category_counts == 0, current_means, new_means)

    if DEBUG:
        jax.debug.print("Category counts: {v}", v=category_counts)
        # Empirical mean of each cluster; NaN for clusters with no datapoint
        jax.debug.print("Current means: {v}", v=cluster_sums / category_counts)
        jax.debug.print("Posterior means: {v}", v=posterior_means)
        jax.debug.print(fmt="Posterior variance: {v}", v=posterior_variances)
        jax.debug.print("Resampled means: {v}", v=new_means)
        jax.debug.print("Chosen means: {v}", v=chosen_means)

    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        subkey, C["clusters", "mean"].set(chosen_means), argdiffs
    )
    return new_trace


def update_datapoint_assignment(key, trace):
    """Gibbs update of the cluster assignment of every datapoint, in parallel.

    Each assignment is resampled from its local posterior
    P(idx | probs, clusters, obs) rather than from the prior P(idx | probs).
    Conditioned on everything else, the assignments are independent, so they
    can all be resampled at once.

    The local posterior is obtained by enumeration:
        P(idx | obs) = P(idx, obs) / sum_idx P(idx, obs).
    Since a categorical accepts unnormalised weights, sampling with the joint
    log-densities log P(idx = j, obs) as logits is enough.

    Args:
        key: PRNG key; split once, one half for sampling the assignments and
            the other for the trace update.
        trace: Trace of `generate_data`.

    Returns:
        The trace updated with the new assignments at `("datapoints", "idx")`.
    """
    choices = trace.get_choices()
    cluster_count = trace.get_args()[0].unwrap()
    datapoint_count = trace.get_args()[1].unwrap()

    # Arguments of `generate_datapoint`, identical for every (datapoint, cluster)
    model_args = (choices["probs"], Cluster(choices["clusters", "mean"]))

    def joint_logpdf(datapoint_index, cluster_index):
        # Log-density of `generate_datapoint` producing this observation
        # together with this candidate assignment.
        observation = choices["datapoints", "obs", datapoint_index]
        constraint = C["obs"].set(observation).at["idx"].set(cluster_index)
        logpdf, _ = generate_datapoint.assess(constraint, model_args)
        return logpdf

    def logits_for(datapoint_index):
        # One joint log-density per candidate cluster.
        return jax.vmap(lambda cluster_index: joint_logpdf(datapoint_index, cluster_index))(
            jnp.arange(cluster_count)
        )

    local_logits = jax.vmap(logits_for)(jnp.arange(datapoint_count))

    # Conjugate update by sampling from the posterior categorical.
    # Note: something like generate_datapoint.vmap().importance might work more
    # generally but would likely be slower here.
    sample_key, update_key = jax.random.split(key)
    resampled = genjax.categorical.vmap().simulate(sample_key, (local_logits,))
    new_assignments = resampled.get_choices()

    # Gibbs resampling of datapoint assignment to clusters
    unchanged_args = genjax.Diff.no_change(trace.args)
    updated_trace, _, _, _ = trace.update(
        update_key, C["datapoints", "idx"].set(new_assignments), unchanged_args
    )
    return updated_trace


def update_cluster_weights(key, trace):
    """Gibbs update of the cluster weights.

    The Dirichlet prior is conjugate to the categorical assignments, so the
    posterior is a Dirichlet whose concentration is the prior concentration
    plus the number of datapoints currently assigned to each cluster.

    Args:
        key: PRNG key; split once, one half for sampling the new weights and
            the other for the trace update.
        trace: Trace of `generate_data`.

    Returns:
        The trace updated with the new weights at address "probs".
    """
    n_clusters = trace.get_args()[0].unwrap()
    # Read the concentration from the trace arguments so that the update always
    # matches the prior `generate_data` was run with, rather than relying on
    # the global ALPHA.
    alpha = trace.get_args()[2]

    # Count number of points per cluster
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"],
        length=n_clusters,
        minlength=n_clusters,
    )

    # Conjugate update for Dirichlet distribution
    # See https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial
    new_alpha = alpha / n_clusters * jnp.ones(n_clusters) + category_counts

    # Gibbs resampling of cluster weights
    key, subkey = jax.random.split(key)
    new_probs = generate_cluster_weight.simulate(key, (new_alpha,)).get_retval()

    if DEBUG:
        jax.debug.print(fmt="Category counts: {v}", v=category_counts)
        jax.debug.print(fmt="New alpha: {v}", v=new_alpha)
        jax.debug.print(fmt="New probs: {v}", v=new_probs)

    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(subkey, C["probs"].set(new_probs), argdiffs)
    return new_trace

Running inference.

In [None]:
# `infer` returns the per-sweep histories together with the final trace when
# DEBUG is on, and only the final trace otherwise.
inference_output = infer(datapoints)
if DEBUG:
    (
        all_posterior_means,
        all_posterior_weights,
        all_cluster_assignment,
        posterior_trace,
    ) = inference_output
else:
    posterior_trace = inference_output

Plotting results

In [None]:
# Function to create plot for a given index
def create_plot(idx):
    """Draw the sampler state stored at position `idx` of the recorded histories.

    Reads the module-level `datapoints`, `all_posterior_means`,
    `all_posterior_weights` and `all_cluster_assignment` (the histories are
    only produced when inference is run with `DEBUG` enabled).

    Every non-empty cluster gets its datapoints, a star at its mean whose size
    grows with the cluster weight, and an error bar of half-width
    sqrt(PRIOR_VARIANCE) around that mean. Empty clusters are skipped.

    Args:
        idx: Position in the recorded histories.

    Returns:
        The matplotlib figure.
    """
    observations = datapoints["datapoints", "obs"]
    means = all_posterior_means[idx]
    weights = all_posterior_weights[idx]
    assignment = all_cluster_assignment[idx]

    # All points are drawn on the horizontal axis y = 0.
    baseline = jnp.zeros_like(observations)
    bar_half_width = jnp.sqrt(PRIOR_VARIANCE)

    fig, ax = plt.subplots(figsize=(12, 5))

    for cluster in range(len(means)):
        members = assignment == cluster
        if not jnp.any(members):
            # Nothing assigned to this cluster: leave it out of the plot.
            continue

        colour = f"C{cluster}"
        centre = means[cluster]
        weight = weights[cluster]

        # Datapoints currently assigned to this cluster.
        ax.scatter(
            observations[members],
            baseline[members],
            color=colour,
            alpha=0.5,
            s=20,
        )

        # Cluster mean, with marker size and label driven by the weight.
        ax.scatter(
            centre,
            0,
            color=colour,
            marker="*",
            s=100 + weight * 600,
            alpha=1,
            label=f"Cluster {cluster + 1} (Prob: {weight:.6f})",
        )

        # Error bar around the cluster mean.
        ax.errorbar(
            centre,
            0,
            xerr=bar_half_width,
            fmt="o",
            color=colour,
            capsize=5,
        )

    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3)
    ax.set_title(f"Iteration {idx}")
    plt.tight_layout()
    return fig


from IPython.display import HTML
from matplotlib import animation

NUM_FRAMES = 50

# Render NUM_FRAMES evenly spaced history entries and keep each one as an RGB
# image array.
frames = []
for frame_number in range(NUM_FRAMES):
    fig = create_plot(int(N_ITER / NUM_FRAMES * frame_number))
    fig.canvas.draw()
    # The canvas buffer is flat RGBA; reshape it to (height, width, 4).
    width, height = fig.canvas.get_width_height()
    rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    rgba = rgba.reshape((height, width) + (4,))
    # Drop the alpha channel to get RGB.
    frames.append(rgba[:, :, :3])
    plt.close(fig)

# Assemble the rendered images into an animation, one artist list per frame.
fig = plt.figure(figsize=(12, 5))
artists = [[plt.imshow(frame)] for frame in frames]
ani = animation.ArtistAnimation(
    fig,
    artists,
    interval=1000,  # 1 second between frames
    blit=True,
)

# Display animation
HTML(ani.to_jshtml())

In [None]:
# Toy illustration of the masking used for empty clusters, on vectors of size 10
category_counts = jnp.array([0, 0, 0, 0, 0, 4, 0, 2, 1, 7])
current_means = jnp.arange(1.0, 11.0)  # 1.0, 2.0, ..., 10.0
new_means = current_means + 0.5  # 1.5, 2.5, ..., 10.5

# Empty clusters (count of 0) keep their current mean; the others take the new one
is_empty = category_counts == 0
chosen_means = jnp.where(is_empty, current_means, new_means)
chosen_means