Skip to content

Commit

Permalink
handling pytrees properly
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 1, 2021
1 parent 2be9710 commit 6e6ece9
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 171 deletions.
34 changes: 20 additions & 14 deletions src/rmhmc/integrator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
__all__ = ["leapfrog", "implicit_midpoint"]

from typing import Any, Callable, NamedTuple, Tuple, Union
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from .base_types import (
Array,
KineticFunction,
KineticState,
Momentum,
Expand Down Expand Up @@ -39,6 +40,12 @@ class ImplicitMidpointState(NamedTuple):
]


def _axpby(a: Scalar, x: Array, y: Array, b: Optional[Scalar] = None) -> Array:
if b is None:
return tree_map(lambda x_, y_: a * x_ + y_, x, y)
return tree_map(lambda x_, y_: a * x_ + b * y_, x, y)


def leapfrog(
potential_fn: PotentialFunction, kinetic_fn: KineticFunction
) -> Tuple[IntegratorInitFunction, IntegratorUpdateFunction]:
Expand All @@ -51,13 +58,11 @@ def update_fn(
step_size: Scalar, kinetic_state: KineticState, state: IntegratorState
) -> Tuple[IntegratorState, bool]:
assert isinstance(state, LeapfrogState)
p = tree_map(
lambda p, dUdq: p - 0.5 * step_size * dUdq, state.p, state.dUdq
)
p = _axpby(-0.5 * step_size, state.dUdq, state.p)
dTdp = jax.grad(kinetic_fn, argnums=2)(kinetic_state, None, p)
q = tree_map(lambda q, dTdp: q + step_size * dTdp, state.q, dTdp)
q = _axpby(step_size, dTdp, state.q)
dUdq = dU(q)
p = tree_map(lambda p, dUdq: p - 0.5 * step_size * dUdq, p, dUdq)
p = _axpby(-0.5 * step_size, dUdq, p)
return LeapfrogState(q, p, dUdq), True

return init_fn, update_fn
Expand Down Expand Up @@ -88,22 +93,23 @@ def step(args: Tuple[Position, Momentum]) -> Tuple[Position, Momentum]:
q, p = args
dHdq, dHdp = vector_field(kinetic_state, q, p)
return (
state.q + 0.5 * step_size * dHdp,
state.p - 0.5 * step_size * dHdq,
_axpby(0.5 * step_size, dHdp, state.q),
_axpby(-0.5 * step_size, dHdq, state.p),
)

# Use an initial half step using the pre-computed vector field
q = state.q + 0.5 * step_size * state.dHdp
p = state.p - 0.5 * step_size * state.dHdq
q = _axpby(0.5 * step_size, state.dHdp, state.q)
p = _axpby(-0.5 * step_size, state.dHdq, state.p)

# Solve for the midpoint
(q, p), success = solve_fixed_point(step, (q, p), **solver_kwargs)

# Compute the resulting vector field and update the state
dHdq = (2.0 / step_size) * (state.p - p)
dHdp = (2.0 / step_size) * (q - state.q)
q = 2 * q - state.q
p = 2 * p - state.p
a = 2.0 / step_size
dHdq = _axpby(-a, p, state.p, a)
dHdp = _axpby(a, q, state.q, -a)
q = _axpby(2.0, q, state.q, -1.0)
p = _axpby(2.0, p, state.p, -1.0)

return ImplicitMidpointState(q, p, dHdq, dHdp), success

Expand Down
Empty file added tests/__init__.py
Empty file.
190 changes: 33 additions & 157 deletions tests/integrator_test.py
Original file line number Diff line number Diff line change
@@ -1,164 +1,29 @@
import dataclasses
from functools import partial
from typing import Callable, Tuple
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from rmhmc.base_types import Array, Momentum, Position, Scalar
from rmhmc.hamiltonian import (
System,
compute_total_energy,
euclidean,
integrate,
integrate_trajectory,
riemannian,
)
from rmhmc.integrator import IntegratorState


