diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 2bd789937b82..41ab137f58c1 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -146,6 +146,7 @@ Operators tie_in top_k transpose + zeta .. _lax-control-flow: diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index 7269409b6e74..65087f840118 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -65,6 +65,10 @@ def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" return random_gamma_grad_p.bind(a, x) +def zeta(x: ArrayLike, q: ArrayLike) -> Array: + r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" + return zeta_p.bind(x, q) + def bessel_i0e(x: ArrayLike) -> Array: r"""Exponentially scaled modified Bessel function of order 0: :math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)` @@ -639,6 +643,9 @@ def bessel_i0e_impl(x): mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), multiple_results=False)) +zeta_p = standard_naryop([_float, _float], 'zeta') +mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.ZetaOp)) + bessel_i0e_p = standard_unop(_float, 'bessel_i0e') mlir.register_lowering(bessel_i0e_p, mlir.lower_fun(bessel_i0e_impl, diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 21a4e1b73665..bbed51de8766 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -21,6 +21,7 @@ import jax.numpy as jnp from jax import jit +from jax import jvp from jax import vmap from jax import lax @@ -252,9 +253,22 @@ def rel_entr( ] +@custom_derivatives.custom_jvp @_wraps(osp_special.zeta, module='scipy.special') def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: - assert q is not None, "Riemann zeta function is not implemented yet." + if q is None: + raise NotImplementedError( + "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") + x, q = promote_args_inexact("zeta", x, q) + return lax.zeta(x, q) + + +# There is no general closed-form derivative for the zeta function, so we compute +# derivatives via a series expansion +def _zeta_series_expansion(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: + if q is None: + raise NotImplementedError( + "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") # Reference: Johansson, Fredrik. # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives." # Numerical Algorithms 69.2 (2015): 253-270. @@ -280,6 +294,8 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array: T = T0 * (dtype(0.5) + T1.sum(-1)) return S + I + T +zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type] + @_wraps(osp_special.polygamma, module='scipy.special', update_doc=False) def polygamma(n: ArrayLike, x: ArrayLike) -> Array: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 67633686599b..1d32bddd33f8 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1449,6 +1449,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "global_array_to_host_local_array", "host_local_array_to_global_array", "call_exported", + "zeta", # Not high priority? "after_all", "all_to_all", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index c7591f0fa748..6d0636856bea 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -248,6 +248,8 @@ random_gamma_grad as random_gamma_grad, random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, + zeta as zeta, + zeta_p as zeta_p, ) from jax._src.lax.slicing import ( GatherDimensionNumbers as GatherDimensionNumbers, diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 40b3e4de8232..e0961b24374a 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -127,9 +127,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t op_record( "xlog1py", 2, float_dtypes, jtu.rand_default, True ), - # TODO: enable gradient test for zeta by restricting the domain of - # of inputs to some reasonable intervals - op_record("zeta", 2, float_dtypes, jtu.rand_positive, False), + op_record("zeta", 2, float_dtypes, jtu.rand_positive, True), # TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise op_record( "expi", 1, [np.float32], diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index dd0585cfa914..12703a827e7f 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -198,12 +198,6 @@ def testIssue980(self): self.assertAllClose(np.zeros((4,), dtype=np.float32), lsp_special.expit(x)) - @jax.numpy_rank_promotion('raise') - def testIssue3758(self): - x = np.array([1e5, 1e19, 1e10], dtype=np.float32) - q = np.array([1., 40., 30.], dtype=np.float32) - self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q)) - def testIssue13267(self): """Tests betaln(x, 1) across wide range of x.""" xs = jnp.geomspace(1, 1e30, 1000)