This notebook builds on top of the `custom_distribution` one.

In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import Pytree, Weight, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import Any, Tuple
from jax.random import PRNGKey, split
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions  # short alias for the TFP distributions module used below
pretty()  # NOTE(review): presumably enables GenJAX pretty-printing of notebook outputs — confirm

Recall how we defined a distribution for a Gaussian mixture, using the `Distribution` class.

In [None]:
def _gaussian_mixture_logpdf(x, probs, means, scales):
    """Return the exact log-density of a Gaussian mixture at ``x``.

    Computes ``log sum_i w_i * Normal(x; means[i], scales[i])`` where ``w`` is
    ``probs`` normalized to sum to one. ``tfd.Categorical(probs=...)`` samples
    proportionally to ``probs``, so normalizing keeps the density consistent
    with the sampler.
    """
    probs = jnp.asarray(probs)
    means = jnp.asarray(means)
    scales = jnp.asarray(scales)
    # The weight of component i must enter as log(probs[i]). Passing the
    # probability value to `Categorical.log_prob` would instead score it as if
    # it were a category index, giving every component the wrong weight.
    log_weights = jnp.log(probs) - jnp.log(jnp.sum(probs))
    # Batched Normal: one component log-density per entry of `means`/`scales`.
    component_logpdfs = tfd.Normal(loc=means, scale=scales).log_prob(x)
    return jax.scipy.special.logsumexp(log_weights + component_logpdfs)


@Pytree.dataclass
class GaussianMixture(Distribution):
    """Gaussian mixture whose density is computed exactly.

    Both methods take ``probs`` (mixture weights), ``means`` (component means)
    and ``vars``. ``vars`` is passed to ``tfd.Normal`` as ``scale``, so despite
    its name each entry is used as a standard deviation.
    """

    def random_weighted(self, key: PRNGKey, probs, means, vars) -> Tuple[Any, Weight]:
        """Sample from the mixture and return ``(sample, -log p(sample))``.

        A component index is drawn to generate the sample, but it is then
        marginalized out: the weight is the log of the exact reciprocal
        marginal density.
        """
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        # Split before any use: sampling with `key` and then splitting that
        # same key reuses it and correlates the two draws.
        cat_key, normal_key = split(key)
        cat_index = jnp.asarray(tfd.Categorical(probs=probs).sample(seed=cat_key))
        normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
        normal_sample = normal.sample(seed=normal_key)
        weight_recip = -_gaussian_mixture_logpdf(normal_sample, probs, means, vars)
        return normal_sample, weight_recip

    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        """Return the exact log-density ``log p(x)``; ``key`` is unused."""
        return _gaussian_mixture_logpdf(x, probs, means, vars)

In the class above, note how `estimate_logpdf` computes the density as a sum over all possible paths that could lead to a particular outcome `x`.

In fact, the same occurs in `random_weighted`: even though we know exactly which path produced `normal_sample`, when evaluating the reciprocal density we still sum over all possible paths that could lead to `normal_sample`.

Precisely, this requires summing over all possible values of the categorical distribution `cat`. We technically sampled two random values, `cat_index` and `normal_sample`, but we are only interested in the distribution of `normal_sample`: we marginalized out the intermediate random variable `cat_index`.

Mathematically, we have
$$p(\text{normal\_sample}) = \sum_{\text{cat\_index}} p(\text{normal\_sample}, \text{cat\_index}).$$

GenJAX supports a more general kind of distribution: one that only needs to be able to estimate its density.
For this to be valid, the estimates must be unbiased, i.e. correct on average.

More precisely, `estimate_logpdf` should return the log of an unbiased estimate of the density, while `random_weighted` should return the log of an unbiased estimate of the reciprocal density. In general, you cannot get one from the other, as the following example shows.

Flip a coin: with probability 50% return 1, otherwise return 3. This is an unbiased estimator of 2.
If we instead return the reciprocals (1/1 with probability 50%, and 1/3 otherwise), the average value is (1 + 1/3)/2 = 2/3, which is not 1/2.

Let's now define a Gaussian mixture distribution that only estimates its density.

In [None]:
@Pytree.dataclass
class StochasticGaussianMixture(Distribution):
    """Gaussian mixture that only *estimates* its density, in constant time.

    Instead of summing over all components, each method looks at a single
    component index ``z``:

    * ``estimate_logpdf`` draws ``z`` from the categorical prior and returns
      ``log p(x | z)``. Since ``p(x) = E_{z ~ p(z)}[p(x | z)]``, ``p(x | z)`` is
      an unbiased estimate of the density.
    * ``random_weighted`` returns ``-log p(x | z)`` for the ``z`` that generated
      ``x``. That ``z`` is distributed as ``p(z | x)``, and
      ``E_{z ~ p(z | x)}[1 / p(x | z)] = sum_i p(z=i) / p(x) = 1 / p(x)``, so
      ``1 / p(x | z)`` is an unbiased estimate of the reciprocal density.

    ``vars`` is passed to ``tfd.Normal`` as ``scale``: despite the name, its
    entries are used as standard deviations.
    """

    def random_weighted(self, key: PRNGKey, probs, means, vars) -> Tuple[Any, Weight]:
        """Sample from the mixture; weight is ``-log p(x | z)`` for the drawn ``z``."""
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        # Split before any use so the two draws use independent keys.
        cat_key, normal_key = split(key)
        cat_index = jnp.asarray(tfd.Categorical(probs=probs).sample(seed=cat_key))
        normal = tfd.Normal(loc=means[cat_index], scale=vars[cat_index])
        normal_sample = normal.sample(seed=normal_key)
        # log(1 / p(x | z)) for the component that generated x: an unbiased
        # estimate of 1 / p(x) (see the class docstring).
        weight_recip = -normal.log_prob(normal_sample)
        return normal_sample, weight_recip

    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        """Return ``log p(x | z)`` for one ``z`` drawn from the categorical prior.

        TODO: a proposal for ``z`` closer to ``p(z | x)`` than the prior would
        reduce the variance of this estimate.
        """
        # Convert like `random_weighted` does: indexing a Python list with a
        # traced index fails under jit/vmap.
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)
        cat_index = jnp.asarray(tfd.Categorical(probs=probs).sample(seed=key))
        return tfd.Normal(loc=means[cat_index], scale=vars[cat_index]).log_prob(x)

To test, we start by creating a generative function using our new distribution.

In [None]:
key = PRNGKey(0)
sgm = StochasticGaussianMixture()


@genjax.gen
def model(cat_probs, means, vars):
    """Sample "x" from the mixture, then "y" from a mixture centered on "x".

    The second mixture reuses the weights and scales of the first, but every
    one of its components has mean ``x``.
    """
    x = sgm(cat_probs, means, vars) @ "x"
    centered_on_x = jnp.repeat(x, len(means))
    y = sgm(cat_probs, centered_on_x, vars) @ "y"
    return x, y

We can then simulate from the model, assess a trace, or use importance sampling with the default proposal, seamlessly.

In [None]:
# Four components: means 0, 1, 2, 3, all with scale 1.
cat_probs = jnp.array([0.1, 0.4, 0.2, 0.3])
means = jnp.arange(4.0)
vars = jnp.ones(4)

# Draw one joint sample of ("x", "y") and display the resulting trace.
tr = model.simulate(key, (cat_probs, means, vars))
tr

In [None]:
# TODO: currently raises a not implemented error
# model.assess(tr.get_choices(), (cat_probs, means, vars))

In [None]:
# Constrain the "y" address to an observed value; the unconstrained "x" is
# filled in by GenJAX's default importance proposal.
y = 2.0
model.importance(key, C["y"].set(y), (cat_probs, means, vars))

Let's also check that the distribution `sgm` gives an unbiased estimate of the density.

In [None]:
gm = GaussianMixture()
x = 2.0
N = 42
n_estimates = 2000000
# Weights proportional to 1, 2, ..., N. Built from an integer range so there
# are exactly N of them: a float `arange` with step 1/N can return N + 1
# elements because of rounding, which would no longer match `means`.
cat_probs = jnp.arange(1, N + 1) / N
cat_probs = cat_probs / jnp.sum(cat_probs)
means = jnp.arange(0.0, N * 1.0, 1.0)
vars = jnp.ones(N) / N  # used as standard deviations (tfd.Normal `scale`)
key = PRNGKey(0)
keys = split(key, n_estimates)

# Exact log-density, marginalizing over all N components.
log_density = gm.estimate_logpdf(key, x, cat_probs, means, vars)

# Average the n_estimates density estimates in probability space:
# log(mean(exp(estimates))) = logsumexp(estimates) - log(n).
jitted = jax.jit(jax.vmap(sgm.estimate_logpdf, in_axes=(0, None, None, None, None)))
estimates = jitted(keys, x, cat_probs, means, vars)
log_mean_estimates = jax.scipy.special.logsumexp(estimates) - jnp.log(len(estimates))

# With an unbiased estimator and an exact reference this ratio should be close
# to 1; a consistent ratio far from 1 points to a bug in one of the densities.
ratio = jnp.exp(log_mean_estimates - log_density)
log_density, log_mean_estimates, ratio

In [None]:
# TODO: find a way to plot decently
# plt.hist(estimates, bins=500)

One benefit of using density estimates instead of exact ones is that it can be much faster to compute.

In [None]:
# TODO: doesn't shine here as the problem is too simple: it's parallel friendly linear time for the exact one.
N = 30000
n_estimates = 10
key = PRNGKey(0)
keys = split(key, n_estimates)
cat_probs = jnp.array(jnp.arange(1.0 / N, 1.0 + 1.0 / N, 1.0 / N))
cat_probs = cat_probs / jnp.sum(cat_probs)
means = jnp.arange(0.0, N * 1.0, 1.0)
vars = jnp.ones(N) / N

jitted_exact = jax.jit(gm.estimate_logpdf)
jitted_approx = jax.jit(
    lambda key, x, cat_probs, means, vars: jax.scipy.special.logsumexp(
        jax.vmap(sgm.estimate_logpdf, in_axes=(0, None, None, None, None))(
            key, x, cat_probs, means, vars
        )
    )
    - jnp.log(n_estimates)
)

jitted_exact(key, x, cat_probs, means, vars)
jitted_approx(keys, x, cat_probs, means, vars)
%timeit jitted(keys, x, cat_probs, means, vars)
%timeit jitted_approx(keys, x, cat_probs, means, vars)

We need both methods because they are used in different situations, notably depending on whether the distribution appears in a proposal or in a model, as we now show!

Let's define a simple model and a proposal which both use our `sgm` distribution.

In [None]:
@genjax.gen
def model(cat_probs, means, vars):
    """Sample "x" from the mixture, then "y" from a mixture whose components are all centered on "x"."""
    x = sgm(cat_probs, means, vars) @ "x"
    y_means = jnp.repeat(x, len(means))
    y = sgm(cat_probs, y_means, vars) @ "y"
    return (x, y)


@genjax.gen
def proposal(obs, cat_probs, means, vars):
    """Propose "x" given the observed "y" stored in ``obs``.

    Each component mean is moved halfway towards the observed "y"; weights and
    scales are unchanged.
    """
    y = obs["y"]
    # Elementwise midpoint via broadcasting (no vmap needed).
    new_means = (jnp.asarray(means) + y) / 2
    x = sgm(cat_probs, new_means, vars) @ "x"
    return (x, y)

Let's define importance sampling once again! Note that it is exactly the same as the usual one. This is because GenJAX implements `simulate` using `random_weighted` and `assess` using `estimate_logpdf`.

In [None]:
def gensp_importance_sampling(target, proposal):
    """Build a single-particle importance sampler from ``proposal`` to ``target``.

    The returned function takes ``(key, target_args, proposal_args)``. Both
    argument containers are unpacked with ``*``: ``proposal_args`` goes to
    ``proposal.simulate(key, ...)`` and ``target_args`` to
    ``target.assess(choices, ...)``. It returns ``(trace, log_weight)`` where
    ``log_weight`` is the target's assessed score minus the proposal trace's
    score.

    NOTE(review): ``target`` is assessed on the proposal's choices only; any
    observed addresses of the target must already be among them (or be merged
    in first) — confirm this is the intended usage.
    """

    def _inner(key, target_args, proposal_args):
        proposed = proposal.simulate(key, *proposal_args)
        choices, log_proposal = proposed.get_sample(), proposed.get_score()
        log_target, _retval = target.assess(choices, *target_args)
        return (proposed, log_target - log_proposal)

    return _inner

Testing

In [None]:
# Observed value for the model's "y" address.
chm = C["y"].set(2.0)

# TODO: awkward parenthesis and same problem about assess not implemented.
# The extra nesting is needed because `gensp_importance_sampling` unpacks each
# args container with `*` before passing it to simulate/assess.
# gensp_importance_sampling(model, proposal)(key, ((cat_probs, means, vars),), ((chm, cat_probs, means, vars),))

Finally, for those curious, here is the math that lets us correctly (i.e., unbiasedly) estimate the pdf and its reciprocal.

In [None]:
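# Density: p(x) = sum_i p(z=i) p(x | z=i) = E_{z ~ p(z)}[p(x | z)], so sampling z from the
# categorical prior and returning p(x | z) is an unbiased estimate of p(x) (`estimate_logpdf`).
# Reciprocal: the z that generated x is distributed as p(z | x), and
# E_{z ~ p(z | x)}[1 / p(x | z)] = sum_i p(z=i | x) / p(x | z=i) = sum_i p(z=i) / p(x) = 1 / p(x),
# so returning 1 / p(x | z) for that z is an unbiased estimate of 1 / p(x) (`random_weighted`).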