Skip to content

Commit

Permalink
adding test for sampling momentum
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 1, 2021
1 parent 6e6ece9 commit 241cf40
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/hamiltonian_test.py
@@ -0,0 +1,44 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax, random, vmap
from jax.flatten_util import ravel_pytree

from rmhmc.base_types import Array, Momentum
from rmhmc.hamiltonian import euclidean, riemannian

L = np.random.default_rng(9).normal(size=(5, 5))
L[np.diag_indices_from(L)] = np.exp(L[np.diag_indices_from(L)])
L[np.triu_indices_from(L, 1)] = 0.0


@pytest.mark.parametrize(
"cov",
[
3.5 * jnp.eye(2),
10.0 * jnp.eye(5),
jnp.diag(np.random.default_rng(77).uniform(size=4)),
L @ L.T + 1e-1 * jnp.eye(L.shape[0]),
],
)
def test_sample_momentum_euclidean(cov: Array) -> None:
ndim = cov.shape[0]
system = euclidean(lambda x: 0.5 * jnp.sum(x ** 2), cov=cov)
kinetic_state = system.kinetic_tune_init(ndim)

def _sample(
key: random.KeyArray, _: int
) -> Tuple[random.KeyArray, Momentum]:
key1, key2 = random.split(key)
return key2, system.sample_momentum(
kinetic_state, jnp.zeros(ndim), key1
)

_, result = lax.scan(_sample, random.PRNGKey(5), jnp.arange(100_000))

np.testing.assert_allclose(
jnp.dot(cov, jnp.cov(result, rowvar=0)), np.eye(ndim), atol=0.05
)

0 comments on commit 241cf40

Please sign in to comment.