## Nested approximate marginalisation & RAVI stacks [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChiSym/genjax/blob/main/docs/cookbook/inactive/expressivity/ravi_stack.ipynb)
### How to be recursively wrong everywhere all the time yet correct at the end

In [None]:
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"

In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import SelectionBuilder as S
from genjax import Target, gen, pretty

pretty()

Say you have a model of interest on which you want to do inference. It is a mixture of 3 Gaussians, two of which are close to each other while the third is far away. We will informally call the lone Gaussian cluster 1 and the pair of nearby Gaussians cluster 2.

In [None]:
@gen
def model():
    idx = genjax.categorical(probs=[0.5, 0.25, 0.25]) @ "idx"
    # under the prior, 50% chance to be in cluster 1 and 50% chance to be in cluster 2.
    means = jnp.array([0.0, 10.0, 11.0])
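    # genjax.normal is parameterized by a standard deviation; `stds` also avoids shadowing the builtin `vars`
    stds = jnp.array([1.0, 1.0, 1.0])
    x = genjax.normal(means[idx], stds[idx]) @ "x"
    y = genjax.normal(means[idx], stds[idx]) @ "y"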
    return x, y


obs1 = C["x"].set(1.0)
obs2 = C["x"].set(10.5)

We will only care about the values of "x" and "y" in the output, so we will marginalize "idx" out.

In [None]:
marginal_model = model.marginal(
    selection=S["x"] | S["y"]
)  # This means we are projecting onto the variables x and y, marginalizing out the rest
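
For intuition about what marginalizing out "idx" means, here is a minimal plain-Python sketch that computes the exact marginal density of ("x", "y") by summing the joint density over the three components. It is not GenJAX code: `normal_pdf` and `marginal_density_xy` are hypothetical helpers introduced for illustration, and the sketch assumes the second argument of `genjax.normal` is a standard deviation. This small model happens to admit a closed form; the point of `marginal` is that it does not rely on one.

```python
import math

# Mixture parameters copied from `model` above.
PROBS = [0.5, 0.25, 0.25]
MEANS = [0.0, 10.0, 11.0]
STDS = [1.0, 1.0, 1.0]  # assumption: treated as standard deviations


def normal_pdf(value, mean, std):
    """Density of a univariate Gaussian with the given mean and std."""
    z = (value - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def marginal_density_xy(x, y):
    """Sum the joint density p(idx, x, y) over idx to obtain p(x, y)."""
    return sum(
        p * normal_pdf(x, m, s) * normal_pdf(y, m, s)
        for p, m, s in zip(PROBS, MEANS, STDS)
    )


print(marginal_density_xy(0.0, 0.0))  # point near cluster 1
print(marginal_density_xy(10.5, 10.5))  # point between the two Gaussians of cluster 2
```

Note that "x" and "y" are independent only once "idx" is known; after marginalizing, observing "x" is informative about "y", which is what the rest of this notebook exploits.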

Testing the marginal model.

In [None]:
key = jax.random.key(0)
marginal_model.simulate(key, ())

In [None]:
tr, w = marginal_model.importance(key, obs1, ())
tr.get_choices()

Depending on what we observe, we want to infer which cluster likely generated the data: cluster 1 (the single Gaussian far from the others) or cluster 2 (the two Gaussians close to each other).

Let's create a data-driven proposal that targets the model and incorporates this logic.
To avoid being too eager in our custom logic, we use it as a probabilistic heuristic rather than a deterministic one. After all, it's possible that the value 10.5 for "x" was generated from the cluster with a single Gaussian.
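
To put a number on "it's possible", the following plain-Python sketch computes the exact posterior probability that "idx" is 0 (cluster 1) given an observed "x". It is an illustration only, not GenJAX code: `normal_pdf` and `cluster_1_posterior` are hypothetical helpers, and the second argument of `genjax.normal` is assumed to be a standard deviation.

```python
import math

# Mixture parameters copied from `model` above.
PROBS = [0.5, 0.25, 0.25]
MEANS = [0.0, 10.0, 11.0]
STDS = [1.0, 1.0, 1.0]  # assumption: treated as standard deviations


def normal_pdf(value, mean, std):
    """Density of a univariate Gaussian with the given mean and std."""
    z = (value - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def cluster_1_posterior(x_obs):
    """Exact posterior probability that idx == 0 (cluster 1) given x_obs."""
    joint = [p * normal_pdf(x_obs, m, s) for p, m, s in zip(PROBS, MEANS, STDS)]
    return joint[0] / sum(joint)


print(cluster_1_posterior(1.0))   # observation from obs1
print(cluster_1_posterior(10.5))  # observation from obs2
```

For "x" equal to 10.5 the probability of cluster 1 is tiny but strictly positive. A proposal that gives that cluster nonzero probability therefore still covers the full support of the posterior.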

In [None]:
@gen
def proposal(target: Target):
    x_obs = target.constraint["x"]
    probs = jax.lax.cond(
        x_obs < 5.0,
        lambda _: jnp.array([0.9, 0.1]),
        lambda _: jnp.array([0.1, 0.9]),
        operand=None,
    )
    # If x_obs < 5, our heuristic proposes near cluster 1 with probability 0.9;
    # otherwise it proposes near cluster 2 with probability 0.9.
    cluster_idx = genjax.categorical(probs=probs) @ "cluster_idx"
    means = jnp.array([0.0, 10.5])
    # the second cluster is more spread out, so we use a larger standard deviation
    stds = jnp.array([1.0, 3.0])
    y = genjax.normal(means[cluster_idx], stds[cluster_idx]) @ "y"
    return y

Testing the proposal.

In [None]:
target = Target(marginal_model, (), obs1)
proposal.simulate(key, (target,))

So now this may seem great, but we cannot yet use this proposal as an importance sampler for the model. The issue is that the traces produced by the proposal don't match the ones the model accepts: the model doesn't know what to do with "cluster_idx".

In [None]:
k_particles = 500
alg = genjax.smc.ImportanceK(target, q=proposal.marginal(), k_particles=k_particles)

try:
    alg.simulate(key, (target,))
except Exception as e:
    # TODO: this currently doesn't raise an exception but in the future it should
    print(e)

Here again, we can use `marginal`, this time to marginalise "cluster_idx" out of the proposal so that it only proposes "y".

In [None]:
k_particles = 500
alg = genjax.smc.ImportanceK(
    target, q=proposal.marginal(selection=S["y"]), k_particles=k_particles
)

alg.simulate(key, (target,))
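
To see what sampling from a marginalised proposal amounts to, here is a minimal plain-Python sketch of a K-particle importance sampler for "y" given "x". It is not the GenJAX implementation and makes several assumptions: all helper names (`normal_pdf`, `target_density`, `proposal_params`, `sample_proposal`, `proposal_density`, `importance_k`) are hypothetical, the second argument of `genjax.normal` is taken to be a standard deviation, and exact closed-form densities stand in for the nested approximate marginals that the title of this notebook alludes to.

```python
import math
import random

# Mixture parameters copied from `model` above.
PROBS = [0.5, 0.25, 0.25]
MEANS = [0.0, 10.0, 11.0]
STDS = [1.0, 1.0, 1.0]  # assumption: treated as standard deviations


def normal_pdf(value, mean, std):
    """Density of a univariate Gaussian with the given mean and std."""
    z = (value - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def target_density(x_obs, y):
    """Unnormalized target p(x_obs, y), with idx summed out."""
    return sum(
        p * normal_pdf(x_obs, m, s) * normal_pdf(y, m, s)
        for p, m, s in zip(PROBS, MEANS, STDS)
    )


def proposal_params(x_obs):
    """Cluster probabilities, means and stds used by `proposal` above."""
    probs = [0.9, 0.1] if x_obs < 5.0 else [0.1, 0.9]
    return probs, [0.0, 10.5], [1.0, 3.0]


def sample_proposal(rng, x_obs):
    """Sample cluster_idx, then y, and return only y (cluster_idx is dropped)."""
    probs, means, stds = proposal_params(x_obs)
    cluster_idx = 0 if rng.random() < probs[0] else 1
    return rng.gauss(means[cluster_idx], stds[cluster_idx])


def proposal_density(x_obs, y):
    """Density of y under the proposal with cluster_idx summed out."""
    probs, means, stds = proposal_params(x_obs)
    return sum(p * normal_pdf(y, m, s) for p, m, s in zip(probs, means, stds))


def importance_k(rng, x_obs, k_particles):
    """Draw k particles, weight them, and resample one by its weight."""
    ys = [sample_proposal(rng, x_obs) for _ in range(k_particles)]
    ws = [target_density(x_obs, y) / proposal_density(x_obs, y) for y in ys]
    chosen = rng.choices(ys, weights=ws, k=1)[0]
    # The average weight is an unbiased estimate of the evidence p(x_obs).
    return chosen, sum(ws) / k_particles


rng = random.Random(0)
y_sample, evidence_estimate = importance_k(rng, x_obs=1.0, k_particles=500)
print(y_sample, evidence_estimate)
```

Because the weight divides by the density of "y" alone, "cluster_idx" never has to appear in the target, which is why the marginalised proposal is compatible with `marginal_model` while the raw one is not.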