# Gaussian Mixture Model Example

This example makes more extensive use of the library: its diffusion samplers, SDEs and SMC (particle filter) components.

In [None]:
# Imports
from typing import Callable

import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
from IPython.display import HTML
from jax import grad, vmap
from jax.tree_util import Partial
from jaxtyping import Array, PRNGKeyArray, Real

from diffusionlib.sampler import SamplerName, get_sampler
from diffusionlib.sde import SDE, VP
from diffusionlib.smc.feynman_kac_model import LikelihoodGuidedPF
from diffusionlib.smc.state_space_model import SimpleObservationSSM

# Effectively remove matplotlib's size cap on animations embedded as HTML/JS below.
matplotlib.rcParams.update({"animation.embed_limit": 2**128})

# Plot colours: exact posterior reference samples vs. sampled/generated points
COLOR_POSTERIOR = "#a2c4c9"
COLOR_ALGORITHM = "#ff7878"

We'll look at an example of a 2-dimensional ($d_x = 2$, `dim_x`) GMM with 25 evenly spaced clusters centered
on the means $\{(8i, 8j) : i, j \in \{-2, \dots, 2\}\}$, each with unit covariance. Our observation
is a 2-dimensional ($d_y = 2$, `dim_y`) value, and given it we'd like to sample from the posterior
distribution. The config below is easy to change, and empirically the methodology below has
worked well with more (or fewer) clusters, larger $d_x$, and larger $d_y$.

In [None]:
# Config
key = random.PRNGKey(100)  # root PRNG key; split before each downstream use

# Problem sizes
num_steps, num_samples = 1_000, 1_000
dim_x = dim_y = 2

measurement_noise_std = 1.0  # std of the additive Gaussian observation noise
size = 8.0  # mean multiplier: spacing between neighbouring cluster centres
center_range = np.array([-2, 2])  # inclusive range of integer grid indices per coordinate

# Noise schedule of the forward (VP) SDE
beta_min, beta_max = 0.01, 20.0

# plotting: pad the grid of centres by half a cell on each side
chart_lims = (center_range + np.array([-0.5, 0.5])) * size

