Hi,
I'm using orbax-checkpoint version 0.5.20.
After moving to the new API, I ran into an issue similar to #646, but with a twist: I don't know the structure of the PyTree beforehand.
The restoration code looks as follows. It works fine when the device configuration is the same as during training. However, I want to read the saved data into CPU memory and apply my own sharding from there.
import orbax.checkpoint

# checkpoint_manager and best_step are defined elsewhere in the training code.
checkpoint = checkpoint_manager.restore(
    step=best_step,
    args=orbax.checkpoint.args.Composite(
        configuration=orbax.checkpoint.args.PyTreeRestore(),
        metric_values=orbax.checkpoint.args.PyTreeRestore(),
        hparams=orbax.checkpoint.args.JsonRestore(),
    ),
)
This would be possible if I knew the signature of the configuration item.
However, the tree structure of configuration is learnable and dynamic, so I don't know it in advance.
Is there any solution besides changing the saving logic?
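A possible direction that avoids hard-coding the structure, shown here as a plain-Python sketch under stated assumptions, is to build the per-leaf restore arguments by mapping over the checkpoint's own saved metadata instead of over a known target tree. LeafMetadata, RestoreSpec, tree_map and load_leaf are hypothetical stand-ins. In Orbax, the counterparts would presumably be the item's stored metadata (for example via CheckpointManager.item_metadata), jax.tree_util.tree_map, and orbax.checkpoint.RestoreArgs(restore_type=np.ndarray) passed as restore_args to PyTreeRestore. Whether that combination works in 0.5.20 without an explicit target tree is an assumption that has not been verified here.

```python
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np


@dataclass(frozen=True)
class LeafMetadata:
    """Hypothetical per-array metadata, as it might be stored next to a checkpoint."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class RestoreSpec:
    """Hypothetical per-leaf restore argument: request a host (CPU) NumPy array."""
    restore_type: type = np.ndarray


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Map fn over matching leaves of one or more nested dict/list/tuple trees."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, *leaves) for leaves in zip(tree, *rest))
    return fn(tree, *rest)


def load_leaf(meta: LeafMetadata, spec: RestoreSpec) -> np.ndarray:
    """Hypothetical loader: this sketch just allocates zeros of the saved shape/dtype."""
    if spec.restore_type is not np.ndarray:
        raise ValueError("this sketch only supports host NumPy restores")
    return np.zeros(meta.shape, dtype=meta.dtype)


# Structure discovered at load time (from the checkpoint's metadata),
# not written down anywhere in the restoring code.
saved_metadata = {
    "encoder": {"w": LeafMetadata((4, 8), "float32"), "b": LeafMetadata((8,), "float32")},
    "heads": [LeafMetadata((8, 2), "float32"), LeafMetadata((8, 3), "float32")],
}

# One restore spec per leaf, mirroring whatever structure the metadata has.
restore_args = tree_map(lambda _: RestoreSpec(), saved_metadata)
# Restore every leaf to host memory; sharding can be applied afterwards.
host_tree = tree_map(load_leaf, saved_metadata, restore_args)
```

Once the tree lives on the host as NumPy arrays, any sharding can be applied per leaf (in JAX, for instance, with jax.device_put), independent of the device layout used during training.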