# Gen2D

This notebook shows a simple model for clustering a 2D image into different components.

In [None]:
import jax
import jax.numpy as jnp
import model_simple_continuous
from scipy import misc

from genjax import ChoiceMapBuilder as C
from genjax import pretty

pretty()

## Model

### Testing sampling from the model

In [None]:
# Load the sample image used throughout the notebook.
# `scipy.misc.face` was deprecated in SciPy 1.10 and removed in SciPy 1.12, so
# use it when it still exists and otherwise fall back to its replacement,
# `scipy.datasets.face` (which needs the optional `pooch` package to download
# the image on first use).
load_face = getattr(misc, "face", None)
if load_face is None:
    from scipy import datasets

    load_face = datasets.face
image = load_face()
# plt.imshow(image)
# plt.show()

# The image is indexed as (row, column, channel).
H, W, _ = image.shape

image.shape
# NOTE(review): the meaning of each hyperparameter is defined by
# `model_simple_continuous.Hyperparams`; the values here look like simple
# defaults for a smoke test — confirm against the model definition.
hypers = model_simple_continuous.Hyperparams(
    a_x=1.0,
    b_x=1.0,
    a_y=1.0,
    b_y=1.0,
    mu_x=0.0,
    mu_y=0.0,
    a_rgb=jnp.array([1.0, 1.0, 1.0]),
    b_rgb=jnp.array([1.0, 1.0, 1.0]),
    alpha=1.0,
    sigma_xy=jnp.array([1.0, 1.0]),
    sigma_rgb=jnp.array([1.0, 1.0, 1.0]),
    n_blobs=100,
    H=H,
    W=W,
)

# Split the key so the sample below uses fresh randomness and `key` can be
# reused by later cells.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
tr = model_simple_continuous.model.simulate(subkey, (hypers,))

### Image processing

In [None]:
# `jnp.indices((H, W))` has shape (2, H, W): plane 0 holds the row index of
# every pixel and plane 1 the column index. Flatten each plane separately and
# transpose, so that row k is the (row, col) coordinate of pixel k in the same
# row-major order as `image.reshape(H * W, 3)`. (Reshaping the (2, H, W) array
# directly to (H * W, 2) would pair up consecutive row indices instead.)
# NOTE(review): coordinates are ordered (row, col); confirm this is the order
# the model expects for `xy`.
pixel_coords = jnp.indices((H, W)).reshape(2, H * W).T

# One row per pixel: [row, col, r, g, b].
flattened_image = jnp.concatenate((pixel_coords, image.reshape(H * W, 3)), axis=1)

xy, rgb = flattened_image[:, :2], flattened_image[:, 2:]

### Initial trace for inference

In [None]:
key, subkey = jax.random.split(key)
# Constrain the observed pixel coordinates and colours under the likelihood.
obs = C["likelihood_model", "xy"].set(xy) ^ C["likelihood_model", "rgb"].set(rgb)

# `importance` returns a (trace, weight) pair; unpack it so that `tr` is the
# trace itself rather than the tuple.
tr, importance_weight = model_simple_continuous.model.importance(
    subkey, obs, (hypers,)
)

## Inference

We will do inference via exact block-Gibbs, using the fact that the model is defined using conjugate pairs.

In [None]:
def conjugate_update_mvnormal_with_known_cov(
    prior_mean,  # (D,)
    prior_cov,  # (D, D)
    obs_cov,  # (D, D)
    obs,  # (M, D)
):
    """
    Returns the posterior mean and covariance for the mean
    of a multivariate normal distribution with known covariance.
    That is, given
      mu ~ Normal(prior_mean, prior_cov),
      obs_i ~ Normal(mu, obs_cov) for i = 0, 1, ..., M-1,
    this function returns (post_mean, post_cov) where
      P(mu | obs) = Normal(post_mean, post_cov).

    With no observations (M == 0) the prior is returned unchanged.
    """
    M = obs.shape[0]
    prior_precision = jnp.linalg.inv(prior_cov)
    obs_precision = jnp.linalg.inv(obs_cov)
    post_cov = jnp.linalg.inv(prior_precision + M * obs_precision)
    # Sum over the observation axis only, giving a (D,) vector. Using the sum
    # directly (M * mean == sum) avoids a 0 / 0 division when M == 0.
    obs_sum = jnp.sum(obs, axis=0)
    post_mean = post_cov @ (prior_precision @ prior_mean + obs_precision @ obs_sum)
    # M is a static Python int, so this just selects the prior exactly when
    # there is no data.
    return jnp.where(M > 0, post_mean, prior_mean), jnp.where(
        M > 0, post_cov, prior_cov
    )


