
Fix embeddings resizing in TF models #8657

Merged
merged 31 commits from the fix-resize-ebd branch into huggingface:master on Dec 14, 2020

Conversation

jplu
Contributor

@jplu commented Nov 19, 2020

What does this PR do?

Currently, when the embeddings are resized, the biases are not resized at the same time. In TF there is no explicit link between the decoder weights and the biases of a dense layer, unlike in PT. This PR fixes the issue by resizing the biases at the same time, even though I don't know if this is the best solution. @LysandreJik @sgugger what do you think?
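For illustration, a minimal sketch (not the PR's actual code) of the problem being fixed: in TF the LM-head bias is a standalone variable with no link to the decoder weights, so it has to be re-created and re-filled by hand when the vocabulary size changes. The ToyLMHead layer and its resize_bias helper are hypothetical names.

import tensorflow as tf

class ToyLMHead(tf.keras.layers.Layer):
    # Hypothetical LM head, only to illustrate the bias-resizing issue.
    def __init__(self, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size

    def build(self, input_shape):
        # Unlike PyTorch's nn.Linear, nothing ties this bias to the decoder
        # weights, so resizing the embeddings alone leaves it at the old size.
        self.bias = self.add_weight(
            shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias"
        )
        super().build(input_shape)

    def resize_bias(self, new_num_tokens):
        # Create a bias of the new size and copy over the existing values.
        num_to_copy = min(int(self.bias.shape[0]), new_num_tokens)
        old_values = self.bias.value()[:num_to_copy]
        new_bias = self.add_weight(
            shape=(new_num_tokens,), initializer="zeros", trainable=True, name="bias"
        )
        padding = tf.zeros((new_num_tokens - num_to_copy,), dtype=new_bias.dtype)
        new_bias.assign(tf.concat([old_values, padding], axis=0))
        self.bias = new_bias
        self.vocab_size = new_num_tokens

# Usage: build a 10-token head, then grow it to 12 tokens.
head = ToyLMHead(vocab_size=10)
head.build(input_shape=None)
head.resize_bias(12)
assert int(head.bias.shape[0]) == 12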

Comment on lines 791 to 851
if new_num_tokens is not None:
    self.predictions.bias = self.add_weight(
        shape=(new_num_tokens,), initializer="zeros", trainable=True, name="bias"
    )
    self.predictions.decoder_bias = self.add_weight(
        shape=(new_num_tokens,), initializer="zeros", trainable=True, name="bias"
    )
Member

Are you sure this does not erase the current value of the bias, replacing it with all zeros instead?

Contributor Author

It does erase them, yes. I thought the resizing only happened on a fresh new model just before training, to adapt the model to the vocab size it has to be trained on. Something like:

from transformers import TFBertForMaskedLM, BertConfig

config = BertConfig()
model = TFBertForMaskedLM(config)
model.resize_token_embeddings(new_size)  # new_size: the target vocabulary size

But I can make it not erase the current bias values, that is not a problem :)

Member

That's not the case! Resizing is also important in the case of fine-tuning. In that case, you would want to keep the existing token embeddings, but add randomly initialized vectors for the added tokens.
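For illustration only, a minimal sketch, assuming a plain tf.Variable embedding matrix, of the behaviour described here: keep the already-trained vectors and randomly initialize only the entries for the added tokens. The function name resize_embedding_matrix is hypothetical.

import tensorflow as tf

def resize_embedding_matrix(old_embeddings, new_num_tokens, initializer_range=0.02):
    old_num_tokens, hidden_size = old_embeddings.shape
    num_to_copy = min(int(old_num_tokens), new_num_tokens)
    # Start from a randomly initialized matrix of the new size...
    new_values = tf.random.truncated_normal((new_num_tokens, hidden_size), stddev=initializer_range)
    # ...then overwrite the first rows with the already-trained embeddings.
    new_values = tf.concat([old_embeddings.value()[:num_to_copy], new_values[num_to_copy:]], axis=0)
    return tf.Variable(new_values, trainable=True, name="weight")

# Example: grow a 10-token table to 12 tokens; tokens 0-9 keep their values.
old = tf.Variable(tf.random.normal((10, 8)), name="weight")
new = resize_embedding_matrix(old, 12)
assert new.shape == (12, 8)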

@@ -486,16 +486,17 @@ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Var
# todo: initializer range is not always passed in config.
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = self.add_weight(
"weight",
name=word_embeddings.name.split(":")[0],
Contributor Author

Fixes a naming issue. When the resized embedding is created, its name is changed as well, which raises a naming issue when the updated model is saved and then loaded.

from transformers import BertConfig, TFBertForMaskedLM
model = TFBertForMaskedLM(BertConfig())
model(model.dummy_inputs)
model.resize_token_embeddings(28996)
model.save_pretrained("here")
model = TFBertForMaskedLM.from_pretrained("here")

Gives

Some layers from the model checkpoint at here were not used when initializing TFBertForMaskedLM: ['']
- This IS expected if you are initializing TFBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
.....

The weights and bias are not properly loaded. After the naming fix:

All model checkpoint layers were used when initializing TFBertForMaskedLM.

All the layers of TFBertForMaskedLM were initialized from the model checkpoint at here.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.

Collaborator

@sgugger left a comment

Oh boy, that part of the lib is really clumsy (compared to the PyTorch side), thanks for tackling this resizing problem!

After discussing with @LysandreJik, we would like to suggest a different API (mainly to avoid redefining resize_bias and overriding resize_token_embeddings in each model).

First things first, get_output_embeddings is not used anywhere inside the TF utils, so I'd suggest removing it. (It is used on the PyTorch side to tie the weights, but there is no such thing in TF.) The only thing it would break, from a search in the repo, is the check in generation_tf_utils (L187), but that check is super clumsy too (it should use one of the auto mappings instead of using that attribute to determine whether a model has an LM head).

Then, we can replace this get_output_embeddings with a get_output_bias method that would return None by default. The resize_token_embeddings method should then check the result of this method and, if it's not None, resize the bias. This way you would avoid the multiple copies of resize_bias :-)

Let me know if this makes sense.
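To make the suggestion concrete, a rough sketch (hypothetical names, not the merged code) of what the base class could look like so that no model has to redefine resize_bias:

class TFSketchPreTrainedModel:
    def get_output_bias(self):
        # Default: the model has no LM-head bias to resize.
        return None

    def resize_token_embeddings(self, new_num_tokens):
        # Resize the embeddings as before (hypothetical helper).
        self._resize_word_embeddings(new_num_tokens)
        # Then resize the bias only for models that report one.
        bias_layer = self.get_output_bias()
        if bias_layer is not None:
            bias_layer.resize_bias(new_num_tokens)

Models with an LM head would only override get_output_bias to return the layer holding the bias.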

        super().build(input_shape)

    def resize_bias(self, new_num_tokens):
        if new_num_tokens is not None:
            num_tokens_to_copy = min(self.bias.shape[0], new_num_tokens)
Collaborator

self.bias is never used, so we can forget about resizing it. We should also delete it from the model; then we can add it to the variable _keys_to_ignore_on_load_unexpected to avoid the warning.
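A tiny sketch of that last point, assuming the usual pattern for this attribute (a list of regex patterns matched against checkpoint weight names); the class name and pattern are illustrative only:

class TFSketchMaskedLM:
    # With the unused self.bias deleted from the model, checkpoints that still
    # contain it would trigger an "unexpected weight" warning at load time;
    # listing its name here tells from_pretrained to skip it silently.
    _keys_to_ignore_on_load_unexpected = [r"predictions\.bias"]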

Comment on lines 678 to 684
init_weights = self.decoder.value()[:num_tokens_to_copy]
self.decoder = self.add_weight(
    shape=(self.config.vocab_size, self.config.embedding_size),
    initializer="zeros",
    trainable=True,
    name=self.decoder.name.split(":")[0],
)
self.decoder.assign(init_weights)
Collaborator

This part is about the weights, so it shouldn't be there.

@jplu
Contributor Author

jplu commented Nov 24, 2020

Thanks @sgugger for your useful comments. I was thinking the same about get_output_embeddings, but I didn't want to change too many things at the same time.

I very much like the solution you proposed and I'm totally fine with it!

@jplu force-pushed the fix-resize-ebd branch 2 times, most recently from 3d3b129 to 3705fe0 on November 25, 2020 09:11
@jplu
Contributor Author

jplu commented Nov 25, 2020

@sgugger I have reworked the resizing for the bias and applied it to BERT first for testing. Do you agree with this new approach? If yes, I will do the same for the other models.

Collaborator

@sgugger left a comment

Looks perfect to me. Just one nit on the naming: since you're returning the layer that has the bias instead of the actual bias, I think get_bias should be named get_output_layer_with_bias (or something better if you're more inspired).

@@ -184,10 +197,10 @@ def generate(
"""

# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
is_lm, list_models = self.is_lm_model()
Collaborator

Way cleaner, thanks!

Contributor

I don't really agree here -> I prefer to leave the self.get_output_embeddings() check here. E.g. TFRag won't be in any of TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING. And since TFRag will have two models that can both generate, but won't both be able to be in the same MAPPING class, we will already run into problems here. Also this is inconsistent with PyTorch.

Logically, I like this check as it is right now. If the model has output embeddings, then it can generate, because all one needs for generate is a logit output vector. I don't want to tie this functionality to the LM mappings in modeling_auto.py. It creates an unnecessary dependency and makes it unnecessarily inflexible IMO.

Contributor Author

Just to be sure: self.get_output_embeddings() is not used anywhere except for a simple check in generation_tf_utils.py, so what I deduce from that is that it is a rather useless method. Unless it is used somewhere else?

For sure, we can create a smarter way to detect whether a model has an LM layer, if the role of self.get_output_embeddings() is just to know that. What would be the best compromise?

"""
Returns the model's output embeddings.
Get the layer that handles a bias attribute in case the model has an LM head.
Collaborator

Suggested change
Get the layer that handles a bias attribute in case the model has an LM head.
Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the embeddings.

Member

@LysandreJik left a comment

Other than the point made below, the changes look good.

Comment on lines 48 to 59
def is_lm_model(self):
    from .models.auto import (
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    )

    list_models = list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
    list_models.extend(list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values()))
    list_models.extend(list(TF_MODEL_FOR_MASKED_LM_MAPPING.values()))

    return (type(self) in list_models, [model.__name__ for model in list_models])
Member

  1. I find the API for this method a bit weird. I wouldn't expect an is_lm_model method, which implies a boolean return, to return a tuple (bool, List[str]).

  2. I also wouldn't expect that to be a method of TFGenerationMixin given its description:

     A class containing all of the functions supporting generation, to be used as a mixin in
     :class:`~transformers.TFPreTrainedModel`.

  3. I find it weird to have an is_lm_model() method on models. If we have this, why not is_sequence_classification_model, is_question_answering_model, etc.?

Given how it's implemented and what it's used for, I'd rather have a has_tied_weights() method which returns a boolean value, and not a tuple. If you want the list of LM models, I would put that in another method, which I wouldn't place on TFPreTrainedModel (or any of its mixins), as it's a simple operation over three mappings.
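A sketch of that split, with hypothetical names, just to make the proposal concrete: a plain module-level helper for the mapping lookup and a boolean-only method on the model.

from transformers import (
    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_MASKED_LM_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
)

def tf_lm_model_classes():
    # Simple operation over the three mappings; no need to hang it on TFPreTrainedModel.
    classes = list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
    classes.extend(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
    classes.extend(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
    return classes

class TFSketchPreTrainedModel:
    def has_tied_weights(self):
        # Boolean-only, as suggested: is this class registered as an LM model?
        return type(self) in tf_lm_model_classes()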

Contributor Author

Ok, I'm fine with splitting and renaming this method into two separate methods with distinct roles 👍

@@ -931,8 +955,28 @@ def __init__(self, config, *inputs, **kwargs):
self.albert = TFAlbertMainLayer(config, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

def get_output_embeddings(self):
Contributor

I'd really like to keep this API - completely removing the function is very breaking in my opinion. This is not really a private method.

Contributor Author

What does it break, more precisely?

Contributor

I mean, I would think that people make use of this method, for example if they wrap their language model in a custom class or if they built their own generation method.

@jplu
Contributor Author

jplu commented Nov 29, 2020

@sgugger @patrickvonplaten @LysandreJik This PR takes care of resizing all the biases, and if we start to change how the embeddings are resized + modify the generation, I think it would be a bit too much and out of the scope of this PR. So what I propose is to keep generation_tf_utils.py and the self.get_output_embeddings methods as they were at the beginning, and move this discussion to another PR. In that other PR I would also like to fully review how the resizing is done, because the number of lines of code can be largely reduced and simplified. What do you think?

@patrickvonplaten
Contributor

It would be awesome if we could keep the get_output_embeddings() method and leave generate() as it is, and only focus on the resizing problem here. I'm 100% on board with fixing the resizing problem and it'd be awesome to do this orthogonally to get_output_embeddings().

A couple of reasons why I would like to keep get_output_embeddings() (I can copy this to the new PR as well):

  1. Consistency with PyTorch. In PyTorch get_output_embeddings() is even more integrated with other functionalities (like weight tying) and I think we should stay consistent in TF and PT
  2. get_output_embeddings() is an important function IMO to quickly get the correct logit matrix. Without this function it's not at all always obvious how to get the output embeddings for some models (especially EncoderDecoder, RAG, ...). A unified API for all models is of great help here IMO and I use it a lot actually
  3. Don't want to tie the capability of a model to generate() with the MODEL_FOR_.... classes - this is inconsistent with PyTorch and unnecessarily creates a dependency IMO.

@jplu
Contributor Author

jplu commented Nov 30, 2020

Thanks a lot @patrickvonplaten for sharing this! I think we should move this discussion to a more suitable place, and meanwhile I will revert that part of the changes.

@sgugger
Collaborator

sgugger commented Nov 30, 2020

I disagree with you on this @patrickvonplaten

  1. Consistency with PyTorch. In PyTorch get_output_embeddings() is even more integrated with other functionalities (like weight tying) and I think we should stay consistent in TF and PT

The weight tying cannot be done the same way in TF (and honestly the resizing on the PyTorch side is a bit hacky and very hard to understand; it kind of goes against our principle of no magic code), so this alone is not an argument for keeping the get_output_embeddings method.

  2. get_output_embeddings() is an important function IMO to quickly get the correct logit matrix. Without this function it's not at all always obvious how to get the output embeddings for some models (especially EncoderDecoder, RAG, ...). A unified API for all models is of great help here IMO and I use it a lot actually

The problem is that this function is always implemented to return the input embeddings, so as it is, the function does not do anything more than get_input_embeddings while giving the user a false sense of what it returns. (Note that there is no model in TF apart from mobileBERT that has the capability of having different weights for the embeddings and the decoder; the weights are always tied.)

  3. Don't want to tie the capability of a model to generate() with the MODEL_FOR_.... classes - this is inconsistent with PyTorch and unnecessarily creates a dependency IMO.

The PyTorch side has no assert, so in that case, the consistent thing is to remove the assert entirely.

I could be convinced to leave the get_output_embeddings method for mobileBERT only since it's the only model where it returns something useful, but it's dangerous to have it otherwise (unless we had a way to untie the weights, but that's for another PR!)

@sgugger
Collaborator

sgugger commented Dec 1, 2020

Ok, we debriefed a bit with @patrickvonplaten to avoid spamming the PR. I had missed that some models already use output embeddings that are different from the input embeddings (most models are tied), like T5 or mT5. So those, like mobileBERT, will definitely need the get_output_embeddings method. Right now though, the resizing does not work for those models.

In the end, we both agree on keeping that method, adding the get_output_bias method, and having resize_embeddings use the outputs of those two methods, as well as get_input_embeddings, for all the things it has to resize. To check whether the input embeddings and output embeddings are the same (and not resize them twice), we could use the ._handle_name attribute of their weights (or something else if you have a better idea).

Does that all make sense?
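To make the deduplication check concrete, a minimal sketch assuming both getters hand back tf.Variable weights; it compares the variable objects and, as a fallback, their names (in the spirit of the ._handle_name idea). The helper name is hypothetical.

import tensorflow as tf

def is_same_embedding(input_embeddings, output_embeddings):
    # Same object => tied weights, resize only once.
    if input_embeddings is output_embeddings:
        return True
    # Otherwise fall back to comparing the weight names.
    return input_embeddings.name == output_embeddings.name

shared = tf.Variable(tf.zeros((10, 8)), name="word_embeddings")
untied = tf.Variable(tf.zeros((10, 8)), name="decoder_weight")
print(is_same_embedding(shared, shared))  # True  -> one resize
print(is_same_embedding(shared, untied))  # False -> resize both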

@jplu
Contributor Author

jplu commented Dec 1, 2020

Ok, I'm totally fine with this 👍! Nevertheless, there are still a few things I don't get.

Right now though, the resizing does not work for those models.

What do you mean by the resizing does not work? Which one? Do you have a more specific example?

To check if the input embeddings and output_embeddings are the same (and not resize them twice) we could use the ._handle_name attribute of their weights (or something else if you have a better idea).

I don't understand this sentence, do you have an example? Why do we have to check whether the input/output embeddings are different if we get them with two separate methods (namely get_input_embeddings and get_output_embeddings)?

@sgugger
Collaborator

sgugger commented Dec 1, 2020

The new T5 and mT5 models have an output embedding layer that is sometimes tied to the input embeddings (so same weights like BERT) and sometimes different. When it's different, it is not resized.

I don't understand this sentence, do you have an example? Why do we have to check whether the input/output embeddings are different if we get them with two separate methods (namely get_input_embeddings and get_output_embeddings)?

The output embeddings are very often the same as the input embeddings (the BERT situation), so in most instances get_output_embeddings returns the same thing as get_input_embeddings (which is why we initially decided to remove get_output_embeddings when discussing together). However, in some cases it returns something different (mT5 and T5 as mentioned above, or mobileBERT), which is (along with avoiding a breaking change) the main argument to keep this get_output_embeddings method. However, when taking its result in resize_embeddings, we should check whether we get a different result from get_input_embeddings. Is that clearer?

@jplu
Contributor Author

jplu commented Dec 1, 2020

Crystal clear!!! Thanks a lot for the details! I will proceed to the changes once the sprint is finished 👍

@jplu
Contributor Author

jplu commented Dec 7, 2020

@patrickvonplaten I have put back get_output_embeddings, does it seem ok to you now, or did I forget something?

@@ -818,6 +819,32 @@ def __init__(self, config, *inputs, **kwargs):
def get_output_embeddings(self):
return self.albert.embeddings

def resize_token_embeddings(self, new_num_tokens):
Contributor

In PT, albert can just use the "normal" resize_token_embeddings function, see: https://github.com/huggingface/transformers/pull/8880/files#r534143307

Contributor

Do we need a special case here?

Contributor Author

If we don't do that, the two biases are not resized, so yes, it is mandatory.

Collaborator

Maybe a comment on why it's necessary would help (I remember asking you too why there were two biases in ALBERT ;-) )

Contributor Author

Take a look, there is already a comment ;)

# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy

Collaborator

Ah sorry, I was responding too fast here and didn't look closely. That's great then :-)

@@ -1051,6 +1051,24 @@ def __init__(self, config, *inputs, **kwargs):
name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)

def resize_token_embeddings(self, new_num_tokens):
Contributor

yes! This is the same in PT!

@@ -612,14 +612,24 @@ def set_input_embeddings(self, value):
else:
raise NotImplementedError

def get_output_embeddings(self) -> tf.keras.layers.Layer:
Contributor

Can we leave the default get_output_embeddings function for now as well?

Contributor Author

Method restored!

@jplu
Contributor Author

jplu commented Dec 10, 2020

LGTM! @LysandreJik, just missing your approval now. The Flax tests do not pass and I don't know why :(

Member

@LysandreJik left a comment

LGTM. Thanks for working on this and improving the API @jplu!

@LysandreJik merged commit 51d9c56 into huggingface:master on Dec 14, 2020
@jplu deleted the fix-resize-ebd branch on December 14, 2020 10:00