# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.
We recommend running it on a GPU, especially if N_ITER > 10 or N_BLOBS > 100.

In [None]:
from typing import Any

import animation
import gibbs_updates
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import model_simple_continuous
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty
from genjax._src.core.generative.choice_map import ChoiceMap

pretty()  # NOTE(review): presumably enables genjax's pretty-printing of traces/choice maps in the notebook — confirm

## Hyperparameters

In [None]:
N_ITER = 5  # number of Gibbs sweeps
RECORD = True  # keep per-sweep histories of every latent variable (used by the animation)
DEBUG = False
# Debugging aid: set an entry to True to skip (deactivate) that Gibbs move.
# Index order: 0 cluster assignment, 1 xy mean, 2 xy sigma, 3 rgb mean,
# 4 rgb sigma, 5 mixture weight.
TRIVIAL = [False, False, False, False, True, False]

image = datasets.face()
H, W, _ = image.shape

# # Load and convert image
# image = mpimg.imread("image (4).png")[::2, ::2]
# H, W, _ = image.shape

N_BLOBS = 81  # blobs start on a square grid, so this must be a perfect square
if round(N_BLOBS**0.5) ** 2 != N_BLOBS:
    raise ValueError(f"N_BLOBS must be a perfect square, got {N_BLOBS}")
A_XY = jnp.array([100.0, 100.0])
B_XY = jnp.array([10000.0, 10000.0])
# Pixel positions are stored as (x, y) = (column, row) pairs by the inference
# code, so the first xy component spans the width W and the second the height H.
MU_XY = jnp.array([W / 2, H / 2])  # image centre in (x, y) order
A_RGB = jnp.array([25.0, 25.0, 25.0])
B_RGB = jnp.array([450.0, 450.0, 450.0])
ALPHA: float = 1.0
SIGMA_XY = jnp.array([
    W / jnp.sqrt(N_BLOBS),
    H / jnp.sqrt(N_BLOBS),
])  # one grid-cell size per axis, so the initial grid of blobs roughly covers the image
SIGMA_RGB = jnp.array([10.0, 10.0, 10.0])

## Model

### Testing to sample from model

In [None]:
# Bundle all model hyperparameters; the field names mirror the constants above.
hyper_kwargs = dict(
    a_xy=A_XY,
    b_xy=B_XY,
    mu_xy=MU_XY,
    a_rgb=A_RGB,
    b_rgb=B_RGB,
    alpha=ALPHA,
    sigma_xy=SIGMA_XY,
    sigma_rgb=SIGMA_RGB,
    n_blobs=N_BLOBS,
    H=H,
    W=W,
)
hypers = model_simple_continuous.Hyperparams(**hyper_kwargs)

# Notebook-level PRNG state: split once so `key` is advanced before later cells use it.
key, subkey = jax.random.split(jax.random.PRNGKey(0))
# To draw an unconditioned sample from the model, uncomment:
# tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

## Inference

We perform inference with exact block Gibbs sampling, which is possible because the model is built from conjugate prior–likelihood pairs.

### Gibbs updates

NEXT IMMEDIATE STEPS: 
- add one by one the Gibbs updates and test them individually
  - improve performance and try on GPU
  - update sigma_xy
  - update sigma_rgb
  - there's a few weird clusters on the animation
  - ask George how to recover when some cluster has no points or no weight
  - infer hyperparams using exact Gibbs

NEXT BIGGER STEPS:
- add hyperclustering to recover proto-objects
- do real time inference on videos
- model attention (bias the number of Gaussians in a chosen region by artificially making the likelihood matter more there, e.g. by duplicating points in that region)

### Main inference loop

In [None]:
def id(key, trace):
    """No-op Gibbs move: return ``trace`` unchanged, ignoring ``key``.

    Stands in for a real update when that move is disabled through ``TRIVIAL``.
    NOTE(review): this name shadows the builtin ``id`` in this module.
    """
    del key  # unused; kept so the call shape matches the real Gibbs updates
    unchanged = trace
    return unchanged


