### Application: Dirichlet mixture model

We will now see some of these ingredients in action in a simple but more realistic setting by writing a Dirichlet mixture model in GenJAX.

In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import categorical, dirichlet, gen, normal, pretty
from genjax._src.core.pytree import Const, Pytree

# Top-level PRNG key for ad-hoc experiments in this notebook; `infer` below
# creates its own key, so it does not read this one.
key = jax.random.PRNGKey(0)
# Enable GenJAX's pretty-printer (presumably for richer notebook display).
pretty()

# ---- Model hyperparameters ---------------------------------------------------
# Prior mean and variance of every cluster centre.
PRIOR_MEAN = 0.0
PRIOR_VARIANCE = 10.0
# Variance of each observation around its cluster centre.
OBS_VARIANCE = 1.0
# Total Dirichlet concentration; `generate_data` spreads it evenly as
# ALPHA / N_CLUSTERS per cluster.
ALPHA = 1.0

# ---- Problem size and sampler budget -----------------------------------------
N_CLUSTERS = 10
N_DATAPOINTS = 10_000
# Number of Gibbs sweeps performed by `infer`.
N_ITER = 1000


@Pytree.dataclass
class Cluster(Pytree):
    """Parameters of one mixture component.

    When produced by `generate_cluster.repeat(...)`, the field carries one entry
    per cluster and is indexed by cluster id (`clusters.mean[idx]`).
    """

    # Centre of the component. Annotated as `float`, but holds an array of
    # per-cluster centres when the generating function is repeated.
    mean: float


@gen
def generate_cluster(mean, var):
    """Sample one cluster centre from a Normal prior.

    Args:
        mean: Prior mean of the cluster centre.
        var: Prior *variance* of the cluster centre (the model is written in
            terms of variances, e.g. PRIOR_VARIANCE).

    Returns:
        A `Cluster` holding the sampled centre, traced at address "mean".
    """
    # `normal` takes a scale (standard deviation), not a variance; passing `var`
    # directly would make the prior far too wide/narrow whenever var != 1.
    cluster_mean = normal(mean, jnp.sqrt(var)) @ "mean"
    return Cluster(cluster_mean)


@gen
def generate_cluster_weight(alphas):
    """Draw a vector of mixture weights from a Dirichlet distribution.

    Args:
        alphas: Dirichlet concentration vector, one entry per cluster.

    Returns:
        The sampled probability vector, traced at address "probs".
    """
    return dirichlet(alphas) @ "probs"


@gen
def generate_datapoint(probs, clusters):
    """Sample one observation from the mixture.

    Args:
        probs: Mixture weights, one per cluster.
        clusters: `Cluster` whose `mean` holds one centre per cluster.

    Returns:
        The observation, traced at "obs"; the chosen cluster is traced at "idx".
    """
    # `categorical` is given log-probabilities (logits).
    idx = categorical(jnp.log(probs)) @ "idx"
    # `normal` takes a standard deviation, so convert the observation variance.
    obs = normal(clusters.mean[idx], jnp.sqrt(OBS_VARIANCE)) @ "obs"
    return obs


@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    """Full generative model of the Dirichlet mixture.

    Args:
        n_clusters: Static number of clusters K.
        n_datapoints: Static number of observations N.
        alpha: Total concentration of the symmetric Dirichlet prior.

    Returns:
        The N observations. Cluster centres are traced under "clusters", the
        mixture weights at top-level "probs", and the data under "datapoints".
    """
    num_k = n_clusters.unwrap()
    num_n = n_datapoints.unwrap()

    clusters = generate_cluster.repeat(n=num_k)(PRIOR_MEAN, PRIOR_VARIANCE) @ "clusters"

    # Symmetric prior alpha / K per cluster; `inline` keeps "probs" un-namespaced.
    concentration = alpha / num_k * jnp.ones(num_k)
    probs = generate_cluster_weight.inline(concentration)

    return generate_datapoint.repeat(n=num_n)(probs, clusters) @ "datapoints"

