Skip to content

Commit

Permalink
Fix loc and scale parameters in scipy.logistic. Add CDF and SF for se…
Browse files Browse the repository at this point in the history
…veral distributions.
  • Loading branch information
b0nce authored and joglekara committed Mar 24, 2023
1 parent 0e2cf94 commit 8f4b8a0
Show file tree
Hide file tree
Showing 14 changed files with 453 additions and 46 deletions.
17 changes: 16 additions & 1 deletion docs/jax.scipy.rst
Expand Up @@ -172,6 +172,9 @@ jax.scipy.stats.beta

logpdf
pdf
cdf
logcdf
sf

jax.scipy.stats.betabinom
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -192,6 +195,11 @@ jax.scipy.stats.cauchy

logpdf
pdf
cdf
logcdf
sf
isf
ppf

jax.scipy.stats.chi2
~~~~~~~~~~~~~~~~~~~~
Expand All @@ -202,7 +210,9 @@ jax.scipy.stats.chi2

logpdf
pdf

cdf
logcdf
sf


jax.scipy.stats.dirichlet
Expand Down Expand Up @@ -232,6 +242,9 @@ jax.scipy.stats.gamma

logpdf
pdf
cdf
logcdf
sf

jax.scipy.stats.gennorm
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -296,6 +309,8 @@ jax.scipy.stats.norm
logpdf
pdf
ppf
sf
isf

jax.scipy.stats.pareto
~~~~~~~~~~~~~~~~~~~~~~
Expand Down
30 changes: 29 additions & 1 deletion jax/_src/scipy/stats/beta.py
Expand Up @@ -19,7 +19,7 @@
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import betaln, xlogy, xlog1py
from jax.scipy.special import betaln, betainc, xlogy, xlog1py


@_wraps(osp_stats.beta.logpdf, update_doc=False)
Expand All @@ -40,3 +40,31 @@ def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, b, loc, scale))


@_wraps(osp_stats.beta.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
return betainc(
a,
b,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, 1),
)
)


@_wraps(osp_stats.beta.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, b, loc, scale))


@_wraps(osp_stats.beta.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, a, b, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
44 changes: 42 additions & 2 deletions jax/_src/scipy/stats/cauchy.py
Expand Up @@ -18,8 +18,8 @@

from jax import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import promote_args_inexact
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax.numpy import arctan
from jax._src.typing import Array, ArrayLike


Expand All @@ -31,6 +31,46 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
normalize_term = lax.log(lax.mul(pi, scale))
return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))


@_wraps(osp_stats.cauchy.pdf, update_doc=False)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))



@_wraps(osp_stats.cauchy.cdf, update_doc=False)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
pi = _lax_const(x, np.pi)
scaled_x = lax.div(lax.sub(x, loc), scale)
return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x)))


@_wraps(osp_stats.cauchy.logcdf, update_doc=False)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, loc, scale))


@_wraps(osp_stats.cauchy.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, = promote_args_inexact("cauchy.sf", x)
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)


@_wraps(osp_stats.cauchy.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
return lax.add(lax.mul(unscaled, scale), loc)


@_wraps(osp_stats.cauchy.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
pi = _lax_const(q, np.pi)
half_pi = _lax_const(q, np.pi / 2)
unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
return lax.add(lax.mul(unscaled, scale), loc)
49 changes: 39 additions & 10 deletions jax/_src/scipy/stats/chi2.py
Expand Up @@ -20,23 +20,52 @@
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammainc


@_wraps(osp_stats.chi2.logpdf, update_doc=False)
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)
df_on_two = lax.div(df, two)
x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
one = _lax_const(x, 1)
two = _lax_const(x, 2)
y = lax.div(lax.sub(x, loc), scale)
df_on_two = lax.div(df, two)

kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))
kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)), lax.div(y,two))

nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))
nrml_cnst = lax.neg(lax.add(lax.lgamma(df_on_two),lax.div(lax.mul(lax.log(two), df),two)))

log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)
log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs)

@_wraps(osp_stats.chi2.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))
return lax.exp(logpdf(x, df, loc, scale))


@_wraps(osp_stats.chi2.cdf, update_doc=False)
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
two = _lax_const(scale, 2)
return gammainc(
lax.div(df, two),
lax.clamp(
_lax_const(x, 0),
lax.div(
lax.sub(x, loc),
lax.mul(scale, two),
),
_lax_const(x, jnp.inf),
),
)


@_wraps(osp_stats.chi2.logcdf, update_doc=False)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, df, loc, scale))


@_wraps(osp_stats.chi2.sf, update_doc=False)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, df, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)
26 changes: 25 additions & 1 deletion jax/_src/scipy/stats/gamma.py
Expand Up @@ -19,7 +19,7 @@
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax.scipy.special import gammaln, xlogy
from jax.scipy.special import gammaln, xlogy, gammainc, gammaincc


@_wraps(osp_stats.gamma.logpdf, update_doc=False)
Expand All @@ -35,3 +35,27 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1)
@_wraps(osp_stats.gamma.pdf, update_doc=False)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, a, loc, scale))


@_wraps(osp_stats.gamma.cdf, update_doc=False)
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
return gammainc(
a,
lax.clamp(
_lax_const(x, 0),
lax.div(lax.sub(x, loc), scale),
_lax_const(x, jnp.inf),
)
)


@_wraps(osp_stats.gamma.logcdf, update_doc=False)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.log(cdf(x, a, loc, scale))


@_wraps(osp_stats.gamma.sf, update_doc=False)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
return gammaincc(a, lax.div(lax.sub(x, loc), scale))
35 changes: 22 additions & 13 deletions jax/_src/scipy/stats/logistic.py
Expand Up @@ -23,29 +23,38 @@


@_wraps(osp_stats.logistic.logpdf, update_doc=False)
def logpdf(x: ArrayLike) -> Array:
x, = promote_args_inexact("logistic.logpdf", x)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
x = lax.div(lax.sub(x, loc), scale)
two = _lax_const(x, 2)
half_x = lax.div(x, two)
return lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x)))
return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))


@_wraps(osp_stats.logistic.pdf, update_doc=False)
def pdf(x: ArrayLike) -> Array:
return lax.exp(logpdf(x))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, loc, scale))


@_wraps(osp_stats.logistic.ppf, update_doc=False)
def ppf(x):
return logit(x)
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
return lax.add(lax.mul(logit(x), scale), loc)


@_wraps(osp_stats.logistic.sf, update_doc=False)
def sf(x):
return expit(lax.neg(x))
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
return expit(lax.neg(lax.div(lax.sub(x, loc), scale)))


@_wraps(osp_stats.logistic.isf, update_doc=False)
def isf(x):
return -logit(x)
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
return lax.add(lax.mul(lax.neg(logit(x)), scale), loc)


@_wraps(osp_stats.logistic.cdf, update_doc=False)
def cdf(x):
return expit(x)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
return expit(lax.div(lax.sub(x, loc), scale))
12 changes: 12 additions & 0 deletions jax/_src/scipy/stats/norm.py
Expand Up @@ -24,6 +24,7 @@
from jax._src.typing import Array, ArrayLike
from jax.scipy import special


@_wraps(osp_stats.norm.logpdf, update_doc=False)
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
Expand Down Expand Up @@ -54,3 +55,14 @@ def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
@_wraps(osp_stats.norm.ppf, update_doc=False)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return jnp.asarray(special.ndtri(q) * scale + loc, float)


@_wraps(osp_stats.norm.sf, update_doc=False)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
cdf_result = cdf(x, loc, scale)
return lax.sub(_lax_const(cdf_result, 1), cdf_result)


@_wraps(osp_stats.norm.isf, update_doc=False)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
1 change: 1 addition & 0 deletions jax/_src/scipy/stats/t.py
Expand Up @@ -36,6 +36,7 @@ def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
return lax.neg(lax.add(normalize_term, lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))


@_wraps(osp_stats.t.pdf, update_doc=False)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
return lax.exp(logpdf(x, df, loc, scale))
3 changes: 3 additions & 0 deletions jax/scipy/stats/beta.py
Expand Up @@ -18,4 +18,7 @@
from jax._src.scipy.stats.beta import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)
5 changes: 5 additions & 0 deletions jax/scipy/stats/cauchy.py
Expand Up @@ -18,4 +18,9 @@
from jax._src.scipy.stats.cauchy import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
isf as isf,
ppf as ppf,
)
3 changes: 3 additions & 0 deletions jax/scipy/stats/chi2.py
Expand Up @@ -18,4 +18,7 @@
from jax._src.scipy.stats.chi2 import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)
3 changes: 3 additions & 0 deletions jax/scipy/stats/gamma.py
Expand Up @@ -18,4 +18,7 @@
from jax._src.scipy.stats.gamma import (
logpdf as logpdf,
pdf as pdf,
cdf as cdf,
logcdf as logcdf,
sf as sf,
)
2 changes: 2 additions & 0 deletions jax/scipy/stats/norm.py
Expand Up @@ -21,4 +21,6 @@
logpdf as logpdf,
pdf as pdf,
ppf as ppf,
sf as sf,
isf as isf,
)

0 comments on commit 8f4b8a0

Please sign in to comment.