def _gibbs_moves():
    """Return the Gibbs moves of one sweep, in the order they are applied.

    Each entry is ``(address, update)``. ``update(key, trace)`` resamples the
    choices stored at ``address`` for every index in parallel. A move's position
    in this tuple is also its index into the ``TRIVIAL`` debug flags.
    """
    return (
        (("likelihood_model", "blob_idx"), gibbs_updates.update_cluster_assignment),
        (("blob_model", "xy_mean"), gibbs_updates.update_xy_mean),
        (("blob_model", "sigma_xy"), gibbs_updates.update_xy_sigma),
        (("blob_model", "rgb_mean"), gibbs_updates.update_rgb_mean),
        (("blob_model", "sigma_rgb"), gibbs_updates.update_rgb_sigma),
        (("blob_model", "mixture_weight"), gibbs_updates.update_mixture_weight),
    )


def _gibbs_sweep(key, tr):
    """Apply every Gibbs move once, in order.

    Moves whose ``TRIVIAL`` flag is set are skipped. A PRNG key is still split
    for them, so the random stream seen by the remaining moves does not depend
    on which moves are disabled.

    Returns:
        ``(key, tr, snapshots)``, where ``snapshots[k]`` is the value at the
        address of move ``k`` immediately after that move.
    """
    snapshots = []
    for skip, (address, update) in zip(TRIVIAL, _gibbs_moves()):
        key, subkey = jax.random.split(key)
        if not skip:
            tr = jax.jit(update)(subkey, tr)
        snapshots.append(tr.get_choices()[address])
    return key, tr, snapshots


def _flatten_image(image, H, W):
    """Return per-pixel ``(xy, rgb)`` arrays of shapes ``(H*W, 2)`` and ``(H*W, 3)``.

    ``xy`` holds ``(x, y) = (column, row)`` pixel coordinates. Only the first
    three channels of ``image`` are used, so RGBA images are accepted too.
    """
    y_coords, x_coords = jnp.indices((H, W))
    coords = jnp.stack([x_coords, y_coords], axis=-1)  # (H, W, 2) of (x, y) pairs
    # Concatenating before splitting promotes coordinates and colours to one
    # common dtype (e.g. uint8 colours become the coordinate integer dtype).
    flat = jnp.concatenate(
        (coords.reshape(H * W, 2), image[..., :3].reshape(H * W, 3)), axis=1
    )
    return flat[:, :2], flat[:, 2:]


def _initial_grid(n_blobs, H, W):
    """Return ``(n_blobs, 2)`` blob centres, in (x, y) order, on a square grid.

    Blob ``i`` sits at the centre of cell (column ``i % side``, row ``i // side``)
    of a ``side x side`` grid spanning the image, with ``side = sqrt(n_blobs)``.

    Raises:
        ValueError: if ``n_blobs`` is not a perfect square.
    """
    side = int(round(n_blobs**0.5))
    if side * side != n_blobs:
        raise ValueError(f"n_blobs must be a perfect square, got {n_blobs}")
    idx = jnp.arange(n_blobs)
    cols = idx % side
    rows = idx // side
    return jnp.stack([(cols + 0.5) * (W / side), (rows + 0.5) * (H / side)], axis=-1)


def _initial_constraints(xy, rgb, n_blobs, H, W):
    """Build the choice map used to draw the initial trace.

    It constrains the observed pixels, sets uniform mixture weights and places
    the blob means on a square grid. Each pixel is assigned to its nearest
    grid centre.
    """
    obs = C["likelihood_model", "xy"].set(xy) | C["likelihood_model", "rgb"].set(rgb)
    weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    grid = _initial_grid(n_blobs, H, W)
    xy_means = C["blob_model", "xy_mean"].set(grid)
    # Nearest grid centre by squared Euclidean distance.
    # NOTE: this materialises an (H*W, n_blobs, 2) intermediate.
    nearest = jnp.argmin(
        jnp.sum((xy[:, None, :] - grid[None, :, :]) ** 2, axis=2), axis=1
    )
    assignment = C["likelihood_model", "blob_idx"].set(nearest)
    return obs | weights | xy_means | assignment


