# Block-Gibbs Sampling on a Dirichlet Mixture Model

We will now see some of the key ingredients in action in a simple but more realistic setting by writing a Dirichlet mixture model in GenJAX.

## Clustering Points on the Real Line

The goal here is to cluster datapoints on the real line. To do so, we model a fixed number of clusters, each as a 1D Gaussian with fixed variance, and we infer their means.

### Model Description

The "model of the world" postulates:
- A fixed number of 1D Gaussians
- Each Gaussian is assigned a weight representing the proportion of points that belong to its cluster
- Each datapoint is assigned to a cluster

### Generative Process

We turn this into a generative model as follows:
- We have a fixed prior mean and variance for where the cluster centers might be
- We sample a mean for each cluster
- We sample the cluster weights jointly from a Dirichlet distribution, so they sum to 1
- For each datapoint:
  - We sample a cluster assignment proportional to the cluster weights
  - We sample the datapoint noisily around the mean of the cluster
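For intuition, here is a minimal sketch of the same generative process written with plain `jax.random` calls instead of GenJAX. The function name `sample_mixture` and its argument order are our own, and we treat the prior and observation parameters as variances, taking square roots to get standard deviations.

```python
import jax
import jax.numpy as jnp


def sample_mixture(key, n_clusters, n_points, prior_mean, prior_var, obs_var, alpha):
    """Forward-sample the mixture's generative process with plain jax.random."""
    k_means, k_probs, k_idx, k_obs = jax.random.split(key, 4)
    # One mean per cluster, drawn from the Normal prior (std = sqrt(variance)).
    means = prior_mean + jnp.sqrt(prior_var) * jax.random.normal(k_means, (n_clusters,))
    # Cluster weights from a symmetric Dirichlet; they sum to 1 by construction.
    probs = jax.random.dirichlet(k_probs, alpha / n_clusters * jnp.ones(n_clusters))
    # One cluster assignment per datapoint, drawn proportionally to the weights.
    idx = jax.random.categorical(k_idx, jnp.log(probs), shape=(n_points,))
    # Each datapoint is its cluster's mean plus Gaussian noise.
    obs = means[idx] + jnp.sqrt(obs_var) * jax.random.normal(k_obs, (n_points,))
    return means, probs, idx, obs


means, probs, idx, obs = sample_mixture(jax.random.key(0), 4, 1000, 50.0, 10.0, 1.0, 4.0)
print(means.shape, probs.sum(), obs.shape)
```

The GenJAX model below encodes the same steps, but also records every random choice in a trace so that inference can later revisit and update them.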

### Implementation Details

We, the modelers, get to choose how this process is implemented.
- We choose distributions for each sampling step in a way that makes **inference tractable**.
- More precisely, we choose conjugate pairs so that we can do inference via Gibbs sampling.
  - Gibbs sampling is an MCMC method that starts from an initial trace and then repeatedly resamples the traced choices we want to infer.
  - To update a choice, Gibbs sampling samples from its conditional distribution given all the other choices, which is tractable thanks to the conjugate relationships.
  - Block-Gibbs resamples a whole group of choices at once from their joint conditional; here, all cluster means, all datapoint assignments, or all cluster weights in a single step.
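To make the mechanics concrete, here is a minimal sketch of Gibbs sampling on a toy target that is not part of this tutorial: a standard bivariate normal with correlation `rho`, whose full conditionals are $x \mid y \sim \mathcal{N}(\rho y, 1 - \rho^2)$ and $y \mid x \sim \mathcal{N}(\rho x, 1 - \rho^2)$. It uses plain JAX rather than GenJAX, and the function name `gibbs_bivariate_normal` is our own.

```python
import jax
import jax.numpy as jnp


def gibbs_bivariate_normal(key, rho, n_steps):
    """Gibbs sampling for a standard bivariate normal with correlation rho."""
    cond_std = jnp.sqrt(1.0 - rho**2)

    def sweep(state, step_key):
        x, y = state
        kx, ky = jax.random.split(step_key)
        # Resample x from its conditional given y: N(rho * y, 1 - rho^2).
        x = rho * y + cond_std * jax.random.normal(kx)
        # Resample y from its conditional given the new x: N(rho * x, 1 - rho^2).
        y = rho * x + cond_std * jax.random.normal(ky)
        return (x, y), jnp.stack([x, y])

    init = (jnp.zeros(()), jnp.zeros(()))
    _, samples = jax.lax.scan(sweep, init, jax.random.split(key, n_steps))
    return samples


samples = gibbs_bivariate_normal(jax.random.key(0), 0.8, 20_000)
print(jnp.corrcoef(samples[:, 0], samples[:, 1])[0, 1])  # should be close to 0.8
```

Each sweep resamples one variable at a time from its exact conditional. The mixture model below follows the same pattern, except that its conditionals come from conjugacy and each update covers a whole block of variables.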

In [None]:
import genstudio.plot as Plot
import jax
import jax.numpy as jnp
import numpy as np

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import categorical, dirichlet, gen, normal, pretty
from genjax._src.core.pytree import Const

pretty()
key = jax.random.key(0)

We define the generative model we described above. It has several hyperparameters that are set by hand. An extension to the model could instead infer these hyperparameters by placing priors on them, with fixed hyper-hyperparameters.
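As a quick check on what these values imply: `ALPHA` evaluates to 5000 / (40 × 10) = 12.5, and the model spreads it evenly across clusters, so each cluster's Dirichlet parameter is 12.5 / 40 = 0.3125. A symmetric Dirichlet with parameter below 1 tends to put most of its mass on a few components. The sketch below (plain JAX, re-declaring the constants so it runs on its own) samples a few weight vectors from that prior and counts the clusters with at least 1% of the weight, the same threshold the plotting code uses.

```python
import jax
import jax.numpy as jnp

N_DATAPOINTS = 5000
N_CLUSTERS = 40
ALPHA = float(N_DATAPOINTS / (N_CLUSTERS * 10))  # 12.5
concentration = ALPHA / N_CLUSTERS  # per-cluster Dirichlet parameter

# Draw 5 weight vectors from the symmetric Dirichlet prior used by the model.
weights = jax.random.dirichlet(
    jax.random.key(0), concentration * jnp.ones(N_CLUSTERS), shape=(5,)
)
# Number of clusters holding at least 1% of the mass in each draw.
print(concentration, (weights >= 0.01).sum(axis=1))
```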

In [None]:
# Hyperparameters
PRIOR_VARIANCE = 10.0
OBS_VARIANCE = 1.0
N_DATAPOINTS = 5000
N_CLUSTERS = 40
ALPHA = float(N_DATAPOINTS / (N_CLUSTERS * 10))
PRIOR_MEAN = 50.0
N_ITER = 50

# Debugging mode
DEBUG = True


# Sub generative functions of the bigger model.
# Note: GenJAX's `normal` is parameterized by (mean, standard deviation),
# so we take the square root of the variances before sampling.
@gen
def generate_cluster(mean, var):
    cluster_mean = normal(mean, jnp.sqrt(var)) @ "mean"
    return cluster_mean


@gen
def generate_cluster_weight(alphas):
    probs = dirichlet(alphas) @ "probs"
    return probs


@gen
def generate_datapoint(probs, clusters):
    idx = categorical(jnp.log(probs)) @ "idx"
    obs = normal(clusters[idx], jnp.sqrt(OBS_VARIANCE)) @ "obs"
    return obs


# Main model
@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    clusters = (
        generate_cluster.repeat(n=n_clusters.unwrap())(PRIOR_MEAN, PRIOR_VARIANCE)
        @ "clusters"
    )

    probs = generate_cluster_weight.inline(
        alpha / n_clusters.unwrap() * jnp.ones(n_clusters.unwrap())
    )

    datapoints = (
        generate_datapoint.repeat(n=n_datapoints.unwrap())(probs, clusters)
        @ "datapoints"
    )

    return datapoints

We create some synthetic data to test inference.
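It helps to work out where these clusters land. The offsets are `PRIOR_VARIANCE * (-4 + 8 * i / N_CLUSTERS)`, which for the values above is `-40 + 2 * i`, so cluster `i` is uniform on `[10 + 2i, 11 + 2i)`: intervals of width 1, spaced 2 apart, from 10 up to 89. With `OBS_VARIANCE = 1.0`, neighboring groups are only two standard deviations apart under the model, and the points are uniform rather than Gaussian, so the model only approximates this data. The short check below recomputes the interval starts in plain JAX, re-declaring the constants so it runs on its own.

```python
import jax.numpy as jnp

PRIOR_VARIANCE = 10.0
PRIOR_MEAN = 50.0
N_CLUSTERS = 40

cluster_indices = jnp.arange(N_CLUSTERS)
offsets = PRIOR_VARIANCE * (-4 + 8 * cluster_indices / N_CLUSTERS)
# Each synthetic cluster is Uniform[low, low + 1) with low = PRIOR_MEAN + offset.
lows = PRIOR_MEAN + offsets
print(lows[:3], lows[-1], jnp.diff(lows)[:3])
```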

In [None]:
# Generate synthetic data with N_CLUSTERS clusters evenly spaced
points_per_cluster = int(N_DATAPOINTS / N_CLUSTERS)
cluster_indices = jnp.arange(N_CLUSTERS)
offsets = PRIOR_VARIANCE * (-4 + 8 * cluster_indices / N_CLUSTERS)

# Create keys for each cluster
keys = jax.random.split(jax.random.key(0), N_CLUSTERS)

# Generate uniform random points for each cluster
uniform_points = jax.vmap(lambda k: jax.random.uniform(k, shape=(points_per_cluster,)))(
    keys
)

# Add offset and prior mean to each cluster's points
shifted_points = uniform_points + (PRIOR_MEAN + offsets[:, None])

datapoints = C["datapoints", "obs"].set(shifted_points.reshape(-1))

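We now write the main inference loop. As we said at the beginning, we do MCMC via Gibbs sampling, so inference consists of a main loop that evolves a trace over time. Each iteration performs one Gibbs sweep, resampling the cluster means, then the datapoint assignments, then the cluster weights. The final trace contains a sample from the approximate posterior.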
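For reference, here are the three full conditionals that one Gibbs sweep samples from. In our own notation (not used in the code), $K$ is the number of clusters, $z_i$ the assignment of datapoint $x_i$, $n_k$ the number of points assigned to cluster $k$ with empirical mean $\bar{x}_k$, $\mu_0$ = `PRIOR_MEAN`, $\tau^2$ = `PRIOR_VARIANCE`, $\sigma^2$ = `OBS_VARIANCE`, and $\alpha$ = `ALPHA`:

$$
\mu_k \mid z, x \;\sim\; \mathcal{N}\!\left(\frac{\tau^2\,\bar{x}_k + (\sigma^2/n_k)\,\mu_0}{\tau^2 + \sigma^2/n_k},\; \left(\frac{1}{\tau^2} + \frac{n_k}{\sigma^2}\right)^{-1}\right)
$$

$$
P(z_i = k \mid \pi, \mu, x_i) \;\propto\; \pi_k\, \mathcal{N}(x_i;\, \mu_k, \sigma^2)
$$

$$
\pi \mid z \;\sim\; \mathrm{Dirichlet}\!\left(\tfrac{\alpha}{K} + n_1, \ldots, \tfrac{\alpha}{K} + n_K\right)
$$

The cluster means are conditionally independent given the assignments, and the assignments are conditionally independent given the means and weights, which is what lets each block be resampled in parallel. For an empty cluster ($n_k = 0$) the first conditional reduces to the prior; the code below instead keeps the previous mean.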

In [None]:
def infer(datapoints):
    key = jax.random.key(32421)
    args = (Const(N_CLUSTERS), Const(N_DATAPOINTS), ALPHA)
    key, subkey = jax.random.split(key)
    initial_weights = C["probs"].set(jnp.ones(N_CLUSTERS) / N_CLUSTERS)
    constraints = datapoints | initial_weights
    tr, _ = generate_data.importance(subkey, constraints, args)

    if DEBUG:
        all_posterior_means = [tr.get_choices()["clusters", "mean"]]
        all_posterior_weights = [tr.get_choices()["probs"]]
        all_cluster_assignment = [tr.get_choices()["datapoints", "idx"]]

        jax.debug.print("Initial means: {v}", v=all_posterior_means[0])
        jax.debug.print("Initial weights: {v}", v=all_posterior_weights[0])

        for _ in range(N_ITER):
            # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_means)(subkey, tr)
            all_posterior_means.append(tr.get_choices()["clusters", "mean"])

            # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_datapoint_assignment)(subkey, tr)
            all_cluster_assignment.append(tr.get_choices()["datapoints", "idx"])

            # Gibbs update on `probs`
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_weights)(subkey, tr)
            all_posterior_weights.append(tr.get_choices()["probs"])

        return all_posterior_means, all_posterior_weights, all_cluster_assignment, tr

    else:
        # One Gibbs sweep consists of updating every latent variable once
        def update(carry, _):
            key, tr = carry
            # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
            key, subkey = jax.random.split(key)
            tr = update_cluster_means(subkey, tr)

            # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
            key, subkey = jax.random.split(key)
            tr = update_datapoint_assignment(subkey, tr)

            # Gibbs update on `probs`
            key, subkey = jax.random.split(key)
            tr = update_cluster_weights(subkey, tr)
            return (key, tr), None

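        # Overall inference performs a fixed number of Gibbs sweeps.
        # `update` is closed over rather than passed to `jax.jit` as an
        # argument, since Python functions are not valid JAX argument types.
        run_sweeps = jax.jit(
            lambda key, tr: jax.lax.scan(update, (key, tr), None, length=N_ITER)
        )
        (key, tr), _ = run_sweeps(key, tr)
        return tr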


def update_cluster_means(key, trace):
    # We can update each cluster in parallel
    # For each cluster, we find the datapoints in that cluster and compute their mean
    datapoint_indexes = trace.get_choices()["datapoints", "idx"]
    datapoints = trace.get_choices()["datapoints", "obs"]
    n_clusters = trace.get_args()[0].unwrap()
    current_means = trace.get_choices()["clusters", "mean"]

    # Count number of points per cluster
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"],
        length=n_clusters,
        minlength=n_clusters,
    )

    # Empirical mean of the datapoints assigned to each cluster.
    # Contains NaN (0 / 0) for clusters that have no datapoints.
    cluster_means = (
        jax.vmap(
            lambda i: jnp.sum(jnp.where(datapoint_indexes == i, datapoints, 0)),
            in_axes=(0),
            out_axes=(0),
        )(jnp.arange(n_clusters))
        / category_counts
    )

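    # Conjugate update for a Normal prior on the mean with i.i.d. Normal
    # observations of known variance.
    # See https://people.eecs.berkeley.edu/~jordan/courses/260-spring10/lectures/lecture5.pdf
    # Note that the linked notes have a typo in the posterior-mean formula;
    # the expression below is the corrected one.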
    posterior_means = (
        PRIOR_VARIANCE
        / (PRIOR_VARIANCE + OBS_VARIANCE / category_counts)
        * cluster_means
        + (OBS_VARIANCE / category_counts)
        / (PRIOR_VARIANCE + OBS_VARIANCE / category_counts)
        * PRIOR_MEAN
    )

    posterior_variances = 1 / (1 / PRIOR_VARIANCE + category_counts / OBS_VARIANCE)

    # Gibbs resampling of cluster means
    key, subkey = jax.random.split(key)
    new_means = (
        generate_cluster.vmap()
        .simulate(key, (posterior_means, posterior_variances))
        .get_choices()["mean"]
    )

    # Clusters with no datapoints yield NaN samples; keep their previous mean
    # in that case, i.e., skip the Gibbs update for them.
    chosen_means = jnp.where(category_counts == 0, current_means, new_means)

    if DEBUG:
        jax.debug.print("Category counts: {v}", v=category_counts)
        jax.debug.print("Empirical means: {v}", v=cluster_means)
        jax.debug.print("Posterior means: {v}", v=posterior_means)
        jax.debug.print(fmt="Posterior variance: {v}", v=posterior_variances)
        jax.debug.print("Resampled means: {v}", v=new_means)
        jax.debug.print("Chosen means: {v}", v=chosen_means)

    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        subkey, C["clusters", "mean"].set(chosen_means), argdiffs
    )
    return new_trace


def update_datapoint_assignment(key, trace):
    # We want to update the index of each datapoint, in parallel.
    # This means resampling each idx not from the prior P(idx | probs) but from
    # its local posterior P(idx | probs, clusters, obs).
    # We need to do this for every address ["datapoints", "idx", i], and since
    # these are conditionally independent given the rest of the trace,
    # we can resample them in parallel.

    # The conjugate update for a categorical is the exact posterior via enumeration:
    # P(x | y) = P(x, y) / sum_x P(x, y).
    # Sampling from Categorical(P(x = 1 | y), P(x = 2 | y), ...) is the same as
    # sampling from Categorical(P(x = 1, y), P(x = 2, y), ...),
    # since the weights need not be normalized.
    # In addition, if the model factorizes as P(x, y1, y2) = P(x, y1) P(y2 | y1),
    # then P(x | y1, y2) = P(x | y1): the factor P(y2 | y1) does not depend on x
    # and drops out of the categorical. More generally, we only need the Markov
    # blanket of x ("idx" in our situation): its parents, its children, and its
    # children's other parents, all conveniently wrapped in the
    # generate_datapoint generative function.
    def compute_local_density(x, i):
        # Log joint density of datapoint x's observation together with its
        # assignment to cluster i, under the current means and weights.
        obs_value = trace.get_choices()["datapoints", "obs", x]
        chm = C["obs"].set(obs_value).at["idx"].set(i)
        clusters = trace.get_choices()["clusters", "mean"]
        probs = trace.get_choices()["probs"]
        args = (probs, clusters)
        model_logpdf, _ = generate_datapoint.assess(chm, args)
        return model_logpdf

    n_clusters = trace.get_args()[0].unwrap()
    n_datapoints = trace.get_args()[1].unwrap()
    local_densities = jax.vmap(
        lambda x: jax.vmap(lambda i: compute_local_density(x, i))(
            jnp.arange(n_clusters)
        )
    )(jnp.arange(n_datapoints))

    # Conjugate update by sampling from posterior categorical
    # Note: something like generate_datapoint.vmap().importance could likely
    # be used instead; it would perhaps work in a more general setting,
    # but it would be slower here.
    key, subkey = jax.random.split(key)
    new_datapoint_indexes = (
        genjax.categorical.vmap().simulate(key, (local_densities,)).get_choices()
    )
    # Gibbs resampling of datapoint assignment to clusters
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        subkey, C["datapoints", "idx"].set(new_datapoint_indexes), argdiffs
    )
    return new_trace


def update_cluster_weights(key, trace):
    # Count number of points per cluster
    n_clusters = trace.get_args()[0].unwrap()
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"],
        length=n_clusters,
        minlength=n_clusters,
    )

    # Conjugate update for Dirichlet distribution
    # See https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial
    new_alpha = ALPHA / n_clusters * jnp.ones(n_clusters) + category_counts

    # Gibbs resampling of cluster weights
    key, subkey = jax.random.split(key)
    new_probs = generate_cluster_weight.simulate(key, (new_alpha,)).get_retval()

    if DEBUG:
        jax.debug.print(fmt="Category counts: {v}", v=category_counts)
        jax.debug.print(fmt="New alpha: {v}", v=new_alpha)
        jax.debug.print(fmt="New probs: {v}", v=new_probs)
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(subkey, C["probs"].set(new_probs), argdiffs)
    return new_trace

We can now run inference, obtaining the final trace and, in debug mode, the intermediate values of the cluster means, weights, and assignments for visualizing inference.

In [None]:
if DEBUG:
    (
        all_posterior_means,
        all_posterior_weights,
        all_cluster_assignment,
        posterior_trace,
    ) = infer(datapoints)
else:
    posterior_trace = infer(datapoints)

posterior_trace

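### Plotting Results

The animation below relies on the intermediate values that `infer` records when `DEBUG = True`.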

In [None]:
# Prepare data for the animation
data_points = datapoints["datapoints", "obs"].tolist()
np.random.seed(42)
jitter = np.random.uniform(-0.05, 0.05, size=len(data_points)).tolist()
std_dev = np.sqrt(OBS_VARIANCE) * 1.5
all_cluster_assignments_list = [a.tolist() for a in all_cluster_assignment]
all_posterior_means_list = [m.tolist() for m in all_posterior_means]
all_posterior_weights_list = [w.tolist() for w in all_posterior_weights]

# Define a consistent color palette to use throughout the visualization
color_palette = """
const plotColors = [
    "#4c78a8", "#f58518", "#e45756", "#72b7b2", "#54a24b", 
    "#eeca3b", "#b279a2", "#ff9da6", "#9d755d", "#bab0ac",
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
];
"""

# Shared data initialization for all plot components
frame_data_js = (
    """
// Get current frame data
const frame = $state.frame;
const hoveredCluster = $state.hoveredCluster;
const means = """
    + str(all_posterior_means_list)
    + """[frame];
const weights = """
    + str(all_posterior_weights_list)
    + """[frame];
const assignments = """
    + str(all_cluster_assignments_list)
    + """[frame];
const stdDev = """
    + str(std_dev)
    + """;
"""
)

# Create a visualizer with animation
(
    Plot.initialState({"frame": 0, "hoveredCluster": None})
    |
    # Main visualization that updates based on the current frame
    Plot.plot(
        {
            "marks": [
                # 1. Data points with jitter - show all data points with optional highlighting
                Plot.dot(
                    Plot.js(
                        """function() {
                    """
                        + frame_data_js
                        + """
                    const dataPoints = """
                        + str(data_points)
                        + """;
                    const jitter = """
                        + str(jitter)
                        + """;
                    
                    """
                        + color_palette
                        + """
                    
                    // Return all points with hover-aware opacity
                    return dataPoints.map((x, i) => {
                        const clusterIdx = assignments[i];
                        // If a cluster is hovered, reduce opacity of other clusters' points
                        const isHovered = hoveredCluster !== null && clusterIdx === hoveredCluster;
                        const opacity = hoveredCluster === null ? 0.5 : (isHovered ? 0.7 : 0.15);
                        return {
                            x: x,
                            y: jitter[i],
                            color: plotColors[clusterIdx % 20],
                            opacity: opacity
                        };
                    });
                }()"""
                    ),
                    {"x": "x", "y": "y", "fill": "color", "r": 3, "opacity": "opacity"},
                ),
                # 2. Combined error bars (both horizontal lines and vertical caps)
                Plot.line(
                    Plot.js(
                        """function() {
                    """
                        + frame_data_js
                        + """
                    const capSize = 0.04;  // Size of the vertical cap lines
                    
                    """
                        + color_palette
                        + """
                    
                    // We'll collect all line segments in a flat array
                    const result = [];
                    
                    for (let i = 0; i < means.length; i++) {
                        // Only include error bars for clusters with weight >= 0.01
                        if (weights[i] >= 0.01) {
                            // Determine if this cluster is being hovered
                            const isHovered = hoveredCluster === i;
                            const opacity = hoveredCluster === null ? 0.7 : (isHovered ? 1.0 : 0.3);
                            const strokeWidth = isHovered ? 4 : 3;
                            const color = plotColors[i % 20];
                            
                            // Add horizontal line (error bar itself)
                            result.push({x: means[i] - stdDev, y: 0, cluster: i, color, opacity, width: strokeWidth});
                            result.push({x: means[i] + stdDev, y: 0, cluster: i, color, opacity, width: strokeWidth});
                            
                            // Add left cap (vertical line)
                            result.push({x: means[i] - stdDev, y: -capSize, cluster: i, color, opacity, width: strokeWidth});
                            result.push({x: means[i] - stdDev, y: capSize, cluster: i, color, opacity, width: strokeWidth});
                            
                            // Add right cap (vertical line)
                            result.push({x: means[i] + stdDev, y: -capSize, cluster: i, color, opacity, width: strokeWidth});
                            result.push({x: means[i] + stdDev, y: capSize, cluster: i, color, opacity, width: strokeWidth});
                        }
                    }
                    return result;
                }()"""
                    ),
                    {
                        "x": "x",
                        "y": "y",
                        "stroke": "color",
                        "strokeWidth": "width",
                        "opacity": "opacity",
                        "z": "cluster",
                    },
                ),
                # 3. Cluster means as stars
                Plot.dot(
                    Plot.js(
                        """function() {
                    """
                        + frame_data_js
                        + """
                    """
                        + color_palette
                        + """
                        
                    // Create a simple array for each cluster mean
                    return means.map((mean, i) => {
                        // Only include means for clusters with sufficient weight
                        if (weights[i] >= 0.01) {
                            const isHovered = hoveredCluster === i;
                            return {
                                x: mean,
                                y: 0,
                                cluster: i,
                                color: plotColors[i % 20],
                                opacity: isHovered ? 1.0 : 0.8
                            };
                        }
                        return null;  // Skip low-weight clusters
                    }).filter(d => d !== null);  // Remove null values
                }()"""
                    ),
                    {
                        "x": "x",
                        "y": "y",
                        "fill": "color",
                        "r": 10,
                        "symbol": "star",
                        "stroke": "black",
                        "strokeWidth": 2,
                        "opacity": "opacity",
                    },
                ),
            ],
            "grid": True,
            "marginTop": 40,
            "marginRight": 40,
            "marginBottom": 40,
            "marginLeft": 40,
            "style": {"height": "400px"},
            "title": Plot.js(
                "`Dirichlet Mixture Model - Iteration ${$state.frame} of "
                + str(len(all_posterior_means) - 1)
                + "`"
            ),
            "subtitle": "Cluster centers (★) with ±1.5 observation standard deviations (—) and data points (•)",
        }
    )
    |
    # Animation controls and legend with hover effects
    Plot.html(
        [
            "div",
            {"className": "p-4"},
            [
                "div",
                {"className": "mb-4"},
                Plot.Slider(
                    "frame",
                    init=0,
                    range=[0, len(all_posterior_means) - 1],
                    step=1,
                    label="Iteration",
                    width="100%",
                    fps=8,
                ),
            ],
            [
                "div",
                {"className": "mt-4"},
                Plot.js(
                    """function() {
                """
                    + frame_data_js
                    + """
                // Count assignments in current frame
                const counts = {};
                assignments.forEach(a => { counts[a] = (counts[a] || 0) + 1; });
                
                """
                    + color_palette
                    + """
                
                // Sort clusters by weight, filter by minimum weight, and limit to top 10
                const topClusters = Object.keys(weights)
                    .map(i => ({ 
                        id: parseInt(i), 
                        weight: weights[i], 
                        count: counts[parseInt(i)] || 0 
                    }))
                    .filter(c => c.weight >= 0.01)
                    .sort((a, b) => b.weight - a.weight)
                    .slice(0, 10);
                
                // Create placeholder rows for consistent height
                const placeholders = Array(Math.max(0, 10 - topClusters.length))
                    .fill(0)
                    .map(() => ["tr", {"className": "h-8"}, ["td", {"colSpan": 3}, ""]]);
                    
                return [
                    "div", {},
                    ["h3", {}, `Top Clusters by Weight (Iteration ${frame})`],
                    ["div", {"style": {"height": "280px", "overflow": "auto"}},
                        ["table", {"className": "w-full mt-2"},
                            ["thead", ["tr",
                                ["th", {"className": "text-left"}, "Cluster"],
                                ["th", {"className": "text-left"}, "Weight"],
                                ["th", {"className": "text-left"}, "Points"]
                            ]],
                            ["tbody",
                                ...topClusters.map(cluster => 
                                    ["tr", {
                                        "className": "h-8",
                                        "style": {
                                            "cursor": "pointer",
                                            "backgroundColor": $state.hoveredCluster === cluster.id ? "#f0f0f0" : "transparent"
                                        },
                                        "onMouseEnter": () => { $state.hoveredCluster = cluster.id; },
                                        "onMouseLeave": () => { $state.hoveredCluster = null; }
                                    },
                                    ["td", {"className": "py-1"}, 
                                        ["div", {"className": "flex items-center"},
                                            ["div", {
                                                "style": {
                                                    "backgroundColor": plotColors[cluster.id % 20],
                                                    "width": "24px",
                                                    "height": "24px",
                                                    "borderRadius": "4px",
                                                    "border": "1px solid rgba(0,0,0,0.2)",
                                                    "display": "inline-block",
                                                    "marginRight": "8px"
                                                }
                                            }],
                                            `Cluster ${cluster.id}`
                                        ]
                                    ],
                                    ["td", {"className": "py-1"}, cluster.weight.toFixed(4)],
                                    ["td", {"className": "py-1"}, cluster.count]
                                    ]
                                ),
                                ...placeholders
                            ]
                        ]
                    ]
                ];
            }()"""
                ),
            ],
        ]
    )
)

For the interested reader, here are some exercises to try out to make this model better:
1) Extend the model to infer the variance of the clusters by replacing the `OBS_VARIANCE` hyperparameter with an inverse-gamma prior and doing block-Gibbs on it using normal-inverse-gamma conjugacy.
2) Try a better initialization of the datapoint assignments: pick a point and use something like k-means to assign all the surrounding points to the same initial cluster. Iterate over the points until they all have an initial cluster.
3) Improve inference using SMC via data annealing: subsample 1/100 of the data and run inference on it, then run inference again on 1/10 of the data starting from the cluster means and weights inferred in the previous trace, and finally repeat on the whole dataset.
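As a possible starting point for exercise 1, here is a minimal sketch of a simpler variant: a single observation variance $\sigma^2$ shared by all clusters, with an inverse-gamma prior $\mathrm{InvGamma}(a_0, b_0)$ and the cluster means held fixed. Under those assumptions the conditional is $\sigma^2 \mid x, z, \mu \sim \mathrm{InvGamma}\big(a_0 + N/2,\ b_0 + \tfrac{1}{2}\sum_i (x_i - \mu_{z_i})^2\big)$. The full exercise would use normal-inverse-gamma conjugacy to update means and variances jointly. The function name `sample_obs_variance` and the prior parameters `a0`, `b0` are hypothetical and not part of the notebook.

```python
import jax
import jax.numpy as jnp


def sample_obs_variance(key, obs, idx, means, a0, b0):
    """Draw sigma^2 ~ InvGamma(a0 + N/2, b0 + SSE/2) given assignments and means.

    Hypothetical helper: assumes one variance shared by all clusters, an
    InvGamma(a0, b0) prior (shape a0, scale b0), and fixed cluster means.
    """
    residuals = obs - means[idx]
    post_shape = a0 + obs.shape[0] / 2.0
    post_scale = b0 + 0.5 * jnp.sum(residuals**2)
    # If g ~ Gamma(post_shape, rate=1), then post_scale / g ~ InvGamma(post_shape, post_scale).
    g = jax.random.gamma(key, post_shape)
    return post_scale / g


# Usage on synthetic data whose true observation variance is 4.
k_data, k_draw = jax.random.split(jax.random.key(0))
means = jnp.array([0.0, 10.0])
idx = jnp.repeat(jnp.arange(2), 5000)
obs = means[idx] + 2.0 * jax.random.normal(k_data, (10000,))
draws = jax.vmap(lambda k: sample_obs_variance(k, obs, idx, means, 1.0, 1.0))(
    jax.random.split(k_draw, 200)
)
print(draws.mean())  # should be close to 4
```

In a block-Gibbs sweep, a draw like this would become a fourth update alongside the means, assignments, and weights, and the sampled variance would then replace `OBS_VARIANCE` in the other conditionals.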

Note that the model is still expected to get stuck in local minima (the clustering at the borders isn't great), and one way to improve upon it would be to use a split-merge move, via reversible-jump MCMC.