Skip to content

Commit

Permalink
Merge pull request #18683 from jakevdp:gamma-neg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 585764804
  • Loading branch information
jax authors committed Nov 27, 2023
2 parents d488714 + 01fde43 commit 5274ca9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
6 changes: 4 additions & 2 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def gammaln(x: ArrayLike) -> Array:
The JAX version only accepts real-valued inputs.""")
def gamma(x: ArrayLike) -> Array:
x, = promote_args_inexact("gamma", x)
return lax.exp(lax.lgamma(x))

# Compute the sign for negative x, matching the semantics of scipy.special.gamma
floor_x = lax.floor(x)
sign = jnp.where((x > 0) | (x == floor_x), 1.0, (-1.0) ** floor_x)
return sign * lax.exp(lax.lgamma(x))

betaln = _wraps(
osp_special.betaln,
Expand Down
11 changes: 10 additions & 1 deletion tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
"betainc", 3, float_dtypes, jtu.rand_positive, False
),
op_record(
"gamma", 1, float_dtypes, jtu.rand_positive, True
"gamma", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"digamma", 1, float_dtypes, jtu.rand_positive, True
Expand Down Expand Up @@ -199,6 +199,15 @@ def testScipySpecialFunBernoulli(self, n):
self._CheckAgainstNumpy(scipy_op, lax_op, args_maker, atol=0, rtol=1E-5)
self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5)

def testGammaSign(self):
# Test that the sign of `gamma` matches at integer-valued inputs.
dtype = jax.numpy.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.arange(-10, 10).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol)



if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 5274ca9

Please sign in to comment.