# Rao-Blackwellisation
## From high variance low compute to high compute low variance

We will revisit the Gen1D model and use a Collapsed Gibbs sampler for inference.


The idea is to marginalize out the cluster weights "probs" and the cluster means "means".
By the Rao-Blackwell theorem, replacing a sampled quantity with its conditional expectation cannot increase the variance of an estimator. In practice, a collapsed sampler also tends to mix faster and to explore the tails of the posterior distribution better, at the price of more computation per update.
Once both are marginalized out, we can directly compute the predictive distribution of a cluster assignment ["idx", i] given
all the other cluster assignments in "idx", and build a more efficient sampler from it.

See Graphical Models for Visual Object Recognition and Tracking by Erik B. Sudderth.
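
To make the variance claim concrete, here is a minimal sketch on a toy two-component mixture that is unrelated to the model below. The names `crude_estimate` and `rao_blackwell_estimate` and the toy means are illustrative assumptions. Both functions estimate the same expectation, but the second replaces each sampled observation by its conditional mean given the component label, which is the Rao-Blackwellisation step.

```python
import random
import statistics


def crude_estimate(rng, n, mus):
    """Average of sampled x, where z is uniform on {0, 1} and x | z ~ Normal(mus[z], 1)."""
    total = 0.0
    for _ in range(n):
        z = rng.randrange(2)
        total += rng.gauss(mus[z], 1.0)
    return total / n


def rao_blackwell_estimate(rng, n, mus):
    """Average of E[x | z] = mus[z]: the Gaussian noise is integrated out analytically."""
    return sum(mus[rng.randrange(2)] for _ in range(n)) / n


rng = random.Random(0)
mus = (0.0, 1.0)  # toy component means (assumption for this example)

# Repeat each estimator many times to compare their spread.
crude = [crude_estimate(rng, 50, mus) for _ in range(200)]
collapsed = [rao_blackwell_estimate(rng, 50, mus) for _ in range(200)]

var_crude = statistics.variance(crude)
var_collapsed = statistics.variance(collapsed)
print(f"crude: {var_crude:.4f}, Rao-Blackwellised: {var_collapsed:.4f}")
```

Per sample, the crude estimator has variance Var(mus[z]) + 1 while the collapsed one has only Var(mus[z]), so the second printed number should be clearly smaller.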

In [None]:
import imageio
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import categorical, dirichlet, gen, normal, pretty
from genjax._src.core.pytree import Const, Pytree

pretty()
key = jax.random.key(0)

TODO: There are a few ways to revise our model to do this.
First, we can keep the model unchanged and adapt inference so that it ignores the sampled "idx" values and applies the collapsed-Gibbs logic to the rest.
We can then overwrite the initially sampled "idx" values with a sample from the approximate posterior.

TODO: we could try a partially collapsed Gibbs sampler. It should still be parallel-friendly, otherwise there is no real point in using GenJAX, at least not until much later in the development.
- if we marginalize out only the cluster weights, we are left with two Gibbs updates:
  - update cluster means: no change
  - update cluster assignment:
  $\begin{equation}
    P(["idx", i] ~|~ "idx"_{-i}, "obs",{\footnotesize\text{ALPHA}}, "means") \propto P(["idx", i] ~|~ "idx"_{-i}, {\footnotesize\text{ALPHA}})  P(["obs", i] ~|~ ["idx", i], "means")
    \end{equation}$

For the fully collapsed sampler, the Markov structure gives the following factorization:
$\begin{equation}
P(["idx", i] ~|~ "idx"_{-i}, "obs",{\footnotesize\text{ALPHA}}, {\footnotesize\text{(PRIOR\_MEAN,PRIOR\_VARIANCE)}}) \propto P(["idx", i] ~|~ "idx"_{-i}, {\footnotesize\text{ALPHA}})  P(["obs", i] ~|~ "idx", "obs"_{-i}, {\footnotesize\text{(PRIOR\_MEAN,PRIOR\_VARIANCE)}})
\end{equation}$

The first term is obtained by marginalizing out the mixture weights:
$\begin{equation}
P(["idx", i] = k ~|~ "idx"_{-i}, {\footnotesize\text{ALPHA}}) = \frac{N_k^{-i}+{\footnotesize\text{ALPHA}} /{\footnotesize\text{N\_CLUSTERS}}}{{\footnotesize\text{N\_DATAPOINTS}}-1+ {\footnotesize\text{ALPHA}}}
\end{equation}$
where $N_k^{-i}$ is the number of observations currently assigned to cluster $k$, excluding datapoint $i$.
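
As a quick check of this formula, the sketch below evaluates it for every cluster in plain Python. The function name `assignment_prior_probs` and the toy counts are illustrative assumptions, not part of the notebook's model code.

```python
def assignment_prior_probs(counts_without_i, alpha):
    """Return (N_k^{-i} + alpha / K) / (N - 1 + alpha) for every cluster k.

    counts_without_i[k] is the number of datapoints in cluster k once
    datapoint i has been removed, so the counts sum to N - 1.
    """
    n_clusters = len(counts_without_i)
    denominator = sum(counts_without_i) + alpha  # N - 1 + alpha
    return [(n_k + alpha / n_clusters) / denominator for n_k in counts_without_i]


# Toy example: 3 clusters, 6 other datapoints, alpha = 1.5
probs = assignment_prior_probs([4, 2, 0], alpha=1.5)
print(probs)
```

The probabilities sum to one, and an empty cluster still receives mass alpha / K over the denominator, which is what lets the sampler open new clusters.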

For the second term, we have:
$\begin{equation}
P(["obs", i] ~|~ ["idx", i] = k, "idx"_{-i}, "obs"_{-i}, {\footnotesize\text{(PRIOR\_MEAN,PRIOR\_VARIANCE)}}) = P(["obs", i] ~|~ \{["obs", j] ~|~ ["idx",j]=k,j\neq i\} ,{\footnotesize\text{(PRIOR\_MEAN,PRIOR\_VARIANCE)}})
\end{equation}$

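where the right-hand side is, for our model, a Gaussian predictive distribution, because OBS_VARIANCE is fixed and each cluster mean has a conjugate normal prior. With an unknown observation variance and a conjugate prior on it, the predictive would instead be a Student-t distribution, which can usually be approximated by a moment-matched Gaussian.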
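
Since the observation variance is fixed and the prior on each cluster mean is normal, this predictive has a closed Gaussian form. The sketch below computes its log density from two sufficient statistics of a cluster: its size and the sum of its observations, both excluding datapoint i. The function name `predictive_log_density` and the example numbers are assumptions for illustration. The posterior formulas are the standard normal-normal conjugate update, the same one used in `update_cluster_means`.

```python
import math


def predictive_log_density(x, n_k, sum_k, prior_mean, prior_variance, obs_variance):
    """Log density of x under the Gaussian predictive of a cluster.

    n_k and sum_k are the number and the sum of the observations
    currently assigned to the cluster, excluding x itself.
    """
    # Posterior over the cluster mean (normal-normal conjugacy).
    post_variance = 1.0 / (1.0 / prior_variance + n_k / obs_variance)
    post_mean = post_variance * (prior_mean / prior_variance + sum_k / obs_variance)
    # Integrating out the mean adds its posterior variance to the noise variance.
    pred_variance = post_variance + obs_variance
    return -0.5 * (
        math.log(2.0 * math.pi * pred_variance)
        + (x - post_mean) ** 2 / pred_variance
    )


# An empty cluster falls back to the prior predictive Normal(50, 10 + 1).
empty = predictive_log_density(50.0, 0, 0.0, 50.0, 10.0, 1.0)
# A cluster already holding two points near 58 explains x = 58 much better.
fitted = predictive_log_density(58.0, 2, 116.0, 50.0, 10.0, 1.0)
print(empty, fitted)
```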

The resulting new Gibbs sweep looks like:
1. Sample a random permutation $\tau$ of $\{1,...,{\footnotesize\text{N\_DATAPOINTS}}\}$
2. For each i in  $\{\tau(1),...,\tau({\footnotesize\text{N\_DATAPOINTS}})\}$
   1.  For each cluster, compute the predictive likelihood (the second term) with datapoint i removed, possibly from cached sufficient statistics
   2.  Sample a new cluster assignment for datapoint i from a categorical with probabilities proportional to the product of the first and second terms
   3.  Update the cached sufficient statistics

At the end, optionally, we can sample the mixture weights and cluster means given the current assignment of the datapoints, using the conditional updates of the previous (uncollapsed) Gibbs sampling scheme.
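
The sweep can be sketched sequentially in plain Python as follows. This is a minimal illustration under stated assumptions, not the GenJAX implementation: the names `log_predictive` and `collapsed_gibbs_sweep`, the toy data and the toy hyperparameters are all hypothetical, and the per-cluster cached statistics are a count and a sum of observations.

```python
import math
import random


def log_predictive(x, n_k, sum_k, prior_mean, prior_var, obs_var):
    """Log density of x under the Gaussian predictive of a cluster with n_k points."""
    post_var = 1.0 / (1.0 / prior_var + n_k / obs_var)
    post_mean = post_var * (prior_mean / prior_var + sum_k / obs_var)
    pred_var = post_var + obs_var
    return -0.5 * (math.log(2.0 * math.pi * pred_var) + (x - post_mean) ** 2 / pred_var)


def collapsed_gibbs_sweep(rng, obs, idx, counts, sums, alpha, prior_mean, prior_var, obs_var):
    """One in-place sweep over all assignments; counts and sums are the cached statistics."""
    n_clusters = len(counts)
    order = list(range(len(obs)))
    rng.shuffle(order)  # random visiting order
    for i in order:
        x, k_old = obs[i], idx[i]
        # Remove datapoint i so the statistics describe everything except i.
        counts[k_old] -= 1
        sums[k_old] -= x
        # Unnormalised log-probabilities: first term times second term.
        # The shared denominator N - 1 + alpha cancels after normalisation.
        log_w = [
            math.log(counts[k] + alpha / n_clusters)
            + log_predictive(x, counts[k], sums[k], prior_mean, prior_var, obs_var)
            for k in range(n_clusters)
        ]
        top = max(log_w)  # subtract the max for numerical stability
        weights = [math.exp(lw - top) for lw in log_w]
        k_new = rng.choices(range(n_clusters), weights=weights)[0]
        # Put datapoint i back under its new label.
        idx[i] = k_new
        counts[k_new] += 1
        sums[k_new] += x


# Toy run: two well-separated groups, three available clusters.
rng = random.Random(0)
obs = [rng.gauss(40.0, 1.0) for _ in range(20)] + [rng.gauss(60.0, 1.0) for _ in range(20)]
n_clusters = 3
idx = [rng.randrange(n_clusters) for _ in obs]
counts = [idx.count(k) for k in range(n_clusters)]
sums = [sum(x for x, k in zip(obs, idx) if k == c) for c in range(n_clusters)]
for _ in range(10):
    collapsed_gibbs_sweep(rng, obs, idx, counts, sums, 1.0, 50.0, 100.0, 1.0)
print(counts)
```

Because every assignment depends on the counts left by the previous ones, this loop is inherently sequential, which is the cost of collapsing and the motivation for the parallel variants discussed next.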

One could also use a partially collapsed Gibbs sampler for more parallelization, or perform an inexact block-Gibbs update of all assignments at once and correct it with a Metropolis-Hastings acceptance ratio, which gives more parallelism across datapoints (again a coarse-to-fine inference scheme, in the style of Hogwild Gibbs sampling).
Possible references:
- https://www.phontron.com/paper/neubig14pgibbs.pdf
- https://www.sciencedirect.com/science/article/pii/S0167865517300752?casa_token=nYtOwi2vkpEAAAAA:EI_UgTLBHlM4bqwbH4YkXEeHAuGCIEkRcA6WOgGQxYxtyDbNNAl23AKEQtvVNwz0_1oZxxu5
- https://cs.brown.edu/~sudderth/papers/sudderthPhD.pdf p90-94

In [None]:
# Hyper parameters
PRIOR_VARIANCE = 10.0
OBS_VARIANCE = 1.0
N_DATAPOINTS = 5000
N_CLUSTERS = 40
ALPHA = float(N_DATAPOINTS / (N_CLUSTERS * 10))
PRIOR_MEAN = 50.0
N_ITER = 50

# Debugging mode
DEBUG = True


@Pytree.dataclass
class Cluster(Pytree):
    mean: float


@gen
def generate_cluster(mean, var):
    # `normal` is parameterized by a standard deviation, so convert the variance.
    cluster_mean = normal(mean, jnp.sqrt(var)) @ "mean"
    return Cluster(cluster_mean)


@gen
def generate_cluster_weight(alphas):
    probs = dirichlet(alphas) @ "probs"
    return probs


@gen
def generate_datapoint(probs, clusters):
    idx = categorical(jnp.log(probs)) @ "idx"
    # `normal` is parameterized by a standard deviation, so convert the variance.
    obs = normal(clusters.mean[idx], jnp.sqrt(OBS_VARIANCE)) @ "obs"
    return obs


@gen
def generate_data(n_clusters: Const[int], n_datapoints: Const[int], alpha: float):
    clusters = (
        generate_cluster.repeat(n=n_clusters.unwrap())(PRIOR_MEAN, PRIOR_VARIANCE)
        @ "clusters"
    )

    probs = generate_cluster_weight.inline(
        alpha / n_clusters.unwrap() * jnp.ones(n_clusters.unwrap())
    )

    datapoints = (
        generate_datapoint.repeat(n=n_datapoints.unwrap())(probs, clusters)
        @ "datapoints"
    )

    return datapoints

In [None]:
datapoints = C["datapoints", "obs"].set(
    jnp.concatenate([
        jax.random.uniform(jax.random.key(i), shape=(int(N_DATAPOINTS / N_CLUSTERS),))
        + PRIOR_MEAN
        + PRIOR_VARIANCE * (-4 + 8 * i / N_CLUSTERS)
        for i in range(N_CLUSTERS)
    ])
)

In [None]:
def infer(datapoints):
    key = jax.random.key(32421)
    args = (Const(N_CLUSTERS), Const(N_DATAPOINTS), ALPHA)
    key, subkey = jax.random.split(key)
    initial_weights = C["probs"].set(jnp.ones(N_CLUSTERS) / N_CLUSTERS)
    constraints = datapoints | initial_weights
    tr, _ = generate_data.importance(subkey, constraints, args)

    if DEBUG:
        all_posterior_means = [tr.get_choices()["clusters", "mean"]]
        all_posterior_weights = [tr.get_choices()["probs"]]
        all_cluster_assignment = [tr.get_choices()["datapoints", "idx"]]

        for _ in range(N_ITER):
            # Gibbs update on `("clusters", i, "mean")` for each i, in parallel
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_means)(subkey, tr)
            all_posterior_means.append(tr.get_choices()["clusters", "mean"])

            # Gibbs update on `("datapoints", i, "idx")` for each `i`, in parallel
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_datapoint_assignment)(subkey, tr)
            all_cluster_assignment.append(tr.get_choices()["datapoints", "idx"])

            # Gibbs update on `probs`
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_weights)(subkey, tr)
            all_posterior_weights.append(tr.get_choices()["probs"])

        return all_posterior_means, all_posterior_weights, all_cluster_assignment, tr

    else:

        def update(carry, _):
            key, tr = carry
            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_means)(subkey, tr)

            key, subkey = jax.random.split(key)
            tr = jax.jit(update_datapoint_assignment)(subkey, tr)

            key, subkey = jax.random.split(key)
            tr = jax.jit(update_cluster_weights)(subkey, tr)
            return (key, tr), None

        (key, tr), _ = jax.lax.scan(update, (key, tr), None, length=N_ITER)
        return tr


def update_cluster_means(key, trace):
    datapoint_indexes = trace.get_choices()["datapoints", "idx"]
    datapoints = trace.get_choices()["datapoints", "obs"]
    n_clusters = trace.get_args()[0].unwrap()
    current_means = trace.get_choices()["clusters", "mean"]

    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"],
        length=n_clusters,
        minlength=n_clusters,
    )

    cluster_means = (
        jax.vmap(
            lambda i: jnp.sum(jnp.where(datapoint_indexes == i, datapoints, 0)),
            in_axes=(0),
            out_axes=(0),
        )(jnp.arange(n_clusters))
        / category_counts
    )

    posterior_means = (
        PRIOR_VARIANCE
        / (PRIOR_VARIANCE + OBS_VARIANCE / category_counts)
        * cluster_means
        + (OBS_VARIANCE / category_counts)
        / (PRIOR_VARIANCE + OBS_VARIANCE / category_counts)
        * PRIOR_MEAN
    )

    posterior_variances = 1 / (1 / PRIOR_VARIANCE + category_counts / OBS_VARIANCE)

    key, subkey = jax.random.split(key)
    new_means = (
        generate_cluster.vmap()
        .simulate(key, (posterior_means, posterior_variances))
        .get_choices()["mean"]
    )

    chosen_means = jnp.where(category_counts == 0, current_means, new_means)

    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        subkey, C["clusters", "mean"].set(chosen_means), argdiffs
    )
    return new_trace


def update_datapoint_assignment(key, trace):
    def compute_local_density(x, i):
        datapoint_mean = trace.get_choices()["datapoints", "obs", x]
        chm = C["obs"].set(datapoint_mean).at["idx"].set(i)
        clusters = Cluster(trace.get_choices()["clusters", "mean"])
        probs = trace.get_choices()["probs"]
        args = (probs, clusters)
        model_logpdf, _ = generate_datapoint.assess(chm, args)
        return model_logpdf

    n_clusters = trace.get_args()[0].unwrap()
    n_datapoints = trace.get_args()[1].unwrap()
    local_densities = jax.vmap(
        lambda x: jax.vmap(lambda i: compute_local_density(x, i))(
            jnp.arange(n_clusters)
        )
    )(jnp.arange(n_datapoints))

    key, subkey = jax.random.split(key)
    new_datapoint_indexes = (
        genjax.categorical.vmap().simulate(key, (local_densities,)).get_choices()
    )
    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(
        subkey, C["datapoints", "idx"].set(new_datapoint_indexes), argdiffs
    )
    return new_trace


def update_cluster_weights(key, trace):
    n_clusters = trace.get_args()[0].unwrap()
    category_counts = jnp.bincount(
        trace.get_choices()["datapoints", "idx"],
        length=n_clusters,
        minlength=n_clusters,
    )

    new_alpha = ALPHA / n_clusters * jnp.ones(n_clusters) + category_counts
    key, subkey = jax.random.split(key)
    new_probs = generate_cluster_weight.simulate(key, (new_alpha,)).get_retval()

    argdiffs = genjax.Diff.no_change(trace.args)
    new_trace, _, _, _ = trace.update(subkey, C["probs"].set(new_probs), argdiffs)
    return new_trace

In [None]:
if DEBUG:
    (
        all_posterior_means,
        all_posterior_weights,
        all_cluster_assignment,
        posterior_trace,
    ) = infer(datapoints)
else:
    posterior_trace = infer(datapoints)

In [None]:
# Function to create plot for a given index
def create_plot(idx):
    datapoint = datapoints["datapoints", "obs"]
    posterior_means = all_posterior_means[idx]
    posterior_weights = all_posterior_weights[idx]
    cluster_assignment = all_cluster_assignment[idx]

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 5))

    # Plot datapoints colored by cluster assignment and posterior means together
    for i in range(len(posterior_means)):
        # Only plot points assigned to this cluster

        mask = cluster_assignment == i
        n_points = jnp.sum(mask)  # Count points in this cluster

        if not jnp.any(mask):  # Skip if no points assigned to this cluster
            continue

        # Add random jitter to y-coordinates
        key = jax.random.PRNGKey(i)  # Use cluster index as seed
        y_jitter = jax.random.uniform(
            key, shape=datapoint[mask].shape, minval=-0.1, maxval=0.1
        )

        ax.scatter(
            datapoint[mask],
            y_jitter,  # Use jittered y-coordinates
            color=f"C{i}",
            alpha=0.3,
            s=5,
        )

        # Plot posterior means with size proportional to weights
        weight = posterior_weights[i]  # Get current weight for this iteration
        ax.scatter(
            posterior_means[i],
            0,
            color=f"C{i}",
            marker="*",
            s=300 + weight * 1200,  # Marker size grows with the cluster weight
            alpha=1,
            label=f"Cluster {i + 1} (Prob: {weight:.6f}, Points: {n_points})",
        )

        # Show one observation standard deviation, sqrt(OBS_VARIANCE), around each cluster mean
        ax.errorbar(
            posterior_means[i],
            0,
            xerr=jnp.sqrt(OBS_VARIANCE),
            fmt="o",
            color=f"C{i}",
            capsize=5,
        )

    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3)
    ax.set_title(f"Iteration {idx}")
    plt.tight_layout()
    return fig


NUM_FRAMES = 50

# Create animation
frames = []
for i in range(NUM_FRAMES):
    fig = create_plot(int(N_ITER / NUM_FRAMES * i))
    # Render the figure and convert the canvas to an RGBA image array
    fig.canvas.draw()
    image = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (4,))
    frames.append(image[:, :, :3])  # Drop the alpha channel to get RGB
    plt.close(fig)

# Create animation from frames
fig = plt.figure(figsize=(12, 5))
from matplotlib import animation

ani = animation.ArtistAnimation(
    fig,
    [[plt.imshow(frame)] for frame in frames],
    interval=200,  # .2 second between frames
    blit=True,
)

# Save animation as GIF
imageio.mimsave("dirichlet_mixture_animation.gif", frames, fps=15)

# Display animation in notebook
HTML(ani.to_jshtml())