def sho(use_euclidean: bool) -> System:
    """Build the ``sho`` test system (presumably a simple harmonic oscillator).

    The log posterior is ``-0.5 * sum(q ** 2)``.

    Args:
        use_euclidean: When true the system comes from
            ``euclidean(log_posterior)``; otherwise it comes from
            ``riemannian(log_posterior, metric)``.

    Returns:
        The system produced by the selected constructor.
    """

    def neg_half_sq_norm(q: Position) -> Scalar:
        return -0.5 * jnp.sum(q ** 2)

    def unit_metric(q: Position) -> Array:
        # A diagonal matrix of ones shaped after ``q`` (assumes ``q`` is 1-D,
        # as in the PROBLEMS table -- TODO confirm for other callers).
        return jnp.diag(jnp.ones_like(q))

    return (
        euclidean(neg_half_sq_norm)
        if use_euclidean
        else riemannian(neg_half_sq_norm, unit_metric)
    )


def planet(use_euclidean: bool) -> System:
    """Build the ``planet`` test system with an inverse-distance log posterior.

    The log posterior is ``1 / sqrt(sum(q ** 2))``, i.e. the reciprocal of the
    Euclidean distance of ``q`` from the origin.

    Args:
        use_euclidean: When true the system comes from
            ``euclidean(log_posterior)``; otherwise it comes from
            ``riemannian(log_posterior, metric)``.

    Returns:
        The system produced by the selected constructor.
    """

    def inverse_radius(q: Position) -> Scalar:
        radius = jnp.sqrt(jnp.sum(q ** 2))
        return 1.0 / radius

    def unit_metric(q: Position) -> Array:
        # A diagonal matrix of ones shaped after ``q`` (assumes ``q`` is 1-D,
        # as in the PROBLEMS table -- TODO confirm for other callers).
        return jnp.diag(jnp.ones_like(q))

    return (
        euclidean(inverse_radius)
        if use_euclidean
        else riemannian(inverse_radius, unit_metric)
    )


def banana_problem(fixed: bool, use_euclidean: bool) -> System:
    """Build the "banana" test system from seeded synthetic data.

    Observations ``y`` are drawn once, from a NumPy generator seeded with
    1234, as ``theta[0] + theta[1] ** 2`` plus Gaussian noise of scale
    ``sigma_y``. The log posterior of a position ``q`` combines the squared
    residuals of ``q[0] + q[1] ** 2`` against ``y`` with a zero-centred
    quadratic prior of scale ``sigma_theta`` on ``q``.

    Args:
        fixed: When true the metric is the constant ``10 * I``; otherwise it
            is the position-dependent matrix built in ``metric`` below.
        use_euclidean: Only honoured together with ``fixed``; in that case
            the system comes from ``euclidean(log_posterior)``. Every other
            combination goes through ``riemannian(log_posterior, metric)``.

    Returns:
        The system produced by the selected constructor.
    """
    t = 0.5
    sigma_y = 2.0
    sigma_theta = 2.0
    num_obs = 100

    # Seeded generator so the synthetic data set is identical on every run.
    random = np.random.default_rng(1234)
    theta = np.array([t, np.sqrt(1.0 - t)])
    y = (
        theta[0]
        + np.square(theta[1])
        + sigma_y * random.normal(size=(num_obs,))
    )

    def log_posterior(q: Position) -> Scalar:
        p = q[0] + jnp.square(q[1])
        ll = jnp.sum(jnp.square(y - p)) / sigma_y ** 2
        # The prior must be evaluated at the position ``q``. Using the
        # data-generating ``theta`` here only adds a constant, so the prior
        # would contribute no gradient even though the metric below already
        # carries its ``1 / sigma_theta ** 2`` precision.
        lp = jnp.sum(jnp.square(q)) / sigma_theta ** 2
        return -0.5 * (ll + lp)

    if fixed:

        def metric(q: Position) -> Array:
            return 10 * jnp.diag(jnp.ones_like(q))

    else:

        def metric(q: Position) -> Array:
            # (n / sigma_y^2) * J^T J with J = [1, 2 * q[1]] (the gradient of
            # the mean ``q[0] + q[1] ** 2``), plus the prior precision
            # ``1 / sigma_theta ** 2`` on the diagonal.
            n = y.size
            s = 2.0 * n * q[1] / sigma_y ** 2
            return jnp.array(
                [
                    [n / sigma_y ** 2 + 1.0 / sigma_theta ** 2, s],
                    [
                        s,
                        4.0 * n * jnp.square(q[1]) / sigma_y ** 2
                        + 1.0 / sigma_theta ** 2,
                    ],
                ]
            )

    # A position-dependent metric cannot be expressed by the Euclidean
    # constructor, so ``use_euclidean`` is ignored unless ``fixed`` is set.
    if fixed and use_euclidean:
        return euclidean(log_posterior)
    return riemannian(log_posterior, metric)


