-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updated jnp.ceil/floor/trunc to preserve int dtypes #21441
Updated jnp.ceil/floor/trunc to preserve int dtypes #21441
Conversation
@Micky774 can you review this PR please |
Overall looks good, it may be worth adding a quick changelog entry since this will result in different dtypes, what do you think @jakevdp Either way, @vfdev-5 can you go ahead and push again to trigger CI? Edit: Do we have test coverage for the integral dtype preservation of |
fe91af6
to
b8a01f0
Compare
b8a01f0
to
b050380
Compare
ce2c9bb
to
0b9d6b7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, but I think we also need a changelog entry.
I'll push an update now |
0b9d6b7
to
8cc1a83
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
975efbc
to
33c2d53
Compare
33c2d53
to
9b61ae5
Compare
9b61ae5
to
19f1166
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
It looks like this change is breaking tests of expected output dtypes when tested against older numpy versions. We'll have to account for that in the tests. |
19f1166
to
781cd2f
Compare
781cd2f
to
0197aa8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks!
Actually, I have one more question here: what about booleans? The array API spec doesn't mention this explicitly, but it might be surprising if booleans are promoted to float when integers are not. I've asked the NumPy devs about this at numpy/numpy#26766 (comment). |
I asked Array API folks about that and here is the answer on this question:
cc @kgryte |
OK, in that case let's make the change for booleans as well. I find it strange that booleans are silently promoted to float, while integers keep their type. Alternatively, we could choose to promote booleans to int, by calling |
0197aa8
to
02f4bbb
Compare
@jakevdp I included bools in the same route as for integers, let me know if we need to implement the alternative solution you suggested above. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
@jakevdp we should remove similarly ceil, floor, trunc from |
Yes, we should do that |
Description: - Updated jnp.ceil/floor/trunc to preserve int dtypes - Updated tests - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
02f4bbb
to
70b4823
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
JAX has an unreleased change google/jax#21441 that make it so that `int`s and `bool`s are no longer promoted to floats by `floor` and `ceil`. Our current implementation follows the Numpy promotion rule, i.e. `int`s and `bool`s are promoted to floats. Therefore: - updated unit tests, which use `jax.numpy` as the reference. - updated our JAX implementation of `floor` and `ceil` (note that `ceil` would already cast).
…9946) JAX has an unreleased change google/jax#21441 that make it so that `int`s and `bool`s are no longer promoted to floats by `floor` and `ceil`. Our current implementation follows the Numpy promotion rule, i.e. `int`s and `bool`s are promoted to floats. Therefore: - updated unit tests, which use `jax.numpy` as the reference. - updated our JAX implementation of `floor` and `ceil` (note that `ceil` would already cast).
Description:
Related to #21088
cc @Micky774