Skip to content

Commit

Permalink
Remove superfluous double-where in xlogy and xlog1py.
Browse files Browse the repository at this point in the history
These functions have custom derivatives, so there seems to be no point to using the double-where guard on the primal function: the implementation can never be differentiated!

PiperOrigin-RevId: 551843160
  • Loading branch information
hawkinsp authored and jax authors committed Jul 28, 2023
1 parent 88e11ae commit 640ee1e
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
# Note: xlogy(0, 0) should return 0 according to the function documentation.
x, y = promote_args_inexact("xlogy", x, y)
x_ok = x != 0.
safe_x = jnp.where(x_ok, x, 1.)
safe_y = jnp.where(x_ok, y, 1.)
return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x))
return jnp.where(x_ok, lax.mul(x, lax.log(y)), jnp.zeros_like(x))

def _xlogy_jvp(primals, tangents):
(x, y) = primals
Expand All @@ -145,9 +143,7 @@ def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
# Note: xlog1py(0, -1) should return 0 according to the function documentation.
x, y = promote_args_inexact("xlog1py", x, y)
x_ok = x != 0.
safe_x = jnp.where(x_ok, x, 1.)
safe_y = jnp.where(x_ok, y, 1.)
return jnp.where(x_ok, lax.mul(safe_x, lax.log1p(safe_y)), jnp.zeros_like(x))
return jnp.where(x_ok, lax.mul(x, lax.log1p(y)), jnp.zeros_like(x))

def _xlog1py_jvp(primals, tangents):
(x, y) = primals
Expand Down

0 comments on commit 640ee1e

Please sign in to comment.