
Lean Model Checkpointing #504

Closed
dorukhansergin opened this issue Feb 8, 2019 · 5 comments
Labels: docs (More/better docs)

@dorukhansergin (Contributor)

Hi,

I ran a Trial and saved my model using torchbearer.callbacks.checkpointers.Best to a file model.pt.

When I load the file with torch.load and try to make a forward pass with it, I get the following error:

import torch

model = MyModule()
state_dict = torch.load('vae.pt')
model.load_state_dict(state_dict)  # <== I get the error here

AttributeError: 'StateKey' object has no attribute 'startswith'

I get that the model is being saved so that it can be recovered ready for use with torchbearer, but how can we save the model lean?

It seems like here, the model is only saved for reusability by torchbearer.
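
For what it's worth, plain PyTorch can always write a lean copy of just the weights after training, outside of torchbearer's checkpointers. A minimal sketch (model_lean.pt is just a placeholder name):

import torch

# Save only the module's parameters, not the full torchbearer state.
torch.save(model.state_dict(), 'model_lean.pt')

# Later, restore into a fresh instance of the same architecture.
model = MyModule()
model.load_state_dict(torch.load('model_lean.pt'))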

Thanks a lot!

@ethanwharris (Member)

Hello,

Thanks for your feedback :) We need to document this better, but you should hopefully be able to load the model with model.load_state_dict(state_dict[torchbearer.MODEL]), provided model is a subclass of nn.Module. Let us know if that works. I'll leave the issue open, as this should be presented better; any ideas you may have are welcome!
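
Concretely, a minimal sketch based on the snippet above (MyModule and vae.pt are from the original report):

import torch
import torchbearer

# The checkpoint is torchbearer's full state dict, keyed by StateKey
# objects rather than strings, which is why passing it straight to
# nn.Module.load_state_dict raises the 'startswith' AttributeError.
checkpoint = torch.load('vae.pt')

model = MyModule()
model.load_state_dict(checkpoint[torchbearer.MODEL])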

@dorukhansergin (Contributor, Author)

@ethanwharris, thank you so much for the prompt reply.

I will try to create a simple script to replicate the error(s), along with the details of my configuration. I came up with a band-aid solution, which I hope to publish too for people who might be experiencing the same problem. Or maybe I'm just missing something.

I see that you got rid of pass_state in the master repo. This issue might be related to that.

@dorukhansergin (Contributor, Author)

Okay, now I get it.

  • model.load_state_dict(state_dict[torchbearer.MODEL]) works great. Thanks for the heads-up.
  • I totally forgot that I defined forward's signature with state so that I could pass state. My post-training operations assume the signature to be forward(x) only, so they throw errors about state not being passed. I therefore need to overload forward in my Module with the original signature to be on the safe side (a sketch of one such fix follows below). I don't know if this is a common problem, but it could be added to the documentation. In fact, as an apology for taking up your time, allow me to open the pull request for it.
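
A minimal sketch of that signature fix, making state an optional keyword instead of overloading (this default-argument variant is my own workaround, and the layer is a placeholder):

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)  # placeholder layer

    def forward(self, x, state=None):
        # state is supplied when running inside a torchbearer Trial,
        # but defaults to None so plain calls like model(x) still work.
        return self.linear(x)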

I'm a big fan of torchbearer and the team's work. Keep it up!

Thanks for the help!

@ethanwharris (Member)

No problem, glad to be of help! A PR would be much appreciated, as this is definitely something we should document better.

@ethanwharris (Member)

Closed by #508