In [None]:
# posterior_means = jnp.ones(10)
# posterior_variances = jnp.ones(10)

# key, subkey = jax.random.split(key)
# new_means = (
#     generate_cluster.vmap()
#     .simulate(key, (posterior_means, posterior_variances))
#     .get_retval()
# )
# new_means

# key, subkey = jax.random.split(key)
# new_means = (
#     generate_cluster.vmap()
#     .simulate(key, (posterior_means, posterior_variances))
#     .get_choices()["mean"]
# )
# argdiffs = genjax.Diff.no_change(tr.args)
# new_trace, _, _, _ = tr.update(subkey, C["clusters", "mean"].set(new_means), argdiffs)

In [None]:
# Synthetic observations: N_CLUSTERS groups of N_DATAPOINTS // N_CLUSTERS
# points each, where group i is Uniform(0, 2*i + 1) drawn with PRNGKey(i).
# Sizes are derived from the hyperparameters so the observation vector always
# has the length the model's `n_datapoints` expects.
# NOTE(review): every group starts at 0, so the groups overlap heavily; shifted
# ranges (e.g. offset + uniform) may have been intended — confirm.
if N_DATAPOINTS % N_CLUSTERS:
    raise ValueError(
        f"N_DATAPOINTS ({N_DATAPOINTS}) must be divisible by N_CLUSTERS ({N_CLUSTERS})"
    )
_points_per_group = N_DATAPOINTS // N_CLUSTERS
datapoints = C["datapoints", "obs"].set(
    jnp.concatenate([
        (2 * i + 1)
        * jax.random.uniform(jax.random.PRNGKey(i), shape=(_points_per_group,))
        for i in range(N_CLUSTERS)
    ])
)

In [None]:
def infer(datapoints, key=None, n_iter=None):
    """Fit the Dirichlet mixture model with blocked Gibbs sampling.

    An initial trace is obtained by importance sampling with the observations
    as constraints. Each sweep then resamples, in order, the cluster means, the
    per-datapoint assignments, and the mixture weights.

    Args:
        datapoints: Choice map constraining ("datapoints", "obs").
        key: Optional PRNG key; defaults to ``jax.random.PRNGKey(342)``.
        n_iter: Optional number of Gibbs sweeps; defaults to ``N_ITER``.

    Returns:
        The trace after the final sweep.
    """
    if key is None:
        key = jax.random.PRNGKey(342)
    if n_iter is None:
        n_iter = N_ITER

    args = (Const(N_CLUSTERS), Const(N_DATAPOINTS), ALPHA)
    key, subkey = jax.random.split(key)
    tr, _ = generate_data.importance(subkey, datapoints, args)

    for _ in range(n_iter):
        # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
        key, subkey = jax.random.split(key)
        tr = update_cluster_means(subkey, tr)

        # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        tr = update_datapoint_assignment(subkey, tr)

        # Gibbs update on `probs`
        key, subkey = jax.random.split(key)
        tr = update_cluster_weights(subkey, tr)

    return tr


