In [None]:
from dataclasses import dataclass
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import jit
from jax import tree_util as jtu

import genjax
from genjax import ChoiceMap, GenerativeFunction, pretty
from genjax import ChoiceMapBuilder as C
from genjax import SelectionBuilder as S
from genjax.typing import static_check_supports_grad

# `pretty` comes from genjax; presumably it switches on nicer notebook display — confirm.
pretty()
# Root PRNG key for the notebook; later cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(0)

Let's create a class `SymplecticIntegrator`. Integrators solve differential equations by taking a point and integrating along the path defined by the differential equation to the solution at some time `t`. An integrator is symplectic when it satisfies a particular condition that is useful for the stability of the solution of certain dynamical systems, in particular the one used in the HMC algorithm.

The particular integrator we will implement is the classic leapfrog integrator.

In [None]:
class SymplecticIntegrator:
    """Base interface for symplectic integrators.

    Subclasses override `integrate` to evolve a position/momentum pair under
    the dynamics defined by a potential and a kinetic energy function.
    """

    def integrate(self, potential, kinetic, init_q, init_p):
        """Evolve `(init_q, init_p)`; must be overridden by subclasses.

        Args:
            potential: callable giving the potential energy of a position.
            kinetic: callable giving the kinetic energy of a momentum.
            init_q: initial position.
            init_p: initial momentum.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


# Vector space structure of choice maps
# choice maps are the data structures that store all the random variables
# described by a Gen model. Each trace contains a choice map and some other
# metadata
def chm_scale(c: float, chm: ChoiceMap) -> ChoiceMap:
    """Return `chm` with every leaf multiplied by the scalar `c`."""

    def _scaled(leaf):
        return c * leaf

    return jtu.tree_map(_scaled, chm)


def chm_add(chm1: ChoiceMap, chm2: ChoiceMap) -> ChoiceMap:
    """Return the leaf-wise sum of two trees that share the same structure."""

    def _sum_leaves(left, right):
        return left + right

    return jtu.tree_map(_sum_leaves, chm1, chm2)


# Utility functions
def tree_zip(diff_tree, nondiff_tree):
    """Recombine two complementary partial trees into one full tree.

    At every position the value from `diff_tree` is used unless it is `None`,
    in which case the value from `nondiff_tree` is used instead.

    Args:
        diff_tree: tree whose missing entries are marked with `None`.
        nondiff_tree: tree of the same shape supplying the missing entries.

    Returns:
        The merged tree.
    """

    def _is_placeholder(node):
        return node is None

    def _pick(diff_leaf, nondiff_leaf):
        return nondiff_leaf if diff_leaf is None else diff_leaf

    # By default `tree_map` treats `None` as an empty subtree, so a `None`
    # placeholder in `diff_tree` facing a real value in `nondiff_tree` would be
    # reported as a structure mismatch. Treating `None` as a leaf lets `_pick`
    # see the placeholder and substitute the other tree's value.
    return jtu.tree_map(_pick, diff_tree, nondiff_tree, is_leaf=_is_placeholder)


def my_closure(f, diff_tree, nondiff_tree):
    """Call `f` on the tree rebuilt by `tree_zip` from the two partial trees."""
    return f(tree_zip(diff_tree, nondiff_tree))


def get_diff_tree(tree):
    """Keep the leaves accepted by `static_check_supports_grad`; map the rest to `None`."""

    def _keep_if_diff(leaf):
        if static_check_supports_grad(leaf):
            return leaf
        return None

    return jtu.tree_map(_keep_if_diff, tree)


def get_nondiff_tree(tree):
    """Keep the leaves rejected by `static_check_supports_grad`; map the rest to `None`."""

    def _keep_if_nondiff(leaf):
        if static_check_supports_grad(leaf):
            return None
        return leaf

    return jtu.tree_map(_keep_if_nondiff, tree)


@dataclass
class ChmLeapFrog(SymplecticIntegrator):
    """
    A leap frog integrator on the space of choice maps.

    Each step performs a half-step momentum update, a full-step position
    update, and a second half-step momentum update.
    """

    step_size: float  # time increment of one leapfrog step
    num_steps: int  # number of leapfrog steps performed by `integrate`

    def integrate(self, potential, kinetic, init_q, init_p):
        """Run `num_steps` leapfrog steps starting from `(init_q, init_p)`.

        Args:
            potential: callable mapping a position tree to a scalar energy.
            kinetic: callable mapping a momentum tree to a scalar energy.
            init_q: initial position tree.
            init_p: initial momentum tree.

        Returns:
            The `(q, p)` pair reached after the last step.
        """
        # Leaves that cannot be differentiated are held fixed at their initial
        # values and re-inserted before each energy evaluation.
        non_diff_q = get_nondiff_tree(init_q)
        non_diff_p = get_nondiff_tree(init_p)

        def diff_potential(x):
            return my_closure(potential, x, non_diff_q)

        def diff_kinetic(x):
            return my_closure(kinetic, x, non_diff_p)

        dUdq = jax.grad(diff_potential)
        dKdp = jax.grad(diff_kinetic)
        half_step = self.step_size / 2

        def step(qp, _):
            q, p = qp
            # p <- p - (eps / 2) * dU/dq(q)
            p_halfway = chm_add(p, get_diff_tree(chm_scale(-half_step, dUdq(q))))
            # q <- q + eps * dK/dp(p)
            next_q = chm_add(
                q, get_diff_tree(chm_scale(self.step_size, dKdp(p_halfway)))
            )
            # p <- p - (eps / 2) * dU/dq(q); the sign must match the first
            # half-step, otherwise the update is not a leapfrog step.
            next_p = chm_add(
                p_halfway, get_diff_tree(chm_scale(-half_step, dUdq(next_q)))
            )
            return (next_q, next_p), None

        final_state, _ = jax.lax.scan(
            step, (init_q, init_p), length=self.num_steps
        )
        return final_state

We now define a class `HMCSampler` that will run HMC on the choicemap of a generative function.

In [None]:
@dataclass
class HMCSampler:
    """One-step Hamiltonian Monte Carlo kernel over the choices of a model.

    The potential energy is the negated score of the model trace constrained
    to the current parameters merged with `data`; the kinetic energy is
    `|p|^2 / (2 * mass)`.
    """

    model: GenerativeFunction  # the joint model
    model_args: Tuple  # the arguments of the model (e.g. hyperparameters)
    data: ChoiceMap  # the data on which we are conditioning
    integrator: SymplecticIntegrator  # the symplectic integration strategy
    mass: float = 1.0  # scalar mass shared by every momentum coordinate

    def potential(self, params: ChoiceMap) -> float:
        """Return minus the score of the trace constrained to `params` and `data`."""
        # The key is fixed. NOTE(review): presumably harmless because params
        # and data together constrain every random choice — confirm.
        trace = self.model.importance(
            jax.random.PRNGKey(0), params.merge(self.data), self.model_args
        )[0]
        return -trace.get_score()

    def kinetic(self, momentum: ChoiceMap) -> float:
        """Return the scalar kinetic energy `sum(p**2) / mass / 2`."""
        # Each leaf is summed so that array-valued leaves still produce a
        # scalar, which `jax.grad` requires of the function it differentiates.
        squared_norm = jtu.tree_reduce(
            lambda total, leaf: total + jnp.sum(leaf**2), momentum, 0.0
        )
        return squared_norm / self.mass / 2

    def sample_momentum(self, key, params=None) -> ChoiceMap:
        """Draw a momentum with the same tree structure and leaf shapes as `params`.

        Args:
            key: PRNG key.
            params: tree whose structure and leaf shapes the momentum mirrors.

        Returns:
            A tree of independent normal draws with standard deviation
            `sqrt(mass)`, matching the density implied by `kinetic`.

        Raises:
            ValueError: if `params` is not provided.
        """
        if params is None:
            raise ValueError(
                "sample_momentum needs `params` to infer the momentum structure"
            )
        leaves, structure = jtu.tree_flatten(params)
        keys = jax.random.split(key, len(leaves))
        scale = jnp.sqrt(self.mass)
        # One draw per coordinate: each momentum leaf has its parameter's shape.
        momentum_leaves = [
            scale * jax.random.normal(leaf_key, jnp.shape(leaf))
            for leaf_key, leaf in zip(keys, leaves)
        ]
        return jtu.tree_unflatten(structure, momentum_leaves)

    def hamiltonian(self, params: ChoiceMap, momentum: ChoiceMap) -> float:
        """Return potential energy of `params` plus kinetic energy of `momentum`."""
        return self.potential(params) + self.kinetic(momentum)

    def hmc_step(self, key, params) -> Tuple[ChoiceMap, bool]:
        """Propose new parameters with the integrator and accept or reject them.

        Args:
            key: PRNG key, split between momentum sampling and the accept test.
            params: current parameters.

        Returns:
            `(next_params, accepted)`, where `next_params` is the proposal if
            it was accepted and `params` otherwise.
        """
        momentum_key, mh_key = jax.random.split(key)
        momentum = self.sample_momentum(momentum_key, params)
        hamiltonian = self.hamiltonian(params, momentum)
        new_params, new_momentum = self.integrator.integrate(
            self.potential, self.kinetic, params, momentum
        )
        new_hamiltonian = self.hamiltonian(new_params, new_momentum)
        # Metropolis test in log space: accept when log(u) < H_old - H_new.
        log_mh_ratio = hamiltonian - new_hamiltonian
        acc = jnp.log(jax.random.uniform(mh_key)) < log_mh_ratio
        return jax.lax.cond(acc, lambda: new_params, lambda: params), acc

We can write a simple generative function on which to test our HMC algorithm.

In [None]:
@genjax.gen
def model():
    """Sample `x` from `normal(0.0, 4.0)`, then return `y` drawn from `normal(x, 1.0)`."""
    x = genjax.normal(0.0, 4.0) @ "x"
    return genjax.normal(x, 1.0) @ "y"

We can finally test HMC on our model.

In [None]:
# Starting value for the latent choice "x", and the observed value of "y".
init_params = C["x"].set(0.0)
data = C["y"].set(1.0)
# Leapfrog settings: step size and number of steps per proposal.
leap_size = 1e-3
num_steps = 40
hmc_sampler = HMCSampler(model, (), data, ChmLeapFrog(leap_size, num_steps))

# One HMC transition; returns the resulting parameters and the acceptance flag.
hmc_sampler.hmc_step(key, init_params)

We can also test for a slightly more complex model, e.g. an HMM.

In [None]:
# Sizes: steps in the chain, dimension of the state, number of simulated chains.
length_chain = 50
state_size = 50
number_runs = 100
# for numerical stability of the HMM, ensuring that the eigenvalues of the transition matrices are around 1.
magic_number = jnp.exp(1)
normalizer = 1.0 / jnp.sqrt(state_size / magic_number)
key, subkey = jax.random.split(key)
# Random square matrix with normal entries rescaled by `normalizer`.
transition_matrix = jax.random.normal(subkey, (state_size, state_size)) * normalizer

key, subkey = jax.random.split(key)
# Same construction for the map from latent state to observation mean.
observation_matrix = jax.random.normal(subkey, (state_size, state_size)) * normalizer
# Identity matrices passed as the second argument of `mv_normal` in `hmm_step`.
latent_variance = jnp.eye(state_size)
obs_variance = jnp.eye(state_size)
# NOTE(review): this subkey is not consumed before the next split below; the
# call only advances `key`.
key, subkey = jax.random.split(key)


@genjax.gen
def initial_state_model():
    """Sample the initial state from a zero-mean `mv_normal` at address "initial_state"."""
    mean = jnp.zeros(state_size, dtype=float)
    identity = jnp.identity(state_size, dtype=float)
    return genjax.mv_normal(mean, identity) @ "initial_state"


@genjax.gen
def hmm_step(x, _):
    """Advance the chain by one step.

    Samples the next latent state at "new_x" around `transition_matrix @ x`,
    samples an observation at "obs" around `observation_matrix @ new_x`, and
    returns `(new_x, None)` in the carry/output form used by `scan`.
    """
    latent_mean = jnp.matmul(transition_matrix, x)
    new_x = genjax.mv_normal(latent_mean, latent_variance) @ "new_x"
    obs_mean = jnp.matmul(observation_matrix, new_x)
    _ = genjax.mv_normal(obs_mean, obs_variance) @ "obs"
    return new_x, None


@genjax.gen
def hmm():
    """Sample an initial state at "init", then run `length_chain` steps of `hmm_step` at "steps"."""
    initial = initial_state_model() @ "init"
    chain = hmm_step.scan(n=length_chain)
    _ = chain(initial, None) @ "steps"


# Testing that the model runs
# Simulate `number_runs` repetitions of the HMM under jit.
jitted = jit(hmm.repeat(n=number_runs).simulate)
key, subkey = jax.random.split(key)
trace = jitted(subkey, ())
# Display the sampled choices of the simulated trace.
trace.get_sample()

In [None]:
# Creating observations
# One "obs" constraint per chain step, built with vmap over the step index.
obs = jax.vmap(
    lambda idx: C["steps", idx, "obs"].set(
        idx.astype(float) * jnp.arange(state_size) / state_size
    )
)(jnp.arange(length_chain))

# Creating an initial state
# Run the model constrained to the observations and keep only the latent
# choices, which are the parameters HMC will update.
# NOTE(review): `initial_state` is sampled inside the "init" sub-call of `hmm`;
# confirm that S["initial_state"] actually selects it (it may need to be
# S["init", "initial_state"]).
key, subkey = jax.random.split(key)
init_params = (
    hmm.importance(subkey, obs, ())[0]
    .get_choices()
    .filter(S["initial_state"] | S["steps", ..., "new_x"])
    .simplify()
)

# type of init_params
param_structure = jtu.tree_structure(init_params)


# Running HMC
# HMCSampler's fields are (model, model_args, data, integrator, mass), so the
# integrator must be the fourth positional argument. Passing `param_structure`
# there would shift the integrator into `mass`.
hmc_sampler = HMCSampler(hmm, (), obs, ChmLeapFrog(leap_size, num_steps))
key, subkey = jax.random.split(key)
jit(hmc_sampler.hmc_step)(subkey, init_params)

Finally, note that this is not a tutorial on HMC. Typical implementations have many more parameters, including a burn-in period, a way to tune the step size and number of steps, and more sophisticated methods to decide when to stop the algorithm. This simpler version is only a slightly fancier version of the typical Metropolis-Hastings random-walk algorithm that one could use as a rejuvenation kernel in an SMC algorithm.