Skip to content

Commit

Permalink
von Mises distribution (pyro-ppl#623)
Browse files Browse the repository at this point in the history
* ADDED: Von Mises distribution log prob and half closed interval constraint

* ADDED: Rejection sampler for sampling von Mises distribution; FIXED: deletion of _IntegerGreaterThan.

* ADDED: tests, variance and mean.

* ADDED: comments

* FIXED: comments

* REMOVED: halfopen interval from distributions.constraints; RENAMED: open_interval to interval in distributions.constraints.

* ADDED: VonMises to __all__ in distributions init.

* REMOVED: main from debugging in directional; ADDED: missing comma in distributions init.

* REMOVED: GreaterThanEqual from constraints.

* ADDED: Copyright notice.

* UPDATED: factored out _sample_centered from VonMises to distributions.util; changed support to interval [-pi,pi]; changed private attributes _loc, _concentration.

* FIXED: test and arg_constraints.

* FIXED: lint and change np to jnp in util.

* CHANGED: np to jnp in directional.

* CHANGED: style for von_mises_centered call.

* UPDATED: for loop to while; CHANGED: to tests to check correct sample stats (circular mean and variance); loc domain to reals.

* ADDED: tests cases.

* FIXED: lint

* UPDATED: hardcoded max iterations for _von_mises_centered.

* CHANGED: loc attribute of VonMises not mapped to support.
  • Loading branch information
OlaRonning committed Jun 14, 2020
1 parent 16c1f5c commit 35aeab7
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 9 deletions.
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
PRNGIdentity,
ZeroInflatedPoisson
)
from numpyro.distributions.directional import VonMises
from numpyro.distributions.distribution import (
Distribution,
ExpandedDistribution,
Expand Down Expand Up @@ -113,6 +114,7 @@
'TruncatedPolyaGamma',
'Uniform',
'Unit',
'VonMises',
'ZeroInflatedPoisson',

]
2 changes: 1 addition & 1 deletion numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
'Constraint',
]


import jax.numpy as jnp


Expand All @@ -60,6 +59,7 @@ class Constraint(object):
A constraint object represents a region over which a variable is valid,
e.g. within which a variable can be optimized.
"""

def __call__(self, x):
raise NotImplementedError

Expand Down
61 changes: 61 additions & 0 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import math

import jax.numpy as jnp
from jax import lax

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, von_mises_centered
from numpyro.distributions.util import validate_sample
from numpyro.util import copy_docs_from


@copy_docs_from(Distribution)
class VonMises(Distribution):
    """ von Mises distribution over angles.

    Parameterised by a location ``loc`` (any real number) and a positive
    ``concentration``. Samples are wrapped into the interval given by ``support``.
    """
    arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}

    support = constraints.interval(-math.pi, math.pi)

    def __init__(self, loc, concentration, validate_args=None):
        """ von Mises distribution for sampling directions.
        :param loc: center of distribution
        :param concentration: concentration of distribution
        :param validate_args: forwarded unchanged to ``Distribution.__init__``
        """
        self.loc, self.concentration = promote_shapes(loc, concentration)

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(loc))

        super(VonMises, self).__init__(batch_shape=batch_shape,
                                       validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """ Generate sample from von Mises distribution
        :param sample_shape: shape of samples
        :param key: random number generator key
        :return: samples from von Mises
        """
        samples = von_mises_centered(key, self.concentration, sample_shape + self.shape())
        samples = samples + self.loc  # VM(0, concentration) -> VM(loc,concentration)
        # Wrap the shifted angles back into [-pi, pi).
        samples = (samples + jnp.pi) % (2. * jnp.pi) - jnp.pi

        return samples

    @validate_sample
    def log_prob(self, value):
        """ Log density of the von Mises distribution at ``value``.

        The density is ``exp(k * cos(x - loc)) / (2 * pi * I0(k))``. ``lax.bessel_i0e``
        is the exponentially scaled Bessel function ``exp(-k) * I0(k)``, hence
        ``log I0(k) = log(i0e(k)) + k``; the ``+ k`` is folded into the cosine term
        as ``k * (cos(x - loc) - 1)``.

        :param value: angle(s) at which to evaluate the log density
        :return: log density, broadcast against ``loc`` and ``concentration``
        """
        log_normalizer = jnp.log(2. * jnp.pi) + jnp.log(lax.bessel_i0e(self.concentration))
        return self.concentration * (jnp.cos(value - self.loc) - 1.) - log_normalizer

    @property
    def mean(self):
        """ Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi] """
        return jnp.broadcast_to((self.loc + jnp.pi) % (2. * jnp.pi) - jnp.pi, self.batch_shape)

    @property
    def variance(self):
        """ Computes circular variance of distribution """
        # 1 - I1(k) / I0(k); the exponential scaling of i1e and i0e cancels in the ratio.
        return jnp.broadcast_to(1. - lax.bessel_i1e(self.concentration) / lax.bessel_i0e(self.concentration),
                                self.batch_shape)
74 changes: 73 additions & 1 deletion numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from jax.scipy.linalg import solve_triangular
from jax.util import partial


# Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
_tr_params = namedtuple('tr_params', ['c', 'b', 'a', 'alpha', 'u_r', 'v_r', 'm', 'log_p', 'log1_p', 'log_h'])

Expand Down Expand Up @@ -63,6 +62,7 @@ def _binomial_btrs(key, p, n):
Hormann, "The Generation of Binomial Random Variates"
(https://core.ac.uk/download/pdf/11007254.pdf)
"""

