
jax.Array must be fully replicated to be saved in aggregate file #353

Closed
tatami-galaxy opened this issue Jun 11, 2023 · 6 comments


@tatami-galaxy

I'm trying to save a checkpoint and get this error. Saving code:

```python
from flax.training import orbax_utils

ckpt = {'state': state, 'config': model.config}
save_args = orbax_utils.save_args_from_target(ckpt)
checkpoint_manager.save(global_step + 1, ckpt, save_kwargs={'save_args': save_args})
```

This line in `orbax/checkpoint/pytree_checkpoint_handler.py` is raising the error:

```python
if isinstance(value, jax.Array) and not value.is_fully_replicated:
  raise ValueError(
      'jax.Array must be fully replicated to be saved in aggregate file.'
  )
```

`state` is a `flax.training.train_state.TrainState` instance. What could be causing this? I tried disabling jax.Array with `jax.config.update('jax_array', False)`, but that does not work with jax and jaxlib 0.4.7.
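To see which leaves trip the check quoted above, a small helper can walk the checkpoint pytree and report every `jax.Array` whose `is_fully_replicated` is `False`. This is a minimal sketch assuming JAX 0.4.6 or later (for `jax.tree_util.tree_flatten_with_path`); `non_replicated_leaves` and the sample `ckpt` are hypothetical, not part of Orbax or Flax.

```python
import jax
import jax.numpy as jnp
import numpy as np


def non_replicated_leaves(tree):
    """Return key paths of jax.Array leaves that are not fully replicated.

    These are the leaves that would fail the aggregate-file check
    `isinstance(value, jax.Array) and not value.is_fully_replicated`.
    """
    paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in paths_and_leaves
        if isinstance(leaf, jax.Array) and not leaf.is_fully_replicated
    ]


# Hypothetical stand-in for {'state': state, 'config': model.config}.
ckpt = {
    "params": {"w": jnp.ones((4, 4))},  # single-device jax.Array: fully replicated
    "step": np.int32(0),                # NumPy scalar: not a jax.Array, always allowed
    "config": {"lr": 1e-3},             # Python scalar: not a jax.Array, always allowed
}
print(non_replicated_leaves(ckpt))  # → []
```

Running it on the real `ckpt` before `checkpoint_manager.save` should list the paths of any leaves that would be rejected.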

@cpgaffney1
Collaborator

To use the aggregate option, the leaves must be NumPy arrays or basic scalar types; alternatively, you can reshard your jax.Arrays so they are fully replicated across all devices. This is pretty easy to do: just supply a sharding of None instead. The JAX documentation has lots of details.
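A minimal sketch of both options from the comment above, assuming a JAX version that provides the `jax.sharding` module. Instead of a literal `None` sharding, it uses `NamedSharding(mesh, PartitionSpec())` as an explicit way to request full replication: an empty `PartitionSpec` partitions no axis, so every device holds a full copy. The mesh axis name `"devices"` and the sample pytree are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over every available device; the axis name is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# An empty PartitionSpec shards no dimension, i.e. full replication.
replicated_sharding = NamedSharding(mesh, PartitionSpec())

# Hypothetical pytree standing in for the training state.
tree = {"w": jnp.arange(8.0), "b": jnp.zeros(3)}

# Option 1: reshard every jax.Array leaf so it is fully replicated.
replicated_tree = jax.device_put(tree, replicated_sharding)
assert all(x.is_fully_replicated for x in jax.tree_util.tree_leaves(replicated_tree))

# Option 2: copy everything back to host NumPy arrays, which the aggregate file accepts.
host_tree = jax.device_get(tree)
assert all(isinstance(x, np.ndarray) for x in jax.tree_util.tree_leaves(host_tree))
```

Option 2 avoids sharding questions entirely but materializes the whole tree in host memory, which may matter for large models.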

@tatami-galaxy
Author

tatami-galaxy commented Jun 13, 2023

Hi @cpgaffney1, thanks for the response. I replicate the state before training with `state = jax_utils.replicate(state)`, and I shard the batches during training with `flax.training.common_utils.shard`. The error goes away if I call `flax.jax_utils.unreplicate()` on the state before saving, like so: `ckpt = {'state': unreplicate(state), 'config': model.config}`. Is this supposed to happen?
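A minimal sketch of what the two Flax helpers do, written with plain JAX. It assumes `replicate` stacks one copy per local device along a new leading axis via `jax.device_put_replicated`, and that `unreplicate` keeps only the first replica; this reflects my understanding of Flax's behavior, not a verbatim copy of its code. Unreplicating yields an ordinary array with the original shape, which is consistent with the error disappearing.

```python
import jax
import numpy as np

devices = jax.local_devices()
x = np.arange(32)

# Roughly what flax.jax_utils.replicate does: one copy per device,
# stacked along a new leading axis of size len(devices).
replicated = jax.device_put_replicated(x, devices)
assert replicated.shape == (len(devices), 32)

# Roughly what flax.jax_utils.unreplicate does (assumption): keep replica 0.
unreplicated = jax.tree_util.tree_map(lambda a: a[0], replicated)
assert unreplicated.shape == (32,)
np.testing.assert_array_equal(np.asarray(unreplicated), x)
```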

@cpgaffney1
Collaborator

`flax.jax_utils.replicate` uses `jax.device_put_replicated`. When I run the following:

```python
import jax
import numpy as np

replicated = jax.device_put_replicated(np.arange(32), jax.devices())
print(replicated.is_fully_replicated)
```

`is_fully_replicated` is actually `False`, which is surprising. I'm checking whether this is expected behavior.

@cpgaffney1
Collaborator

I'm told that this is a known bug. In the meantime, you should replicate the state some other way; perhaps just use pjit.
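A minimal sketch of the pjit route, assuming a JAX version in which `jax.jit` accepts `out_shardings` (in later JAX releases pjit's sharding arguments are available directly on `jax.jit`). Compiling an identity function with a fully replicated output sharding reshards every leaf of the state; the mesh axis name and the sample `state` are illustrative assumptions, not the real `TrainState`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
replicated = NamedSharding(mesh, PartitionSpec())  # no axis partitioned

# Identity function compiled so that every output leaf is fully replicated.
replicate_tree = jax.jit(lambda tree: tree, out_shardings=replicated)

# Hypothetical stand-in for the training state.
state = {"params": {"w": jnp.ones((4, 4))}, "step": jnp.array(0)}
state = replicate_tree(state)
assert all(x.is_fully_replicated for x in jax.tree_util.tree_leaves(state))
```

Unlike `device_put_replicated`, this keeps the original leaf shapes (no leading device axis), so the saved checkpoint matches the unreplicated state's structure.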

@tatami-galaxy
Author

Thanks. What would be the difference between saving the state after replicating it with pjit versus calling unreplicate() and saving it, as I'm doing now?

@cpgaffney1
Collaborator

If the arrays are fully replicated, they can be safely saved into the aggregate msgpack file. If you call unreplicate, I believe Flax's behavior is to instruct Orbax to save them with TensorStore, which supports sharded arrays.
