# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.

In [None]:
import jax
import jax.numpy as jnp
import model_simple_continuous
from scipy import misc

from genjax import ChoiceMapBuilder as C
from genjax import pretty

# NOTE(review): presumably turns on GenJAX's pretty-printed display of objects
# in the notebook — confirm against the GenJAX documentation.
pretty()

## Model

### Testing to sample from model

In [None]:
# Load the sample image; its height and width are forwarded to the model
# through the hyperparameters below.
# NOTE(review): `scipy.misc.face` is deprecated in recent SciPy releases in
# favour of `scipy.datasets.face` — confirm against the installed SciPy version.
image = misc.face()
H, W, _ = image.shape

# Hyperparameters handed to the model. The meaning of each field is defined by
# `model_simple_continuous.Hyperparams`, which is not shown in this notebook.
hypers = model_simple_continuous.Hyperparams(
    a_xy=jnp.array([1.0, 1.0]),
    b_xy=jnp.array([1.0, 1.0]),
    mu_xy=jnp.array([0.0, 0.0]),
    a_rgb=jnp.array([1.0, 1.0, 1.0]),
    b_rgb=jnp.array([1.0, 1.0, 1.0]),
    alpha=1.0,
    sigma_xy=jnp.array([1.0, 1.0]),
    sigma_rgb=jnp.array([1.0, 1.0, 1.0]),
    n_blobs=10,
    H=H,
    W=W,
)

# Split the PRNG key so the sub-key used for simulation is never reused.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# Smoke test: simulate one trace from the jitted model with these hyperparameters.
tr = jax.jit(model_simple_continuous.model.simulate)(subkey, (hypers,))

## Inference

We will do inference via exact block-Gibbs, using the fact that the model is defined using conjugate pairs.

### Gibbs updates

In [None]:
# def conjugate_update_mvnormal_with_known_cov(
#     prior_mean,  # (D,)
#     prior_cov,  # (D, D)
#     obs_cov,  # (D, D)
#     obs,  # (M, D)
# ):
#     """
#     Returns the posterior mean and covariance for the mean
#     of a multivariate normal distribution with known covariance.
#     That is, given
#       mu ~ Normal(prior_mean, prior_cov),
#       obs_i ~ Normal(mu, obs_cov) for i = 0, 1, ..., M-1,
#     this function returns (post_mean, post_cov) where
#       P(mu | obs) = Normal(post_mean, post_cov).
#     """
#     M = obs.shape[0]
#     post_cov = jnp.linalg.inv(jnp.linalg.inv(prior_cov) + M * jnp.linalg.inv(obs_cov))
#     obsmean = jnp.sum(obs, axis=0) / M
#     post_mean = post_cov @ (
#         jnp.linalg.inv(prior_cov) @ prior_mean + M * jnp.linalg.inv(obs_cov) @ obsmean
#     )
#     return jnp.where(M > 0, post_mean, prior_mean), jnp.where(
#         M > 0, post_cov, prior_cov
#     )


# def dirichlet_categorical_update(key, associations, n_clusters, alpha):
#     """Returns (categorical_vector, metadata_dict)."""

#     def get_assoc_count(cluster_idx):
#         masked_relevant_datapoint_indices = tiling.relevant_datapoints_for_blob(
#             cluster_idx
#         )
#         relevant_associations = associations[masked_relevant_datapoint_indices.value]
#         return jnp.sum(
#             jnp.logical_and(
#                 masked_relevant_datapoint_indices.flag,
#                 relevant_associations == cluster_idx,
#             )
#         )

#     assoc_counts = jax.vmap(get_assoc_count)(jnp.arange(n_clusters))
#     prior_alpha = alpha
#     post_alpha = prior_alpha + assoc_counts
#     return dirichlet(post_alpha)(key), {}


# def conjugate_dirichlet_categorical(
#     key, associations, n_clusters, alpha, λ=model_simple_continuous.GAMMA_RATE_PARAMETER
# ):
#     """
#     Conjugate update for the case where we have
#         X_i ~ Gamma(alpha_i / n, lambda) for i = 1, 2, ..., n;
#         X_0 := sum_i X_i
#         p := [X_1, X_2, ..., X_n] / X_0
#         Y_i ~ Categorical(p) for i = 1, 2, ..., m.

#     Here, `n_clusters` is `n`, `associations` is `Y`,
#     and `alpha[i-1]` is `alpha_i`.

#     Returns (mixture_weights, metadata), where `mixture_weights`
#     is the same thing as the vector `[X_1, X_2, ..., X_n]`.
#     """
#     ## Derivation of this update:
#     # With notation as the above, it turns out
#     # X_0 ~ Gamma(alpha.sum(), lambda),
#     # p ~ Dirichlet(alpha_1, alpha_2, ..., alpha_n),
#     # and X_0 and p are independent.
#     # Thus, the posterior on (X_0, p) is
#     # p ~ dirichlet_categorical_posterior(alpha, n, assoc_counts);
#     # X_0 ~ gamma(alpha.sum(), lambda). # Ie. same as the prior.
#     k1, k2 = jax.random.split(key)
#     posterior_pvec, _ = dirichlet_categorical_update(
#         k1, associations, n_clusters, alpha
#     )
#     total = gamma(alpha.sum(), λ)(k2)
#     return posterior_pvec * total, {}


# # One option in the meantime is to replace inverse_gamma in the model by a categorical with 64 values.
# def conjugate_update_mean_normal_inverse_gamma():
#     return None


# def conjugate_update_sigma_normal_inverse_gamma():
#     return None


# def conjugate_discrete_enumeration():
#     return None


# def update_xy_mean(xy_mean, xy_mean_blanket):
#     return None


# def update_xy_sigma(xy_sigma, xy_sigma_blanket):
#     return None


# def update_rgb_mean(rgb_mean, rgb_mean_blanket):
#     return None


# def update_rgb_sigma(rgb_sigma, rgb_sigma_blanket):
#     return None

In [None]:
### exact inference on discretized model
### exact Gibbs move on discretized model
### exact Gibbs for continuous model
### test exact Gibbs
### update for cont model

### What do I want?
# Ideally, model.exact_infer(key, obs, args)
# as well as the generality for sub_model.exact_infer(key, obs, args)
# something cool would be an automatic plotting of the BayesNet of the model.

### Math for complexity of exact inference on this model:
# model:
# P(latents | image, hypers)
# = P(latents, image | hypers) / P(image)
# = P(xy_mean, rgb_mean, mixture_weight | hypers)
# * P(blob_idx | mixture_weight)
# * P(xy | xy_mean, blob_idx, hypers)
# * P(rgb | rgb_mean, blob_idx, hypers) / P(image),
# where P(image) = sum_{xy_mean, rgb_mean, mixture_weight, blob_idx} P(latents, image | hypers)
# size of the sum: |xy_mean| * | rgb_mean| * |mixture_weight| * |blob_idx| * |n_blobs | * | image |
# = 64**2 * 64 ** 3 * 64 * 100 * 10 ** 6
# ~ 7 * 10^18, without counting the cost of each eval nor ranging over a set of hyperparameters.
# L4 GPU can do 30 * 10^12 flop per sec -> problem.


### However, that's the naive version, not taking the Markov blanket and conditional independence into account. E.g. in an HMM, using dynamic programming
# we can reduce the complexity from exponential in the length of the chain to linear. The argument there is as follows for a chain of length 3.
# P(x1, x2, x3 | y1, y2, y3)
# = P(x1 | y1) . P(x2 | x1, y2). P(x3 | x2, y3)
# = (P(x1 , y1) / \sum_{x1} P(x1, y1))
# . (P(x1, x2, y2) / \sum_{x1, x2} P(x1, x2, y2))
# . (...)
# =

# NOTES:
# - I may want to keep the distribution for observed data continuous to simplify and avoid unnecessary 0 likelihood.
# - can compress the image to make inference faster, and one can even do SMC from low res to high res.



NEXT STEPS:
1) minimal visualization of the trace over time
2) inference loop with identity rewrite for Gibbs updates
3) add one by one the Gibbs updates and test them individually

### Main inference loop

In [None]:
# Number of sweeps performed by the main inference loop in `infer`.
N_ITER = 100
# Gates the recording of sampled values in `infer` for plotting and debugging.
RECORD = True
# Gates the debug printing of the initial xy means and mixture weights in `infer`.
DEBUG = False


def identity_update(key, tr):
    """Placeholder update step that hands back `tr` unchanged.

    Args:
        key: accepted so the call signature is `(key, tr)`; it is ignored.
        tr: the value to pass through.

    Returns:
        The very same `tr` object that was passed in.
    """
    del key  # intentionally unused by this no-op step
    return tr


