### Application: Dirichlet mixture model

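We will now see some of these ingredients in action in a simple but more realistic setting by writing a Dirichlet mixture model in GenJAX. Each cluster mean is drawn from a standard normal, the mixture weights are drawn from a symmetric Dirichlet prior, and each datapoint first picks a cluster according to those weights and is then observed with unit-variance Gaussian noise around that cluster's mean.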

In [None]:
import jax.numpy as jnp

from genjax import categorical, dirichlet, gen, normal
from genjax._src.core.pytree import Const, Pytree


@Pytree.dataclass
class Cluster(Pytree):
    """Parameters of one mixture component.

    ``generate_cluster`` fills ``mean`` from a standard-normal draw. When built
    via ``generate_cluster.repeat``, ``mean`` presumably holds one value per
    cluster (a batched array) rather than a scalar -- confirm.
    """

    # Location of the component's Gaussian likelihood (its scale is fixed at 1
    # in ``generate_datapoint``).
    mean: float


@gen
def generate_cluster():
    """Sample a single mixture component.

    The component's mean gets a standard-normal prior and is recorded at the
    address ``"mean"``.

    Returns:
        A ``Cluster`` wrapping the sampled mean.
    """
    return Cluster(mean=normal(0, 1) @ "mean")


@gen
def generate_datapoint(probs, clusters):
    """Sample one observation from the mixture.

    Args:
        probs: mixture weights, one per cluster (as drawn by the Dirichlet
            prior in ``generate_data``; presumably shape (n_clusters,) -- confirm).
        clusters: a ``Cluster`` whose ``mean`` field holds one mean per
            cluster, as produced by ``generate_cluster.repeat``.

    Returns:
        The observed value, drawn from Normal(mean of the chosen cluster, 1).
    """
    # `categorical` is given log-probabilities (logits), hence the log.
    idx = categorical(jnp.log(probs)) @ "idx"
    # `clusters` is a single batched Cluster (the output of `repeat`), not a
    # sequence of Clusters, so index its `mean` field rather than the
    # dataclass itself.
    obs = normal(clusters.mean[idx], 1) @ "obs"
    return obs


@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    """Generate a dataset from a Dirichlet mixture of unit-variance Gaussians.

    Args:
        n_clusters: static number of mixture components.
        n_datapoints: static number of observations to generate.
        alpha: total concentration of the symmetric Dirichlet prior; each
            component receives ``alpha / n_clusters``.

    Returns:
        The observations produced by ``generate_datapoint.repeat``
        (presumably an array of shape (n_datapoints,) -- confirm).
    """
    # The sizes arrive wrapped in `Const`; unwrap each once to get the plain
    # ints that `repeat` and `jnp.ones` are given below.
    k = n_clusters.unwrap()
    n = n_datapoints.unwrap()

    clusters = generate_cluster.repeat(n=k)() @ "clusters"
    probs = dirichlet(alpha / k * jnp.ones(k)) @ "probs"
    # `unwrap` must be *called*: passing the bound method `unwrap` itself
    # would hand `repeat` a non-integer count.
    datapoints = generate_datapoint.repeat(n=n)(probs, clusters) @ "datapoints"
    return datapoints

In [None]:
# TODO: implement inference by Gibbs sampling. Sketch:
# def infer(datapoints):
#     Initialize a trace.
#     For N iterations:
#         Gibbs update on `("clusters", i, "mean")` for each `i`, in parallel.
#         Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel.
#         Gibbs update on `"probs"`.