# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.

In [None]:
import animation
import gibbs_updates
import jax
import jax.numpy as jnp
import model_simple_continuous
from scipy import datasets

from genjax import ChoiceMapBuilder as C
from genjax import pretty

pretty()

## Model

### Testing sampling from the model

In [None]:
# Load the sample image from `scipy.datasets` and read off its height and
# width; the third axis (discarded here) is the colour-channel axis.
image = datasets.face()
H, W, _ = image.shape

# Hyperparameters handed to the model. The exact role of each field is defined
# in `model_simple_continuous.Hyperparams`; the notes below are a reviewer's
# reading of the names and should be confirmed against that definition.
# NOTE(review): `*_xy` fields have one entry per spatial axis and `*_rgb`
# fields one entry per colour channel (inferred from the array lengths here).
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    mu_xy=jnp.array([H / 2, W / 2]),  # centre of the image
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    alpha=1.0,
    sigma_xy=jnp.array([50.0, 50.0]),
    sigma_rgb=jnp.array([10.0, 10.0, 10.0]),
    n_blobs=200,
    H=H,
    W=W,
)

# Seed the PRNG and split off a subkey for the (currently disabled) simulate
# call below.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

In [None]:
# TODO: find a way to modify `hypers` in place. For now, following McCoy's
# recommendation, rebuild the whole structure with the one changed field; this
# creates a copy, but it is convenient.

# Same hyperparameters as before, except for `a_xy`.
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([300.0, 300.0]),
    b_xy=hypers.b_xy,
    mu_xy=hypers.mu_xy,
    a_rgb=hypers.a_rgb,
    b_rgb=hypers.b_rgb,
    alpha=hypers.alpha,
    sigma_xy=hypers.sigma_xy,
    sigma_rgb=hypers.sigma_rgb,
    n_blobs=hypers.n_blobs,
    H=hypers.H,
    W=hypers.W,
)

# Possible shorter spelling (disabled; `replace` is not imported in this notebook):
# replace(hypers, a_xy=jnp.array([301.0, 301.0]))

## Inference

We will do inference via exact block-Gibbs, using the fact that the model is defined using conjugate pairs.

### Gibbs updates

NEXT STEPS: 
- finish debugging/improving Genstudio visualization
- add one by one the Gibbs updates and test them individually
  - update xy mean
  - update rgb mean
  - update cluster assignment
  - update cluster weight
  - update sigma_xy
  - update sigma_rgb

### Main inference loop

In [None]:
# Number of passes of the main loop in `infer`; each pass applies every update once.
N_ITER = 100
# When True, `infer` records the sampled choices after every update for
# plotting and debugging.
RECORD = True
# When True (and recording is on), `infer` prints the initial means and weights.
DEBUG = False
# When True, every Gibbs update in `infer` is replaced by the no-op `id`, so
# the trace never changes; useful while the real updates are being developed.
TRIVIAL = True


def id(key, trace):
    """No-op stand-in for a Gibbs update: hand back `trace` untouched.

    Takes the same `(key, trace)` arguments as the calls to the real updates
    in `gibbs_updates`, so it can be swapped in for any of them. `key` is
    accepted only for that reason and is ignored.

    NOTE: this name shadows the builtin `id` within this notebook.
    """
    del key  # intentionally unused
    return trace


