Skip to content

Commit

Permalink
Merge pull request #19390 from jakevdp:jnp-sign
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599203136
  • Loading branch information
jax authors committed Jan 17, 2024
2 parents 199591f + fb56224 commit aac996c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Expand Up @@ -48,9 +48,11 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
reshaped to the dimension of the input, following a similar change to
{func}`numpy.unique` in NumPy 2.0.
* {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is
consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0.
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
of the function in SciPy v1.13.
of {func}`scipy.special.logsumexp` in SciPy v1.13.

* JAX now supports the bool DLPack type for both import and export.
Previously bool values could not be imported and were exported as integers.
Expand Down
13 changes: 1 addition & 12 deletions jax/_src/numpy/ufuncs.py
Expand Up @@ -173,6 +173,7 @@ def _arccosh(x: ArrayLike, /) -> Array:
arccosh = _one_to_one_unop(np.arccosh, _arccosh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sign = _one_to_one_unop(np.sign, lax.sign)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)

Expand Down Expand Up @@ -257,18 +258,6 @@ def rint(x: ArrayLike, /) -> Array:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)


@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike, /) -> Array:
check_arraylike('sign', x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.complexfloating):
re = lax.real(x)
return lax.complex(
lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
return lax.sign(x)


@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Expand Down
21 changes: 20 additions & 1 deletion tests/lax_numpy_operators_test.py
Expand Up @@ -56,6 +56,7 @@
default_dtypes = float_dtypes + int_dtypes
inexact_dtypes = float_dtypes + complex_dtypes
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
real_dtypes = float_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes


Expand Down Expand Up @@ -272,7 +273,9 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
[]),
op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_default, [], check_dtypes=False),
op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []),
# numpy < 2.0.0 has a different convention for complex sign.
op_record("sign", 1, real_dtypes if jtu.numpy_version() < (2, 0, 0) else number_dtypes,
all_shapes, jtu.rand_some_inf_and_nan, []),
# numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately.
op_record("copysign", 2, default_dtypes + unsigned_dtypes,
all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False),
Expand Down Expand Up @@ -646,6 +649,22 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(np_op, op, args_maker)

# This test can be deleted once we test against NumPy 2.0.
@jtu.sample_product(
shape=all_shapes,
dtype=complex_dtypes
)
def testSignComplex(self, shape, dtype):
rng = jtu.rand_default(self.rng())
if jtu.numpy_version() >= (2, 0, 0):
np_fun = np.sign
else:
np_fun = lambda x: (x / np.where(x == 0, 1, abs(x))).astype(np.result_type(x))
jnp_fun = jnp.sign
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testDeferToNamedTuple(self):
class MyArray(NamedTuple):
arr: jax.Array
Expand Down

0 comments on commit aac996c

Please sign in to comment.