Restoring to CPU #986

@lachinov

Hi,

I'm using orbax checkpoint version 0.5.20.
After moving to the new API, I ran into an issue similar to #646, but with a twist: I don't know the structure of the PyTree beforehand.
The restoration code below works fine when the device configuration is the same as during training. What I want is to read the saved data onto the CPU and apply my own sharding from there.

checkpoint = checkpoint_manager.restore(
    step=best_step,
    args=orbax.checkpoint.args.Composite(
        configuration=orbax.checkpoint.args.PyTreeRestore(),
        metric_values=orbax.checkpoint.args.PyTreeRestore(),
        hparams=orbax.checkpoint.args.JsonRestore(),
    ),
)

Restoring to the CPU would be possible if I knew the signature of the configuration tree.
However, the structure of that tree is learnable and dynamic, so I don't know its signature beforehand.
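
To make the goal concrete, here is a minimal sketch of the structure-agnostic step I have in mind. It uses only JAX, not orbax, and the helper names `cpu_sharding_tree` and `place_on_cpu` are hypothetical. It assumes that some description of the saved tree (shapes and dtypes) can be obtained without hard-coding the structure, which is the part I am missing.

```python
import jax
import jax.numpy as jnp
import numpy as np


def cpu_sharding_tree(abstract_tree):
    """Build a tree with the same structure as `abstract_tree` whose
    leaves are all a single-CPU-device sharding (hypothetical helper)."""
    cpu = jax.devices("cpu")[0]
    sharding = jax.sharding.SingleDeviceSharding(cpu)
    return jax.tree_util.tree_map(lambda _: sharding, abstract_tree)


def place_on_cpu(tree):
    """Put every array leaf of an arbitrary PyTree on the first CPU device."""
    cpu = jax.devices("cpu")[0]
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), tree)


# Stand-in for a tree whose structure is only known at run time.
abstract = {
    "layer": {"w": jax.ShapeDtypeStruct((2, 3), jnp.float32)},
    "extra": [jax.ShapeDtypeStruct((4,), jnp.int32)],
}
shardings = cpu_sharding_tree(abstract)

# Stand-in for restored host data with the same (unknown) structure.
restored = {
    "layer": {"w": np.zeros((2, 3), np.float32)},
    "extra": [np.arange(4, dtype=np.int32)],
}
on_cpu = place_on_cpu(restored)
```

Because `tree_map` never needs the structure spelled out, the same two helpers work for any tree. From `on_cpu`, each leaf could then be moved to its final layout with `jax.device_put` and the sharding of my choice.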

Is there any solution besides changing the saving logic?