def _btrs_body_fn(val):
_, key, _, _ = val
key, key_u, key_v = random.split(key, 3)
Expand Down Expand Up @@ -346,6 +346,77 @@ def clamp_probs(probs):
return jnp.clip(probs, a_min=finfo.tiny, a_max=1. - finfo.eps)


def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
    """ Compute centered von Mises samples using rejection sampling from [1] with wrapped Cauchy proposal.

    *** References ***
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
    Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    :param key: random number generator key
    :param concentration: concentration of distribution
    :param shape: shape of samples; when empty, the shape of ``concentration`` is used
    :param dtype: float precision, used for choosing the correct s cutoff
    :return: centered samples from von Mises
    """
    out_dtype = canonicalize_dtype(dtype)
    out_shape = shape if shape else jnp.shape(concentration)
    # Cast first, then broadcast, so the jitted sampler receives an array of the final shape.
    kappa = jnp.broadcast_to(lax.convert_element_type(concentration, out_dtype), out_shape)
    return _von_mises_centered(key, kappa, out_shape, out_dtype)


@partial(jit, static_argnums=(2, 3))
def _von_mises_centered(key, concentration, shape, dtype):
# Cutoff from TensorFlow probability
# (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
s_cutoff_map = {jnp.dtype(jnp.float16): 1.8e-1,
jnp.dtype(jnp.float32): 2e-2,
jnp.dtype(jnp.float64): 1.2e-4}
s_cutoff = s_cutoff_map.get(dtype)

r = 1. + jnp.sqrt(1. + 4. * concentration ** 2)
rho = (r - jnp.sqrt(2. * r)) / (2. * concentration)
s_exact = (1. + rho ** 2) / (2. * rho)

s_approximate = 1. / concentration

s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

def cond_fn(*args):
""" check if all are done or reached max number of iterations """
i, _, done, _, _ = args[0]
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

def body_fn(*args):
i, key, done, _, w = args[0]
uni_ukey, uni_vkey, key = random.split(key, 3)

u = random.uniform(key=uni_ukey, shape=shape, dtype=concentration.dtype, minval=-1., maxval=1.)
z = jnp.cos(jnp.pi * u)
w = jnp.where(done, w, (1. + s * z) / (s + z)) # Update where not done

y = concentration * (s - w)
v = random.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype, minval=-1., maxval=1.)

accept = (y * (2. - y) >= v) | (jnp.log(y / v) + 1. >= y)

return i+1, key, accept | done, u, w

init_done = jnp.zeros(shape, dtype=bool)
init_u = jnp.zeros(shape)
init_w = jnp.zeros(shape)

_, _, done, u, w = lax.while_loop(
cond_fun=cond_fn,
body_fun=body_fn,
init_val=(jnp.array(0), key, init_done, init_u, init_w)
)

return jnp.sign(u) * jnp.arccos(w)


# This is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Expand Down Expand Up @@ -376,6 +447,7 @@ class lazy_property(object):
first call; thereafter replacing the wrapped method into an instance
attribute.
"""

def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)
Expand Down
3 changes: 3 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def constrain_fn(model, model_args, model_kwargs, params, return_deterministic=F
sites from the model. Defaults to `False`.
:return: `dict` of transformed params.
"""

