Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shape_poly] Simplify the indexing computations to be compatible with shape polymorphism #18679

Merged
merged 1 commit into from
Dec 1, 2023

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Nov 27, 2023

Currently, we do not support shape polymorphism when we index with a
slice, e.g., x[a:b:c], and insted direct the user to use to
lax.dynamic_slice. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., if start >= stop that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function _preprocess_slice to contain
all the computations for the start and the size of the slice.

To test that this does not break the JAX index computations, I ran
the tests with JAX_NUM_GENERATED_CASES=1000, especially the lax_numpy_indexer_test.py.

@gnecula gnecula self-assigned this Nov 27, 2023
@gnecula gnecula added the pull ready Ready for copybara import and testing label Nov 27, 2023
@gnecula gnecula force-pushed the poly_getitem2 branch 3 times, most recently from 3c82539 to 35341ad Compare November 28, 2023 08:25
@gnecula gnecula requested a review from jakevdp November 28, 2023 08:50
@gnecula gnecula removed the request for review from jakevdp November 28, 2023 11:17
@gnecula gnecula force-pushed the poly_getitem2 branch 2 times, most recently from 855c251 to 8325e1e Compare November 29, 2023 11:44
@gnecula gnecula requested a review from jakevdp November 29, 2023 12:04
…ith shape polymorphism

Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a:b:c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.

Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.

To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.
@copybara-service copybara-service bot merged commit e60aa3b into google:main Dec 1, 2023
14 checks passed
@gnecula gnecula deleted the poly_getitem2 branch December 1, 2023 11:25
gnecula added a commit to gnecula/jax that referenced this pull request Dec 2, 2023
This bug was introduced in google#18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 2, 2023
This bug was introduced in google#18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 2, 2023
This bug was introduced in google#18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 2, 2023
This bug was introduced in google#18679, and was not caught
in unit tests because we were not testing cases when the
slice needs to be clamped.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 5, 2023
This fixes a bug introduced in google#18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
gnecula added a commit to gnecula/jax that referenced this pull request Dec 5, 2023
This fixes a bug introduced in google#18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants