diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 5547cfab0427..7e8a6ac86418 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -192,6 +192,7 @@ jax.scipy.stats.beta cdf logcdf sf + logsf jax.scipy.stats.betabinom ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -225,6 +226,7 @@ jax.scipy.stats.cauchy cdf logcdf sf + logsf isf ppf @@ -240,6 +242,7 @@ jax.scipy.stats.chi2 cdf logcdf sf + logsf jax.scipy.stats.dirichlet @@ -272,6 +275,7 @@ jax.scipy.stats.gamma cdf logcdf sf + logsf jax.scipy.stats.gennorm ~~~~~~~~~~~~~~~~~~~~~~~ @@ -350,12 +354,13 @@ jax.scipy.stats.norm .. autosummary:: :toctree: _autosummary - cdf - logcdf logpdf pdf + cdf + logcdf ppf sf + logsf isf jax.scipy.stats.pareto diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index cf8ae2e3b8e9..4e796ca33bfc 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, @_wraps(osp_stats.beta.logcdf, update_doc=False) def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike, - loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, a, b, loc, scale)) @_wraps(osp_stats.beta.sf, update_doc=False) def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike, - loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - cdf_result = cdf(x, a, b, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale) + return betainc( + b, + a, + 1 - lax.clamp( + _lax_const(x, 0), + lax.div(lax.sub(x, loc), scale), + _lax_const(x, 1), + ) + ) + + +@_wraps(osp_stats.beta.logsf, update_doc=False) +def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike, + loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(sf(x, a, b, loc, scale)) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 426b1eec0a07..38565ff65c7a 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: @_wraps(osp_stats.cauchy.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - x, = promote_args_inexact("cauchy.sf", x) - cdf_result = cdf(x, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) + x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale) + return cdf(-x, -loc, scale) + + +@_wraps(osp_stats.cauchy.logsf, update_doc=False) +def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale) + return logcdf(-x, -loc, scale) @_wraps(osp_stats.cauchy.isf, update_doc=False) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index 912f225befe8..76decb29e722 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -20,7 +20,7 @@ from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax._src.typing import Array, ArrayLike -from jax.scipy.special import gammainc +from jax.scipy.special import gammainc, gammaincc @_wraps(osp_stats.chi2.logpdf, update_doc=False) @@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1 @_wraps(osp_stats.chi2.sf, update_doc=False) def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: - cdf_result = cdf(x, df, loc, scale) - return lax.sub(_lax_const(cdf_result, 1), cdf_result) + x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale) + two = _lax_const(scale, 2) + return gammaincc( + lax.div(df, two), + lax.clamp( + _lax_const(x, 0), + lax.div( + lax.sub(x, loc), + lax.mul(scale, two), + ), + _lax_const(x, jnp.inf), + ), + ) + + +@_wraps(osp_stats.chi2.logsf, update_doc=False) +def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(sf(x, df, loc, scale)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index dcfb9439a6e1..8a5e70215bc6 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale) return gammaincc(a, lax.div(lax.sub(x, loc), scale)) + + +@_wraps(osp_stats.gamma.logsf, update_doc=False) +def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + return lax.log(sf(x, a, loc, scale)) diff --git a/jax/_src/scipy/stats/truncnorm.py b/jax/_src/scipy/stats/truncnorm.py index 3f116c42e0d6..e4c48271de60 100644 --- a/jax/_src/scipy/stats/truncnorm.py +++ b/jax/_src/scipy/stats/truncnorm.py @@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1): @_wraps(osp_stats.truncnorm.logsf, update_doc=False) def logsf(x, a, b, loc=0, scale=1): x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale) - x, a, b = jnp.broadcast_arrays(x, a, b) - x = lax.div(lax.sub(x, loc), scale) - logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b) - logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b) - - logsf = jnp.select( - # third condition: avoid catastrophic cancellation (from scipy) - [x >= b, x <= a, logsf > -0.1, x > a], - [-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf] - ) - logsf = jnp.where(a >= b, jnp.nan, logsf) - return logsf + return logcdf(-x, -b, -a, -loc, scale) @_wraps(osp_stats.truncnorm.sf, update_doc=False) diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 963181fa0226..5c57dda6bb56 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -16,9 +16,10 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.scipy.stats.beta import ( - logpdf as logpdf, - pdf as pdf, cdf as cdf, logcdf as logcdf, + logpdf as logpdf, + logsf as logsf, + pdf as pdf, sf as sf, ) diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index b3b0d994c865..4ff79f5f9888 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -16,11 +16,12 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.scipy.stats.cauchy import ( - logpdf as logpdf, - pdf as pdf, cdf as cdf, - logcdf as logcdf, - sf as sf, isf as isf, + logcdf as logcdf, + logpdf as logpdf, + logsf as logsf, + pdf as pdf, ppf as ppf, + sf as sf, ) diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index 9cb28c8a616b..e17a2e331958 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -16,9 +16,10 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.scipy.stats.chi2 import ( - logpdf as logpdf, - pdf as pdf, cdf as cdf, logcdf as logcdf, + logpdf as logpdf, + logsf as logsf, + pdf as pdf, sf as sf, ) diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 268fc4fa03de..8efecafed3bd 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -16,9 +16,10 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.scipy.stats.gamma import ( - logpdf as logpdf, - pdf as pdf, cdf as cdf, logcdf as logcdf, + logpdf as logpdf, + logsf as logsf, + pdf as pdf, sf as sf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index ad9549ca17df..49d5075a435e 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -257,6 +257,22 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float32: 2e-3, np.float64: 1e-4}) + @genNamedParametersNArgs(5) + def testBetaLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.beta.logsf + lax_fun = lsp_stats.beta.logsf + + def args_maker(): + x, a, b, loc, scale = map(rng, shapes, dtypes) + return [x, a, b, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) + def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 a = b = 1. @@ -279,7 +295,7 @@ def args_maker(): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14}) @genNamedParametersNArgs(3) def testCauchyLogCdf(self, shapes, dtypes): @@ -299,6 +315,42 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, atol={np.float64: 1e-14}) + @genNamedParametersNArgs(3) + def testCauchyCdf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.cdf + lax_fun = lsp_stats.cauchy.cdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, + atol={np.float64: 1e-14}) + + @genNamedParametersNArgs(3) + def testCauchyLogSf(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + scipy_fun = osp_stats.cauchy.logsf + lax_fun = lsp_stats.cauchy.logsf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + # clipping to ensure that scale is not too low + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, + atol={np.float64: 1e-14}) + @genNamedParametersNArgs(3) def testCauchySf(self, shapes, dtypes): rng = jtu.rand_default(self.rng()) @@ -314,7 +366,8 @@ def args_maker(): with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}, + atol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testCauchyIsf(self, shapes, dtypes): @@ -450,6 +503,21 @@ def args_maker(): @genNamedParametersNArgs(4) def testGammaLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.gamma.logsf + lax_fun = lsp_stats.gamma.logsf + + def args_maker(): + x, a, loc, scale = map(rng, shapes, dtypes) + return [x, a, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(4) + def testGammaSf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) scipy_fun = osp_stats.gamma.sf lax_fun = lsp_stats.gamma.sf @@ -960,6 +1028,21 @@ def args_maker(): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testChi2Cdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.cdf + lax_fun = lsp_stats.chi2.cdf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) def testChi2Sf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) @@ -975,6 +1058,21 @@ def args_maker(): tol=5e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) + def testChi2LogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.chi2.logsf + lax_fun = lsp_stats.chi2.logsf + + def args_maker(): + x, df, loc, scale = map(rng, shapes, dtypes) + return [x, df, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())