Skip to content

Commit

Permalink
Create lax.zeta with native HLO lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 16, 2023
1 parent 4cf85b9 commit 6cd467f
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
Expand Up @@ -146,6 +146,7 @@ Operators
tie_in
top_k
transpose
zeta

.. _lax-control-flow:

Expand Down
7 changes: 7 additions & 0 deletions jax/_src/lax/special.py
Expand Up @@ -65,6 +65,10 @@ def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise derivative of samples from `Gamma(a, 1)`."""
return random_gamma_grad_p.bind(a, x)

def zeta(x: ArrayLike, q: ArrayLike) -> Array:
r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`"""
return zeta_p.bind(x, q)

def bessel_i0e(x: ArrayLike) -> Array:
r"""Exponentially scaled modified Bessel function of order 0:
:math:`\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)`
Expand Down Expand Up @@ -639,6 +643,9 @@ def bessel_i0e_impl(x):
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
multiple_results=False))

zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.ZetaOp))

bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
mlir.register_lowering(bessel_i0e_p,
mlir.lower_fun(bessel_i0e_impl,
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/scipy/special.py
Expand Up @@ -21,6 +21,7 @@

import jax.numpy as jnp
from jax import jit
from jax import jvp
from jax import vmap
from jax import lax

Expand Down Expand Up @@ -252,9 +253,22 @@ def rel_entr(
]


@custom_derivatives.custom_jvp
@_wraps(osp_special.zeta, module='scipy.special')
def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
assert q is not None, "Riemann zeta function is not implemented yet."
if q is None:
raise NotImplementedError(
"Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.")
x, q = promote_args_inexact("zeta", x, q)
return lax.zeta(x, q)


# There is no general closed-form derivative for the zeta function, so we compute
# derivatives via a series expansion
def _zeta_series_expansion(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
if q is None:
raise NotImplementedError(
"Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.")
# Reference: Johansson, Fredrik.
# "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
# Numerical Algorithms 69.2 (2015): 253-270.
Expand All @@ -280,6 +294,8 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
T = T0 * (dtype(0.5) + T1.sum(-1))
return S + I + T

zeta.defjvp(partial(jvp, _zeta_series_expansion)) # type: ignore[arg-type]


@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False)
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1449,6 +1449,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"global_array_to_host_local_array",
"host_local_array_to_global_array",
"call_exported",
"zeta",
# Not high priority?
"after_all",
"all_to_all",
Expand Down
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
Expand Up @@ -248,6 +248,8 @@
random_gamma_grad as random_gamma_grad,
random_gamma_grad_p as random_gamma_grad_p,
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
zeta as zeta,
zeta_p as zeta_p,
)
from jax._src.lax.slicing import (
GatherDimensionNumbers as GatherDimensionNumbers,
Expand Down
4 changes: 1 addition & 3 deletions tests/lax_scipy_special_functions_test.py
Expand Up @@ -127,9 +127,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
op_record(
"xlog1py", 2, float_dtypes, jtu.rand_default, True
),
# TODO: enable gradient test for zeta by restricting the domain of
# of inputs to some reasonable intervals
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
op_record("zeta", 2, float_dtypes, jtu.rand_positive, True),
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
op_record(
"expi", 1, [np.float32],
Expand Down
6 changes: 0 additions & 6 deletions tests/lax_scipy_test.py
Expand Up @@ -198,12 +198,6 @@ def testIssue980(self):
self.assertAllClose(np.zeros((4,), dtype=np.float32),
lsp_special.expit(x))

@jax.numpy_rank_promotion('raise')
def testIssue3758(self):
x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
q = np.array([1., 40., 30.], dtype=np.float32)
self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

def testIssue13267(self):
"""Tests betaln(x, 1) across wide range of x."""
xs = jnp.geomspace(1, 1e30, 1000)
Expand Down

0 comments on commit 6cd467f

Please sign in to comment.