# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.

In [None]:
import animation
import gibbs_updates
import jax
import jax.numpy as jnp
import model_simple_continuous
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty

# Call GenJAX's `pretty` helper once at start-up.
# NOTE(review): presumably this enables nicer notebook rendering of GenJAX objects — confirm in the GenJAX docs.
pretty()

## Model

### Testing sampling from the model

In [None]:
# Load SciPy's sample image and read off its height and width.
image = datasets.face()
H, W, _ = image.shape

# Hyperparameters of the blob model. The meaning of each field is defined by
# `model_simple_continuous.Hyperparams`; they are grouped here by theme only.
hypers = model_simple_continuous.Hyperparams(
    # Image geometry and mixture size.
    H=H,
    W=W,
    n_blobs=100,
    alpha=1.0,
    # Spatial (xy) settings; `mu_xy` sits at the centre of the image.
    mu_xy=jnp.array([H / 2, W / 2]),
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    sigma_xy=jnp.array([50.0, 50.0]),
    # Colour (rgb) settings.
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    sigma_rgb=jnp.array([10.0, 10.0, 10.0]),
)

# Seed the PRNG and split off a subkey for the (currently disabled) simulation below.
key, subkey = jax.random.split(jax.random.PRNGKey(0))
# tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

## Inference

We will do inference via exact block-Gibbs, using the fact that the model is defined using conjugate pairs.

### Gibbs updates

NEXT STEPS:
- Add the Gibbs updates one at a time and test each of them individually:
  - update the cluster weights
  - some updates seem broken and need fixing
  - improve the visuals
  - update `sigma_xy`
  - update `sigma_rgb`
  - use proper `sigma_xy` and `sigma_rgb` for the cluster variances
  - infer the hyperparameters using exact Gibbs

### Main inference loop

In [None]:
from genjax._src.core.generative.choice_map import ChoiceMap

# Number of Gibbs sweeps performed by `infer`.
N_ITER = 10
# When true, `infer` keeps the per-iteration history of the sampled choices.
RECORD = True
# When true, `infer` prints the initial means and weights via `jax.debug.print`.
DEBUG = False
# When true, `infer` replaces the `sigma_xy` and `sigma_rgb` updates with the identity `id`.
TRIVIAL = True


def id(key, trace):
    """Identity "update": return `trace` unchanged.

    Takes the same `(key, trace)` arguments as the Gibbs updates so it can
    stand in for one of them; the key is not used.

    NOTE: this shadows the builtin `id` within the notebook. The name is kept
    because `infer` calls it.
    """
    del key  # accepted only to match the update signature
    return trace