From the [MCGDiff paper](https://arxiv.org/pdf/2308.07983.pdf) (page 5), the marginals of the
forward process are available in closed form, meaning we don't need a trained model to
estimate the score or noise ($\epsilon$).

In [None]:
# Define epsilon functions
# NOTE: the epsilon (noise) predictor is the score scaled by minus the marginal standard deviation.
def get_model_fn(
    ou_dist: Callable[[Array], dist.MixtureSameFamily], sde: SDE
) -> Callable[[Array, Array], Array]:
    """Build a batched epsilon estimator from a closed-form forward marginal.

    For each ``(x, t)`` the returned function computes
    ``-sqrt(sde.marginal_variance(t)) * grad_x log p_t(x)``, where ``p_t`` is
    ``ou_dist(sde.marginal_mean_coeff(t))``.

    Args:
        ou_dist: Maps a mean coefficient to the forward marginal distribution.
        sde: Forward SDE providing ``marginal_variance`` and ``marginal_mean_coeff``.

    Returns:
        A function of ``(x, t)`` vectorized (via ``vmap``) over the leading axis
        of both arguments.
    """

    def scaled_log_density(x, t):
        # The gradient is taken w.r.t. ``x`` only (``grad``'s default argnums=0).
        sigma_t = jnp.sqrt(sde.marginal_variance(t))
        return -sigma_t * ou_dist(sde.marginal_mean_coeff(t)).log_prob(x)

    return vmap(grad(scaled_log_density))

Next we define the mixture model:

In [None]:
# Define mixture model function
def ou_mixt(mean_coeff: float, means: Array, dim_x: int, weights: Array) -> dist.MixtureSameFamily:
    """Build a Gaussian mixture with scaled means and identity covariances.

    Args:
        mean_coeff: Factor applied to every component mean.
        means: Sequence of component mean vectors (stacked with ``jnp.vstack``).
        dim_x: Dimension of each component.
        weights: Mixing probabilities, one per component.

    Returns:
        ``MixtureSameFamily`` of ``MultivariateNormal`` components, each with
        covariance ``I_{dim_x}``.
    """
    scaled_means = mean_coeff * jnp.vstack(means)
    num_components = scaled_means.shape[0]
    identity_covs = jnp.broadcast_to(jnp.eye(dim_x), (num_components, dim_x, dim_x))
    return dist.MixtureSameFamily(
        mixing_distribution=dist.CategoricalProbs(weights),
        component_distribution=dist.MultivariateNormal(
            loc=scaled_means, covariance_matrix=identity_covs
        ),
    )

We'll define a helper function for Gaussian-Gaussian posterior calculation:

In [None]:
# Define posterior
def gaussian_posterior(
    y: Real[Array, " dim_y"],
    likelihood_a: Real[Array, "dim_y dim_x"],
    likelihood_bias: Real[Array, " dim_y"],
    likelihood_precision: Real[Array, "dim_y dim_y"],
    prior_loc: Real[Array, " dim_x"],
    prior_covar: Real[Array, "dim_x dim_x"],
) -> dist.MultivariateNormal:
    """Return the exact posterior of a linear-Gaussian model.

    Model::

        x     ~ N(prior_loc, prior_covar)
        y | x ~ N(likelihood_a @ x + likelihood_bias, inv(likelihood_precision))

    Args:
        y: Observed vector.
        likelihood_a: Linear observation matrix ``A``.
        likelihood_bias: Additive observation offset ``b``.
        likelihood_precision: Observation noise precision ``Λ`` (inverse covariance).
        prior_loc: Prior mean ``μ0``.
        prior_covar: Prior covariance ``Σ0``.

    Returns:
        ``MultivariateNormal`` for ``x | y`` with precision ``Σ0^-1 + A^T Λ A``
        and mean ``Σ_post (A^T Λ (y - b) + Σ0^-1 μ0)``.
    """
    prior_precision_matrix = jnp.linalg.inv(prior_covar)

    # A^T Λ appears in both the precision update and the mean.
    a_t_precision = likelihood_a.T @ likelihood_precision

    posterior_precision_matrix = prior_precision_matrix + a_t_precision @ likelihood_a
    posterior_covariance_matrix = jnp.linalg.inv(posterior_precision_matrix)

    posterior_mean = posterior_covariance_matrix @ (
        a_t_precision @ (y - likelihood_bias) + prior_precision_matrix @ prior_loc
    )

    # Matrix inversion can leave tiny asymmetries; symmetrize so the covariance handed
    # to MultivariateNormal is exactly symmetric. (Symmetrizing a square matrix cannot
    # raise, so no fallback path is needed.)
    posterior_covariance_matrix = (
        posterior_covariance_matrix + posterior_covariance_matrix.T
    ) / 2

    return dist.MultivariateNormal(
        loc=posterior_mean, covariance_matrix=posterior_covariance_matrix
    )

Next we define a function which provides essentially exact posterior samples:

In [None]:
# Define posterior for the mixture model
def get_posterior(
    obs: Array, prior: dist.MixtureSameFamily, a: Array, sigma_y: Array
) -> dist.MixtureSameFamily:
    """Return the exact posterior ``p(x | obs)`` for a Gaussian-mixture prior.

    Model: ``x ~ prior`` (a mixture of multivariate normals) and
    ``obs | x ~ N(a @ x, sigma_y)``. Each component is updated with
    ``gaussian_posterior``. Its weight is multiplied by that component's marginal
    likelihood of ``obs``, which is obtained from Bayes' rule evaluated at the
    component's posterior mean ``m_k``::

        log p_k(obs) = log N(obs; a m_k, sigma_y) + log N(m_k; mu_k, Sigma_k)
                       - log N(m_k; m_k, Sigma_post_k)

    Only the quadratic part of the likelihood term is computed. Its normalizing
    constant is shared by every component and cancels when the weights are normalized.

    Args:
        obs: Observation vector.
        prior: Mixture prior with ``MultivariateNormal`` components.
        a: Linear observation matrix.
        sigma_y: Observation noise covariance matrix.

    Returns:
        ``MixtureSameFamily`` of the per-component posteriors with normalized
        categorical logits.
    """
    components: dist.MultivariateNormal = prior.component_distribution  # type: ignore
    component_covs: Array = components.covariance_matrix  # type: ignore
    component_means = components.mean
    component_probs = prior.mixing_distribution.probs  # type: ignore

    noise_precision = jnp.linalg.inv(sigma_y)
    zero_bias = jnp.zeros_like(obs)

    post_means, post_covs, log_weights = [], [], []
    for mu_k, cov_k, pi_k in zip(component_means, component_covs, component_probs):
        post_k = gaussian_posterior(obs, a, zero_bias, noise_precision, mu_k, cov_k)
        post_means.append(post_k.mean)
        post_covs.append(post_k.covariance_matrix)

        # Component evidence via Bayes' rule at x = posterior mean (see docstring)
        resid = obs - a @ post_k.loc
        log_marginal = (
            -0.5 * resid @ noise_precision @ resid.T
            + dist.MultivariateNormal(mu_k, covariance_matrix=cov_k).log_prob(post_k.mean)
            - post_k.log_prob(post_k.mean)
        )
        log_weights.append(jnp.log(pi_k) + log_marginal)

    log_weights = jnp.array(log_weights)
    normalized_logits = log_weights - jax.scipy.special.logsumexp(log_weights)

    return dist.MixtureSameFamily(
        dist.CategoricalLogits(logits=normalized_logits),
        dist.MultivariateNormal(
            loc=jnp.stack(post_means, axis=0),
            covariance_matrix=jnp.stack(post_covs, axis=0),
        ),
    )

Next we define functions for the inverse problem; see
[MCGDiff paper](https://arxiv.org/pdf/2308.07983.pdf), Appendix B.3.1 for details on the measurement 
model.

In [None]:
# Inverse problem functions
def extended_svd(a: Array) -> tuple[Array, Array, Array, Array]:
    """Compute an "extended" SVD of ``a`` whose right factor is a full basis.

    Args:
        a: Matrix of shape ``(m, n)``.

    Returns:
        ``(u, s, v, coordinate_mask)`` with ``k = min(m, n)``:

        - ``u``: ``(m, k)`` left-singular vectors (thin factor),
        - ``s``: ``(k,)`` singular values, in descending order,
        - ``v``: ``(n, n)`` orthogonal matrix of right-singular vectors (as rows),
        - ``coordinate_mask``: length-``n`` boolean mask. It is ``True`` for the
          first ``k`` rows of ``v`` (those paired with ``s``) and ``False`` for the
          remaining rows, which ``a`` maps to zero.

        Hence ``a == u @ diag(s) @ v[coordinate_mask]``.
    """
    # A full SVD is required so that `v` also contains the null-space directions;
    # with `full_matrices=False`, `v` has only k rows and the mask is always all-True.
    u_full, s, v = jnp.linalg.svd(a, full_matrices=True)
    num_singular = s.shape[0]

    # Keep the thin left factor so `u @ diag(s)` stays well-shaped when m > n.
    u = u_full[:, :num_singular]

    coordinate_mask = jnp.arange(v.shape[0]) < num_singular

    return u, s, v, coordinate_mask


def generate_measurement_equations(
    dim_x: int,
    dim_y: int,
    mixt: dist.MixtureSameFamily,
    noise_std: float,
    key: PRNGKeyArray,
):
    """Draw a random linear measurement model and one noisy observation from it.

    The singular vectors of a standard-normal ``(dim_y, dim_x)`` matrix are kept.
    Its singular values are replaced by sorted uniform draws in ``[0, 1)``. The
    resulting matrix ``A`` is applied to a single draw ``x*`` from ``mixt``, and
    Gaussian noise with std ``noise_std`` is added.

    Args:
        dim_x: State dimension.
        dim_y: Observation dimension.
        mixt: Prior to draw the ground-truth state from.
        noise_std: Observation noise standard deviation.
        key: PRNG key; split four ways internally.

    Returns:
        ``(a_recon, sigma_y, u, s_new, v, coordinate_mask, init_obs)``. ``sigma_y``
        is the diagonal noise covariance ``noise_std**2 * I``.
    """
    key_a, key_diag, key_init_sample, key_init_obs = random.split(key, 4)

    # Only the singular vectors of this random matrix are used
    u, s, v, coordinate_mask = extended_svd(random.normal(key_a, (dim_y, dim_x)))

    # Fresh singular values, sorted largest-first to match the SVD convention
    s_new = jnp.sort(random.uniform(key_diag, s.shape), descending=True)
    a_recon: Real[Array, "{dim_y} {dim_x}"] = u @ jnp.diag(s_new) @ v[coordinate_mask]

    # y = A x* + noise_std * eps for a single prior draw x*
    x_star: Real[Array, "{dim_x}"] = mixt.sample(key_init_sample)
    clean_obs = a_recon @ x_star
    init_obs: Real[Array, "{dim_y}"] = clean_obs + noise_std * random.normal(
        key_init_obs, clean_obs.shape
    )

    sigma_y = jnp.diag(jnp.full(dim_y, noise_std**2))

    return a_recon, sigma_y, u, s_new, v, coordinate_mask, init_obs

Now we actually build our prior model (i.e. $p(x_0)$, the equally spaced grid GMM), from which we'll
draw samples below.

In [None]:
# Build prior (equal weighted, grid GMM)
# Each centre repeats the pair (-size * i, -size * j) across the coordinates and is truncated
# to `dim_x` entries, so odd `dim_x` also yields correctly sized means. For even `dim_x`
# this is exactly [-size*i, -size*j, -size*i, -size*j, ...].
means = [
    jnp.array(([-size * i, -size * j] * ((dim_x + 1) // 2))[:dim_x])
    for i in range(center_range[0], center_range[1] + 1)
    for j in range(center_range[0], center_range[1] + 1)
]

# Equal mixing weights
weights = jnp.ones(len(means))
weights = weights / weights.sum()

# `ou_mixt_fun(c)` is the mixture with every mean scaled by `c`; c = 1 gives the data prior
ou_mixt_fun = Partial(ou_mixt, means=means, dim_x=dim_x, weights=weights)
mixt = ou_mixt_fun(1)

We can either sample from the prior directly (analytically):

In [None]:
# Get analytic prior samples
# Draw with a fresh subkey: the root `key` is split again in later cells, so using it
# directly here would reuse randomness. Use the configured sample count.
key, sub_key = random.split(key)
analytic_prior_samples = mixt.sample(sub_key, (num_samples,))

In [None]:
# Plot analytic prior samples
fig, ax = plt.subplots(figsize=(6, 6))

# Reference lines through the origin
ax.axhline(0, color="black", lw=0.5)
ax.axvline(0, color="black", lw=0.5)

# Scatter of the analytic samples
ax.scatter(
    analytic_prior_samples[:, 0],
    analytic_prior_samples[:, 1],
    s=10,
    color=COLOR_ALGORITHM,
    alpha=0.5,
    edgecolors="black",
    lw=0.5,
)

# Limits, labels and title in one call
ax.set(
    xlim=tuple(chart_lims),
    ylim=tuple(chart_lims),
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
    title="Analytic Prior samples",
)
ax.grid(True)

plt.show()

Or we can sample using a diffusion model:

In [None]:
# Get model prior samples (i.e. DDPM)
key, sub_key = random.split(key)

# VP forward SDE; the closed-form mixture marginals give the epsilon estimator directly
sde = VP(jnp.array(beta_min), jnp.array(beta_max))
model = get_model_fn(ou_mixt_fun, sde)

sampler = get_sampler(
    SamplerName.DDIM_VP,
    model=model,
    shape=(num_samples, dim_x),
    num_steps=num_steps,
    beta_min=beta_min,
    beta_max=beta_max,
    stack_samples=True,
    eta=1.0,  # NOTE: eta = 1 turns DDIM into DDPM
)

# NOTE: index 0 of the leading axis of `prior_samples` is X_0 (not X_T).
prior_samples: Real[Array, "{num_steps} {num_samples} {dim_x}"] = sampler.sample(sub_key)

In [None]:
# Plot model prior samples
fig, ax = plt.subplots(figsize=(6, 6))

# Reference lines through the origin
ax.axhline(0, color="black", lw=0.5)
ax.axvline(0, color="black", lw=0.5)

# Final (X_0) diffusion samples
ax.scatter(
    prior_samples[0, :, 0],
    prior_samples[0, :, 1],
    s=10,
    color=COLOR_ALGORITHM,
    alpha=0.5,
    edgecolors="black",
    lw=0.5,
)

# Limits, labels and title in one call
ax.set(
    xlim=tuple(chart_lims),
    ylim=tuple(chart_lims),
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
    title="Prior samples",
)
ax.grid(True)

plt.show()

The diffusion model prior is generative, starting from $\mathcal{N}(\mathbf{0}_{d_x}, I_{d_x})$
noise ($X_T$), and evolving according to a discretization of the backwards SDE of some forward
process (determined by $\beta_{min}$ and $\beta_{max}$) using the score; $X_T \rightarrow X_0$.
We can create an animation showing this "denoising" process:

In [None]:
# Create animation of prior sampling (i.e. unconditional diffusion)
fig, ax = plt.subplots(figsize=(6, 6))

skip = 10
# Walk the stacked samples from the X_T end towards X_0 every `skip` steps, then append X_0
rev_subset_prior_samples = np.concatenate(
    [prior_samples[::-skip], prior_samples[:1]], axis=0
)


def animate(i):
    """Redraw frame ``i`` of the reverse-diffusion animation."""
    ax.clear()

    # Reference lines through the origin
    ax.axhline(0, color="black", lw=0.5)
    ax.axvline(0, color="black", lw=0.5)

    frame = rev_subset_prior_samples[i]
    ax.scatter(
        frame[:, 0],
        frame[:, 1],
        s=10,
        color=COLOR_ALGORITHM,
        alpha=0.5,
        edgecolors="black",
        lw=0.5,
    )

    ax.set(
        xlim=tuple(chart_lims),
        ylim=tuple(chart_lims),
        xlabel="Coordinate 1",
        ylabel="Coordinate 2",
        title=f"Prior sample generation\nt={num_steps - (i * skip)}",
    )
    ax.grid(True)


ani = animation.FuncAnimation(fig, animate, frames=len(rev_subset_prior_samples), interval=100)
plt.close()

HTML(ani.to_jshtml())

Now we realise an observation matrix and an observation to set up the inverse problem:

In [None]:
# Setup inverse problem: random measurement matrix plus one noisy observation of a prior draw
key, sub_key = random.split(key)

a, sigma_y, u, diag, v, coordinate_mask, init_obs = generate_measurement_equations(
    dim_x=dim_x,
    dim_y=dim_y,
    mixt=mixt,
    noise_std=measurement_noise_std,
    key=sub_key,
)

Now, we have some observation $y$, which was collected according to:
$$
y := Ax_0^* + \sigma_y\epsilon,\quad \epsilon \sim \mathcal{N}\left(\mathbf{0}_{d_y}, I_{d_y}\right)
$$
(where $x_0^*$ is some sample drawn from the analytic prior; the overall data distribution). Our
goal then is to sample from the posterior $p(x_0 \mid y)$.

Again, we can exactly sample from the posterior in the case of this model; this is useful for
comparison with our particle method:

In [None]:
# Get posterior samples from the closed-form mixture posterior
key, sub_key = random.split(key)
posterior = get_posterior(init_obs, mixt, a, sigma_y)
posterior_samples = posterior.sample(sub_key, (num_samples,))

In [None]:
# Plot true posterior samples
fig, ax = plt.subplots(figsize=(6, 6))

# Reference lines through the origin
ax.axhline(0, color="black", lw=0.5)
ax.axvline(0, color="black", lw=0.5)

# Exact posterior draws
ax.scatter(
    posterior_samples[:, 0],
    posterior_samples[:, 1],
    s=10,
    color=COLOR_POSTERIOR,
    alpha=0.5,
    edgecolors="black",
    lw=0.5,
)

# Limits, labels and title in one call
ax.set(
    xlim=tuple(chart_lims),
    ylim=tuple(chart_lims),
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
    title="Posterior samples",
)
ax.grid(True)

plt.show()

Now we consider using SMC to target the posterior. We start by shrinking the measurement towards 0 so
that it aligns with the starting particles $x_T^{(i)} \sim \mathcal{N}(0_{d_x}, I_{d_x})$ (so that the
shrunk measurement is close to $Ax_T^{(i)}$, which is centred at $0_{d_y}$) for each particle $i$. The aim
is to "schedule" this shrinkage so that this proximity holds for every $t \in \{T, \dots, 0\}$. There are
many ways of shrinking; we can do so geometrically or according to the $\alpha$ values associated with the
prior model SDE (we use $\sqrt{\overline{\alpha}_t}$, which aligns with the `MCGDiff` paper).

In [None]:
# Create auxiliary "shrunk" Y sequence: row t is the observation scaled for step t

# Geometric schedule: alpha ** t
alpha = 0.995
alphas_geo = alpha ** np.arange(num_steps).reshape(-1, 1)

# Schedule matching the SDE: sqrt(alpha_bar_t)
alphas_sde = sampler.sqrt_alphas_cumprod[:, None]

ys_geo, ys_sde = init_obs * alphas_geo, init_obs * alphas_sde

# NOTE: the particle filter runs "backwards" (T -> 0), so reverse the time axis
ys_geo_rev, ys_sde_rev = ys_geo[::-1], ys_sde[::-1]

In [None]:
# Plot alpha schedule: each coordinate of the shrunk observation over t, for both schedules
plt.figure(figsize=(8, 6))
plt.axhline(0, color="black", lw=0.5)

for i in range(ys_geo.shape[1]):
    # 1-based coordinate labels, matching the "Coordinate {i+1}" labels used elsewhere
    plt.plot(ys_geo[:, i], label=f"Coord {i + 1} Geometric")
    plt.plot(ys_sde[:, i], label=f"Coord {i + 1} SDE")

# Add titles and labels
plt.xlabel("$t$")
plt.ylabel(r"$\alpha_tY$")
plt.legend()
plt.title(r"$\alpha$ schedule comparison")
plt.show()

The geometric approach ends up being very difficult to tune properly, since for many values of $t$
it doesn't provide sufficient information for guiding the particle filter (via the
weight-resampling). Empirically, the SDE-based schedule works very well.

We define our state space model (running backwards in time, from $T$ to $0$):
\begin{align*}
\overline{X}_T &\sim \mathcal{N}(0_{d_x}, I_{d_x}) \\
\overline{X}_t &\sim p_{t \mid t+1}(\cdot \mid \overline{X}_{t+1}) \\
Y_t &\sim \mathcal{N}(A\overline{X}_t, \sigma_y^2 I_{d_y})
\end{align*}
(where $p_{t \mid t+1}(\cdot \mid \overline{X}_{t+1})$ is our prior model, i.e. the unconditional DDPM).

In [None]:
# State space model: diffusion prior transitions with a linear-Gaussian observation (A, sigma_y)
mgmm_ssm = SimpleObservationSSM(
    sampler=sampler,
    dim_x=dim_x,
    a=a,
    sigma_y=sigma_y,
)

Our FK model uses the $p_{t \mid t+1}$ transition of the SSM as the proposal, but weights according
to the ratio of likelihoods:
\begin{align*}
\omega_t^{(i)} &= \frac{g(y_t \mid \overline{x}_t^{(i)})}{g(y_{t+1} \mid x_{t+1}^{(i)})} \\
&= \frac{g(\overline{\alpha}_t^{\frac{1}{2}} y \mid \overline{x}_t^{(i)})}{g(\overline{\alpha}_{t+1}^{\frac{1}{2}} y \mid x_{t+1}^{(i)})} \\
&= \frac{\mathcal{N}\left(\overline{\alpha}_t^{\frac{1}{2}} y;\ A\overline{x}_t^{(i)}, \sigma_y^2 I_{d_y}\right)}{\mathcal{N}\left(\overline{\alpha}_{t+1}^{\frac{1}{2}} y;\ Ax_{t+1}^{(i)}, \sigma_y^2 I_{d_y}\right)}
\end{align*}
(this procedure is very similar to `MCGDiff`, but with a simplified covariance).

In [None]:
# Feynman-Kac model guided by the SDE-matched shrunk observations, ordered T -> 0
fk_guided = LikelihoodGuidedPF(
    mgmm_ssm,
    data=ys_sde_rev,
)

Now we run the particle filter, using systematic resampling and resampling whenever the ESS ratio drops
below 90% (this encourages more resampling, which empirically worked a bit better). Note that this
threshold, in conjunction with the scheduling, provides a lot of flexibility, and it isn't immediately
clear what's optimal.

In [None]:
# SMC sampler over the guided FK model: one particle per sample, resample when ESS ratio < 0.9
particle_sampler = get_sampler(
    name=SamplerName.SMC,
    fk_model=fk_guided,
    num_particles=num_samples,
    essr_min=0.9,
)
particle_samples = particle_sampler.sample(key)

We can plot our final posterior samples compared with the true ones:

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Reference lines through the origin
ax.axhline(0, color="black", lw=0.5)
ax.axvline(0, color="black", lw=0.5)

# Exact posterior draws (drawn first, so the particles appear on top)
ax.scatter(
    posterior_samples[:, 0],
    posterior_samples[:, 1],
    s=10,
    color=COLOR_POSTERIOR,
    alpha=0.5,
    edgecolors="black",
    lw=0.5,
)

# Final particle cloud from the guided SMC
ax.scatter(
    particle_samples[:, 0],
    particle_samples[:, 1],
    s=10,
    color=COLOR_ALGORITHM,
    alpha=0.5,
    edgecolors="black",
    lw=0.5,
)

# Limits, labels and title in one call
ax.set(
    xlim=tuple(chart_lims),
    ylim=tuple(chart_lims),
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
    title="Particle posterior samples",
)
ax.grid(True)

plt.show()

We can create a nice animation showing the evolution of the guided particle filter:

In [None]:
# Create animation of particle posterior sampling
# Keep every `history_skip`-th particle snapshot plus the final one. The same stride labels
# the frames, so subsampling and titles cannot drift apart.
history_skip = 10
subhist = [
    *particle_sampler.particle_history[::history_skip],
    particle_sampler.particle_history[-1],
]

fig, ax = plt.subplots(figsize=(6, 6))


def animate(i: int):
    """Draw frame ``i``: exact posterior samples with the current particle cloud on top."""
    # Clear axis for next frame
    ax.clear()

    # Axes
    ax.axhline(0, color="black", lw=0.5)
    ax.axvline(0, color="black", lw=0.5)

    # True posterior samples
    ax.scatter(
        x=posterior_samples[:, 0],
        y=posterior_samples[:, 1],
        color=COLOR_POSTERIOR,
        alpha=0.5,
        edgecolors="black",
        lw=0.5,
        s=10,
    )

    # Particle samples
    ax.scatter(
        x=subhist[i][:, 0],
        y=subhist[i][:, 1],
        color=COLOR_ALGORITHM,
        alpha=0.5,
        edgecolors="black",
        lw=0.5,
        s=10,
    )

    # Limits
    ax.set_xlim(*chart_lims)
    ax.set_ylim(*chart_lims)

    # Labels
    ax.set_xlabel("Coordinate 1")
    ax.set_ylabel("Coordinate 2")
    ax.set_title(f"Particle posterior sampling\nt={num_steps - (i * history_skip)}")

    ax.grid(True)


ani = animation.FuncAnimation(fig, animate, frames=len(subhist), interval=100)
plt.close()

HTML(ani.to_jshtml())

Finally, we can look at histograms (for each coordinate) of the observation matrix applied to the
true and particle posterior samples (true observation shown by black line).

In [None]:
# Histograms of A x for the true and particle posterior samples, one row per observation
# coordinate; the black line marks the observed value.
# `squeeze=False` keeps `axs` two-dimensional even when dim_y == 1.
fig, axs = plt.subplots(nrows=dim_y, ncols=2, figsize=(8, 3 * dim_y), squeeze=False)

part_meas = particle_samples @ a.T
true_meas = posterior_samples @ a.T

for i in range(dim_y):
    # True posterior samples re-measured
    axs[i, 0].set_title("True posterior observations")
    axs[i, 0].set_ylabel(f"Coordinate {i+1}")
    axs[i, 0].hist(true_meas[:, i], bins=20, color=COLOR_POSTERIOR)
    axs[i, 0].axvline(init_obs[i], color="black")

    # Particle posterior samples re-measured
    axs[i, 1].set_title("Particle posterior observations")
    axs[i, 1].hist(part_meas[:, i], bins=20, color=COLOR_ALGORITHM)
    axs[i, 1].axvline(init_obs[i], color="black")

plt.tight_layout()
plt.show()