fixes cache miss in abstract_eval_shape for bcoo dot general #23238

Merged (3 commits) on Aug 30, 2024

Contributor

quattro commented Aug 26, 2024

This change fixes the cache miss that occurs when performing abstract_eval_shape for bcoo_dot_general. It was discussed in #15915 (reply in thread).
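One common way a function-keyed cache can miss on every call is when the callable is rebuilt each time, for example with functools.partial. The sketch below illustrates that mechanism with functools.lru_cache as a toy stand-in. It is not JAX's tracing cache and not the code changed in this PR, and every name in it (out_shape, cached_eval, cached_eval_static, and the two wrappers) is hypothetical.

```python
import functools


def out_shape(shape, *, transpose):
    """Toy shape rule: reverse the dimensions when transpose is set."""
    return tuple(reversed(shape)) if transpose else tuple(shape)


@functools.lru_cache(maxsize=None)
def cached_eval(fn, shape):
    """Toy cache keyed on (fn, shape); stands in for a function-keyed cache."""
    return fn(shape)


@functools.lru_cache(maxsize=None)
def cached_eval_static(fn, shape, transpose):
    """Same idea, but the parameter is part of the key, not baked into fn."""
    return fn(shape, transpose=transpose)


def fresh_partial_each_call(shape, transpose):
    # A new partial object is created on every call. partial objects hash by
    # identity, so each call presents a key the cache has never seen.
    return cached_eval(functools.partial(out_shape, transpose=transpose), shape)


def stable_function_each_call(shape, transpose):
    # The module-level function is the same object on every call, and the
    # parameter travels as a hashable argument, so repeated calls hit the cache.
    return cached_eval_static(out_shape, shape, transpose)
```

Calling fresh_partial_each_call twice with the same arguments records two misses in cached_eval.cache_info(), while calling stable_function_each_call twice records one miss and one hit in cached_eval_static.cache_info().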

Collaborator

jakevdp commented Aug 26, 2024

Thanks! This issue also comes up in a number of other places in this file – would you like to fix those as well as part of this PR?

jakevdp self-assigned this Aug 26, 2024
Collaborator

jakevdp commented Aug 26, 2024

Regarding the CI failure: in test_bcoo_dynamic_slice we need to cast slice_sizes to a list, i.e. this line: https://github.com/google/jax/blob/c33ce857847750b836b8c899c5d48c12b2842afc/tests/sparse_bcoo_bcsr_test.py#L1062
should be changed to this:

slice_sizes = rng.randint(0, shape, len(shape)).tolist()

We may also need to cast it to a tuple within the function to ensure that it's hashable.
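The hashability concern can be illustrated with a small standard-library sketch. It assumes a cache that hashes its arguments, modelled here with functools.lru_cache. The helper names cached_sliced_shape and sliced_shape are hypothetical and are not part of bcoo.py.

```python
import functools


@functools.lru_cache(maxsize=None)
def cached_sliced_shape(shape, slice_sizes):
    """Hypothetical cached helper: both arguments form the cache key."""
    return tuple(min(dim, size) for dim, size in zip(shape, slice_sizes))


def sliced_shape(shape, slice_sizes):
    # Lists are unhashable, so they cannot be used in a cache key directly.
    # Converting to tuples of plain ints before the cached call avoids the
    # TypeError and lets equal inputs share a single cache entry.
    return cached_sliced_shape(
        tuple(int(d) for d in shape),
        tuple(int(s) for s in slice_sizes),
    )
```

Passing a list straight to cached_sliced_shape raises TypeError, whereas sliced_shape accepts lists or tuples because it normalizes them first.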

Collaborator

jakevdp commented Aug 26, 2024

Actually, strike that idea: what we need is to move this line to above the eval_shape call inside bcoo_dynamic_slice: https://github.com/google/jax/blob/c33ce857847750b836b8c899c5d48c12b2842afc/jax/experimental/sparse/bcoo.py#L2046

Contributor Author

quattro commented Aug 26, 2024

> Actually, strike that idea: what we need is to move this line to above the eval_shape call: https://github.com/google/jax/blob/c33ce857847750b836b8c899c5d48c12b2842afc/jax/experimental/sparse/bcoo.py#L2046

Will do! I missed some tests locally because I forgot to install my local repo (d'oh), so I'll recompile jaxlib and jax and run the test suite.

Collaborator

jakevdp commented Aug 26, 2024

No need to compile jaxlib: you should be able to use the most recent jaxlib release along with a local install of jax via pip install -e .[cpu]

Contributor Author

quattro commented Aug 26, 2024

Hmm. When I do that, I'm unable to import anything from jax...

python -c "import jax; print(jax.__version__)"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
AttributeError: module 'jax' has no attribute '__version__'

I'll see if I can pin down what is going on. Maybe a version mismatch between dev and non-dev?

Collaborator

jakevdp commented Aug 27, 2024

That error generally comes from version skew between jax and jaxlib. Try doing pip install -U .[cpu] assuming you're running locally on a CPU, and haven't installed a GPU jaxlib locally. If you've installed a GPU jaxlib in the past, then update that instead.
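A quick way to check for that kind of skew without importing jax itself is to read the installed distribution metadata. This is a minimal sketch using only the standard library. The function name installed_versions is hypothetical, and the version reported for an editable install may not match what the source tree would report.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If the two reported versions are far apart, or one of them is missing, reinstalling as suggested above is a reasonable next step.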

Contributor Author

quattro commented Aug 28, 2024

Great! That seems to have fixed things on my end. I didn't see any failures in tests/sparse_bcoo_bcsr_test.py or tests/sparse_test.py.

google-ml-butler bot added the "kokoro:force-run" and "pull ready" (Ready for copybara import and testing) labels Aug 30, 2024
Collaborator

jakevdp commented Aug 30, 2024

Thanks again for this! It was a tricky issue and this fix will make a lot of code more efficient.

copybara-service bot merged commit f8a4662 into jax-ml:main Aug 30, 2024
16 checks passed