def substitute_fn(site):
if site['name'] in params:
return biject_to(site['fn'].support)(params[site['name']])
Expand Down Expand Up @@ -436,6 +437,7 @@ class Predictive(object):
:return: dict of samples from the predictive distribution.
"""

def __init__(self, model, posterior_samples=None, guide=None, params=None, num_samples=None,
return_sites=None, parallel=False):
if posterior_samples is None and num_samples is None:
Expand Down Expand Up @@ -505,6 +507,7 @@ def log_likelihood(model, posterior_samples, *args, **kwargs):
:param kwargs: model kwargs.
:return: dict of log likelihoods at observation sites.
"""

def single_loglik(samples):
model_trace = trace(substitute(model, samples)).get_trace(*args, **kwargs)
return {name: site['fn'].log_prob(site['value']) for name, site in model_trace.items()
Expand Down
27 changes: 20 additions & 7 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sample(self, key, sample_shape=()):
dist.Chi2: lambda df: osp.chi2(df),
dist.Dirichlet: lambda conc: osp.dirichlet(conc),
dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1./rate),
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1. / rate),
dist.Gumbel: lambda loc, scale: osp.gumbel_r(loc=loc, scale=scale),
dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
Expand All @@ -97,7 +97,6 @@ def sample(self, key, sample_shape=()):
dist.Logistic: lambda loc, scale: osp.logistic(loc=loc, scale=scale)
}


CONTINUOUS = [
T(dist.Beta, 1., 2.),
T(dist.Beta, 1., jnp.array([2., 2.])),
Expand Down Expand Up @@ -172,6 +171,11 @@ def sample(self, key, sample_shape=()):
T(dist.Uniform, jnp.array([0., 0.]), jnp.array([[2.], [3.]])),
]

# Test cases for directional distributions; the VonMises arguments are
# (loc, concentration), covering scalar and array-valued parameters.
DIRECTIONAL = [
T(dist.VonMises, 2., 10.),
T(dist.VonMises, 2., jnp.array([150., 10.])),
T(dist.VonMises, jnp.array([1 / 3 * jnp.pi, -1.]), jnp.array([20., 30.])),
]

DISCRETE = [
T(dist.BetaBinomial, 2., 5., 10),
Expand Down Expand Up @@ -291,7 +295,7 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
raise NotImplementedError('{} not implemented.'.format(constraint))


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE + DIRECTIONAL)
@pytest.mark.parametrize('prepend_shape', [
(),
(2,),
Expand Down Expand Up @@ -383,7 +387,7 @@ def test_pathwise_gradient(jax_dist, sp_dist, params):
assert_allclose(actual_grad[i], expected_grad, rtol=0.005)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE + DIRECTIONAL)
@pytest.mark.parametrize('prepend_shape', [
(),
(2,),
Expand Down Expand Up @@ -564,7 +568,7 @@ def test_gamma_poisson_log_prob(shape):
assert_allclose(actual, expected, rtol=0.05)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE + DIRECTIONAL)
def test_log_prob_gradient(jax_dist, sp_dist, params):
if jax_dist in [dist.LKJ, dist.LKJCholesky]:
pytest.skip('we have separated tests for LKJCholesky distribution')
Expand Down Expand Up @@ -596,7 +600,7 @@ def fn(*args):
assert_allclose(jnp.sum(actual_grad), expected_grad, rtol=0.01, atol=0.01)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE + DIRECTIONAL)
def test_mean_var(jax_dist, sp_dist, params):
if jax_dist is _ImproperWrapper:
pytest.skip("Improper distribution does not has mean/var implemented")
Expand Down Expand Up @@ -650,14 +654,23 @@ def test_mean_var(jax_dist, sp_dist, params):

assert_allclose(jnp.mean(corr_samples, axis=0), expected_mean, atol=0.01)
assert_allclose(jnp.std(corr_samples, axis=0), expected_std, atol=0.01)
elif jax_dist in [dist.VonMises]:
# circular mean = sample mean
assert_allclose(d_jax.mean, jnp.mean(samples, 0), rtol=0.05, atol=1e-2)

# circular variance
x, y = jnp.mean(jnp.cos(samples), 0), jnp.mean(jnp.sin(samples), 0)

expected_variance = 1 - jnp.sqrt(x ** 2 + y ** 2)
assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2)
else:
if jnp.all(jnp.isfinite(d_jax.mean)):
assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
if jnp.all(jnp.isfinite(d_jax.variance)):
assert_allclose(jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2)


@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE)
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS + DISCRETE + DIRECTIONAL)
@pytest.mark.parametrize('prepend_shape', [
(),
(2,),
Expand Down

0 comments on commit 35aeab7

Please sign in to comment.