# Gen2D

This notebook builds a simple model that clusters the pixels of a 2D image into a fixed number of components ("blobs"). Each pixel is described by its position (xy) and its color (rgb).

In [None]:
import gibbs_updates
import imageio
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import model_simple_continuous
import numpy as np
from IPython.display import HTML
from matplotlib import animation
from matplotlib.patches import Ellipse
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty

# Enable genjax's pretty-printing (presumably changes how traces / choice maps are displayed — confirm).
pretty()

## Model

### Testing sampling from the model

In [None]:
# Load the demo image; its height and width set the extent of the spatial prior.
image = datasets.face()
(H, W, _) = image.shape

hypers = model_simple_continuous.Hyperparams(
    # xy (position) prior parameters
    a_xy=jnp.array((100.0, 100.0)),
    b_xy=jnp.array((10000.0, 10000.0)),
    mu_xy=jnp.array((H / 2, W / 2)),  # centre of the image
    # rgb (colour) prior parameters
    a_rgb=jnp.array((25.0, 25.0, 25.0)),
    b_rgb=jnp.array((450.0, 450.0, 450.0)),
    # mixture-weight concentration
    alpha=1.0,
    sigma_xy=jnp.array((50.0, 50.0)),
    sigma_rgb=jnp.array((10.0, 10.0, 10.0)),
    # number of blobs and image size
    n_blobs=10,
    H=H,
    W=W,
)

key, subkey = jax.random.split(jax.random.PRNGKey(0))
# To draw an unconstrained sample from the model, uncomment:
# tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

## Inference

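We perform inference with exact block-Gibbs sampling. Because the model is built from conjugate prior–likelihood pairs, each block of latent variables can be resampled exactly from its full conditional distribution.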

### Gibbs updates

NEXT STEPS: add the Gibbs updates one by one and test each of them individually:
- update xy mean
- update rgb mean
- update cluster assignment
- update cluster weight
- update sigma_xy
- update sigma_rgb

### Main inference loop

In [None]:
N_ITER = 100  # number of Gibbs sweeps
RECORD = True  # keep the value of each latent after every update (needed for the plots below)
DEBUG = False  # print the initial xy means and mixture weights
SEED = 32421  # default seed of the PRNG key used by `infer`


def infer(image, hypers, seed=SEED):
    """Fit the blob model to `image` with block-Gibbs sampling.

    Each pixel is observed through its (row, col) position ("xy") and its three
    colour channels ("rgb"), both cast to float32. An initial trace is drawn with
    `model.importance`, constrained on the observations and on uniform mixture
    weights. N_ITER sweeps follow; each sweep resamples, in this order, the xy
    means, xy sigmas, rgb means, rgb sigmas, cluster assignments and mixture
    weights. The sweeps run whether or not RECORD is set.

    Args:
        image: pixel array reshapeable to (H * W, 3); assumes (H, W, 3) RGB — TODO confirm for other inputs.
        hypers: `model_simple_continuous.Hyperparams`; `H`, `W` and `n_blobs` are read here.
        seed: integer seed for the inference PRNG key.

    Returns:
        Tuple `(xy_means, xy_sigmas, rgb_means, rgb_sigmas, weights, assignments, tr)`.
        When RECORD is True, each of the first six is a list holding the initial
        value followed by one entry per sweep (length N_ITER + 1). When RECORD is
        False, they are empty lists. `tr` is the final trace.
    """
    key = jax.random.key(seed)

    # --- Image pre-processing ------------------------------------------------
    H = hypers.H
    W = hypers.W
    # jnp.indices((H, W)) has shape (2, H, W). Flatten each coordinate plane and
    # transpose so that row k holds the (row, col) of pixel k in row-major order,
    # i.e. the same pixel order as image.reshape(H * W, 3). Reshaping the
    # (2, H, W) array straight to (H * W, 2) would pair rows with rows and
    # columns with columns.
    xy = jnp.indices((H, W)).reshape(2, H * W).T.astype(jnp.float32)
    rgb = jnp.asarray(image).reshape(H * W, 3).astype(jnp.float32)

    # --- Initial trace ---------------------------------------------------------
    n_blobs = hypers.n_blobs
    obs = C["likelihood_model", "xy"].set(xy) ^ C["likelihood_model", "rgb"].set(rgb)
    # Start from uniform mixture weights instead of a random draw.
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    constraints = obs | initial_weights

    key, subkey = jax.random.split(key)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(
        subkey, constraints, (hypers,)
    )

    # --- Gibbs schedule --------------------------------------------------------
    # One sweep applies these updates in this order. Each resamples the listed
    # address for every blob / pixel in parallel. The jit wrappers are built
    # once here rather than on every iteration.
    schedule = [
        (("blob_model", "xy_mean"), jax.jit(gibbs_updates.update_xy_mean)),
        (("blob_model", "sigma_xy"), jax.jit(gibbs_updates.update_xy_sigma)),
        (("blob_model", "rgb_mean"), jax.jit(gibbs_updates.update_rgb_mean)),
        (("blob_model", "sigma_rgb"), jax.jit(gibbs_updates.update_rgb_sigma)),
        (
            ("likelihood_model", "blob_idx"),
            jax.jit(gibbs_updates.update_cluster_assignment),
        ),
        (
            ("blob_model", "mixture_weight"),
            jax.jit(gibbs_updates.update_mixture_weight),
        ),
    ]
    history = {address: [] for address, _ in schedule}

    if RECORD:
        initial_choices = tr.get_choices()
        for address, values in history.items():
            values.append(initial_choices[address])

    if DEBUG:
        jax.debug.print(
            "Initial means: {v}", v=tr.get_choices()["blob_model", "xy_mean"]
        )
        jax.debug.print(
            "Initial weights: {v}", v=tr.get_choices()["blob_model", "mixture_weight"]
        )

    # --- Main inference loop ----------------------------------------------------
    # Unlike the previous version, this is NOT nested under `if RECORD:`;
    # inference always runs and only the bookkeeping is optional.
    for _ in range(N_ITER):
        for address, update in schedule:
            key, subkey = jax.random.split(key)
            tr = update(subkey, tr)
            if RECORD:
                # Record right after the update that wrote this address.
                history[address].append(tr.get_choices()[address])

    return (
        history["blob_model", "xy_mean"],
        history["blob_model", "sigma_xy"],
        history["blob_model", "rgb_mean"],
        history["blob_model", "sigma_rgb"],
        history["blob_model", "mixture_weight"],
        history["likelihood_model", "blob_idx"],
        tr,
    )


# Hyperparameters actually used for inference. They replace the ones from the
# sampling cell above; note the much smaller sigma_xy / sigma_rgb.
image = datasets.face()
H, W, _ = image.shape  # derived here so this cell does not depend on earlier cells
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    mu_xy=jnp.array([H / 2, W / 2]),
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    alpha=1.0,
    sigma_xy=jnp.array([1.0, 1.0]),
    sigma_rgb=jnp.array([1.0, 1.0, 1.0]),
    n_blobs=10,
    H=H,
    W=W,
)

# `infer` is a Python driver: it builds Python lists and calls Gibbs updates
# that are each jit-compiled. It is therefore called directly. Wrapping it in
# jax.jit would unroll all N_ITER sweeps into one huge trace, and would need
# H, W and n_blobs to be concrete inside that trace.
(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = infer(image, hypers)

### Visualization

In [None]:
NUM_FRAMES = 10


def create_plot(idx, background=None):
    """Plot the sampler state recorded at history index `idx`.

    Pixels are drawn (subsampled 1-in-50 per cluster) in the colour of the rgb
    mean of the blob they are assigned to. Each blob's xy mean is drawn as a
    star whose size grows with its mixture weight, on top of an ellipse sized
    from sqrt(sigma_xy).

    Args:
        idx: index into the recorded histories (0 is the initial trace).
        background: image drawn faintly underneath. Defaults to
            `datasets.face()`, the image inference ran on. Pass it in to avoid
            reloading it for every frame.

    Returns:
        The created matplotlib Figure; the caller is responsible for closing it.
    """
    if background is None:
        background = datasets.face()
    img_h, img_w = background.shape[:2]
    subsample_step = 50  # plot 1 in 50 pixels of each cluster for legibility
    ellipse_scale = 20  # ellipse width/height = ellipse_scale * std (visual scale only)

    xy_means = np.asarray(all_posterior_xy_means[idx])
    xy_variances = np.asarray(all_posterior_xy_variances[idx])
    rgb_means = np.asarray(all_posterior_rgb_means[idx])
    weights = np.asarray(all_posterior_weights[idx])
    assignment = np.asarray(all_cluster_assignment[idx]).reshape(H * W)

    fig, ax = plt.subplots(figsize=(12, 5), facecolor="white")
    ax.set_facecolor("white")

    ax.imshow(background, alpha=0.4, extent=[0, img_w, img_h, 0])
    ax.set_xlim(0, img_w)
    ax.set_ylim(img_h, 0)

    # Screen coordinates of every pixel in row-major order: x = column, y = row.
    rows, cols = np.mgrid[0:H, 0:W]
    points = np.stack([cols.ravel(), rows.ravel()], axis=1)

    for i in range(len(xy_means)):
        mask = assignment == i
        n_points = int(mask.sum())
        if n_points == 0:  # skip clusters with no assigned pixels
            continue

        # rgb means are on the 0-255 scale. Clip because matplotlib rejects
        # colours outside [0, 1], and a posterior mean can leave that range.
        cluster_color = tuple(
            float(c) for c in np.clip(rgb_means[i].astype(float) / 255.0, 0.0, 1.0)
        )

        points_subsampled = points[mask][::subsample_step]
        ax.scatter(
            points_subsampled[:, 0],
            points_subsampled[:, 1],
            c=[cluster_color],
            marker=".",
            s=10,
            alpha=0.25,
        )

        # The model's xy coordinates are (row, col), as in mu_xy = [H / 2, W / 2].
        # Column 1 therefore goes on the horizontal axis and column 0 on the
        # vertical one.
        mean_row = float(xy_means[i, 0])
        mean_col = float(xy_means[i, 1])
        weight = float(weights[i])
        ax.scatter(
            mean_col,
            mean_row,
            c=[cluster_color],
            marker="*",
            s=150 + weight * 800,  # star size grows with the mixture weight
            alpha=1,
            edgecolor="black",
            linewidth=1,
            label=f"Cluster {i + 1} (Prob: {weight:.6f}, Points: {int(n_points)})",
        )

        # sigma_xy is treated as a variance here (sqrt -> std) — TODO confirm against the model.
        std_row = float(np.sqrt(xy_variances[i, 0]))
        std_col = float(np.sqrt(xy_variances[i, 1]))
        ellipse = Ellipse(
            (mean_col, mean_row),
            width=ellipse_scale * std_col,
            height=ellipse_scale * std_row,
            alpha=0.5,
            # facecolor (not color) so the black edgecolor is not overridden
            facecolor=cluster_color,
            edgecolor="black",
            linewidth=1,
        )
        ax.add_patch(ellipse)

    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=3,
        facecolor="white",
        edgecolor="black",
        framealpha=1,
    )
    ax.set_title(f"Iteration {idx}", pad=20, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


# Create animation
frames = []
for i in range(NUM_FRAMES):
    fig = create_plot(int(N_ITER / NUM_FRAMES * i))
    fig.canvas.draw()
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    frames.append(image)
    plt.close(fig)

# Create animation from frames
fig = plt.figure(figsize=(12, 5), facecolor="white")
from matplotlib import animation

ani = animation.ArtistAnimation(
    fig,
    [[plt.imshow(frame)] for frame in frames],
    interval=200,
    blit=True,
)

# Save animation as GIF
imageio.mimsave("dirichlet_mixture_animation.gif", frames, fps=15)

# Display animation in notebook
HTML(ani.to_jshtml())