Skip to content

Commit

Permalink
Fix#10219
Browse files Browse the repository at this point in the history
  • Loading branch information
YouJiacheng committed Apr 12, 2022
1 parent 35b32ee commit 4695dd9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
8 changes: 7 additions & 1 deletion jax/_src/scipy/stats/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@
from jax.scipy.special import expit, logit

from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.lax_numpy import _promote_args_inexact
from jax._src.numpy import lax_numpy as jnp


@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x):
return lax.neg(x) - 2. * lax.log1p(lax.exp(lax.neg(x)))
x, = _promote_args_inexact("logistic.logpdf", x)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))

@_wraps(osp_stats.logistic.pdf, update_doc=False)
def pdf(x):
Expand Down
7 changes: 7 additions & 0 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ def args_maker():
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)

def testLogisticLogpdfOverflow(self):
# Regression test for https://github.com/google/jax/issues/10219
self.assertAllClose(
np.array([-100, -100], np.float32),
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
check_dtypes=False)

@genNamedParametersNArgs(1)
def testLogisticPpf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
Expand Down

0 comments on commit 4695dd9

Please sign in to comment.