def update_cluster_means(key, trace):
    """Gibbs-resample every cluster mean from its exact conditional, in parallel.

    With prior N(PRIOR_MEAN, PRIOR_VARIANCE) on each centre and observations
    N(centre, OBS_VARIANCE), the conditional of cluster k given its n_k
    assigned points (sum s_k) is Normal with
        var_k  = 1 / (1 / PRIOR_VARIANCE + n_k / OBS_VARIANCE)
        mean_k = var_k * (PRIOR_MEAN / PRIOR_VARIANCE + s_k / OBS_VARIANCE)
    This precision form reduces to the prior for an empty cluster (n_k = 0)
    instead of dividing by zero.

    Args:
        key: PRNG key.
        trace: Trace of `generate_data`.

    Returns:
        A new trace with ("clusters", "mean") replaced by fresh samples.
    """
    choices = trace.get_choices()
    assignments = choices["datapoints", "idx"]
    observations = choices["datapoints", "obs"]
    # generate_data's args are (n_clusters, n_datapoints, alpha).
    n_clusters = trace.get_args()[0].unwrap()

    # Per-cluster sufficient statistics. Scatter-add works with traced indices,
    # unlike boolean-mask indexing.
    counts = jnp.bincount(assignments, length=n_clusters)
    sums = jnp.zeros(n_clusters, dtype=observations.dtype).at[assignments].add(observations)

    # Normal-Normal conjugate update with known observation variance.
    # See https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture5.pdf
    posterior_variances = 1.0 / (1.0 / PRIOR_VARIANCE + counts / OBS_VARIANCE)
    posterior_means = posterior_variances * (
        PRIOR_MEAN / PRIOR_VARIANCE + sums / OBS_VARIANCE
    )

    # Draw directly from N(mean_k, var_k), independently of how any `normal`
    # distribution object parameterises its scale.
    sample_key, update_key = jax.random.split(key)
    new_means = posterior_means + jnp.sqrt(posterior_variances) * jax.random.normal(
        sample_key, posterior_means.shape
    )

    argdiffs = genjax.Diff.no_change(trace.get_args())
    new_trace, _, _, _ = trace.update(
        update_key, C["clusters", "mean"].set(new_means), argdiffs
    )
    return new_trace


def update_datapoint_assignment(key, trace):
    """Gibbs-resample every datapoint's cluster assignment, in parallel.

    The conditional of a categorical assignment is exact by enumeration:
        P(idx_i = k | x_i, probs, means)
            = probs[k] N(x_i; means[k], OBS_VARIANCE) / sum_j probs[j] N(x_i; means[j], OBS_VARIANCE)

    Args:
        key: PRNG key.
        trace: Trace of `generate_data`.

    Returns:
        A new trace with ("datapoints", "idx") replaced by fresh samples.
    """
    choices = trace.get_choices()
    observations = choices["datapoints", "obs"]
    cluster_means = choices["clusters", "mean"]
    probs = choices["probs"]

    # (n_datapoints, n_clusters) unnormalised log-posterior. The Gaussian
    # normalising constant is identical for every cluster, so it is dropped;
    # `jax.random.categorical` normalises the logits row-wise itself.
    sq_dist = (observations[:, None] - cluster_means[None, :]) ** 2
    logits = jnp.log(probs)[None, :] - 0.5 * sq_dist / OBS_VARIANCE

    sample_key, update_key = jax.random.split(key)
    new_assignments = jax.random.categorical(sample_key, logits, axis=-1)

    argdiffs = genjax.Diff.no_change(trace.get_args())
    new_trace, _, _, _ = trace.update(
        update_key, C["datapoints", "idx"].set(new_assignments), argdiffs
    )
    return new_trace


def update_cluster_weights(key, trace):
    """Gibbs-resample the mixture weights ("probs") from their Dirichlet conditional.

    Under the symmetric prior Dirichlet(alpha / K, ..., alpha / K) used by
    `generate_data`, the conditional given the assignments is
    Dirichlet(alpha / K + counts).

    Args:
        key: PRNG key.
        trace: Trace of `generate_data`.

    Returns:
        A new trace with "probs" replaced by a fresh sample.
    """
    # generate_data's args are (n_clusters, n_datapoints, alpha); use the
    # trace's own alpha so the update matches the prior it was generated with.
    n_clusters_arg, _, alpha = trace.get_args()
    n_clusters = n_clusters_arg.unwrap()

    # Number of datapoints currently assigned to each cluster.
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"], length=n_clusters
    )

    # Conjugate update for Dirichlet distribution
    # See https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial
    new_alpha = alpha / n_clusters * jnp.ones(n_clusters) + category_counts

    sample_key, update_key = jax.random.split(key)
    # `simulate` takes the argument tuple and returns a trace; the sampled
    # weights are that trace's return value.
    new_probs = generate_cluster_weight.simulate(sample_key, (new_alpha,)).get_retval()

    argdiffs = genjax.Diff.no_change(trace.get_args())
    new_trace, _, _, _ = trace.update(update_key, C["probs"].set(new_probs), argdiffs)
    return new_trace