This notebook builds on the `custom_distribution` notebook.

In [None]:
import jax
import jax.numpy as jnp
from genjax import Pytree, Weight, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import Any, Tuple
from jax.random import PRNGKey, split
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
pretty()

Recall how we defined a distribution for a Gaussian mixture, using the `Distribution` class.

In [None]:
@Pytree.dataclass
class GaussianMixture(Distribution):
    def random_weighted(self, key: PRNGKey, probs, means, vars) -> Tuple[Any, Weight]:
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
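        # Split first so the categorical and normal draws use independent keys.
        key, subkey = split(key)
        cat = tfd.Categorical(probs=probs)
        cat_index = jnp.asarray(cat.sample(seed=key))
        # Note: `vars` is passed as `scale`, so it is used as a standard deviation.
        normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
        normal_sample = normal.sample(seed=subkey)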
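        # Marginal density: sum over components k of probs[k] * Normal_k(normal_sample).
        # Each row z of `zipped` is (probs[k], means[k], vars[k]), so z[0] is the weight.
        zipped = jnp.stack([probs, means, vars], axis=1)
        weight_recip = -jnp.log(
            jnp.sum(
                jax.vmap(
                    lambda z: z[0] * tfd.Normal(loc=z[1], scale=z[2]).prob(normal_sample)
                )(zipped)
            )
        )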

        return normal_sample, weight_recip

    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        # Each row z of `zipped` is (probs[k], means[k], vars[k]), so z[0] is the weight.
        zipped = jnp.stack([probs, means, vars], axis=1)
        return jnp.log(
            jnp.sum(
                jax.vmap(
                    lambda z: z[0] * tfd.Normal(loc=z[1], scale=z[2]).prob(x)
                )(zipped)
            )
        )

In the class above, note how `estimate_logpdf` computes the density as a sum over all possible paths that could lead to a particular outcome `x`.

In fact, the same occurs in `random_weighted`: even though we know exactly which path we took to get the sample `normal_sample`, when evaluating the reciprocal density we also sum over all possible paths that could lead to that value.

Precisely, this requires summing over all possible values of the categorical distribution `cat`. We technically sampled two random values, `cat_index` and `normal_sample`, but we are only interested in the distribution of `normal_sample`: we marginalized out the intermediate random variable `cat_index`.

Mathematically, we have
`p(normal_sample) = sum_{cat_index} p(normal_sample, cat_index)`.

GenJAX supports a more general kind of distribution, one that only needs to be able to estimate its density.
The correctness criterion is that the estimate must be unbiased, i.e. equal to the correct value on average.

More precisely, `estimate_logpdf` should return the log of an unbiased density estimate, while `random_weighted` should return the log of an unbiased estimate of the reciprocal density. In general you can't get one from the other, as the following example shows.

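Flip a fair coin and return 1 on heads and 3 on tails. This gives an unbiased estimator of 2.
If we instead return the reciprocals, 1/1 on heads and 1/3 on tails, the average value is 2/3, which is not 1/2: the reciprocal of an unbiased estimator is not an unbiased estimator of the reciprocal.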
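
The coin-flip arithmetic can be checked exactly with rational numbers. This is only a small illustrative sketch using Python's standard library; the variable names are made up for the example.

```python
from fractions import Fraction

# The estimator returns 1 or 3, each with probability 1/2.
outcomes = [Fraction(1), Fraction(3)]
half = Fraction(1, 2)

# Average of the estimator itself.
mean_estimate = sum(half * v for v in outcomes)

# Average of the reciprocals of the estimator's outcomes.
mean_reciprocal = sum(half / v for v in outcomes)

# Reciprocal of the quantity being estimated.
reciprocal_of_mean = 1 / mean_estimate

print(mean_estimate, mean_reciprocal, reciprocal_of_mean)
```

The second and third printed values differ, which is why a distribution has to provide the density estimate and the reciprocal density estimate separately.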

Let's now define a Gaussian mixture distribution that only estimates its density.

In [None]:
@Pytree.dataclass
class StochasticGaussianMixture(Distribution):
    def random_weighted(self, key: PRNGKey, probs, means, vars) -> Tuple[Any, Weight]:
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        # Split first so the categorical and normal draws use independent keys.
        key, subkey = split(key)
        cat = tfd.Categorical(probs=probs)
        cat_index = jnp.asarray(cat.sample(seed=key))
        normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
        normal_sample = normal.sample(seed=subkey)
        # We can estimate the reciprocal density in constant time,
        # using only the component that was actually sampled.
        # The math behind this is detailed later.
        weight_recip = -tfd.Normal(
            loc=means[cat_index], scale=vars[cat_index]
        ).log_prob(normal_sample)

        return normal_sample, weight_recip

    # We can also estimate the density in constant time.
    # The math behind this is also detailed later.
    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        # Convert to arrays so that indexing with a traced `cat_index` works.
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        cat = tfd.Categorical(probs=probs)
        cat_index = jnp.asarray(cat.sample(seed=key))
        return tfd.Normal(loc=means[cat_index], scale=vars[cat_index]).log_prob(x)

Testing

In [None]:
# TODO: just run, see and plot the stochasticity, compare to the original implementation
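
As a stand-in until this cell is filled in, here is a minimal sketch of the kind of check one could run. It is written in plain JAX rather than against the GenJAX classes above, and the parameter values and function names (`exact_logpdf`, `stochastic_logpdf`, `scales`) are assumptions made for the example. It draws many stochastic log-density estimates at a fixed point and compares them with the exact mixture density.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical mixture parameters; `scales` are standard deviations.
probs = jnp.array([0.3, 0.7])
means = jnp.array([-1.0, 2.0])
scales = jnp.array([0.5, 1.0])
x = 0.5


def exact_logpdf(x):
    # log sum_k probs[k] * Normal(x; means[k], scales[k])
    return logsumexp(jnp.log(probs) + norm.logpdf(x, means, scales))


def stochastic_logpdf(key, x):
    # Pick one component at random and return its log density at x.
    k = jax.random.categorical(key, jnp.log(probs))
    return norm.logpdf(x, means[k], scales[k])


keys = jax.random.split(jax.random.PRNGKey(0), 20_000)
log_estimates = jax.vmap(lambda key: stochastic_logpdf(key, x))(keys)

exact_density = jnp.exp(exact_logpdf(x))
mean_density = jnp.mean(jnp.exp(log_estimates))  # average in density space
mean_log = jnp.mean(log_estimates)  # average in log space

print("exact density:          ", exact_density)
print("mean of estimates:      ", mean_density)
print("exact log density:      ", exact_logpdf(x))
print("mean of log estimates:  ", mean_log)
```

Each individual estimate takes one of only two values, one per component, so it is noisy. The average of the exponentiated estimates should be close to the exact density, whereas the average of the log estimates sits below the exact log density (Jensen's inequality): the estimator is unbiased for the density, not for its logarithm.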

We need both methods because they are used in different situations, notably depending on whether the distribution appears in a proposal or in a model, as we now show!

In [None]:
# TODO: add importance sampling here.
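
Until the GenJAX version is added, the following plain-JAX sketch illustrates the idea under some assumptions. It does not use GenJAX's importance sampling API; the helper names, the mixture parameters, the target `exp(-x^2 / 2)` and the `Normal(0, 4)` proposal are all hypothetical choices for the example. When the mixture is the proposal, we sample from it and need an estimate of its reciprocal density, which is the role of `random_weighted`. When the mixture is the model, we need an estimate of its density at a value sampled elsewhere, which is the role of `estimate_logpdf`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical mixture parameters; `scales` are standard deviations.
probs = jnp.array([0.3, 0.7])
means = jnp.array([-1.0, 1.0])
scales = jnp.array([1.5, 2.0])


def mixture_random_weighted(key):
    # Sample x from the mixture and return the log of a reciprocal density estimate.
    key_cat, key_normal = jax.random.split(key)
    k = jax.random.categorical(key_cat, jnp.log(probs))
    x = means[k] + scales[k] * jax.random.normal(key_normal)
    return x, -norm.logpdf(x, means[k], scales[k])


def mixture_estimate_logpdf(key, x):
    # Return the log of a density estimate at x, using one random component.
    k = jax.random.categorical(key, jnp.log(probs))
    return norm.logpdf(x, means[k], scales[k])


n = 50_000
keys = jax.random.split(jax.random.PRNGKey(0), n)


# (a) Mixture as the proposal. Unnormalized target f(x) = exp(-x^2 / 2),
# whose normalizing constant is sqrt(2 * pi).
def log_weight_mixture_as_proposal(key):
    x, log_recip_q = mixture_random_weighted(key)
    return -0.5 * x**2 + log_recip_q


z_hat = jnp.mean(jnp.exp(jax.vmap(log_weight_mixture_as_proposal)(keys)))


# (b) Mixture as the model. Proposal is Normal(0, 4); the weights should average to 1
# because the mixture density integrates to 1.
def log_weight_mixture_as_model(key):
    key_x, key_est = jax.random.split(key)
    x = 4.0 * jax.random.normal(key_x)
    return mixture_estimate_logpdf(key_est, x) - norm.logpdf(x, 0.0, 4.0)


one_hat = jnp.mean(jnp.exp(jax.vmap(log_weight_mixture_as_model)(keys)))

print("estimate of sqrt(2*pi):", z_hat)
print("estimate of 1:         ", one_hat)
```

In both cases the importance weights remain unbiased even though the mixture density is never computed exactly. Swapping the two estimators (for example, negating `mixture_estimate_logpdf` to stand in for the reciprocal) would generally break this.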

For those interested, we now explain the math that enabled us to get fast and unbiased density estimators.

In [None]:
# TODO: add the math magic here
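
Pending the full explanation, one way to see why both estimators in `StochasticGaussianMixture` are unbiased is sketched below. The notation is an assumption made for this sketch: `\pi_k` stands for `probs[k]`, `\mu_k` for `means[k]`, `\sigma_k` for the scale of component `k`, and `p(x)` for the exact mixture density.

```latex
\begin{align*}
p(x) &= \sum_k \pi_k \, \mathcal{N}(x; \mu_k, \sigma_k) \\[6pt]
% estimate_logpdf: x is given, k is drawn from the prior \pi
\mathbb{E}_{k \sim \pi}\!\left[ \mathcal{N}(x; \mu_k, \sigma_k) \right]
  &= \sum_k \pi_k \, \mathcal{N}(x; \mu_k, \sigma_k) = p(x) \\[6pt]
% random_weighted: (k, x) are drawn jointly, so given x, k follows the posterior
p(k \mid x) &= \frac{\pi_k \, \mathcal{N}(x; \mu_k, \sigma_k)}{p(x)} \\[6pt]
\mathbb{E}_{k \sim p(\cdot \mid x)}\!\left[ \frac{1}{\mathcal{N}(x; \mu_k, \sigma_k)} \right]
  &= \sum_k \frac{\pi_k \, \mathcal{N}(x; \mu_k, \sigma_k)}{p(x)} \cdot \frac{1}{\mathcal{N}(x; \mu_k, \sigma_k)}
   = \frac{1}{p(x)} \sum_k \pi_k = \frac{1}{p(x)}
\end{align*}
```

The key difference is which distribution `cat_index` follows. In `estimate_logpdf` it is drawn from the prior `probs`, independently of `x`. In `random_weighted` it is the component that actually generated the sample, so conditionally on the sample it follows the posterior, and the component densities cancel in the expectation.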