jax.Array must be fully replicated to be saved in aggregate file #353
Comments
Hi @cpgaffney1, thanks for the response. I am replicating the state prior to training with `flax.jax_utils`, which uses `jax.device_put_replicated` under the hood. When I run the save, the error above is raised.
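To illustrate the replication path described above (a hedged sketch; the tiny `state` dict below is hypothetical, not the code from this issue), `jax.device_put_replicated` stacks a copy of each pytree leaf onto every local device, adding a leading device axis:

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny "train state"; the real issue uses a
# flax.training.train_state.TrainState, but any pytree behaves the same.
state = {"w": jnp.ones((2, 3)), "step": jnp.array(0)}

# This is what flax.jax_utils.replicate does under the hood:
replicated = jax.device_put_replicated(state, jax.local_devices())

# Each leaf gains a leading axis of length len(jax.local_devices()),
# sharded one slice per device (pmap-style), rather than a single
# fully replicated jax.Array.
print(replicated["w"].shape)  # (num_devices, 2, 3)
```

The leading device axis is why the resulting arrays are pmap-sharded rather than "fully replicated" in the `jax.Array` sense that the error message refers to.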
I'm told that this is a known bug. In the meantime, you should use a different method to replicate the state; perhaps just use `pjit`.
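A minimal sketch of the suggested alternative: replicate via a sharding rather than `pmap`-style stacking. This assumes a one-axis mesh over all devices; the axis name `"d"` and the example array are illustrative, not from the issue:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-axis mesh over all local devices; "d" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))

# An empty PartitionSpec means "replicate on every device in the mesh".
replicated_sharding = NamedSharding(mesh, PartitionSpec())

x = jax.device_put(jnp.arange(8.0), replicated_sharding)
print(x.is_fully_replicated)  # True
```

Because the array keeps its original shape and every device holds a full copy, `is_fully_replicated` is true, which is the condition named in the error message for saving into the aggregate file.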
Thanks. What would be the difference between saving the replicated state after `pjit` vs. calling `unreplicate()` and saving, as I'm doing now?
If the arrays are replicated, they can be safely saved into the msgpack aggregate file. If you call `unreplicate()`, I believe Flax's behavior is to instruct Orbax to save using TensorStore, which supports sharded arrays.
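For context on the `unreplicate()` half of that comparison, `flax.jax_utils.unreplicate` is essentially "take the first device's copy" of each leaf; a hedged sketch using plain JAX tree utilities (the example pytree is hypothetical):

```python
import jax
import jax.numpy as jnp

# A pmap-style replicated pytree, as produced by flax.jax_utils.replicate.
replicated = jax.device_put_replicated({"w": jnp.ones((2,))}, jax.local_devices())

# unreplicate drops the leading device axis by indexing slice 0:
unreplicated = jax.tree_util.tree_map(lambda x: x[0], replicated)
print(unreplicated["w"].shape)  # (2,)
```

The unreplicated leaves are ordinary single-copy arrays again, which is why the save then goes down a different code path than the fully replicated case.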
I'm trying to save a checkpoint and getting this error message. Saving code:

This line in `orbax/checkpoint/pytree_checkpoint_handler.py` is throwing the error. `state` is an instance of `flax.training.train_state.TrainState`. What could be causing this? I tried disabling jax.Array with `jax.config.update('jax_array', False)`, but that does not work with jax and jaxlib 0.4.7.