In [None]:
from typing import Any

import jax
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import Pytree, Weight, gen, pretty
from genjax._src.generative_functions.distributions.distribution import Distribution
from genjax.typing import PRNGKey
from jax import jit
from jax.scipy.special import logsumexp
from tensorflow_probability.substrates import jax as tfp

# Short alias for the TFP distributions namespace used by the cells below.
tfd = tfp.distributions

# Fixed seed so a Restart & Run All reproduces the same random draws.
SEED = 0
key = jax.random.PRNGKey(SEED)

# Turn on genjax's pretty-printing for displayed outputs.
pretty()

We can define a custom distribution as a class by subclassing `Distribution`. Here we'll create a simple mixture of Gaussians.

In [None]:
@Pytree.dataclass
class GaussianMixture(Distribution):
    """A one-dimensional mixture of Gaussians with an optional static shift.

    Component ``k`` is chosen with probability ``probs[k]`` and is a Normal with
    location ``means[k] + bias`` and scale ``vars[k]``. Despite its name, ``vars``
    is passed to ``tfd.Normal`` as ``scale`` (a standard deviation), not as a
    variance.
    """

    # It can have static args: a constant shift added to every component mean.
    bias: float = Pytree.static(default=0.0)

    def random_weighted(self, key: PRNGKey, probs, means, vars) -> tuple[Any, Weight]:
        """Sample ``x`` from the mixture and return it with its log reciprocal density.

        Because this distribution's density is exact, the returned weight is
        ``log(1 / p(x)) = -log p(x)``.

        NOTE(review): the ``(sample, weight)`` order and the reciprocal-density
        convention follow this notebook's original comment; confirm them against
        the installed genjax ``Distribution.random_weighted`` contract.

        Args:
            key: PRNG key. It is split before use, so no key is consumed twice.
            probs: mixture weights, one per component.
            means: component locations (``bias`` is added on top).
            vars: component scales (standard deviations) for ``tfd.Normal``.

        Returns:
            ``(x, -log p(x))``.
        """
        # Make sure the inputs are jnp arrays so indexing/broadcasting works under jax.
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)

        # Split up front: reusing `key` both to sample and to derive new keys
        # violates JAX's "use each key once" rule.
        cat_key, normal_key = jax.random.split(key)

        # Ancestral sampling: choose a component, then sample from that Normal.
        cat_index = tfd.Categorical(probs=probs).sample(seed=cat_key)
        component = tfd.Normal(loc=means[cat_index] + self.bias, scale=vars[cat_index])
        normal_sample = component.sample(seed=normal_key)

        # Exact density, so the log reciprocal density is simply -log p(x).
        # (`estimate_logpdf` does not use its key argument.)
        weight_recip = -self.estimate_logpdf(key, normal_sample, probs, means, vars)
        return normal_sample, weight_recip

    def estimate_logpdf(self, key: PRNGKey, x, probs, means, vars) -> Weight:
        """Return the exact mixture log density ``log p(x)``.

        ``log p(x) = logsumexp_k(log probs[k] + log N(x; means[k] + bias, vars[k]))``.
        The sum is taken in log space, so a point far from every component gets
        a finite (very negative) log density rather than ``log(0) = -inf``.

        Args:
            key: unused; the density is computed exactly.
            x: the point at which to evaluate the density.
            probs: mixture weights, one per component.
            means: component locations (``bias`` is added on top).
            vars: component scales (standard deviations) for ``tfd.Normal``.

        Returns:
            The log density of ``x`` under the mixture.
        """
        probs = jnp.asarray(probs)
        means = jnp.asarray(means)
        vars = jnp.asarray(vars)

        # One batched Normal over all components; log_prob broadcasts x across them.
        components = tfd.Normal(loc=means + self.bias, scale=vars)
        # The mixture weight of component k is probs[k] itself (not a
        # Categorical pmf evaluated at that weight).
        return logsumexp(jnp.log(probs) + components.log_prob(x))

Testing the distribution inside a generative function, both eagerly and under `jit`:

In [None]:
# A concrete instance of the distribution with no shift on the component means.
gauss_mix = GaussianMixture(bias=0.0)


@gen
def model(probs):
    # Both sites share the same component locations and scales.
    component_means = jnp.array([0.0, 1.0])
    component_scales = jnp.array([1.0, 1.0])
    first = gauss_mix(probs, component_means, component_scales) @ "mix1"
    second = gauss_mix(probs, component_means, component_scales) @ "mix2"
    return first, second


probs = jnp.array([0.5, 0.5])

# Run the model eagerly and under `jit`, each with a fresh subkey.
key, subkey = jax.random.split(key)
eager_trace = model.simulate(subkey, (probs,))
key, subkey = jax.random.split(key)
jit_trace = jit(model.simulate)(subkey, (probs,))

# Only a cell's last expression is rendered, so display both traces together
# (previously the eager result was computed and silently discarded).
eager_trace, jit_trace

Testing importance sampling with the `mix1` choice constrained to an observed value:

In [None]:
# Constrain the "mix1" choice to 3.0 and run importance sampling under `jit`.
key, subkey = jax.random.split(key)
observations = C["mix1"].set(3.0)
jit(model.importance)(subkey, observations, (probs,))