def infer(image, hypers):
    """Condition the model on `image` and run `N_ITER` Gibbs sweeps.

    Each sweep applies six updates in a fixed order (xy mean, xy sigma,
    rgb mean, rgb sigma, cluster assignment, mixture weight). When the
    module-level flag `TRIVIAL` is set, every update is replaced by the no-op
    `id`. When `RECORD` is set, the value at the updated address is stored
    once for the initial trace and once after every update.

    Args:
        image: array that can be reshaped to `(hypers.H * hypers.W, 3)`,
            i.e. an H x W image with three colour channels.
        hypers: `model_simple_continuous.Hyperparams`; `H`, `W` and `n_blobs`
            are read here and the whole object is passed to the model.

    Returns:
        A 7-tuple `(xy_means, xy_variances, rgb_means, rgb_variances,
        weights, cluster_assignments, trace)`. The first six entries are
        lists of recorded values (empty when `RECORD` is False); the last is
        the final trace.
    """
    key = jax.random.key(32421)

    # Image pre-processing: one row per pixel, laid out as (row, col, r, g, b).
    H = hypers.H
    W = hypers.W
    # `jnp.indices((H, W))` has shape (2, H, W). Flatten the grid and transpose
    # so that row k holds the (row, col) of pixel k, matching the row-major
    # order of `image.reshape(H * W, 3)`. Reshaping straight to (H * W, 2)
    # would pair up unrelated indices instead.
    pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T
    flattened_image = jnp.concatenate(
        (pixel_coords, image.reshape(H * W, 3)), axis=1
    )
    xy, rgb = flattened_image[:, :2], flattened_image[:, 2:]

    # Setup for better initial trace: observe the pixels and start from
    # uniform mixture weights.
    n_blobs = hypers.n_blobs
    obs = C["likelihood_model", "xy"].set(xy) ^ C["likelihood_model", "rgb"].set(rgb)
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    constraints = obs | initial_weights

    # Sample an initial trace
    key, subkey = jax.random.split(key)
    args = (hypers,)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(subkey, constraints, args)

    # Histories for plotting and debugging purposes (left empty unless RECORD).
    all_posterior_xy_means = []
    all_posterior_xy_variances = []
    all_posterior_rgb_means = []
    all_posterior_rgb_variances = []
    all_cluster_assignment = []
    all_posterior_weights = []

    # One entry per Gibbs update, in the order they are applied:
    # (function name in `gibbs_updates`, address it resamples, its history).
    # Functions are looked up by name only when actually needed, so running
    # with TRIVIAL=True does not require them to exist yet.
    sweep = (
        ("update_xy_mean", ("blob_model", "xy_mean"), all_posterior_xy_means),
        ("update_xy_sigma", ("blob_model", "sigma_xy"), all_posterior_xy_variances),
        ("update_rgb_mean", ("blob_model", "rgb_mean"), all_posterior_rgb_means),
        ("update_rgb_sigma", ("blob_model", "sigma_rgb"), all_posterior_rgb_variances),
        (
            "update_cluster_assignment",
            ("likelihood_model", "blob_idx"),
            all_cluster_assignment,
        ),
        (
            "update_mixture_weight",
            ("blob_model", "mixture_weight"),
            all_posterior_weights,
        ),
    )

    initial_choices = tr.get_choices()
    if RECORD:
        for _, address, history in sweep:
            history.append(initial_choices[address])

    if DEBUG:
        jax.debug.print(
            "Initial means: {v}", v=initial_choices["blob_model", "xy_mean"]
        )
        jax.debug.print(
            "Initial weights: {v}", v=initial_choices["blob_model", "mixture_weight"]
        )

    # Main inference loop. It runs regardless of RECORD; only the bookkeeping
    # is conditional.
    for _ in range(N_ITER):
        for update_name, address, history in sweep:
            # Fresh randomness for every update.
            key, subkey = jax.random.split(key)
            if TRIVIAL:
                tr = id(subkey, tr)
            else:
                tr = jax.jit(getattr(gibbs_updates, update_name))(subkey, tr)
            if RECORD:
                history.append(tr.get_choices()[address])

    return (
        all_posterior_xy_means,
        all_posterior_xy_variances,
        all_posterior_rgb_means,
        all_posterior_rgb_variances,
        all_posterior_weights,
        all_cluster_assignment,
        tr,
    )


# Reload the image and derive its size here, so this cell does not depend on
# `H` and `W` left over from an earlier cell.
image = datasets.face()
H, W, _ = image.shape

# Hyperparameters for the inference run. Compared with the model-testing cell
# above, this uses unit `sigma_xy` / `sigma_rgb` and only 10 blobs.
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([100.0, 100.0]),
    b_xy=jnp.array([10000.0, 10000.0]),
    mu_xy=jnp.array([H / 2, W / 2]),  # centre of the image
    a_rgb=jnp.array([25.0, 25.0, 25.0]),
    b_rgb=jnp.array([450.0, 450.0, 450.0]),
    alpha=1.0,
    sigma_xy=jnp.array([1.0, 1.0]),
    sigma_rgb=jnp.array([1.0, 1.0, 1.0]),
    n_blobs=10,
    H=H,
    W=W,
)

# Run the whole inference procedure under a single `jax.jit` and unpack the
# recorded histories together with the final trace.
(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = jax.jit(infer)(image, hypers)

### Visualization

In [None]:
# TODO: known problems with the visualization:
# - the colour of each point needs to match that of its cluster,
# - the count of points per cluster is wrong,
# - hovering is broken,
# - the number of clusters shown is wrong.

# Build the animation from the histories recorded by `infer`.
# Example usage with the original data:
visualization = animation.create_cluster_visualization(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_weights,
    all_posterior_rgb_means,
    num_frames=15,  # Show more frames
    pixel_sampling=10,  # Sample every 10th pixel
    confidence_factor=3.0,  # Scale factor for ellipses
    min_weight=0.01,  # Minimum weight threshold
)

# Display the visualization
visualization