How to restore checkpoints if not all arrays are sharded? #381

Closed
gianlucadetommaso opened this issue Jun 27, 2023 · 10 comments

@gianlucadetommaso

gianlucadetommaso commented Jun 27, 2023

Hi,
I am trying to restore a checkpoint in which some of the arrays are not partitioned. I believe this is a common use case, as one may not want (or be able) to partition all parameters in a model. Here is a minimal example:

import numpy as np
import jax
import flax.training.orbax_utils
import orbax.checkpoint

# jax.distributed.initialize must run before any other JAX call, including jax.devices().
jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0)

mesh_shape = (2, 2)  # requires 4 devices
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

sharded = jax.device_put(np.arange(4).reshape(2, 2), jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')))
unsharded = jax.device_put(np.arange(1), jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()))  # fully replicated
ckpt = dict(sharded=sharded, unsharded=unsharded)

async_ckptr = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)
async_ckpt_mgr = orbax.checkpoint.CheckpointManager('/tmp/example', async_ckptr)
async_ckpt_mgr.save(0, ckpt)
async_ckpt_mgr.wait_until_finished()  # block until the asynchronous save has completed

restore_args = flax.training.orbax_utils.restore_args_from_target(ckpt)
restored_ckpt = async_ckpt_mgr.restore(0, items=ckpt, restore_kwargs={'restore_args': restore_args})

I get the following error message:
"ValueError: Sharding of jax.Array cannot be None. Provide mesh and mesh_axes OR sharding."

I then found that I can convert unsharded into a NumPy array, restore it, and then reassign an empty partition spec, as follows:

ref_ckpt = dict(sharded=sharded, unsharded=np.array(unsharded))
restore_args = flax.training.orbax_utils.restore_args_from_target(ref_ckpt)
restored_ckpt = async_ckpt_mgr.restore(0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})
# Put the restored NumPy array back on the mesh, fully replicated.
restored_ckpt["unsharded"] = jax.device_put(restored_ckpt["unsharded"], jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()))

This seems to work as desired. Would you say this is the right way to do it? Is there an easier way?

Thanks!

@cpgaffney1
Collaborator

I think this may be a consequence of using the Flax restore_args_from_target function, which makes certain assumptions about how on-device, fully replicated arrays should be restored. Try using construct_restore_args instead.

@gianlucadetommaso
Author

Thanks, it appears to work.

Here is a follow-up question: in the documentation, the target object passed to both construct_restore_args and CheckpointManager.restore has the same shapes as the train_state we are trying to restore. This seems less than ideal, since it may take up quite a lot of memory.

In the Flax documentation here, it seems to be possible to pass a smaller target reference. However, when I tried creating one with the same axis dimensions as my shardings, I got an error like

ValueError: Cannot intersect index domain { [0, 40*) } with index domain { }: Ranks do not match [source locations='tensorstore/index_space/index_transform.cc:484']

What is the exact logic? How can I create a reference smaller than the actual train state?

@cpgaffney1
Copy link
Collaborator

For starters, the items argument of CheckpointManager.restore is optional, and it is unnecessary in your case since your checkpoint is just a dict, not a custom PyTree.

Secondly, it is possible to skip the initialization of target if you are constrained by memory. Simply use restore_args=dict(sharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...), unsharded=orbax.checkpoint.ArrayRestoreArgs(sharding=...))

I'm not actually sure what the Flax documentation means when it says the reference may be smaller than the actual train state. The construct_restore_args function is really only intended for use when you already have the entire PyTree initialized with arrays of the correct shape and sharding. If you have jax.ShapeDtypeStruct leaves instead, for example, you would just use similar logic to build the restore_args yourself.

@gianlucadetommaso
Author

Fantastic, seems to have worked like a charm. Thanks a lot!

@gianlucadetommaso
Author

Reopening the issue since I've run into a follow-up problem. In my code, the restore happens inside a jitted function. Unfortunately, I get the following concretization error:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=2/0)>

This value became a tracer due to JAX operations on these lines:
  operation a:u32[1] = host_local_array_to_global_array[
  global_mesh=Mesh(device_ids=array([[0]]), axis_names=('processes', 'local_devices'))
  pspec=PartitionSpec('processes',)
] b

Any idea how I can work around this issue?

@cpgaffney1
Collaborator

This one has an easy answer: you can't restore from within a jitted function. We have run into this before with a few other Flax users, and concluded that they need to move their restore outside the jitted function. Sorry!
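A minimal sketch of that pattern: do the host-side loading eagerly, then pass the resulting arrays into the jitted function as ordinary arguments. Here np.load on an .npz file stands in for the Orbax restore, and load_params and apply are hypothetical names for illustration only.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp


def load_params(path):
    # Hypothetical stand-in for a checkpoint restore: plain host-side I/O, not traced.
    with np.load(path) as data:
        return {k: jnp.asarray(data[k]) for k in data.files}


@jax.jit
def apply(params, x):
    # Only array computation happens under jit; no file access.
    return x @ params["w"] + params["b"]


path = os.path.join(tempfile.mkdtemp(), "params.npz")
np.savez(path, w=np.ones((3, 2), np.float32), b=np.zeros(2, np.float32))

params = load_params(path)  # restore outside of jit
out = apply(params, jnp.ones((4, 3)))
```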

@gianlucadetommaso
Author

Alright, then perhaps there is a way to obtain a tree of shapes corresponding to the checkpoint without actually restoring it? I tried using jax.eval_shape in combination with CheckpointManager.restore, but jax.eval_shape traces the function just as jit does, so this approach did not work.

@cpgaffney1
Collaborator

There is, but it's currently not well integrated into the API. Orbax does support getting the pytree structure of the checkpoint via the structure API or using lazy restore. However, if you want the shapes of the arrays, we currently don't have an API for that, though we're working on it. What you can do is parse through the .zarray files, which store the shape (and other metadata) for each parameter.
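A sketch of parsing the .zarray files with only the standard library. It assumes the Zarr v2 layout, where each parameter lives in its own subdirectory containing a JSON .zarray file with "shape" and "dtype" keys; the demo directory is hand-written and hypothetical, not a real Orbax checkpoint. No chunk data is read.

```python
import json
import os
import tempfile


def shapes_from_zarray(ckpt_dir):
    """Return {relative_param_dir: (shape, dtype)} for every .zarray file under ckpt_dir."""
    shapes = {}
    for root, _dirs, files in os.walk(ckpt_dir):
        if ".zarray" in files:
            with open(os.path.join(root, ".zarray")) as f:
                meta = json.load(f)  # Zarr v2 metadata: shape, chunks, dtype, ...
            shapes[os.path.relpath(root, ckpt_dir)] = (tuple(meta["shape"]), meta["dtype"])
    return shapes


# Demo on a hand-written directory mimicking two saved parameters (hypothetical layout).
demo_dir = tempfile.mkdtemp()
for name, shape in [("sharded", [2, 2]), ("unsharded", [1])]:
    os.makedirs(os.path.join(demo_dir, name))
    with open(os.path.join(demo_dir, name, ".zarray"), "w") as f:
        json.dump({"zarr_format": 2, "shape": shape, "chunks": shape, "dtype": "<i4"}, f)

shapes = shapes_from_zarray(demo_dir)
```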

@gianlucadetommaso
Author

Thanks. Related to structure, I've noticed that when there is only one checkpoint in the directory, structure creates a pytree with the same structure as the checkpoint, with numpy arrays as values. However, when multiple checkpoints are available, it returns a pytree with lists of placeholder strings as values. I find this difference a little odd. For example, it would seem I could get shapes in the first case, but not in the second?

Anyway, to get the shapes I'm currently restoring the whole checkpoint lazily (using ArrayRestoreArgs(lazy=True)) and then reading the shape of each value with a tree_map:

shapes = jax.tree_util.tree_map(lambda v: v.get().shape, restored_checkpoint)

Would you say this method allocates the full checkpoint in memory?

@cpgaffney1
Collaborator

If you only have a checkpoint file in the directory, that means all your values were saved with aggregate=True, so they are all stored in that one file. Otherwise, they are stored separately using Tensorstore, and each one is represented by PLACEHOLDER in the structure file. You're right that this difference is a bit annoying, and it's not easy to get metadata for all arrays (such as their shapes) without doing a full restoration.

This does materialize the entire checkpoint. If you want to look up the shapes without doing so, you have to parse through the .zarray files, or use Tensorstore to do so. Again, we're working on a more user-friendly way of doing this. In the future, LazyArray will also have properties like shape and dtype even before materialization.
