
torch.load(pruned_model) #26

Open
wuzhiyang2016 opened this issue Nov 21, 2018 · 7 comments


wuzhiyang2016 commented Nov 21, 2018

When I use torch.load() to load the pruned model, an error occurs: AttributeError: 'module' object has no attribute 'ModifiedVGG16Model'. Has anyone else met this problem?


ghost commented Dec 2, 2018

@wuzhiyang2016 I've got the same problem. I added the ModifiedVGG16Model class in my test script, with no success. Do you still have the same issue?

ghost commented Dec 2, 2018

@jacobgil Any idea?

jacobgil (Owner) commented Dec 2, 2018

The code in the repo saves the entire model with pickling instead of saving the state dict, which is bad practice.
A better way is to save only the state_dict, and then load it.
Going to change it now to:

state_dict = model.state_dict()
torch.save(state_dict, 'model.chkpt')

Then you can load the model like this:

model = ModifiedVGG16Model()
checkpoint = torch.load(checkpoint_path,
                        map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint)
model.eval()
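For completeness, the save/load round trip above can be sketched end to end. TinyModel below is a self-contained stand-in, since ModifiedVGG16Model is defined in the repo's training script and not in this thread:

```python
import torch
import torch.nn as nn

# Stand-in for ModifiedVGG16Model (an assumption for this sketch; the real
# class lives in the repo). Saving a state_dict avoids pickling the class.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

    def forward(self, x):
        return self.features(x)

model = TinyModel()
torch.save(model.state_dict(), 'model.chkpt')  # weights only, no pickled class

# Loading: construct the model first, then fill in the weights.
model2 = TinyModel()
checkpoint = torch.load('model.chkpt',
                        map_location=lambda storage, loc: storage)
model2.load_state_dict(checkpoint)
model2.eval()
```

Because only tensors are stored, loading no longer depends on unpickling the model class, which is what caused the original AttributeError.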

@wuzhiyang2016 (Author)

As the owner says, ModifiedVGG16Model() must be defined in the script that loads the saved model.
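To illustrate why: torch.save(model) pickles only a reference to the class (its module path and name), not the class source, so torch.load needs an importable definition with the same name at load time. A minimal sketch, with an illustrative stand-in class:

```python
import torch
import torch.nn as nn

# Illustrative stand-in; the real ModifiedVGG16Model is defined in the repo.
# The pickle records "module.ClassName", so this definition must exist
# (importable under the same name) when torch.load runs.
class ModifiedVGG16Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, 3))

    def forward(self, x):
        return self.features(x)

torch.save(ModifiedVGG16Model(), 'pruned_model')  # pickles the whole object

# Succeeds only because ModifiedVGG16Model is defined above; without it,
# unpickling raises the AttributeError from this issue.
# (weights_only=False is needed on newer PyTorch to unpickle full models.)
loaded = torch.load('pruned_model', weights_only=False)
```

Defining (or importing) the class in the test script under the exact same name is therefore the fix for the whole-model checkpoints this repo produces.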


mlcoop commented Aug 2, 2019

@jacobgil Hey! There is a problem with a size mismatch. For example:

RuntimeError: Error(s) in loading state_dict for ModifiedVGG16Model:
    size mismatch for features.0.weight: copying a param with shape torch.Size([50, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
    size mismatch for features.0.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for features.2.weight: copying a param with shape torch.Size([44, 50, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).

Is there a way to solve it?

@ms-krajesh

@jacobgil: Since the pruned model has a different number of in/out filters in each layer, is there any way to load the pruned model's state dict with the original model class?


mlcoop commented Aug 15, 2019

@ms-krajesh Probably not a general solution, but I just rewrote the model architecture with the altered layer sizes.
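One way to do that rewriting automatically, rather than by hand, is to read the layer shapes out of the pruned checkpoint and rebuild matching Conv2d layers before calling load_state_dict. A rough sketch under the assumption of a plain conv+ReLU feature stack (the real ModifiedVGG16Model also has pooling and a classifier head, which this ignores):

```python
import torch
import torch.nn as nn

def conv_stack_from_state_dict(state_dict):
    """Rebuild a Conv2d+ReLU stack whose channel counts match the checkpoint."""
    layers = []
    for key, weight in state_dict.items():
        if key.endswith('.weight') and weight.dim() == 4:  # conv weights only
            out_ch, in_ch, kh, kw = weight.shape
            layers += [nn.Conv2d(in_ch, out_ch, (kh, kw), padding=1), nn.ReLU()]
    return nn.Sequential(*layers)

# Example: a "pruned" checkpoint with 50 and 44 filters, as in the error above.
pruned = nn.Sequential(nn.Conv2d(3, 50, 3, padding=1), nn.ReLU(),
                       nn.Conv2d(50, 44, 3, padding=1), nn.ReLU())
sd = pruned.state_dict()

model = conv_stack_from_state_dict(sd)
model.load_state_dict(sd)  # shapes now match, so this succeeds
```

The rebuilt Sequential gets the same parameter keys ('0.weight', '2.weight', ...) as the pruned one, so the strict load no longer raises the size-mismatch RuntimeError.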