@dataclasses.dataclass(frozen=True)
class Problem:
    """Immutable description of one integrator test case.

    ``vol_prec`` is a trailing, defaulted field so existing positional and
    keyword constructions keep working; its default equals the ``1e-4``
    tolerance that the volume-conservation test used to hard-code.
    """

    # Zero-argument factory that returns the system under test.
    builder: Callable[[], System]
    # Initial position handed to the integrator.
    q: Position
    # Initial momentum handed to the integrator.
    p: Momentum
    # Number of integration steps to take.
    num_steps: int
    # Integrator step size.
    step_size: float
    # Tolerance for energy checks (presumably absolute -- verify in tests).
    energy_prec: float = 1e-4
    # Absolute tolerance on positions in the reversibility test.
    pos_prec: float = 5e-5
    # Absolute tolerance on the log-determinant in the volume test.
    vol_prec: float = 1e-4


# Registry of named test cases; the tests parametrize over its sorted keys.
# Each entry pairs a system factory with the initial position and momentum,
# the integration length, and any per-problem tolerance overrides.
PROBLEMS = {
    "sho_riemannian": Problem(
        builder=partial(sho, False),
        q=jnp.array([0.1]),
        p=jnp.array([2.0]),
        num_steps=2000,
        step_size=0.01,
    ),
    "sho_euclidean": Problem(
        builder=partial(sho, True),
        q=jnp.array([0.1]),
        p=jnp.array([2.0]),
        num_steps=2000,
        step_size=0.01,
    ),
    "planet_riemannian": Problem(
        builder=partial(planet, False),
        q=jnp.array([1.0, 0.0]),
        p=jnp.array([0.0, 1.0]),
        num_steps=2000,
        step_size=0.01,
        pos_prec=5e-4,
    ),
    "planet_euclidean": Problem(
        builder=partial(planet, True),
        q=jnp.array([1.0, 0.0]),
        p=jnp.array([0.0, 1.0]),
        num_steps=2000,
        step_size=0.01,
    ),
    "banana_riemannian": Problem(
        builder=partial(banana_problem, False, False),
        q=jnp.array([0.1, 0.3]),
        p=jnp.array([2.0, 0.5]),
        num_steps=2000,
        step_size=0.001,
    ),
    "banana_fixed": Problem(
        builder=partial(banana_problem, True, False),
        q=jnp.array([0.1, 0.3]),
        p=jnp.array([2.0, 0.5]),
        num_steps=2000,
        step_size=0.001,
    ),
    "banana_euclidean": Problem(
        builder=partial(banana_problem, True, True),
        q=jnp.array([0.1, 0.3]),
        p=jnp.array([2.0, 0.5]),
        num_steps=2000,
        step_size=0.001,
        energy_prec=0.002,
    ),
}
from .problems import PROBLEMS


def run(
system: System, num_steps: int, step_size: float, q: Position, p: Momentum
) -> Tuple[Scalar, Array, IntegratorState, Array]:
kinetic_state = system.kinetic_tune_init(q.size)
kinetic_state = system.kinetic_tune_init(ravel_pytree(q)[0].size)
state = system.integrator_init(kinetic_state, q, p)

