Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor array_api namespace, relying more directly on jax.numpy #21013

Merged
merged 1 commit into from
May 2, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Apr 30, 2024

This PR refactors the jax.experimental.array_api namespace by removing unnecessary wrappers around already-compliant functions in the jax.numpy namespace, and structuring the array_api namespace to pull directly from jax.numpy whenever possible. After this PR, the array_api submodule contain only:

  1. Wrapper functions to insulate jax.numpy from breaking changes, which will be removed when the corresponding jax.numpy behavior is deprecated and made array API compliant
  2. Additional features/utilities/API that is not yet present in jax.numpy and needs inclusion (e.g. introducing jax.numpy.matmul, which already exists in jax.numpy.linalg).

This PR also adds several TODO items describing what is required to cull that portion of the array_api submodule, with the understanding that once it is empty, jax.numpy will be fully compliant. I figured it would be a bit neater to keep the TODO notes dense in this submodule, rather than spreading them across the jax.numpy submodule on their corresponding functions. It's also consistent with the TODOs for new functionality or namespace elements.

This PR also modifies jax.numpy.isdtype to accept _ScalarMeta and other dtype-interpretable inputs.

Note that the array-api-tests issues many UserWarnings for the special cases test, as well as for their reporting utilities due to not understanding what @jit wrapped functions are in JAX, so this PR suppresses them in the jax-array-api workflow.

This PR has been validated against the array-api-tests suite for version 2023.12, using jax/experimental/array_api/skips.txt -- although it is worth noting that the test suite does not cover everything, e.g. is still missing support for copy and device keyword tests.

cc: @jakevdp

@Micky774 Micky774 changed the title Refactor array_api namespace, relying more on jax.numpy Refactor array_api namespace, relying more directly on jax.numpy Apr 30, 2024
@Micky774 Micky774 force-pushed the array-api-trim branch 2 times, most recently from 4642197 to 36e53d9 Compare April 30, 2024 20:52
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really nice! A few comments below.

.github/workflows/jax-array-api.yml Outdated Show resolved Hide resolved
jax/experimental/array_api/__init__.py Outdated Show resolved Hide resolved
jax/experimental/array_api/_statistical_functions.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-trim branch 2 times, most recently from a146278 to 9e85f58 Compare April 30, 2024 21:31
@jakevdp jakevdp self-assigned this Apr 30, 2024
@Micky774 Micky774 force-pushed the array-api-trim branch 2 times, most recently from 5c0b3de to f9dbcac Compare April 30, 2024 22:30
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-trim branch 2 times, most recently from 0897b7f to c4cfbcb Compare May 1, 2024 00:24
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/_src/dtypes.py Outdated Show resolved Hide resolved
jax/experimental/array_api/_elementwise_functions.py Outdated Show resolved Hide resolved
pyproject.toml Outdated Show resolved Hide resolved
pyproject.toml Outdated Show resolved Hide resolved
@Micky774
Copy link
Collaborator Author

Micky774 commented May 1, 2024

The current failure is due to needing to add some signature test skips for logical_{and, or, xor} -- that'll be fixed in the next push

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work! This really makes clear what the remaining TODOs are 😀

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 2, 2024
@copybara-service copybara-service bot merged commit 187b2ac into google:main May 2, 2024
14 checks passed
@Micky774 Micky774 deleted the array-api-trim branch May 2, 2024 19:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants