
Make RobertaForMaskedLM implementation identical to fairseq #2928

Closed · wants to merge 5 commits

Conversation

@BramVanroy (Collaborator) commented Feb 20, 2020

closes #1874

The implementation of RoBERTa in transformers differs from the original implementation in fairseq, as the results showed (cf. #1874). I have documented my findings in #1874 (comment) and made the corresponding changes in this PR.

Someone should check, however, that removing get_output_embeddings() does not have any adverse side-effects.

In addition, someone who is knowledgeable about TensorFlow should also check the TF implementation of RoBERTa.

@BramVanroy (Collaborator, Author) commented

TODO: #2913 (comment)

@codecov-io commented Feb 20, 2020

Codecov Report

Merging #2928 into master will decrease coverage by <.01%.
The diff coverage is 100%.


@@            Coverage Diff            @@
##           master   #2928      +/-   ##
=========================================
- Coverage    75.3%   75.3%   -0.01%     
=========================================
  Files          94      94              
  Lines       15424   15423       -1     
=========================================
- Hits        11615   11614       -1     
  Misses       3809    3809
Impacted Files                          Coverage Δ
src/transformers/modeling_roberta.py    95.75% <100%> (-0.02%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 59c23ad...1f290e5

@BramVanroy added the Core: Modeling label on Feb 20, 2020
@BramVanroy assigned and unassigned BramVanroy on Feb 20, 2020
@joeddav (Contributor) commented Feb 20, 2020

Looks good. I tested it out and the outputs match exactly everywhere I can see. Requested review from @LysandreJik as well.

Regarding the test mentioned by @sshleifer, you can just test that a slice of the outputs match rather than the entire tensor. See here for an example.

@BramVanroy (Collaborator, Author) commented

Thanks, will add tests later. I am still a bit confused about why the weights of the embeddings are tied to the LMHead in the original implementation, though. I don't quite get the intention there.

@BramVanroy (Collaborator, Author) commented Feb 20, 2020

Hm, perhaps this warning message should not be there.

Weights of RobertaForMaskedLM not initialized from pretrained model: ['lm_head.weight']
Weights from pretrained model not used in RobertaForMaskedLM: ['lm_head.decoder.weight']

  • lm_head.weight is initialised because it takes the embedding weights
  • the weights from the pretrained model are not used because they are not required

@joeddav (Contributor) commented Feb 20, 2020

@BramVanroy Where are you getting that warning? I don't see it when I call RobertaForMaskedLM.from_pretrained.

@BramVanroy (Collaborator, Author) commented

You can only see it if your logging level is set to INFO or lower, so put the following before loading the model:

import logging

# transformers reports unused/uninitialised weights at INFO level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO)

@joeddav (Contributor) commented Feb 20, 2020

Oh I see. Looks like the problem is just that the weight parameter introduced here has a different name than before. Rather than using the functional API as you did here, I would just manually override decoder.weight when weight is passed, i.e.:

self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if weight is not None:
    # Reuse the provided (embedding) weight so the parameter keeps its "decoder.weight" name.
    self.decoder.weight = weight

As you mentioned, it's not a huge issue since the weights are correctly loaded from the embeddings anyway, but probably a bit cleaner if the names align.
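
For context, here is a minimal, self-contained sketch of how such a head could receive the embedding weights. The class name, the weight argument, and the wiring comment are illustrative assumptions, not the actual transformers code.

import torch.nn as nn

class LMHead(nn.Module):
    # Illustrative head that projects hidden states back onto the vocabulary.
    def __init__(self, config, weight=None):
        super().__init__()
        # Keeping the attribute name "decoder" means the parameter is stored as
        # "lm_head.decoder.weight", which matches the pretrained checkpoint.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if weight is not None:
            # Share the embedding matrix instead of learning a separate projection.
            self.decoder.weight = weight

    def forward(self, features):
        return self.decoder(features)

# Hypothetical wiring inside the model:
# self.lm_head = LMHead(config, weight=self.roberta.embeddings.word_embeddings.weight)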

Bram Vanroy added 3 commits on February 20, 2020 at 21:04:
  • Re-added decoder name to avoid getting warning messages. In practice, this does not change anything about the model
  • …irseq
  • test_lm_inference_identical_to_fairseq compares the output of HuggingFace RoBERTa to a slice of the output tensor of the (original) fairseq RoBERTa
@BramVanroy (Collaborator, Author) commented

For those interested, I found the answer to the "why" thanks to a helpful comment on Twitter. Apparently this is common practice and was introduced a while back in Using the Output Embedding to Improve Language Models.

@LysandreJik (Member) left a comment

Hi @BramVanroy! I can see there's an issue here but I don't think this is the way to solve it.

We already tie the weights together, so there's no need to do any additional tying; we do this for every model that has an LM head (masked or causal).

The issue here comes from the bias I introduced a few weeks ago in #2521. The way I did it meant that the bias was actually applied twice.

The correct way to fix it would be to change

        x = self.decoder(x) + self.bias

to

        x = self.decoder(x)

in the forward method. The bias is already part of the decoder, so no need to apply it once more.

Do you want to update your PR, or should I do one to fix it?
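
For reference, a rough sketch of a head with that fix applied; the class name and layer layout are illustrative, and the bias/weight handling that the library performs when the model is instantiated is done inline here to keep the sketch self-contained.

import torch
import torch.nn as nn

class MaskedLMHead(nn.Module):
    # Illustrative RoBERTa-style masked-LM head with the bias living on the decoder.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Register the bias as its own Parameter, then hand it to the decoder
        # so it is applied exactly once, inside decoder().
        self.decoder.bias = self.bias

    def forward(self, features):
        x = torch.nn.functional.gelu(self.dense(features))
        x = self.layer_norm(x)
        # No "+ self.bias" here: the decoder already adds it.
        return self.decoder(x)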

@BramVanroy (Collaborator, Author) commented

Aha, my bad. I thought I finally contributed something useful! 😳 You can add a PR, I'll close this one. (Perhaps the updated test is still useful so that something like this doesn't happen in the future.)

Can you link to the lines where the weight tying is happening, though? I must have completely missed it.

@BramVanroy closed this on Feb 21, 2020
@LysandreJik (Member) commented

Your contributions are more than useful, @BramVanroy, and I'm glad you tried to fix an issue when you discovered one, thank you.

To answer your question, the PreTrainedModel abstract class has a tie_weights method which ties the input embeddings to the output embeddings.

This method is not directly called by any model class, but it is called by the init_weights method of that same abstract class.

It is this last method that every model calls during instantiation, as RobertaModel does for example.

This is only the PyTorch way, though; the TensorFlow implementation is different. In TensorFlow, we use a single layer that can be called either as an embedding or as a linear layer, as you can see in the BertEmbeddings class. Note the mode flag, which switches between the two behaviours.
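
To make the tying concrete, here is a simplified sketch of the idea in PyTorch; it is not the library's actual code, and the vocabulary/hidden sizes below are made up for illustration.

import torch.nn as nn

def tie_weights(input_embeddings, output_projection):
    # Both weights have shape (vocab_size, hidden_size), so the output projection
    # can simply reuse the embedding matrix as its weight.
    output_projection.weight = input_embeddings.weight

# Hypothetical usage:
embeddings = nn.Embedding(50265, 768)
decoder = nn.Linear(768, 50265, bias=False)
tie_weights(embeddings, decoder)
assert decoder.weight is embeddings.weight  # a single shared Parameter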

Labels
Core: Modeling
Projects
None yet
Development
Successfully merging this pull request may close these issues:
  • Disparity with Fairseq Roberta implementation for predicting the mask token (#1874)
4 participants