-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[shape_poly] Simplify the indexing computations to be compatible with shape polymorphism #18679
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
gnecula
force-pushed
the
poly_getitem2
branch
from
November 27, 2023 11:35
830a016
to
941f857
Compare
gnecula
force-pushed
the
poly_getitem2
branch
3 times, most recently
from
November 28, 2023 08:25
3c82539
to
35341ad
Compare
gnecula
force-pushed
the
poly_getitem2
branch
from
November 28, 2023 10:46
35341ad
to
782fdab
Compare
gnecula
force-pushed
the
poly_getitem2
branch
2 times, most recently
from
November 29, 2023 11:44
855c251
to
8325e1e
Compare
jakevdp
approved these changes
Nov 30, 2023
…ith shape polymorphism Currently, we do not support shape polymorphism when we index with a slice, e.g., `x[a:b:c]`, and insted we direct the user to use to `lax.dynamic_slice`. This is only because so far we have not tried to ensure that the index and bounds checking computations in gather are compatible with shape polymorphism. The problem was that there were a lot of conditionals, e.g., `if start >= stop` that cannot be handled in general in presence of symbolic shapes. Here we introduce a new helper function `_preprocess_slice` to contain all the computations for the start and the size of the slice. To test that this does not break the JAX index computations, I ran the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
gnecula
force-pushed
the
poly_getitem2
branch
from
December 1, 2023 06:40
8325e1e
to
2d1ce13
Compare
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 2, 2023
This bug was introduced in google#18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 2, 2023
This bug was introduced in google#18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 2, 2023
This bug was introduced in google#18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 2, 2023
This bug was introduced in google#18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 5, 2023
This fixes a bug introduced in google#18679, for the case when some elements of the slice are `jax.Array`. We add a new test also.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Dec 5, 2023
This fixes a bug introduced in google#18679, for the case when some elements of the slice are `jax.Array`. We add a new test also.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Currently, we do not support shape polymorphism when we index with a
slice, e.g.,
x[a:b:c]
, and insted direct the user to use tolax.dynamic_slice
. This is only because so far we have not triedto ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g.,
if start >= stop
that cannot behandled in general in presence of symbolic shapes.
Here we introduce a new helper function
_preprocess_slice
to containall the computations for the start and the size of the slice.
To test that this does not break the JAX index computations, I ran
the tests with
JAX_NUM_GENERATED_CASES=1000
, especially thelax_numpy_indexer_test.py
.