How to restore checkpoints if not all arrays are sharded? #381
Comments
I think this may be a consequence of using the Flax […].
Thanks, it appears to work. Here is a follow-up question: in the documentation, the […]. In the Flax documentation here, it seems to be possible to pass a smaller target reference. However, I've tried creating one with the same axis dimensions as my shardings, but I get an error like […].
What is the exact logic? How can I create a reference smaller than the actual train state?
For starters, the […]. Secondly, it is possible to skip the initialization of […] (see the sketch below). I'm not actually sure what the Flax documentation is talking about when they say the reference may be smaller than the actual train state… the […].
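(The inline code references in this reply were lost when the page was captured. For the "skip the initialization" point, one widely used pattern, sketched below, is to build an abstract restore target with jax.eval_shape; this is a sketch of that pattern, not necessarily the exact API the maintainer named, and create_train_state is a hypothetical stand-in.)

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a real train-state constructor.
def create_train_state(rng):
    return {"w": jax.random.normal(rng, (1024, 1024)),
            "b": jnp.zeros((1024,))}

# jax.eval_shape traces the constructor without allocating device memory,
# returning a pytree of jax.ShapeDtypeStruct leaves. That abstract tree can
# then serve as the restore target, so the real initialization is skipped.
abstract_state = jax.eval_shape(create_train_state, jax.random.PRNGKey(0))
print(jax.tree_util.tree_map(lambda x: x.shape, abstract_state))
```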
Fantastic, seems to have worked like a charm. Thanks a lot!
Reopening the issue since I've got a follow-up problem. In my code, the restore call happens within a jitted function. Unfortunately, I get the following concretization error: […]
Any idea how I can work around this issue?
This is an easy answer: you can't restore from within a jitted function. This is an issue we've encountered before with a few other Flax users, and we reached the conclusion that they would just need to move their restore outside the jitted function (see the sketch below). Sorry!
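(For illustration, a minimal sketch of the recommended structure, assuming the PyTreeCheckpointer API of that era; the path and the train step are hypothetical.)

```python
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

CKPT_DIR = "/tmp/demo_ckpt"  # hypothetical path, for illustration only

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(CKPT_DIR, {"w": jnp.ones((4, 4))}, force=True)

# 1) Restore eagerly at the top level, never inside jit-traced code.
state = checkpointer.restore(CKPT_DIR)

# 2) Jit only the numerical work; the restored pytree is an ordinary input.
@jax.jit
def train_step(state, batch):
    return {"w": state["w"] + batch.mean()}

state = train_step(state, jnp.ones((4,)))
```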
Alright, then perhaps there is a way to obtain a tree of shapes corresponding to the checkpoint, without actually restoring the checkpoint? I tried using […].
There is, but it's currently not well integrated into the API. Orbax does support getting the pytree structure of the checkpoint via the […].
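(The method name was dropped from the capture. For later readers: recent orbax-checkpoint versions expose this through a metadata accessor; the following is a sketch assuming Checkpointer.metadata, which may differ from the method originally named here.)

```python
import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()

# metadata() reads only the checkpoint's structural record, not the array
# data, and returns a pytree whose leaves describe each saved array.
meta = checkpointer.metadata("/tmp/demo_ckpt")  # hypothetical path
print(meta)
```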
Thanks. Related to […]. Anyway, in order to get shapes, currently I'm restoring the whole checkpoint lazily (using […]):
Would you say this method allocates the full checkpoint in memory?
If you only have a […], this does materialize the entire checkpoint. If you want to look up the shapes without doing so, you have to parse through the .zarray files (as sketched below), or use Tensorstore to do so. Again, we're working on a more user-friendly way of doing this. In the future, LazyArray will also have properties like shape and dtype even before materialization.
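(A sketch of the .zarray route mentioned above: in a zarr-v2 layout, each saved array's directory contains a ".zarray" JSON file whose "shape" and "dtype" fields can be read without touching the data chunks. The checkpoint path is hypothetical.)

```python
import json
from pathlib import Path

CKPT_DIR = Path("/tmp/demo_ckpt")  # hypothetical checkpoint directory

# Every zarr-v2 array directory holds a ".zarray" metadata file; reading it
# yields the shape and dtype without materializing any array data.
for zarray_file in CKPT_DIR.rglob(".zarray"):
    meta = json.loads(zarray_file.read_text())
    name = zarray_file.parent.relative_to(CKPT_DIR)  # parameter name
    print(name, meta["shape"], meta["dtype"])
```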
Hi,
I am trying to restore a checkpoint where some of its arrays are not partitioned. I believe this is a common use case, as one may not want (or be able) to partition all parameters in a model. The following is a minimal example:
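(The original snippet did not survive the page capture; what follows is a hedged reconstruction of the described setup, with one leaf sharded across devices and one left unsharded. It is not the author's exact code, and the checkpoint path is hypothetical.)

```python
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("x",))

state = {
    # One leaf explicitly sharded across the mesh...
    "sharded": jax.device_put(jnp.ones((8, 2)),
                              NamedSharding(mesh, PartitionSpec("x"))),
    # ...and one deliberately left unpartitioned.
    "unsharded": jnp.ones((3,)),
}

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/mixed_ckpt", state, force=True)

restore_args = {
    "sharded": ocp.ArrayRestoreArgs(
        sharding=NamedSharding(mesh, PartitionSpec("x"))),
    # No mesh/sharding is supplied for this leaf, which is what appears to
    # trigger the ValueError quoted below.
    "unsharded": ocp.ArrayRestoreArgs(),
}
restored = checkpointer.restore("/tmp/mixed_ckpt", item=state,
                                restore_args=restore_args)
```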
I get the following error message:
"ValueError: Sharding of jax.Array cannot be None. Provide
mesh
andmesh_axes
ORsharding
."I then found out that I can transform
unsharded
into anumpy.array
, restore it, then assign again an empty partition. As follows:This seems to work as desired. Would you say this is the right way to do it? Is there an easier way?
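(The workaround snippet was likewise lost in the capture; below is a hedged sketch of one way to express it, reusing the names from the reconstruction above: restore the leaf as numpy, then put it back on device fully replicated.)

```python
restore_args = {
    "sharded": ocp.ArrayRestoreArgs(
        sharding=NamedSharding(mesh, PartitionSpec("x"))),
    # Restore the problematic leaf as a plain numpy array, sidestepping
    # the sharding requirement for jax.Array restoration.
    "unsharded": ocp.RestoreArgs(restore_type=np.ndarray),
}
restored = checkpointer.restore("/tmp/mixed_ckpt", item=state,
                                restore_args=restore_args)

# Re-attach the "empty" partition: place the leaf back on device,
# replicated (an empty PartitionSpec).
restored["unsharded"] = jax.device_put(
    restored["unsharded"], NamedSharding(mesh, PartitionSpec()))
```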
Thanks!