Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated jnp.ceil/floor/trunc to preserve int dtypes #21441

Merged

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented May 27, 2024

Description:

>>> np.ceil(np.arange(5))
array([0., 1., 2., 3., 4.])
>>> np.arange(5)
array([0, 1, 2, 3, 4])
>>> np.__version__
'2.0.0rc2'
  • Removed ceil, floor, trunc implemetnations from array_api and use jnp equivalent functions

Related to #21088

cc @Micky774

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 28, 2024

@Micky774 can you review this PR please

@vfdev-5 vfdev-5 marked this pull request as ready for review May 30, 2024 07:44
@Micky774
Copy link
Collaborator

Micky774 commented May 31, 2024

Overall looks good, it may be worth adding a quick changelog entry since this will result in different dtypes, what do you think @jakevdp

Either way, @vfdev-5 can you go ahead and push again to trigger CI?

Edit: Do we have test coverage for the integral dtype preservation of jnp.trunc? I'm not sure of the generic unary op check cares

jax/_src/numpy/ufuncs.py Outdated Show resolved Hide resolved
@jakevdp jakevdp self-assigned this Jun 4, 2024
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from fe91af6 to b8a01f0 Compare June 13, 2024 16:24
jax/_src/numpy/ufuncs.py Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from b8a01f0 to b050380 Compare June 14, 2024 10:46
@vfdev-5 vfdev-5 requested a review from jakevdp June 14, 2024 13:51
jax/_src/numpy/ufuncs.py Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch 2 times, most recently from ce2c9bb to 0b9d6b7 Compare June 24, 2024 12:31
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, but I think we also need a changelog entry.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jun 24, 2024
@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented Jun 24, 2024

I'll push an update now

@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 0b9d6b7 to 8cc1a83 Compare June 24, 2024 12:56
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

CHANGELOG.md Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch 2 times, most recently from 975efbc to 33c2d53 Compare June 24, 2024 13:36
CHANGELOG.md Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 33c2d53 to 9b61ae5 Compare June 24, 2024 13:40
CHANGELOG.md Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 9b61ae5 to 19f1166 Compare June 24, 2024 13:49
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 24, 2024

It looks like this change is breaking tests of expected output dtypes when tested against older numpy versions. We'll have to account for that in the tests.

@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 19f1166 to 781cd2f Compare June 24, 2024 18:37
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 781cd2f to 0197aa8 Compare June 24, 2024 18:38
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 24, 2024

Actually, I have one more question here: what about booleans? The array API spec doesn't mention this explicitly, but it might be surprising if booleans are promoted to float when integers are not. I've asked the NumPy devs about this at numpy/numpy#26766 (comment).

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented Jun 25, 2024

Actually, I have one more question here: what about booleans? The array API spec doesn't mention this explicitly, but it might be surprising if booleans are promoted to float when integers are not. I've asked the NumPy devs about this at numpy/numpy#26766 (comment).

I asked Array API folks about that and here is the answer on this question:

Those functions are only defined for numerical types. If a library chooses to support boolean dtypes in those functions, that is okay (e.g., for backward compat reasons), but that code is not guaranteed to be portable.

cc @kgryte

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 25, 2024

OK, in that case let's make the change for booleans as well. I find it strange that booleans are silently promoted to float, while integers keep their type. Alternatively, we could choose to promote booleans to int, by calling promote_args_numeric on the inputs.

@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 0197aa8 to 02f4bbb Compare June 25, 2024 17:34
@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented Jun 25, 2024

@jakevdp I included bools in the same route as for integers, let me know if we need to implement the alternative solution you suggested above.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented Jun 25, 2024

@jakevdp we should remove similarly ceil, floor, trunc from jax/experimental/array_api/_elementwise_functions.py and import them from jnp, right ?

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 25, 2024

@jakevdp we should remove similarly ceil, floor, trunc from jax/experimental/array_api/_elementwise_functions.py and import them from jnp, right ?

Yes, we should do that

Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
@vfdev-5 vfdev-5 force-pushed the preserve-int-dtype-ceil-floor-trunc-ops branch from 02f4bbb to 70b4823 Compare June 25, 2024 18:27
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@copybara-service copybara-service bot merged commit e34ba4a into google:main Jun 25, 2024
14 of 15 checks passed
@vfdev-5 vfdev-5 deleted the preserve-int-dtype-ceil-floor-trunc-ops branch June 25, 2024 18:54
hertschuh added a commit to hertschuh/keras that referenced this pull request Jul 1, 2024
JAX has an unreleased change google/jax#21441 that make it so that `int`s and `bool`s are no longer promoted to floats by `floor` and `ceil`.

Our current implementation follows the Numpy promotion rule, i.e. `int`s and `bool`s are promoted to floats. Therefore:
- updated unit tests, which use `jax.numpy` as the reference.
- updated our JAX implementation of `floor` and `ceil` (note that `ceil` would already cast).
fchollet pushed a commit to keras-team/keras that referenced this pull request Jul 1, 2024
…9946)

JAX has an unreleased change google/jax#21441 that make it so that `int`s and `bool`s are no longer promoted to floats by `floor` and `ceil`.

Our current implementation follows the Numpy promotion rule, i.e. `int`s and `bool`s are promoted to floats. Therefore:
- updated unit tests, which use `jax.numpy` as the reference.
- updated our JAX implementation of `floor` and `ceil` (note that `ceil` would already cast).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants