
Why we need the init_weight function in BERT pretrained model #4701

Closed
allanj opened this issue Jun 1, 2020 · 5 comments
Labels
Usage General questions about the library

Comments

allanj (Contributor) commented Jun 1, 2020

❓ Questions & Help

I have already tried asking this question on SO; you can find the link here.

Details

In the Hugging Face transformers code, many fine-tuning models have an init_weights function.
For example (here), there is an init_weights call at the end of the constructor. Even though we use from_pretrained, it will still call the constructor, which calls init_weights.

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

As far as I know, it will call the following code:

def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

My question is: if we are loading a pre-trained language model, why do we need to initialize the weights of every module?

I guess I must be misunderstanding something here.

BramVanroy added the Usage (General questions about the library) label on Jun 1, 2020
BramVanroy (Collaborator) commented Jun 1, 2020

Have a look at the code for .from_pretrained(). What actually happens is something like this:

  • find the correct base model class to initialise
  • initialise that class with pseudo-random initialisation (by using the _init_weights function that you mention)
  • find the file with the pretrained weights
  • overwrite the weights of the model that we just created with the pretrained weights, where applicable

This ensures that layers that were not pretrained (e.g. in some cases the final classification layer) do get initialised in _init_weights but are not overwritten by the checkpoint.
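
In other words, something like the following rough sketch (not the actual transformers internals; the checkpoint path pytorch_model.bin and num_labels=2 below are just placeholders for illustration):

import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)

# Steps 1-2: the constructor runs, so every module (including the brand-new
# classification head) gets a pseudo-random init via init_weights().
model = BertForSequenceClassification(config)

# Steps 3-4: load the pretrained state dict on top of the random init.
# strict=False mirrors the idea that keys missing from the checkpoint
# (e.g. "classifier.weight") simply keep their random values.
# The path below is hypothetical; from_pretrained resolves it for you.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("kept random init for:", missing)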

allanj (Contributor, author) commented Jun 1, 2020

Great, thanks. I also read through the code and that really clears up my confusion.

allanj closed this as completed on Jun 1, 2020
BramVanroy (Collaborator) commented

Good. If the answer was sufficient on Stack Overflow as well, please close that too.

sunersheng commented

(quoting BramVanroy's explanation of .from_pretrained() above)

When we construct BertForSequenceClassification from a pretrained model, don't we overwrite the loaded weights with the random initialisation?

BramVanroy (Collaborator) commented

@sunersheng No, the random initialization happens first and then the existing weights are loaded into it.
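
One way to see that order in practice (just a sketch, assuming bert-base-uncased can be downloaded): load the same checkpoint twice and compare parameters. The encoder weights match because they come from the checkpoint, while the classifier head differs between the two loads because it only ever gets its random init_weights() values.

import torch
from transformers import BertForSequenceClassification

m1 = BertForSequenceClassification.from_pretrained("bert-base-uncased")
m2 = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Encoder weights come from the checkpoint, so both loads are identical.
print(torch.equal(m1.bert.embeddings.word_embeddings.weight,
                  m2.bert.embeddings.word_embeddings.weight))  # True

# The classifier head is not in the checkpoint, so each load keeps its own
# random init_weights() values and the two differ.
print(torch.equal(m1.classifier.weight, m2.classifier.weight))  # False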
