Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compatibility with nightly numpy #22257

Merged
merged 1 commit into from
Jul 3, 2024

Commits on Jul 3, 2024

  1. Fix compatibility with nightly numpy

    Numpy recently merged support for the 2023.12 revision of the Array API:
    numpy/numpy#26724
    
    This breaks two of our tests:
    
    1. The first breakage was caused by differences in how numpy and JAX
       cast negative floats to `uint8`. Specifically
       `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
       `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
       We don't make any promises about consistency with casting floats to
       ints, noting that this can even be backend dependent. To fix our
       test, we now only generate positive inputs when the output dtype is
       unsigned.
    
    2. The second failure was caused by the fact that the approach we took
       in jax-ml#20550 to support backwards compatibility and the Array API for
       `clip` differs from the one used in numpy/numpy#26724. Again, the
       behavior is consistent, but it produces a different signature. I've
       skipped checking `clip`'s signature, but we should revisit it once
       the `a_min` and `a_max` parameters have been removed from JAX.
    
    Fixes jax-ml#22251
    dfm committed Jul 3, 2024
    Configuration menu
    Copy the full SHA
    9e9acc9 View commit details
    Browse the repository at this point in the history