def infer(image, hypers):
    """Fit the blob mixture model to ``image`` with block-Gibbs sampling.

    An initial trace is drawn by importance sampling. It is constrained on:
    - the observed pixels;
    - uniform mixture weights;
    - grid-placed blob means;
    - nearest-centre pixel assignments.

    Then ``N_ITER`` Gibbs sweeps are run. The recording and non-recording paths
    run the same sweep, so every latent variable (including ``sigma_xy`` and
    ``sigma_rgb``) is updated unless disabled via ``TRIVIAL``.

    Args:
        image: Array of shape ``(hypers.H, hypers.W, C)`` with ``C >= 3``.
        hypers: Model hyperparameters. ``H``, ``W`` and ``n_blobs`` are used as
            array shapes, so they must be concrete Python ints.

    Returns:
        If ``RECORD`` is set, returns ``(xy_means, xy_sigmas, rgb_means,
        rgb_sigmas, weights, cluster_assignments, final_trace)``. Each history
        list holds the initial value followed by one entry per sweep.
        Otherwise, returns only the final trace.
    """
    key = jax.random.key(32421)
    H, W, n_blobs = hypers.H, hypers.W, hypers.n_blobs

    xy, rgb = _flatten_image(image, H, W)
    constraints = _initial_constraints(xy, rgb, n_blobs, H, W)

    # Sample an initial trace consistent with the constraints.
    key, subkey = jax.random.split(key)
    args = (hypers,)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(
        subkey, constraints, args
    )

    if not RECORD:
        # One scan step = one full Gibbs sweep; scan traces the body itself.
        def step(carry, _):
            key, tr = carry
            key, tr, _ = _gibbs_sweep(key, tr)
            return (key, tr), None

        (key, tr), _ = jax.lax.scan(step, (key, tr), None, length=N_ITER)
        return tr

    # Record info for plotting and debugging purposes.
    choices = tr.get_choices()
    histories = [[choices[address]] for address, _ in _gibbs_moves()]
    assignments, xy_means, xy_sigmas, rgb_means, rgb_sigmas, weights = histories

    if DEBUG:
        jax.debug.print("Initial means: {v}", v=xy_means[0])
        jax.debug.print("Initial weights: {v}", v=weights[0])

    for _ in range(N_ITER):
        key, tr, snapshots = _gibbs_sweep(key, tr)
        for history, snapshot in zip(histories, snapshots):
            history.append(snapshot)

    return (xy_means, xy_sigmas, rgb_means, rgb_sigmas, weights, assignments, tr)

In [None]:
# Run inference under jit; the structure of the result depends on RECORD.
run_inference = jax.jit(infer)
inference_output = run_inference(image, hypers)
if not RECORD:
    tr = inference_output
else:
    # Per-sweep histories of each latent variable, plus the final trace.
    (
        all_posterior_xy_means,
        all_posterior_xy_variances,
        all_posterior_rgb_means,
        all_posterior_rgb_variances,
        all_posterior_weights,
        all_cluster_assignment,
        tr,
    ) = inference_output

### Visualizing Inference

In [None]:
# Positional inputs to the animation, in the order the helper expects them.
cluster_histories = (
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_weights,
    all_posterior_rgb_means,
    all_cluster_assignment,
)
display_options = dict(
    image=image,
    num_frames=15,
    pixel_sampling=10,  # draw every 10th pixel
    confidence_factor=3.0,  # ellipse scale factor
    min_weight=0.01,  # hide clusters whose weight falls below this
)
visualization = animation.create_cluster_visualization(
    *cluster_histories, **display_options
)

visualization

### Image reconstruction

TODO: write a visualization of the intermediate steps, and maybe generate more points at each step for better rendering.

In [None]:
# Take the final blob parameters from the inferred trace and use them as
# constraints, so the model regenerates pixel-level data from them.
final_choices = tr.get_choices()
sigma_xy, sigma_rgb, xy_mean, rgb_mean, mixture_weight = (
    final_choices["blob_model", name]
    for name in ("sigma_xy", "sigma_rgb", "xy_mean", "rgb_mean", "mixture_weight")
)
obs = (
    C["blob_model", "sigma_xy"].set(sigma_xy)
    | C["blob_model", "sigma_rgb"].set(sigma_rgb)
    | C["blob_model", "xy_mean"].set(xy_mean)
    | C["blob_model", "rgb_mean"].set(rgb_mean)
    | C["blob_model", "mixture_weight"].set(mixture_weight)
)

key, subkey = jax.random.split(key)
args = (hypers,)
new_tr, _ = jax.jit(model_simple_continuous.model.importance)(subkey, obs, args)

In [None]:
generated = new_tr.get_choices()
xy = generated["likelihood_model", "xy"]
rgb = jnp.clip(generated["likelihood_model", "rgb"], 0, 255)

# Side-by-side comparison: regenerated samples (left) vs. the source image (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

# y is negated so that row 0 ends up at the top, as in the displayed image.
ax1.scatter(xy[:, 0], -xy[:, 1], c=rgb / 255.0, s=1)
ax1.axis("equal")
ax1.set(title="Generated Points", xlabel="X", ylabel="Y")

ax2.imshow(image)
ax2.axis("off")
ax2.set_title("Original Image")

plt.tight_layout()  # avoid overlapping titles/labels
plt.show()

In [None]:
def compute_squared_deviations(datapoints, cluster_indices, means, n_clusters):
    """Compute the per-dimension sum of squared deviations from cluster means.

    For each cluster ``k`` and dimension ``d`` this returns
    ``sum over i with cluster_indices[i] == k of (datapoints[i, d] - means[k, d]) ** 2``,
    i.e. the statistic needed by a per-dimension variance update.

    Args:
        datapoints: Array of shape (N, D) containing observations.
        cluster_indices: Integer array of shape (N,) containing cluster
            assignments. Points whose index lies outside ``[0, n_clusters)``
            do not contribute.
        means: Array of shape (K, D) containing cluster means.
        n_clusters: Number of clusters K (a concrete Python int).

    Returns:
        Array of shape (K, D) containing the sum of squared deviations per
        cluster and dimension; rows of clusters with no points are zero.
    """
    # Deviation of every point from the mean of its own cluster: (N, D).
    # Rows with out-of-range indices pick up an arbitrary mean here, but
    # segment_sum drops them below, so they never contribute.
    deviations = datapoints - means[cluster_indices]
    # Scatter-add squared deviations into each cluster's row: O(N * D) work,
    # instead of masking the whole dataset once per cluster.
    return jax.ops.segment_sum(deviations**2, cluster_indices, num_segments=n_clusters)


# Test case: two well-separated clusters with two points each.
datapoints = jnp.array([
    [1.0, 1.0],  # point 0: cluster 0
    [2.0, 2.0],  # point 1: cluster 0
    [10.0, 10.0],  # point 2: cluster 1
    [11.0, 11.0],  # point 3: cluster 1
])
cluster_indices = jnp.array([0, 0, 1, 1])
means = jnp.array([[1.5, 1.5], [10.5, 10.5]])
n_clusters = 2

# Walk through the per-dimension computation for cluster 0 by hand.
cluster_idx = 0
diffs = datapoints - means[cluster_idx]
squared_diffs = diffs**2
mask = (cluster_indices == cluster_idx)[:, None]
masked_diffs = jnp.where(mask, squared_diffs, 0.0)
sum_result = jnp.sum(masked_diffs, axis=0)

for heading, value in (
    ("Diffs for cluster 0:", diffs),
    ("\nSquared diffs:", squared_diffs),
    ("\nMask:", mask),
    ("\nMasked diffs:", masked_diffs),
    ("\nSum result:", sum_result),
):
    print(heading)
    print(value)