forked from pyro-ppl/numpyro
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
von Mises distribution (pyro-ppl#623)
* ADDED: Von Mises distribution log prob and half closed interval constraint * ADDED: Rejection sampler for sampling von Mises distribution; FIXED: deletion of _IntegerGreaterThan. * ADDED: tests, variance and mean. * ADDED: comments * FIXED: comments * REMOVED: halfopen interval from distributions.constraints; RENAMED: open_interval to interval in distributions.constrains. * ADDED: VonMises to __all__ in distributions init. * REMOVED: main from debugging in directional; ADDED: missing comma in distributions init. * REMOVED: GreaterThanEqual from constraints. * ADDED: Copyright notice. * UPDATED: factored out _sample_centered from VonMises to distributions.util; changed support to interval [-pi,pi]; changed private attributes _loc, _concentration. * FIXED: test and arg_constraints. * FIXED: lint and change np to jno in util. * CHANGED: np to jnp in directional. * CHANGED: style for von_mises_centered call. * UPDATED: for loop to while; CHANGED: to tests to check correct sample stats (circular mean and variance); loc domain to reals. * ADDED: tests cases. * FIXED: lint * UPDATED: hardcoded max iterations for _von_mises_centered. * CHANGED: loc attribute of VonMises not mapped to support.
- Loading branch information
1 parent
16c1f5c
commit 35aeab7
Showing
6 changed files
with
160 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import math | ||
|
||
import jax.numpy as jnp | ||
from jax import lax | ||
|
||
from numpyro.distributions import constraints | ||
from numpyro.distributions.distribution import Distribution | ||
from numpyro.distributions.util import promote_shapes, von_mises_centered | ||
from numpyro.distributions.util import validate_sample | ||
from numpyro.util import copy_docs_from | ||
|
||
|
||
@copy_docs_from(Distribution) | ||
class VonMises(Distribution): | ||
arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive} | ||
|
||
support = constraints.interval(-math.pi, math.pi) | ||
|
||
def __init__(self, loc, concentration, validate_args=None): | ||
""" von Mises distribution for sampling directions. | ||
:param loc: center of distribution | ||
:param concentration: concentration of distribution | ||
""" | ||
self.loc, self.concentration = promote_shapes(loc, concentration) | ||
|
||
batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(loc)) | ||
|
||
super(VonMises, self).__init__(batch_shape=batch_shape, | ||
validate_args=validate_args) | ||
|
||
def sample(self, key, sample_shape=()): | ||
""" Generate sample from von Mises distribution | ||
:param sample_shape: shape of samples | ||
:param key: random number generator key | ||
:return: samples from von Mises | ||
""" | ||
samples = von_mises_centered(key, self.concentration, sample_shape + self.shape()) | ||
samples = samples + self.loc # VM(0, concentration) -> VM(loc,concentration) | ||
samples = (samples + jnp.pi) % (2. * jnp.pi) - jnp.pi | ||
|
||
return samples | ||
|
||
@validate_sample | ||
def log_prob(self, value): | ||
return -(jnp.log(2 * jnp.pi) + lax.bessel_i0e(self.concentration)) + ( | ||
self.concentration * jnp.cos(value - self.loc)) | ||
|
||
@property | ||
def mean(self): | ||
""" Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi] """ | ||
return jnp.broadcast_to((self.loc + jnp.pi) % (2. * jnp.pi) - jnp.pi, self.batch_shape) | ||
|
||
@property | ||
def variance(self): | ||
""" Computes circular variance of distribution """ | ||
return jnp.broadcast_to(1. - lax.bessel_i1e(self.concentration) / lax.bessel_i0e(self.concentration), | ||
self.batch_shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters