
Size mismatch while loading pretrained model for fine-tuning with additional characters #210

Closed
PrithaGanguly opened this issue Aug 4, 2020 · 7 comments

Comments

@PrithaGanguly

Hi,

I have followed all the necessary steps for fine-tuning (FT) as given in the other threads, but the size mismatch error keeps recurring because the number of characters in the pre-trained model and in my custom dataset is not the same.
Can somebody please guide me on how to load the pre-trained model with a modified prediction layer so that it can be used for fine-tuning?

Thanks in advance!

@PrithaGanguly
Author

Figured it out, thanks!

@SrijithBalachander

SrijithBalachander commented Aug 22, 2020

@PrithaGanguly Hi, I have a similar question!
Did you replace the final linear layer so that its output dimension matches the required number of characters, or were you able to just modify the character set and make it work?
Thanks!

@PrithaGanguly
Author

@SrijithBalachander Hi, I was facing this issue because my character set was larger than the pretrained model's. To make it work, I modified the weight-loading code in train.py as follows:
if opt.FT:
    checkpoint = torch.load(opt.saved_model)
    # keep only the parameters whose name and shape match the current model
    checkpoint = {k: v for k, v in checkpoint.items()
                  if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)}
    # copy the matching parameters into the model in place
    for name in model.state_dict().keys():
        if name in checkpoint.keys():
            model.state_dict()[name].copy_(checkpoint[name])
else:
    model.load_state_dict(torch.load(opt.saved_model))
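
An equivalent, slightly shorter sketch (assuming the same opt and model objects from train.py) filters the checkpoint the same way and then relies on strict=False to skip the dropped keys, so the resized prediction layer keeps its fresh initialization:

if opt.FT:
    checkpoint = torch.load(opt.saved_model)
    # drop entries whose shape no longer matches (e.g. the enlarged prediction layer)
    filtered = {k: v for k, v in checkpoint.items()
                if k in model.state_dict() and v.shape == model.state_dict()[k].shape}
    # strict=False ignores the keys that were filtered out above
    model.load_state_dict(filtered, strict=False)
else:
    model.load_state_dict(torch.load(opt.saved_model))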

Hope this helps!

@kimlia545

@PrithaGanguly
Hello, can you help me?
I am trying to train after adding characters to the character set.

'train.py'
if opt.FT:
model.load_state_dict(torch.load(opt.saved_model), strict=False)

=> I modified it following your approach:

if opt.FT:
checkpoint = torch.load(opt.saved_model)
checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)}
for name in model.state_dict().keys() :
if name in checkpoint.keys() :
model.state_dict()[name].copy_(checkpoint[name])
else:
model.load_state_dict(torch.load(opt.saved_model))

but I ran into these errors:
python train_test.py --train_data data_lmdb/training --workers 0 --valid_data data_lmdb/validation --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn --saved_model pretrained_models/TPS-ResNet-BiLSTM-Attn.pth --FT
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.Prediction.attention_cell.rnn.weight_ih: copying a param with shape torch.Size([1024, 294]) from checkpoint, the shape in current model is torch.Size([1024, 1637]).
size mismatch for module.Prediction.generator.weight: copying a param with shape torch.Size([38, 256]) from checkpoint, the shape in current model is torch.Size([1381, 256]).
size mismatch for module.Prediction.generator.bias: copying a param with shape torch.Size([38]) from checkpoint, the shape in current model is torch.Size([1381]).

python train_test.py --train_data data_lmdb/training --workers 0 --valid_data data_lmdb/validation --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC --saved_model pretrained_models/TPS-ResNet-BiLSTM-CTC.pth --FT
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.Prediction.weight: copying a param with shape torch.Size([37, 256]) from checkpoint, the shape in current model is torch.Size([1379, 256]).
size mismatch for module.Prediction.bias: copying a param with shape torch.Size([37]) from checkpoint, the shape in current model is torch.Size([1379]).

@Treeboy2762

@kimlia545 Please check your indentation.

    if opt.FT:
        checkpoint = torch.load(opt.saved_model)
        checkpoint = {k: v for k, v in checkpoint.items() if (k in model.state_dict().keys()) and (model.state_dict()[k].shape == checkpoint[k].shape)}
        for name in model.state_dict().keys() :
            if name in checkpoint.keys() :
                model.state_dict()[name].copy_(checkpoint[name])
    else:
        model.load_state_dict(torch.load(opt.saved_model))

@LolaWei

LolaWei commented Dec 13, 2021

Thanks! This solved my mismatch issue!

@devAbreu

This works for me.

if opt.FT:
    current_model_dict = model.state_dict()
    loaded_state_dict = torch.load(opt.saved_model, map_location=torch.device('cpu'))
    # keep a loaded tensor only when its size matches the current model's parameter;
    # otherwise fall back to the current (randomly initialized) parameter
    new_state_dict = {k: v if v.size() == current_model_dict[k].size() else current_model_dict[k]
                      for k, v in zip(current_model_dict.keys(), loaded_state_dict.values())}
    model.load_state_dict(new_state_dict, strict=False)
else:
    model.load_state_dict(torch.load(opt.saved_model))
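
Note that zip(current_model_dict.keys(), loaded_state_dict.values()) pairs entries by position rather than by name, so this variant assumes the checkpoint and the current model list their parameters in the same order; that holds here because only the size of the prediction layer changes.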
