-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixes cache miss in abstract_eval_shape for bcoo dot general #23238
Conversation
Thanks! This issue also comes up in a number of other places in this file – would you like to fix those as well as part of this PR? |
Regarding the CI failure: in slice_sizes = rng.randint(0, shape, len(shape)).tolist() We may also need to cast it to a tuple within the function to ensure that it's hashable. |
Actually, strike that idea: what we need is to move this line to above the |
Will do! I missed some tests locally due to forgetting to install my local repo (d'oh). So will re-compile jaxlib+jax and run test suite. |
No need to compile jaxlib: you should be able to use the most recent jaxlib release along with a local install of |
Hmm. When I do that, I'm unable to import anything from jax... python -c "import jax; print(jax.__version__)"
Traceback (most recent call last):
File "<string>", line 1, in <module>
AttributeError: module 'jax' has no attribute '__version__' I'll see if I can pin down what is going on. Maybe mismatch in version between dev and non-dev? |
That error generally comes from version skew between jax and jaxlib. Try doing |
Great! That seems to have fixed things on my end. I didn't see any failures in tests/sparse_bcoo_bcsr_test.py or tests/sparse_test.py |
Thanks again for this! It was a tricky issue and this fix will make a lot of code more efficient. |
This change fixes the cache miss issue when performing
abstract_eval_shape
forbcoo_dot_general
. This was discussed in #15915 (reply in thread) .