
What is the correct way to restore the checkpoints? #11

Closed
ka2hyeon opened this issue Feb 21, 2023 · 7 comments

@ka2hyeon

When I run

tf.saved_model.load('./robotics_transformer/trained_checkpoints/rt1main')

I get the following error:

IndexError: Read less bytes than requested

All my attempts to restore the checkpoint you provided have failed.
For example, the following code did not work for me either:

from tf_agents.utils.common import Checkpointer

checkpointer = Checkpointer(
    agent=agent,
    ckpt_dir='./robotics_transformer/trained_checkpoints/rt1main',
)
checkpointer.initialize_or_restore()

What is the correct way to restore the checkpoints?

@yaolug
Contributor

yaolug commented Feb 23, 2023

The second way should be correct. How is agent defined?

Could you also try:

checkpointer = Checkpointer(
    agent=agent,
    ckpt_dir='./robotics_transformer/trained_checkpoints/rt1main',
    global_step=424760,
)

@ka2hyeon
Author

ka2hyeon commented Feb 23, 2023

Thank you for the reply.
However, specifying the global_step argument didn't work either.

The agent is an instance of SequenceAgent, created as in ./sequence_agent_test.py.
In my opinion, this error comes from a mismatch between the parameters of the RT-1 model declared in your test code and those of the checkpointed model. I need the parameters that were used to train the checkpoint, but they cannot be inferred from the code alone.

For example, I wonder which values of the following parameters were used in the checkpointed model: num_layers (initialized as 1), layer_size (initialized as 4096), num_heads (initialized as 8), feed_forward_size (initialized as 512). Some parameters can be inferred from your paper (e.g. time_sequence_length=6, vocab_size=256), but I cannot determine the parameters not mentioned in the paper.

@yaolug
Contributor

yaolug commented Feb 23, 2023

See configs/transformer_mixin.gin for parameters.

Also, for the SavedModel, the following could work:

from tf_agents.policies import py_tf_eager_policy

policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    model_path='./robotics_transformer/trained_checkpoints/rt1main',
    load_specs_from_pbtxt=True,
    use_tf_function=True,
)
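A sketch of how such a load is typically wrapped and followed up, assuming the tf-agents eager-policy API from this thread; the helper name is illustrative (not from the repo), and imports are kept inside the function so the sketch stands alone:

```python
# Illustrative helper (not from the repo): load the RT-1 SavedModel as a
# tf-agents eager policy and fetch its per-episode initial state.
def load_rt1_policy(model_path):
    # Import inside the function so this sketch can be defined even
    # where tf-agents is not installed.
    from tf_agents.policies import py_tf_eager_policy

    policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        model_path=model_path,
        load_specs_from_pbtxt=True,
        use_tf_function=True,
    )
    # RT-1 is stateful (it attends over a time window), so each episode
    # needs a fresh policy state.
    policy_state = policy.get_initial_state(batch_size=1)
    return policy, policy_state
```

Actions are then produced with policy.action(time_step, policy_state), which returns a PolicyStep whose state should be fed back on the next call.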

@ka2hyeon
Author

All my problems are solved! Thank you for the kind help.
tf.saved_model.load('...') didn't work, but py_tf_eager_policy.SavedModelPyTFEagerPolicy(...) worked.
Also, I had missed transformer_mixin.gin; now I can restore all the correct parameters.

@AliBuildsAI

@ka2hyeon could you please share your environment via pip list? I tried both methods on both TF1 and TF2 and got errors.

@ka2hyeon
Author

ka2hyeon commented Mar 22, 2023

@AliBuildsAI In my environment, I am using the following TensorFlow-related packages.

tensorboard 2.8.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.8.2
tensorflow-addons 0.17.1
tensorflow-datasets 4.6.0
tensorflow-estimator 2.8.0
tensorflow-hub 0.12.0
tensorflow-io-gcs-filesystem 0.26.0
tensorflow-metadata 1.9.0
tensorflow-model-optimization 0.7.2
tensorflow-probability 0.16.0
tensorflow-text 2.8.2
tf-agents 0.12.0

@oym1994

oym1994 commented Apr 8, 2023

All my problems are solved! Thank you for the kind help. tf.saved_model.load('...') didn't work, but py_tf_eager_policy.SavedModelPyTFEagerPolicy(...) worked. Also, I had missed transformer_mixin.gin; now I can restore all the correct parameters.

Hi, could you please provide the complete code of py_tf_eager_policy.SavedModelPyTFEagerPolicy(...)? Thank you!!!
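A minimal end-to-end sketch assembled from the snippets earlier in this thread, using the same checkpoint path; the zero-observation construction is an assumption for illustration only, and real inputs must match the policy's observation spec:

```python
# Minimal end-to-end sketch assembled from the snippets in this thread.
# The function name and dummy-observation construction are illustrative
# assumptions, not code from the repo.
def load_and_step(model_path='./robotics_transformer/trained_checkpoints/rt1main'):
    # Imports live inside the function so this file can be imported
    # without tf-agents present.
    import tensorflow as tf
    from tf_agents.policies import py_tf_eager_policy
    from tf_agents.trajectories import time_step as ts

    policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        model_path=model_path,
        load_specs_from_pbtxt=True,
        use_tf_function=True,
    )

    # Build a batch-1 all-zeros observation from the policy's spec,
    # then run a single first step of an episode.
    obs = tf.nest.map_structure(
        lambda spec: tf.zeros([1] + list(spec.shape), dtype=spec.dtype),
        policy.time_step_spec.observation,
    )
    time_step = ts.restart(obs, batch_size=1)
    policy_state = policy.get_initial_state(batch_size=1)
    policy_step = policy.action(time_step, policy_state)
    # policy_step.state must be passed back in on the next action() call.
    return policy_step.action
```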
