Error using batch with jit #5
…by broadcasting NumPy arrays and closing over non-arrays. For pmapping, the first argument must be a NumPy array whose leading dimension has size `device_count`. No arrays are allowed in keyword arguments. PiperOrigin-RevId: 270212471
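The constraint in that commit message (the first argument's leading dimension must equal `device_count`, with non-array values closed over rather than passed in) can be sketched with the current `jax.pmap` API. This is a minimal illustration, not code from the commit: the function `scale_rows` and the closed-over float `scale` are hypothetical, and on a CPU-only machine `jax.device_count()` is typically 1, so the mapped axis has size 1.

```python
import jax
import jax.numpy as jnp

# Number of devices pmap maps over (typically 1 on a CPU-only machine).
n_devices = jax.device_count()

# Hypothetical non-array value, closed over by the mapped function
# instead of being passed as an argument.
scale = 2.0

def scale_rows(row):
    # Hypothetical per-device computation on one slice of the leading axis.
    return row * scale

# First positional argument: an array whose leading dimension equals n_devices.
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

# The array is passed positionally, not as a keyword argument.
y = jax.pmap(scale_rows)(x)
print(y.shape)  # leading axis has size n_devices, e.g. (1, 3) on one device
```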
Thanks for bringing this up. I believe this issue should have been fixed by Roman's work on jit compilation in batching. I've tried your repro in Colab and it seems to work, but let me know if any problems persist.
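The thread concerns combining batching (in current JAX, `jax.vmap`) with `jax.jit`. A minimal sketch under that assumption, using a hypothetical toy function `predict` that is not the reporter's code, checks that both orders of composition agree with an explicit per-example loop:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Hypothetical per-example function: a single linear unit with tanh.
    return jnp.tanh(jnp.dot(w, x))

w = jnp.ones(4)
xs = jnp.arange(12.0).reshape(3, 4)  # batch of 3 examples, 4 features each

# Batch over x only (w is shared), then compile the batched function.
jit_of_vmap = jax.jit(jax.vmap(predict, in_axes=(None, 0)))

# Compile the per-example function, then batch the compiled function.
vmap_of_jit = jax.vmap(jax.jit(predict), in_axes=(None, 0))

# Reference result computed one example at a time.
looped = jnp.stack([predict(w, x) for x in xs])

a = jit_of_vmap(w, xs)
b = vmap_of_jit(w, xs)
print(jnp.allclose(a, looped), jnp.allclose(b, looped))
```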
Thanks, but it didn't solve it.
Ah, I see, thanks! So this appears to work if you remove the … One point we should make clearer in the docs (especially once we have docs) is that we have found poor memory characteristics when applying …
I understand, thanks!
It looks like the latest JAX version has this fixed, and both … Please note that I am also unsure how reasonable it is to apply …
Closing. Please let me know if there are still related issues!
I get the error `Too many leaves for PyTreeDef; expected 6.` when I try to run the following code:
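The following is not the reporter's code. In current JAX, the message `Too many leaves for PyTreeDef` comes from the pytree utilities and is raised when a flat list of leaves is unflattened against a tree structure that expects a different number of leaves. As a hypothetical sketch: a parameter tuple of three `(w, b)` pairs has exactly 6 leaves, and passing a seventh leaf to `jax.tree_util.tree_unflatten` triggers the same kind of error (the exact wording may vary across JAX versions).

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical parameters for a 3-layer model: three (weights, bias) pairs -> 6 leaves.
params = (
    (jnp.ones((4, 8)), jnp.zeros(8)),
    (jnp.ones((8, 8)), jnp.zeros(8)),
    (jnp.ones((8, 1)), jnp.zeros(1)),
)

leaves, treedef = tree_util.tree_flatten(params)
print(len(leaves), treedef.num_leaves)  # 6 6

# Rebuilding with the matching number of leaves round-trips the structure.
rebuilt = tree_util.tree_unflatten(treedef, leaves)

# Supplying one extra leaf violates the structure's expected leaf count.
try:
    tree_util.tree_unflatten(treedef, leaves + [jnp.zeros(1)])
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```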