def dirichlet_categorical_update(key, associations, n_clusters, alpha):
    """
    Sample from a Dirichlet whose concentration is `alpha` plus per-cluster
    association counts.

    For each cluster index in `range(n_clusters)`, counts the datapoints
    that `tiling.relevant_datapoints_for_blob` reports as relevant (flag set)
    and whose entry in `associations` equals that cluster index.

    Returns (categorical_vector, metadata_dict); the metadata dict is
    currently always empty.
    """
    # NOTE(review): `tiling` and `dirichlet` are not defined in this part of
    # the notebook — confirm they are in scope before running.

    def count_for_cluster(cluster_idx):
        # Masked set of candidate datapoint indices for this cluster:
        # `.value` holds the indices, `.flag` marks which entries are valid.
        candidates = tiling.relevant_datapoints_for_blob(cluster_idx)
        assigned_here = associations[candidates.value] == cluster_idx
        return jnp.sum(jnp.logical_and(candidates.flag, assigned_here))

    cluster_ids = jnp.arange(n_clusters)
    counts = jax.vmap(count_for_cluster)(cluster_ids)
    posterior_concentration = alpha + counts
    sampler = dirichlet(posterior_concentration)
    return sampler(key), {}


def conjugate_dirichlet_categorical(
    key, associations, n_clusters, alpha, λ=model_simple_continuous.GAMMA_RATE_PARAMETER
):
    """
    Conjugate update for unnormalized mixture weights built from Gamma draws:
        X_i ~ Gamma(alpha_i / n, lambda) for i = 1, 2, ..., n;
        X_0 := sum_i X_i
        p := [X_1, X_2, ..., X_n] / X_0
        Y_i ~ Categorical(p) for i = 1, 2, ..., m.

    Here `n_clusters` is `n`, `associations` is `Y`, `alpha` supplies the
    `alpha_i`, and `λ` is the Gamma rate `lambda`.

    NOTE(review): the code below uses `alpha` directly as the Dirichlet
    concentration and `alpha.sum()` as the shape of X_0, i.e. without the
    `/ n` written in the model above — confirm which parameterisation the
    model actually uses.

    Returns (mixture_weights, metadata), where `mixture_weights` plays the
    role of the vector `[X_1, X_2, ..., X_n]` and `metadata` is an empty dict.
    """
    # Why this works: X_0 and p are independent, with
    #   X_0 ~ Gamma(alpha.sum(), lambda) and p ~ Dirichlet(alpha).
    # The categorical observations depend on p only, so the posterior
    # factorises as
    #   p   ~ Dirichlet-categorical posterior given the association counts,
    #   X_0 ~ Gamma(alpha.sum(), lambda)   (unchanged from the prior).
    pvec_key, total_key = split(key)
    normalized_weights, _ = dirichlet_categorical_update(
        pvec_key, associations, n_clusters, alpha
    )
    total_mass = gamma(alpha.sum(), λ)(total_key)
    mixture_weights = normalized_weights * total_mass
    return mixture_weights, {}


# One option in the meantime is to replace inverse_gamma in the model by a categorical with 64 values.
def conjugate_update_mean_normal_inverse_gamma():
    """Placeholder: not implemented yet; takes no arguments and returns None."""
    # TODO: implement this update.


def conjugate_update_sigma_normal_inverse_gamma():
    """Placeholder: not implemented yet; takes no arguments and returns None."""
    # TODO: implement this update.


def conjugate_discrete_enumeration():
    """Placeholder: not implemented yet; takes no arguments and returns None."""
    # TODO: implement this update.


def update_xy_mean(xy_mean, xy_mean_blanket):
    """Placeholder: not implemented yet; ignores both arguments and returns None."""
    # TODO: implement this update.


def update_xy_sigma(xy_sigma, xy_sigma_blanket):
    """Placeholder: not implemented yet; ignores both arguments and returns None."""
    # TODO: implement this update.


def update_rgb_mean(rgb_mean, rgb_mean_blanket):
    """Placeholder: not implemented yet; ignores both arguments and returns None."""
    # TODO: implement this update.


def update_rgb_sigma(rgb_sigma, rgb_sigma_blanket):
    """Placeholder: not implemented yet; ignores both arguments and returns None."""
    # TODO: implement this update.

In [None]:
### write discretized model
### exact inference on discretized model
### exact Gibbs move on discretized model
### exact Gibbs for continuous model
### test exact Gibbs
### update for cont model