
Fine-tuning #30

Closed
niddal-imam opened this issue Sep 17, 2019 · 6 comments

Comments

@niddal-imam

Hi Holmeyoung,

When I retrain a pre-trained model, it seems that the model forgets what it has learned: if I train a model on synthetic images and then fine-tune it on real-world images, the model's accuracy on the synthetic images decreases. Please correct me if I am wrong: to use a pre-trained model, I should freeze the last layers. If so, how can I freeze the last layers?

Thanks
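For reference, freezing layers in PyTorch is generally done by setting `requires_grad = False` on the parameters you want fixed. A minimal sketch, using a hypothetical stand-in model (`TinyCRNN` is illustrative only; the real model lives in the repo's models/crnn.py):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the repo's CRNN: a cnn backbone plus an rnn head.
class TinyCRNN(nn.Module):
    def __init__(self, nclass=11):
        super().__init__()
        self.cnn = nn.Conv2d(1, 512, kernel_size=3, padding=1)
        self.rnn = nn.Linear(512, nclass)

crnn = TinyCRNN()

# Freeze the backbone so fine-tuning only updates the head.
for p in crnn.cnn.parameters():
    p.requires_grad = False

# Hand the optimizer only the parameters that remain trainable.
trainable = [p for p in crnn.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)
print(len(trainable))  # the head's weight and bias
```

The same pattern works on the real CRNN: iterate over `crnn.cnn.parameters()` (or any submodule) and flip `requires_grad`, then build the optimizer from the remaining trainable parameters.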

@Holmeyoung
Owner

Hi,

  1. Load the pre-trained model, using the same net as the pre-trained model:

     crnn = net.CRNN(params.imgH, params.nc, nclass, params.nh)

Here, nclass shouldn't equal len(params.alphabet) + 1; it should be the class count of the pre-trained model.

  2. Replace the last layer with your own:

     crnn.rnn = nn.Sequential(
                 BidirectionalLSTM(512, nh, nh),
                 BidirectionalLSTM(nh, nh, nclass))

where now nclass = len(params.alphabet) + 1.
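The two steps above can be sketched in isolation. `BidirectionalLSTM` below follows the usual crnn.pytorch definition, and `nh`, `nclass_pre`, `nclass_new` are illustrative values, not taken from this repo's config:

```python
import torch
import torch.nn as nn

# BidirectionalLSTM as commonly defined in crnn.pytorch's models/crnn.py.
class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, x):
        recurrent, _ = self.rnn(x)          # (T, b, 2 * nHidden)
        T, b, h = recurrent.size()
        out = self.embedding(recurrent.view(T * b, h))
        return out.view(T, b, -1)           # (T, b, nOut)

nh = 256
nclass_pre = 11   # class count of the pre-trained model (example value)
nclass_new = 37   # len(new alphabet) + 1 (example value)

# Step 1: the head as sized for the pre-trained model's classes.
old_head = nn.Sequential(
    BidirectionalLSTM(512, nh, nh),
    BidirectionalLSTM(nh, nh, nclass_pre))

# Step 2: a fresh head sized for the new alphabet, assigned to crnn.rnn.
new_head = nn.Sequential(
    BidirectionalLSTM(512, nh, nh),
    BidirectionalLSTM(nh, nh, nclass_new))

# A (T, batch, 512) feature sequence, as the CNN backbone would emit.
feats = torch.randn(26, 2, 512)
print(old_head(feats).shape)   # torch.Size([26, 2, 11])
print(new_head(feats).shape)   # torch.Size([26, 2, 37])
```

Only the last dimension (the per-timestep class logits) changes, so the rest of the network and the pre-trained weights are untouched by the swap.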

@niddal-imam
Author

niddal-imam commented Sep 17, 2019

Thanks for the quick response.
So first I should load the pre-trained model by changing params.py:

    pretrained = 'path/to/my/pre-trained'

But I did not understand the second point. What should I change?

Thanks

@Holmeyoung
Owner

After loading the model, replace the rnn layer.

@niddal-imam
Author

Thank you Holmeyoung. Should I change crnn.py from

    self.cnn = cnn
    self.rnn = nn.Sequential(
        BidirectionalLSTM(512, nh, nh),
        BidirectionalLSTM(nh, nh, nclass))

to

    crnn.rnn = nn.Sequential(
        BidirectionalLSTM(512, nh, nh),
        BidirectionalLSTM(nh, nh, nclass))?

@Holmeyoung
Owner

Hi, you should change net_init in train.py to

import torch.nn as nn
from models.crnn import BidirectionalLSTM

def net_init():
    nclass_pre = 11  # class count of the pre-trained model: len(pre-trained alphabet) + 1
    nclass = len(params.alphabet) + 1
    crnn = net.CRNN(params.imgH, params.nc, nclass_pre, params.nh)
    crnn.apply(weights_init)
    if params.pretrained != '':
        print('loading pretrained model from %s' % params.pretrained)
        if params.multi_gpu:
            crnn = torch.nn.DataParallel(crnn)
        crnn.load_state_dict(torch.load(params.pretrained))

    # Replace the head on the underlying model; when the model has been
    # wrapped in DataParallel, it lives at crnn.module, not crnn.
    base = crnn.module if isinstance(crnn, torch.nn.DataParallel) else crnn
    base.rnn = nn.Sequential(
        BidirectionalLSTM(512, params.nh, params.nh),
        BidirectionalLSTM(params.nh, params.nh, nclass))
    return crnn

@niddal-imam
Author

Thank you very much.
