Skip to content

Commit

Permalink
jax.scipy.stats: add logsf & make sf more accurate near zero
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 22, 2023
1 parent 6670ea4 commit d1c2277
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 36 deletions.
9 changes: 7 additions & 2 deletions docs/jax.scipy.rst
Expand Up @@ -192,6 +192,7 @@ jax.scipy.stats.beta
cdf
logcdf
sf
logsf

jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -225,6 +226,7 @@ jax.scipy.stats.cauchy
cdf
logcdf
sf
logsf
isf
ppf

Expand All @@ -240,6 +242,7 @@ jax.scipy.stats.chi2
cdf
logcdf
sf
logsf


jax.scipy.stats.dirichlet
Expand Down Expand Up @@ -272,6 +275,7 @@ jax.scipy.stats.gamma
cdf
logcdf
sf
logsf

jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -350,12 +354,13 @@ jax.scipy.stats.norm
.. autosummary::
:toctree: _autosummary

cdf
logcdf
logpdf
pdf
cdf
logcdf
ppf
sf
logsf
isf

jax.scipy.stats.pareto
Expand Down
22 changes: 18 additions & 4 deletions jax/_src/scipy/stats/beta.py
Expand Up @@ -59,12 +59,26 @@ def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,

@_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))


@_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
return betainc(
b,
a,
1 - lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)


@_wraps(osp_stats.beta.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, b, loc, scale))
11 changes: 8 additions & 3 deletions jax/_src/scipy/stats/cauchy.py
Expand Up @@ -53,9 +53,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:

@_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = promote_args_inexact("cauchy.sf", x)
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
return cdf(-x, -loc, scale)


@_wraps(osp_stats.cauchy.logsf, update_doc=False)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
return logcdf(-x, -loc, scale)


@_wraps(osp_stats.cauchy.isf, update_doc=False)
Expand Down
22 changes: 19 additions & 3 deletions jax/_src/scipy/stats/chi2.py
Expand Up @@ -20,7 +20,7 @@
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc
from jax.scipy.special import gammainc, gammaincc


@_wraps(osp_stats.chi2.logpdf, update_doc=False)
Expand Down Expand Up @@ -67,5 +67,21 @@ def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1

@_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammaincc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, jnp.inf),
),
)


@_wraps(osp_stats.chi2.logsf, update_doc=False)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, df, loc, scale))
5 changes: 5 additions & 0 deletions jax/_src/scipy/stats/gamma.py
Expand Up @@ -59,3 +59,8 @@ def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))


@_wraps(osp_stats.gamma.logsf, update_doc=False)
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(sf(x, a, loc, scale))
13 changes: 1 addition & 12 deletions jax/_src/scipy/stats/truncnorm.py
Expand Up @@ -88,18 +88,7 @@ def pdf(x, a, b, loc=0, scale=1):
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
x, a, b = jnp.broadcast_arrays(x, a, b)
x = lax.div(lax.sub(x, loc), scale)
logsf = _log_gauss_mass(x, b) - _log_gauss_mass(a, b)
logcdf = _log_gauss_mass(a, x) - _log_gauss_mass(a, b)

logsf = jnp.select(
# third condition: avoid catastrophic cancellation (from scipy)
[x >= b, x <= a, logsf > -0.1, x > a],
[-jnp.inf, 0, jnp.log1p(-jnp.exp(logcdf)), logsf]
)
logsf = jnp.where(a >= b, jnp.nan, logsf)
return logsf
return logcdf(-x, -b, -a, -loc, scale)


@_wraps(osp_stats.truncnorm.sf, update_doc=False)
Expand Down
5 changes: 3 additions & 2 deletions jax/scipy/stats/beta.py
Expand Up @@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)
9 changes: 5 additions & 4 deletions jax/scipy/stats/cauchy.py
Expand Up @@ -16,11 +16,12 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,
)
5 changes: 3 additions & 2 deletions jax/scipy/stats/chi2.py
Expand Up @@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)
5 changes: 3 additions & 2 deletions jax/scipy/stats/gamma.py
Expand Up @@ -16,9 +16,10 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
sf as sf,
)
102 changes: 100 additions & 2 deletions tests/scipy_stats_test.py
Expand Up @@ -257,6 +257,22 @@ def args_maker():
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})

@genNamedParametersNArgs(5)
def testBetaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.beta.logsf
lax_fun = lsp_stats.beta.logsf

def args_maker():
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})

def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
a = b = 1.
Expand All @@ -279,7 +295,7 @@ def args_maker():
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker, tol={np.float64: 1E-14})

@genNamedParametersNArgs(3)
def testCauchyLogCdf(self, shapes, dtypes):
Expand All @@ -299,6 +315,42 @@ def args_maker():
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})

@genNamedParametersNArgs(3)
def testCauchyCdf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.cdf
lax_fun = lsp_stats.cauchy.cdf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})

@genNamedParametersNArgs(3)
def testCauchyLogSf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
scipy_fun = osp_stats.cauchy.logsf
lax_fun = lsp_stats.cauchy.logsf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})

@genNamedParametersNArgs(3)
def testCauchySf(self, shapes, dtypes):
rng = jtu.rand_default(self.rng())
Expand All @@ -314,7 +366,8 @@ def args_maker():
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14},
atol={np.float64: 1e-14})

@genNamedParametersNArgs(3)
def testCauchyIsf(self, shapes, dtypes):
Expand Down Expand Up @@ -450,6 +503,21 @@ def args_maker():

@genNamedParametersNArgs(4)
def testGammaLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.logsf
lax_fun = lsp_stats.gamma.logsf

def args_maker():
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testGammaSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.gamma.sf
lax_fun = lsp_stats.gamma.sf
Expand Down Expand Up @@ -960,6 +1028,21 @@ def args_maker():
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testChi2Cdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.cdf
lax_fun = lsp_stats.chi2.cdf

def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testChi2Sf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
Expand All @@ -975,6 +1058,21 @@ def args_maker():
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testChi2LogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.chi2.logsf
lax_fun = lsp_stats.chi2.logsf

def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
Expand Down

0 comments on commit d1c2277

Please sign in to comment.