Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated jnp.ceil/floor/trunc to preserve int dtypes #21441

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Remember to align the itemized text with the first line of an item within a list

* Changes
* The minimum NumPy version is now 1.24.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.

## jaxlib 0.4.31

Expand Down
2 changes: 2 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def result_type(*args: Any) -> DType:
@jit
def trunc(x: ArrayLike) -> Array:
util.check_arraylike('trunc', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax_internal.asarray(x)
return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x))


Expand Down
6 changes: 6 additions & 0 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,17 @@ def sign(x: ArrayLike, /) -> Array:
@implements(np.floor, module='numpy')
@partial(jit, inline=True)
def floor(x: ArrayLike, /) -> Array:
check_arraylike('floor', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
return lax.floor(*promote_args_inexact('floor', x))

@implements(np.ceil, module='numpy')
@partial(jit, inline=True)
def ceil(x: ArrayLike, /) -> Array:
check_arraylike('ceil', x)
if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')):
return lax.asarray(x)
return lax.ceil(*promote_args_inexact('ceil', x))

@implements(np.exp, module='numpy')
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
can_cast as can_cast,
ceil as ceil,
complex128 as complex128,
complex64 as complex64,
concat as concat,
Expand All @@ -85,6 +86,7 @@
flip as flip,
float32 as float32,
float64 as float64,
floor as floor,
floor_divide as floor_divide,
from_dlpack as from_dlpack,
full as full,
Expand Down Expand Up @@ -160,6 +162,7 @@
tile as tile,
tril as tril,
triu as triu,
trunc as trunc,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
Expand Down Expand Up @@ -192,11 +195,8 @@
)

from jax.experimental.array_api._elementwise_functions import (
ceil as ceil,
clip as clip,
floor as floor,
hypot as hypot,
trunc as trunc,
)

from jax.experimental.array_api._statistical_functions import (
Expand Down
27 changes: 0 additions & 27 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
from jax._src.numpy.util import promote_args


# TODO(micky774): Update jnp.ceil to preserve integral dtype
def ceil(x, /):
"""Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
x, = promote_args("ceil", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.ceil(x)


# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def clip(x, /, min=None, max=None):
Expand All @@ -43,15 +34,6 @@ def clip(x, /, min=None, max=None):
return jax.numpy.clip(x, min=min, max=max)


# TODO(micky774): Update jnp.floor to preserve integral dtype
def floor(x, /):
"""Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i."""
x, = promote_args("floor", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.floor(x)


# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
def hypot(x1, x2, /):
Expand All @@ -64,12 +46,3 @@ def hypot(x1, x2, /):
"values first, such as by using jnp.real or jnp.imag to take the real "
"or imaginary components respectively.")
return jax.numpy.hypot(x1, x2)


# TODO(micky774): Update jnp.trunc to preserve integral dtype
def trunc(x, /):
"""Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i."""
x, = promote_args("trunc", x)
if isdtype(x.dtype, "integral"):
return x
return jax.numpy.trunc(x)
2 changes: 2 additions & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def test_floor(self): self.unary_check(jnp.floor)
@jtu.skip_on_devices("tpu")
def test_ceil(self): self.unary_check(jnp.ceil)
@jtu.skip_on_devices("tpu")
def test_trunc(self): self.unary_check(jnp.trunc)
@jtu.skip_on_devices("tpu")
def test_round(self): self.unary_check(lax.round)
@jtu.skip_on_devices("tpu")
def test_sign(self): self.unary_check(lax.sign)
Expand Down
2 changes: 1 addition & 1 deletion tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ class PallasOpsTest(PallasTest):
[jnp.abs, jnp.negative],
["int16", "int32", "int64", "float16", "float32", "float64"],
),
([jnp.ceil, jnp.floor], ["float32", "float64"]),
([jnp.ceil, jnp.floor], ["float32", "float64", "int32"]),
(
[jnp.exp, jnp.exp2, jnp.sin, jnp.cos, jnp.log, jnp.sqrt],
["float16", "float32", "float64"],
Expand Down