def infer(image, hypers):
    """Run the (currently placeholder) block-Gibbs inference loop on `image`.

    The image is flattened to one row per pixel, an initial trace is drawn with
    `model.importance` while constraining the pixel data and uniform mixture
    weights, and then `N_ITER` sweeps are run. Each sweep applies one update
    per latent address; every update is presently `identity_update`.

    Args:
        image: array that can be reshaped to `(hypers.H * hypers.W, 3)`.
        hypers: hyperparameter object; `H`, `W` and `n_blobs` are read here and
            the whole object is passed to the model as its single argument.

    Returns:
        A 7-tuple
        `(all_posterior_xy_means, all_posterior_xy_variances,
        all_posterior_rgb_means, all_posterior_rgb_variances,
        all_posterior_weights, all_cluster_assignment, tr)`.
        The first six entries are lists holding the initial value followed by
        one value per sweep when `RECORD` is true, and are empty lists
        otherwise. `tr` is the final trace.
    """
    key = jax.random.key(32421)

    # Image pre-processing
    H = hypers.H
    W = hypers.W
    # `jnp.indices((H, W))` has shape (2, H, W). Flatten the grid first and then
    # transpose so that row k holds the (row, col) pair of pixel k, in the same
    # row-major order as `image.reshape(H * W, 3)`. Reshaping straight to
    # (H * W, 2) would pair up unrelated index values.
    pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T
    flattened_image = jnp.concatenate(
        (pixel_coords, image.reshape(H * W, 3)), axis=1
    )
    xy, rgb = flattened_image[:, :2], flattened_image[:, 2:]

    # Setup for better initial trace: constrain the observations and start
    # from uniform mixture weights.
    n_blobs = hypers.n_blobs
    obs = C["likelihood_model", "xy"].set(xy) ^ C["likelihood_model", "rgb"].set(rgb)
    initial_weights = C["blob_model", "mixture_weight"].set(jnp.ones(n_blobs) / n_blobs)
    constraints = obs | initial_weights

    # Sample an initial trace
    key, subkey = jax.random.split(key)
    args = (hypers,)
    tr, _ = jax.jit(model_simple_continuous.model.importance)(subkey, constraints, args)

    # Histories for plotting and debugging purposes (left empty if not RECORD).
    all_posterior_xy_means = []
    all_posterior_xy_variances = []
    all_posterior_rgb_means = []
    all_posterior_rgb_variances = []
    all_cluster_assignment = []
    all_posterior_weights = []

    # One entry per update in a sweep, in execution order: the address being
    # updated and the history list that address is recorded into. Keeping the
    # pairing in one place prevents values landing in the wrong list.
    schedule = (
        (("blob_model", "xy_mean"), all_posterior_xy_means),
        (("blob_model", "sigma_xy"), all_posterior_xy_variances),
        (("blob_model", "rgb_mean"), all_posterior_rgb_means),
        (("blob_model", "sigma_rgb"), all_posterior_rgb_variances),
        (("likelihood_model", "blob_idx"), all_cluster_assignment),
        (("blob_model", "mixture_weight"), all_posterior_weights),
    )

    initial_choices = tr.get_choices()
    if RECORD:
        for address, history in schedule:
            history.append(initial_choices[address])

    if DEBUG:
        jax.debug.print(
            "Initial means: {v}", v=initial_choices["blob_model", "xy_mean"]
        )
        jax.debug.print(
            "Initial weights: {v}", v=initial_choices["blob_model", "mixture_weight"]
        )

    # Build the jitted update once instead of on every step.
    # TODO: replace the identity placeholder with the real Gibbs update for
    # each address.
    gibbs_update = jax.jit(identity_update)

    # Main inference loop. It runs regardless of RECORD; only the bookkeeping
    # is conditional.
    for _ in range(N_ITER):
        for address, history in schedule:
            # Update on `address` for each index i, in parallel
            key, subkey = jax.random.split(key)
            tr = gibbs_update(subkey, tr)
            if RECORD:
                history.append(tr.get_choices()[address])

    return (
        all_posterior_xy_means,
        all_posterior_xy_variances,
        all_posterior_rgb_means,
        all_posterior_rgb_variances,
        all_posterior_weights,
        all_cluster_assignment,
        tr,
    )


# Run inference on the sample image and unpack the six recorded histories
# together with the final trace (this rebinds the notebook-level `tr`).
(
    all_posterior_xy_means,
    all_posterior_xy_variances,
    all_posterior_rgb_means,
    all_posterior_rgb_variances,
    all_posterior_weights,
    all_cluster_assignment,
    tr,
) = infer(image, hypers)

### Visualization