Create lax.polygamma with native HLO lowering
jakevdp committed Aug 16, 2023
1 parent 539a6d1 commit 0ad6196
Showing 5 changed files with 20 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/jax.lax.rst
@@ -111,6 +111,7 @@ Operators
     neg
     nextafter
     pad
+    polygamma
     pow
     real
     reciprocal
15 changes: 15 additions & 0 deletions jax/_src/lax/special.py
@@ -49,6 +49,10 @@ def digamma(x: ArrayLike) -> Array:
   r"""Elementwise digamma: :math:`\psi(x)`."""
   return digamma_p.bind(x)

+def polygamma(m: ArrayLike, x: ArrayLike) -> Array:
+  r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`."""
+  return polygamma_p.bind(m, x)
+
 def igamma(a: ArrayLike, x: ArrayLike) -> Array:
   r"""Elementwise regularized incomplete gamma function."""
   return igamma_p.bind(a, x)
@@ -111,6 +115,12 @@ def igammac_gradx(g, a, x):
 def igammac_grada(g, a, x):
   return -igamma_grada(g, a, x)

+def polygamma_gradm(g, m, x):
+  raise ValueError("polygamma gradient with respect to m is not supported")
+
+def polygamma_gradx(g, m, x):
+  return g * polygamma(add(m, _const(m, 1)), x)
+
 # The below is directly ported from tensorflow/compiler/xla/client/lib/math.cc
 # We try to follow the corresponding functions as closely as possible, so that
 # we can quickly incorporate changes.
@@ -615,6 +625,11 @@ def bessel_i0e_impl(x):

 digamma_p = standard_unop(_float, 'digamma')
 mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
+ad.defjvp(digamma_p, lambda g, x: mul(g, polygamma(_const(x, 1), x)))
+
+polygamma_p = standard_naryop([_float, _float], 'polygamma')
+mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.PolygammaOp))
+ad.defjvp(polygamma_p, polygamma_gradm, polygamma_gradx)

 igamma_p = standard_naryop([_float, _float], 'igamma')
 mlir.register_lowering(igamma_p, mlir.lower_fun(_up_and_broadcast(igamma_impl),
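For context, a minimal usage sketch of the new primitive and its differentiation rules. This is not part of the commit; it assumes a JAX build that includes this change, and the values are purely illustrative.

```python
# Sketch only: assumes a jax build that contains this commit.
import jax
import numpy as np
from jax import lax

m = np.float32(1.0)  # order of the polygamma function, passed as a float
x = np.float32(2.5)

# polygamma(1, x) is the trigamma function; polygamma(0, x) equals digamma(x).
print(lax.polygamma(m, x))

# The JVP rule registered above gives d/dx psi^{(m)}(x) = psi^{(m+1)}(x),
# so the gradient with respect to x should match polygamma(m + 1, x).
print(jax.grad(lax.polygamma, argnums=1)(m, x))
print(lax.polygamma(m + np.float32(1.0), x))

# Differentiating with respect to m is unsupported and raises the ValueError
# defined in polygamma_gradm.
```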
16 changes: 1 addition & 15 deletions jax/_src/scipy/special.py
@@ -27,7 +26,6 @@
 from jax._src import core
 from jax._src import custom_derivatives
 from jax._src import dtypes
-from jax._src.interpreters import ad
 from jax._src.lax.lax import _const as _lax_const
 from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
 from jax._src.numpy.util import _wraps
@@ -67,9 +66,6 @@ def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
 def digamma(x: ArrayLike) -> Array:
   x, = promote_args_inexact("digamma", x)
   return lax.digamma(x)
-ad.defjvp(
-    lax.digamma_p,
-    lambda g, x: lax.mul(g, polygamma(1, x)))  # type: ignore[has-type]


 @_wraps(osp_special.gammainc, module='scipy.special', update_doc=False)
@@ -285,17 +281,7 @@ def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
 def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
   assert jnp.issubdtype(lax.dtype(n), jnp.integer)
   n_arr, x_arr = promote_args_inexact("polygamma", n, x)
-  shape = lax.broadcast_shapes(n_arr.shape, x_arr.shape)
-  return _polygamma(jnp.broadcast_to(n_arr, shape), jnp.broadcast_to(x_arr, shape))
-
-
-@custom_derivatives.custom_jvp
-def _polygamma(n: ArrayLike, x: ArrayLike) -> Array:
-  dtype = lax.dtype(n).type
-  n_plus = n + dtype(1)
-  sign = dtype(1) - (n_plus % dtype(2)) * dtype(2)
-  return jnp.where(n == 0, digamma(x), sign * jnp.exp(gammaln(n_plus)) * zeta(n_plus, x))
-_polygamma.defjvps(None, lambda g, ans, n, x: lax.mul(g, _polygamma(n + 1, x)))
+  return lax.polygamma(n_arr, x_arr)


 # Normal distributions
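The scipy-level wrapper now simply checks that the order has an integer dtype and dispatches to the new primitive instead of the removed zeta-based helper. A small usage sketch (not part of the diff; illustrative values):

```python
# Sketch only: jax.scipy.special.polygamma now dispatches to lax.polygamma.
import jax.numpy as jnp
from jax.scipy.special import digamma, polygamma

x = jnp.array([1.5, 2.5, 3.5])

print(polygamma(0, x))  # order 0 reduces to the digamma function
print(digamma(x))
print(polygamma(2, x))  # higher orders are further derivatives of digamma

# Note: the order must be passed with an integer dtype; the wrapper asserts
# jnp.issubdtype(lax.dtype(n), jnp.integer) before promoting both arguments.
```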
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1432,6 +1432,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
     "clz",
     "igamma_grad_a",
     "random_gamma_grad",
+    "polygamma",
     "reduce_xor",
     "schur",
     "closed_call",
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
@@ -245,6 +245,8 @@
   igamma_p as igamma_p,
   lgamma as lgamma,
   lgamma_p as lgamma_p,
+  polygamma as polygamma,
+  polygamma_p as polygamma_p,
   random_gamma_grad as random_gamma_grad,
   random_gamma_grad_p as random_gamma_grad_p,
   regularized_incomplete_beta_p as regularized_incomplete_beta_p,
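With these re-exports, the function and its primitive become reachable from the public jax.lax namespace; a quick sketch (not part of the diff):

```python
# Sketch only: the new names are part of the public jax.lax namespace.
from jax import lax

print(lax.polygamma)    # the function defined in jax/_src/lax/special.py
print(lax.polygamma_p)  # the primitive, lowered to chlo.PolygammaOp
```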
