-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
dirichlet.py
98 lines (74 loc) · 3.12 KB
/
dirichlet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import lax
import jax.numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_dtypes_inexact
from jax.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def _is_simplex(x: Array) -> Array:
  """Checks, column by column, whether ``x`` lies on the probability simplex.

  The leading axis of ``x`` indexes the components. The result has shape
  ``x.shape[1:]``. It is True where every component is strictly positive and
  the components sum to one within a dtype-dependent tolerance.
  """
  x_sum = jnp.sum(x, axis=0)
  # A fixed tolerance of 1e-6 is finer than the spacing of floats near 1 in
  # float16 (eps ~ 9.8e-4) and bfloat16 (eps ~ 7.8e-3). Valid half-precision
  # inputs whose rounded sum lands one ulp away from 1 would be rejected.
  # Allow a few ulps, but never less than 1e-6. Since 8 * eps(float32) < 1e-6,
  # float32/float64 keep exactly the previous tolerance.
  if jnp.issubdtype(x.dtype, jnp.inexact):
    eps = float(jnp.finfo(x.dtype).eps)
  else:
    eps = 0.0
  tol = max(1e-6, 8 * eps)
  return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < tol)
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet log probability density function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the multivariate
  :func:`~jax.scipy.special.beta` function of the :math:`K` shape parameters.

  Args:
    x: arraylike, value(s) at which to evaluate the log-PDF. The leading axis
      indexes the :math:`K` components. It may instead have :math:`K - 1`
      entries, in which case the final component is taken to be
      ``1 - sum(x, axis=0)``. Any trailing axes form a batch of points.
    alpha: arraylike, one-dimensional array of :math:`K` shape parameters.

  Returns:
    array of logpdf values with shape ``x.shape[1:]``. Points with a
    non-positive component, or whose components do not sum to one (up to a
    small tolerance), yield ``-inf``.

  Raises:
    ValueError: if ``alpha`` is not one-dimensional, or if the leading axis of
      ``x`` has neither :math:`K` nor :math:`K - 1` entries.

  See Also:
    :func:`jax.scipy.stats.dirichlet.pdf`
  """
  # Bring both inputs to a common inexact dtype before the shape checks and
  # the actual density computation.
  x_arr, alpha_arr = promote_dtypes_inexact(x, alpha)
  return _logpdf(x_arr, alpha_arr)
def _logpdf(x: Array, alpha: Array) -> Array:
  """Evaluates the Dirichlet log-density on dtype-promoted inputs.

  Args:
    x: array whose leading axis holds either ``K`` components, or ``K - 1``
      components. In the ``K - 1`` case the last component is filled in as one
      minus the sum of the others. Trailing axes form a batch of evaluation
      points.
    alpha: one-dimensional array of ``K`` shape parameters. ``logpdf``
      promotes ``x`` and ``alpha`` to a common inexact dtype before calling.

  Returns:
    Array of shape ``x.shape[1:]`` holding the log-density. Entries are
    ``-inf`` where ``x`` is not (approximately) on the probability simplex.

  Raises:
    ValueError: if ``alpha`` is not one-dimensional, if ``x`` is 0-d, or if
      ``x.shape[0]`` is neither ``K`` nor ``K - 1``.
  """
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  # Without this check, a 0-d `x` fails below at `x.shape[0]` with an opaque
  # IndexError instead of a descriptive ValueError.
  if x.ndim == 0:
    raise ValueError(
      f"`x` must have at least one dimension; got x.shape={x.shape}"
    )
  if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = _lax_const(x, 1)
  if x.shape[0] != alpha.shape[0]:
    # Append the implied last component so each column sums to one.
    x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
  # log B(alpha) = sum_i log Gamma(alpha_i) - log Gamma(sum_i alpha_i).
  normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Reshape alpha to (K, 1, ..., 1) so it broadcasts over x's batch axes.
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  # sum_i (alpha_i - 1) * log(x_i), reduced over the component axis.
  log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet probability density function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the multivariate
  :func:`~jax.scipy.special.beta` function of the :math:`K` shape parameters.

  Args:
    x: arraylike, value(s) at which to evaluate the PDF. The leading axis
      indexes the :math:`K` components. It may instead have :math:`K - 1`
      entries, in which case the final component is taken to be
      ``1 - sum(x, axis=0)``. Any trailing axes form a batch of points.
    alpha: arraylike, one-dimensional array of :math:`K` shape parameters.

  Returns:
    array of pdf values with shape ``x.shape[1:]``. The value is zero for
    points that are off the probability simplex.

  See Also:
    :func:`jax.scipy.stats.dirichlet.logpdf`
  """
  # Compute in log space and exponentiate once at the end.
  log_density = logpdf(x, alpha)
  return lax.exp(log_density)