Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compatibility with nightly numpy #22257

Merged
merged 1 commit into from
Jul 3, 2024

Conversation

dfm
Copy link
Member

@dfm dfm commented Jul 3, 2024

Numpy recently merged support for the 2023.12 revision of the Array API: numpy/numpy#26724

This breaks two of our tests:

  1. The first breakage was caused by differences in how numpy and JAX cast negative floats to uint8. Specifically np.float32(-1).astype(np.uint8) returns np.uint8(255) whereas jnp.float32(-1).astype(jnp.uint8) produces Array(0, dtype=uint8). We don't make any promises about consistency with casting floats to ints, noting that this can even be backend dependent. I don't believe this failure is identifying any unexpected behavior, and we test many other dtypes properly so I'm not concerned about skipping this test. To fix our test, we now only generate positive inputs when the output dtype is unsigned.

  2. The second failure was caused by the fact that the approach we took in Update jnp.clip to Array API 2023 standard and introduces jax.experimental.array_api.clip #20550 to support backwards compatibility and the Array API for clip differs from the one used in ENH: Add Array API 2023.12 version support numpy/numpy#26724. Again, the behavior is consistent, but it produces a different signature. I've skipped checking clip's signature, but we should revisit it once the a_min and a_max parameters have been removed from JAX.

Fixes #22251

@dfm dfm self-assigned this Jul 3, 2024
Numpy recently merged support for the 2023.12 revision of the Array API:
numpy/numpy#26724

This breaks two of our tests:

1. The first breakage was caused by differences in how numpy and JAX
   cast negative floats to `uint8`. Specifically
   `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
   `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
   We don't make any promises about consistency with casting floats to
   ints, noting that this can even be backend dependent. To fix our
   test, we now only generate positive inputs when the output dtype is
   unsigned.

2. The second failure was caused by the fact that the approach we took
   in google#20550 to support backwards compatibility and the Array API for
   `clip` differs from the one used in numpy/numpy#26724. Again, the
   behavior is consistent, but it produces a different signature. I've
   skipped checking `clip`'s signature, but we should revisit it once
   the `a_min` and `a_max` parameters have been removed from JAX.

Fixes google#22251
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jul 3, 2024
@copybara-service copybara-service bot merged commit 82f4b5a into google:main Jul 3, 2024
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

⚠️ Nightly upstream-dev CI failed ⚠️
3 participants