KeyError: 'global_step' When I load the weight of TAPIR #93

Open
KiritoHarlod opened this issue Apr 28, 2024 · 5 comments
Comments

@KiritoHarlod

I downloaded tapir_checkpoint_panning.npy from the link provided, but it fails to load. I printed ckpt_state.keys() in the restore function in experiment_utils.py and found that ckpt_state contains only the keys 'params' and 'state'; there is no 'global_step' key. Below is the command I used to run the evaluation.
python3 ./tapnet/experiment.py --config=./tapnet/configs/tapir_config.py --jaxline_mode=eval_davis_points --config.checkpoint_dir=./tapnet/checkpoint/ --config.experiment_kwargs.config.davis_points_path=/tapvid_davis/tapvid_davis.pkl
How do I get the code to run? Thanks!
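
For reference, the key check described above can also be reproduced outside the training code. This is only a minimal sketch, assuming the released .npy file is a pickled dictionary that `np.load(..., allow_pickle=True).item()` can unpack. The helper name `checkpoint_keys` and the stand-in file written below are hypothetical, included only so the snippet runs on its own.

```python
import os
import tempfile

import numpy as np


def checkpoint_keys(path):
  """Returns the sorted top-level keys of a pickled-dict .npy checkpoint."""
  # allow_pickle=True is needed because the file holds a Python dict,
  # which NumPy stores as a 0-d object array; .item() unwraps it.
  ckpt_state = np.load(path, allow_pickle=True).item()
  return sorted(ckpt_state.keys())


# Stand-in checkpoint with the two keys reported in this issue (made-up data).
fake_ckpt = {'params': {'w': np.zeros(2)}, 'state': {}}
tmp_path = os.path.join(tempfile.mkdtemp(), 'fake_checkpoint.npy')
np.save(tmp_path, fake_ckpt, allow_pickle=True)

keys = checkpoint_keys(tmp_path)
print(keys)
print('global_step' in keys)
```

Pointing `checkpoint_keys` at the real downloaded file should show whether 'global_step' is present before the restore code is touched.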

@cdoersch
Collaborator

I guess we didn't test this codepath; this is more intended for evaluating checkpoints you've trained yourself.

The logic for restoring the Experiment object from the checkpoint is here: https://github.com/google-deepmind/tapnet/blob/main/utils/experiment_utils.py#L160 -- if you change that line to just initialize the global_step to 0, it should probably get you past that error. I'm not sure how much more modification it will need to re-initialize everything correctly, but hopefully not too much.
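
The suggestion above amounts to giving the step counter a default when the checkpoint lacks one. A minimal sketch, assuming `ckpt_state` behaves like a plain dict; `restore_global_step` and the `SimpleNamespace` stand-in for the experiment state are hypothetical and not part of tapnet:

```python
from types import SimpleNamespace


def restore_global_step(experiment_state, ckpt_state):
  """Sets global_step from the checkpoint, defaulting to 0 if it is absent."""
  if 'global_step' in ckpt_state:
    experiment_state.global_step = int(ckpt_state['global_step'])
  else:
    # The checkpoint in this issue carries only 'params' and 'state'.
    experiment_state.global_step = 0
  return experiment_state


# Hypothetical stand-in for the Jaxline experiment state object.
state = SimpleNamespace(global_step=None)
restore_global_step(state, {'params': {}, 'state': {}})
print(state.global_step)
```

A checkpoint that does contain 'global_step' keeps its stored value, so self-trained checkpoints should behave as before.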

@KiritoHarlod
Author

@cdoersch OK, I made a few simple changes that let tapir_checkpoint_panning.npy load successfully, but I don't know whether they affect TAPIR's results. I'd appreciate your advice.

# The released checkpoint has no 'global_step'; fall back to 0 in that case.
experiment_state.global_step = int(ckpt_state['global_step']) if 'global_step' in ckpt_state else 0
exp_mod = experiment_state.experiment_module
for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
  # If the checkpoint has no 'opt_state' entry, read 'state' instead.
  if name == 'opt_state' and name not in ckpt_state:
    name = 'state'
  setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))

@cdoersch
Collaborator

Interesting -- I'm fairly sure that isn't a correct initialization for opt_state (it should be an Adam state), but for evaluation I guess it doesn't matter? IIRC Jaxline will call the full experiment init (including parameters and opt_state) and then overwrite with checkpoint values, so maybe the right thing to do is just to not set any attributes at all if name=='opt_state'?
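
One way to read this suggestion is to leave an attribute untouched when its checkpoint entry is missing, so the freshly initialized value survives. A rough sketch under that assumption: `restore_attrs`, the `FakeExperiment` class, its attribute names, and the identity `bcast` default are all hypothetical stand-ins (the real code uses `utils.bcast_local_devices` and the experiment module's `CHECKPOINT_ATTRS`). Note that this variant skips 'opt_state' only when the checkpoint lacks it, which is a small departure from skipping it unconditionally.

```python
class FakeExperiment:
  """Hypothetical stand-in for an experiment module with fresh init values."""
  CHECKPOINT_ATTRS = {'_params': 'params', '_state': 'state',
                      '_opt_state': 'opt_state'}

  def __init__(self):
    self._params = 'init_params'
    self._state = 'init_state'
    self._opt_state = 'init_adam_state'


def restore_attrs(exp_mod, ckpt_state, bcast=lambda x: x):
  """Overwrites attributes from the checkpoint, skipping a missing opt_state."""
  for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
    if name == 'opt_state' and name not in ckpt_state:
      # Keep the optimizer state produced by the normal experiment init.
      continue
    setattr(exp_mod, attr, bcast(ckpt_state[name]))
  return exp_mod


exp = restore_attrs(FakeExperiment(),
                    {'params': 'ckpt_params', 'state': 'ckpt_state'})
print(exp._params, exp._opt_state)
```

For evaluation the optimizer state is presumably never read, so keeping the initialized value should be harmless, though that has not been verified in this thread.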

@jaoguerreiro

jaoguerreiro commented Apr 29, 2024

What would be the correct way to run the pretrained TAPIR model on some videos locally (on a GPU)?

@KiritoHarlod
Author

@cdoersch Thank you for the suggestion, I'll give it a try.
