-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
truncnorm.py
297 lines (225 loc) · 9.02 KB
/
truncnorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
def _log_diff(x, y):
  """Return ``log(exp(x) - exp(y))`` evaluated in log space.

  The two log-values are stacked along a new leading axis and reduced with
  ``logsumexp`` using signed weights (``+1`` for ``x``, ``-1`` for ``y``), so
  the exponentials are never formed directly. ``x`` and ``y`` must have the
  same shape; a real-valued result requires ``y <= x``.
  """
  stacked = jnp.array([x, y])
  signs = jnp.array([jnp.ones_like(x), -jnp.ones_like(y)])
  return logsumexp(stacked, b=signs, axis=0)
def _log_gauss_mass(a, b):
  """Log of the standard-normal probability mass on the interval ``[a, b]``.

  Computes ``log(Phi(b) - Phi(a))`` elementwise after broadcasting ``a`` and
  ``b`` together. Three formulations are evaluated and the appropriate one is
  chosen per element (approach carried over from scipy):

  * ``b <= 0`` (interval entirely in the left tail): difference of
    ``log_ndtr`` values.
  * ``a > 0`` (interval entirely in the right tail): mirrored into the left
    tail, because right-tail evaluations are inaccurate.
  * otherwise (interval straddles zero): ``log1p(-Phi(a) - Phi(-b))``, which
    avoids the catastrophic cancellation that appears as the mass nears 1.
  """
  lo, hi = jnp.broadcast_arrays(jnp.array(a), jnp.array(b))

  # Left tail: log(Phi(hi) - Phi(lo)).
  left_mass = _log_diff(log_ndtr(hi), log_ndtr(lo))
  # Right tail via symmetry: Phi(hi) - Phi(lo) == Phi(-lo) - Phi(-hi).
  right_mass = _log_diff(log_ndtr(-lo), log_ndtr(-hi))
  # Straddling zero: Phi(hi) - Phi(lo) == 1 - Phi(lo) - Phi(-hi). Underflow of
  # one ndtr term is insignificant; if both underflow the result is not
  # representable in log space anyway, since log1p(t) ~ t for small t.
  central_mass = jnp.log1p(-ndtr(lo) - ndtr(-hi))

  in_left = hi <= 0
  in_right = lo > 0
  in_center = jnp.logical_not(in_left | in_right)
  return jnp.select(
      [in_left, in_right, in_center],
      [left_mass, right_mass, central_mass],
  )
def logpdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log probability density function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``.

  The truncated normal probability density function is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{\phi(x)}{\Phi(b) - \Phi(a)} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`\phi(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}` is the standard
  normal density and :math:`\Phi` is the standard normal cumulative
  distribution function. The bounds :math:`a` and :math:`b` are specified in
  number of standard deviations from the centroid. JAX uses the scipy
  nomenclature of ``loc`` for the centroid and ``scale`` for the standard
  deviation, so the density is evaluated at ``(x - loc) / scale`` and divided
  by ``scale``.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values. Entries are ``-inf`` where ``(x - loc) / scale``
    lies outside ``[a, b]`` and ``nan`` where ``a >= b``.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
  # Untruncated normal log-density (including the -log(scale) term), minus
  # the log of the probability mass retained by the truncation.
  val = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
  # Bounds a and b are expressed in standardized units.
  x_scaled = lax.div(lax.sub(x, loc), scale)
  val = jnp.where((x_scaled < a) | (x_scaled > b), -jnp.inf, val)
  # a >= b is not a valid distribution (scipy's argcheck requires a < b).
  val = jnp.where(a >= b, jnp.nan, val)
  return val
def pdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal probability density function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``.

  The truncated normal probability density function is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{\phi(x)}{\Phi(b) - \Phi(a)} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`\phi(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}` is the standard
  normal density and :math:`\Phi` is the standard normal cumulative
  distribution function. The bounds :math:`a` and :math:`b` are specified in
  number of standard deviations from the centroid. JAX uses the scipy
  nomenclature of ``loc`` for the centroid and ``scale`` for the standard
  deviation, so the density is evaluated at ``(x - loc) / scale`` and divided
  by ``scale``.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Computed in log space by logpdf; exponentiating maps -inf to 0 outside
  # the support and keeps nan for invalid (a >= b) parameters.
  return lax.exp(logpdf(x, a, b, loc, scale))
def logsf(x, a, b, loc=0, scale=1):
  r"""Truncated normal distribution log survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`; this function returns its logarithm.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  # Reflection: if X ~ truncnorm(a, b, loc, scale) then
  # -X ~ truncnorm(-b, -a, -loc, scale), and P(X > x) == P(-X < -x).
  reflected_x = -x
  reflected_lower = -b
  reflected_upper = -a
  reflected_loc = -loc
  return logcdf(reflected_x, reflected_lower, reflected_upper, reflected_loc, scale)
def sf(x, a, b, loc=0, scale=1):
  """Truncated normal distribution survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Evaluated in log space by logsf for numerical stability, then exponentiated.
  return lax.exp(logsf(x, a, b, loc, scale))
def logcdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values. Entries are ``-inf`` at or below the lower bound,
    ``0`` at or above the upper bound, ``nan`` where ``a >= b``, and ``nan``
    where ``x``, ``loc`` or ``scale`` is ``nan``.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  # Work in standardized units, where the support is [a, b].
  x = lax.div(lax.sub(x, loc), scale)

  # Local names avoid shadowing the module-level logcdf/logsf functions.
  log_total_mass = _log_gauss_mass(a, b)
  log_cdf = _log_gauss_mass(a, x) - log_total_mass
  log_sf = _log_gauss_mass(x, b) - log_total_mass

  log_cdf = jnp.select(
      # Third condition: when the CDF is close to 1, log1p(-sf) avoids the
      # catastrophic cancellation of the direct difference (from scipy).
      [x >= b, x <= a, log_cdf > -0.1],
      [0, -jnp.inf, jnp.log1p(-jnp.exp(log_sf))],
      # The interior value is the default (instead of an explicit `x > a`
      # branch) so that NaN inputs, for which every comparison above is
      # False, propagate NaN rather than falling through to a default of 0.
      default=log_cdf,
  )
  # a >= b is not a valid distribution.
  return jnp.where(a >= b, jnp.nan, log_cdf)
def cdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # The numerically careful work happens in log space; just exponentiate it.
  log_value = logcdf(x, a, b, loc, scale)
  return lax.exp(log_value)