def infer(image, hypers):
    """Run block-Gibbs inference of the blob mixture model on `image`.

    The image is flattened into one observation per pixel, made of its
    (row, col) coordinates and its RGB value. An initial trace is drawn with
    uniform mixture weights and blob means placed on a regular grid, then
    `N_ITER` sweeps of Gibbs updates are applied to it.

    Args:
        image: Array that can be reshaped to `(hypers.H * hypers.W, 3)`.
        hypers: Hyperparameters passed to `model_simple_continuous.model`;
            `H`, `W` and `n_blobs` are also read here.

    Returns:
        A 7-tuple `(xy_means, xy_variances, rgb_means, rgb_variances,
        weights, cluster_assignments, trace)`. The first six are lists holding
        the initial value followed by one entry per sweep when `RECORD` is
        true, and empty lists otherwise. `trace` is the final trace.
    """
    key = jax.random.key(32421)

    # Image pre-processing
    H = hypers.H
    W = hypers.W
    # `jnp.indices((H, W))` has shape (2, H, W). The coordinate axis must be
    # moved last before flattening so that row k holds the (row, col) of the
    # k-th pixel, in the same row-major order as `image.reshape(H * W, 3)`.
    pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T
    flattened_image = jnp.concatenate(
        (pixel_coords, image.reshape(H * W, 3)), axis=1
    )
    xy, rgb = flattened_image[:, :2], flattened_image[:, 2:]

    # Setup for better initial trace
    n_blobs = hypers.n_blobs
    obs: ChoiceMap = C["likelihood_model", "xy"].set(xy) | C[
        "likelihood_model", "rgb"
    ].set(rgb)
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)

    # Regular grid of initial blob centres, in the same (row, col) order as
    # the pixel coordinates above: first component spans H, second spans W.
    grid_side = jnp.sqrt(n_blobs)
    cells_per_side = jnp.array(grid_side, dtype=jnp.int32)
    blob_ids = jnp.arange(n_blobs)
    grid_of_xy_means = jnp.stack(
        [
            (blob_ids // cells_per_side + 0.5) * (H / grid_side),
            (blob_ids % cells_per_side + 0.5) * (W / grid_side),
        ],
        axis=1,
    )
    initial_cluster_xy_mean = C["blob_model", "xy_mean"].set(grid_of_xy_means)
    constraints = obs | initial_weights | initial_cluster_xy_mean

    # Sample an initial trace
    key, subkey = jax.random.split(key)
    args = (hypers,)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(subkey, constraints, args)

    # Histories for plotting and debugging purposes; they stay empty unless RECORD is set.
    all_posterior_xy_means = []
    all_posterior_xy_variances = []
    all_posterior_rgb_means = []
    all_posterior_rgb_variances = []
    all_cluster_assignment = []
    all_posterior_weights = []

    def record(history, trace, *address):
        # Append the current value at `address` to `history` when recording is on.
        if RECORD:
            history.append(trace.get_choices()[address])

    record(all_posterior_xy_means, tr, "blob_model", "xy_mean")
    record(all_posterior_xy_variances, tr, "blob_model", "sigma_xy")
    record(all_posterior_rgb_means, tr, "blob_model", "rgb_mean")
    record(all_posterior_rgb_variances, tr, "blob_model", "sigma_rgb")
    record(all_cluster_assignment, tr, "likelihood_model", "blob_idx")
    record(all_posterior_weights, tr, "blob_model", "mixture_weight")

    if DEBUG:
        initial_choices = tr.get_choices()
        jax.debug.print("Initial means: {v}", v=initial_choices["blob_model", "xy_mean"])
        jax.debug.print(
            "Initial weights: {v}", v=initial_choices["blob_model", "mixture_weight"]
        )

    # Main inference loop. It runs regardless of RECORD; only the bookkeeping is optional.
    for _ in range(N_ITER):
        # Gibbs update on `("likelihood_model", "blob_idx", i)` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        tr = jax.jit(gibbs_updates.update_cluster_assignment)(subkey, tr)
        record(all_cluster_assignment, tr, "likelihood_model", "blob_idx")

        # Gibbs update on `("blob_model", "xy_mean", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        tr = jax.jit(gibbs_updates.update_xy_mean)(subkey, tr)
        record(all_posterior_xy_means, tr, "blob_model", "xy_mean")

        # Gibbs update on `("blob_model", "sigma_xy", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_xy_sigma)(subkey, tr)
        record(all_posterior_xy_variances, tr, "blob_model", "sigma_xy")

        # Gibbs update on `("blob_model", "rgb_mean", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        tr = jax.jit(gibbs_updates.update_rgb_mean)(subkey, tr)
        record(all_posterior_rgb_means, tr, "blob_model", "rgb_mean")

        # Gibbs update on `("blob_model", "sigma_rgb", i)` for each i, in parallel
        key, subkey = jax.random.split(key)
        if TRIVIAL:
            tr = id(subkey, tr)
        else:
            tr = jax.jit(gibbs_updates.update_rgb_sigma)(subkey, tr)
        record(all_posterior_rgb_variances, tr, "blob_model", "sigma_rgb")

        # Gibbs update on `("blob_model", "mixture_weight", i)` for each `i`, in parallel
        key, subkey = jax.random.split(key)
        tr = jax.jit(gibbs_updates.update_mixture_weight)(subkey, tr)
        record(all_posterior_weights, tr, "blob_model", "mixture_weight")

    return (
        all_posterior_xy_means,
        all_posterior_xy_variances,
        all_posterior_rgb_means,
        all_posterior_rgb_variances,
        all_posterior_weights,
        all_cluster_assignment,
        tr,
    )


# Run the jitted inference, then unpack the six histories and the final trace.
inference_outputs = jax.jit(infer)(image, hypers)
(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = inference_outputs

### Visualization

In [None]:
# Rendering options passed by keyword to the animation helper.
visualization_options = dict(
    num_frames=15,
    pixel_sampling=10,  # Sample every 10th pixel
    confidence_factor=3.0,  # Scale factor for ellipses
    min_weight=0.01,  # Minimum weight threshold for showing clusters
)

# Build the animation from the recorded histories (positional order matters).
visualization = animation.create_cluster_visualization(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_weights,
    all_posterior_rgb_means,
    all_cluster_assignment,
    **visualization_options,
)

# Last expression of the cell, so the notebook displays it.
visualization