calc_energy = partial(compute_total_energy, system, kinetic_state)
Expand Down Expand Up @@ -190,17 +55,30 @@ def test_reversibility(problem_name: str) -> None:
system = problem.builder() # type: ignore
func = jax.jit(partial(run, system, problem.num_steps, problem.step_size))
_, _, trace, _ = func(problem.q, problem.p)
_, _, rev_trace, _ = func(trace.q[-1], -trace.p[-1])
np.testing.assert_allclose(
trace.q[:-1][::-1], rev_trace.q[:-1], atol=problem.pos_prec

q = tree_map(lambda x_: x_[-1], trace.q)
p = tree_map(lambda x_: -x_[-1], trace.p)
_, _, rev_trace, _ = func(q, p)

tree_map(
lambda a, b: np.testing.assert_allclose(
a[:-1][::-1], b[:-1], atol=problem.pos_prec
),
trace.q,
rev_trace.q,
)


@pytest.mark.parametrize("problem_name", sorted(PROBLEMS.keys()))
def test_volume_conservation(problem_name: str) -> None:
problem = PROBLEMS[problem_name]
system = problem.builder() # type: ignore
kinetic_state = system.kinetic_tune_init(problem.q.size)

q_flat, unravel = ravel_pytree(problem.q)
p_flat, _ = ravel_pytree(problem.p)

N = q_flat.size
kinetic_state = system.kinetic_tune_init(N)
phi = jax.jit(
partial(
integrate,
Expand All @@ -215,38 +93,36 @@ def test_volume_conservation(problem_name: str) -> None:
ps = []

eps = 1e-6
N = problem.q.size
for n in range(N):
delta = 0.5 * eps * jnp.eye(N, 1, -n)[:, 0]
print(delta)

q = problem.q + delta
q = unravel(q_flat + delta)
plus, _ = phi(system.integrator_init(kinetic_state, q, problem.p))
q = problem.q - delta
q = unravel(q_flat - delta)
minus, _ = phi(system.integrator_init(kinetic_state, q, problem.p))
qs.append((plus.q - minus.q) / eps)
ps.append((plus.p - minus.p) / eps)
qs.append((ravel_pytree(plus.q)[0] - ravel_pytree(minus.q)[0]) / eps)
ps.append((ravel_pytree(plus.p)[0] - ravel_pytree(minus.p)[0]) / eps)

p = problem.p + delta
p = unravel(p_flat + delta)
plus, _ = phi(system.integrator_init(kinetic_state, problem.q, p))
p = problem.p - delta
p = unravel(p_flat - delta)
minus, _ = phi(system.integrator_init(kinetic_state, problem.q, p))
qs.append((plus.q - minus.q) / eps)
ps.append((plus.p - minus.p) / eps)
qs.append((ravel_pytree(plus.q)[0] - ravel_pytree(minus.q)[0]) / eps)
ps.append((ravel_pytree(plus.p)[0] - ravel_pytree(minus.p)[0]) / eps)

F = jnp.concatenate(
(jnp.stack(qs, axis=0), jnp.stack(ps, axis=0)), axis=-1
)
_, ld = jnp.linalg.slogdet(F)
np.testing.assert_allclose(ld, 0.0, atol=1e-4)
np.testing.assert_allclose(ld, 0.0, atol=problem.vol_prec)


@pytest.mark.parametrize("problem_name", sorted(PROBLEMS.keys()))
def test_integrate(problem_name: str) -> None:
problem = PROBLEMS[problem_name]
system = problem.builder() # type: ignore

kinetic_state = system.kinetic_tune_init(problem.q.size)
kinetic_state = system.kinetic_tune_init(ravel_pytree(problem.q)[0].size)
state = system.integrator_init(kinetic_state, problem.q, problem.p)
final_state, success = jax.jit(
partial(
Expand All @@ -271,4 +147,4 @@ def test_integrate(problem_name: str) -> None:
assert np.all(success)

for v, t in zip(final_state, trajectory):
np.testing.assert_allclose(v, t[-1])
tree_map(lambda a, b: np.testing.assert_allclose(a, b[-1]), v, t)

0 comments on commit 6e6ece9

Please sign in to comment.