
Custom TF weights loading #7422

Merged — 20 commits merged into huggingface:master on Oct 5, 2020

Conversation

jplu (Contributor) commented Sep 28, 2020

This PR provides a custom weight loading function that takes dynamic model architecture building into account. More precisely, the new loading function honors the authorized_unexpected_keys and authorized_missing_keys class attributes, making it possible to ignore some layers in the models when loading weights.
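The idea, roughly (an illustrative sketch only, not the actual implementation in modeling_tf_utils.py), is that these attributes hold regex patterns that are stripped from the missing/unexpected key lists before any warning or error is raised:

```python
import re

# Hypothetical helper illustrating the filtering idea; the real loading
# function in modeling_tf_utils.py does more work (it also loads the H5
# weights), but the key handling follows this shape.
def filter_loading_keys(missing_keys, unexpected_keys,
                        authorized_missing_keys=None,
                        authorized_unexpected_keys=None):
    if authorized_missing_keys is not None:
        missing_keys = [
            k for k in missing_keys
            if not any(re.search(pat, k) for pat in authorized_missing_keys)
        ]
    if authorized_unexpected_keys is not None:
        unexpected_keys = [
            k for k in unexpected_keys
            if not any(re.search(pat, k) for pat in authorized_unexpected_keys)
        ]
    return missing_keys, unexpected_keys

# Example with made-up key lists: authorized keys are silently ignored,
# everything else is still reported.
missing, unexpected = filter_loading_keys(
    missing_keys=["classifier"],
    unexpected_keys=["pooler", "nsp___cls", "mlm___cls"],
    authorized_missing_keys=[r"pooler"],
    authorized_unexpected_keys=[r"pooler", r"nsp___cls"],
)
print(missing)     # ['classifier']  -> still reported as missing
print(unexpected)  # ['mlm___cls']   -> still reported as unexpected
```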

codecov bot commented Sep 28, 2020

Codecov Report

Merging #7422 into master will decrease coverage by 0.21%.
The diff coverage is 87.50%.


@@            Coverage Diff             @@
##           master    #7422      +/-   ##
==========================================
- Coverage   78.51%   78.30%   -0.22%     
==========================================
  Files         184      181       -3     
  Lines       36734    35917     -817     
==========================================
- Hits        28843    28125     -718     
+ Misses       7891     7792      -99     
| Impacted Files | Coverage Δ |
| --- | --- |
| src/transformers/modeling_tf_pytorch_utils.py | 88.37% <50.00%> (-1.94%) ⬇️ |
| src/transformers/modeling_tf_utils.py | 88.06% <92.30%> (+0.84%) ⬆️ |
| src/transformers/modeling_tf_bert.py | 98.91% <100.00%> (+<0.01%) ⬆️ |
| src/transformers/modeling_tf_mobilebert.py | 24.59% <0.00%> (-72.35%) ⬇️ |
| src/transformers/configuration_mobilebert.py | 26.47% <0.00%> (-70.59%) ⬇️ |
| src/transformers/modeling_mobilebert.py | 23.51% <0.00%> (-65.93%) ⬇️ |
| src/transformers/modeling_tf_xlm.py | 58.52% <0.00%> (-34.74%) ⬇️ |
| src/transformers/modeling_lxmert.py | 69.91% <0.00%> (-20.82%) ⬇️ |
| src/transformers/trainer_utils.py | 63.30% <0.00%> (-5.35%) ⬇️ |
| src/transformers/trainer.py | 63.23% <0.00%> (-1.59%) ⬇️ |
| ... and 24 more | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 95f792a...6f52cc9.

@jplu jplu marked this pull request as ready for review September 28, 2020 10:15
sgugger (Collaborator) left a comment
This is great, thanks for adding this to the TF side!
I have two nits on the docs and I know I'm the one to blame since I know where you copied it from ;-). We can fix the PreTrainedModel docstrings to match in this PR or I'll do another one for that.

Review comments on src/transformers/modeling_tf_utils.py (outdated, resolved)
jplu (Contributor, Author) commented Sep 28, 2020

Just merged your suggestions :)

@@ -518,7 +518,7 @@ def __init__(self, config, **kwargs):
         self.return_dict = config.use_return_dict
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
         self.encoder = TFBertEncoder(config, name="encoder")
-        self.pooler = TFBertPooler(config, name="pooler")
+        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
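For context, a minimal toy sketch of the pattern this diff introduces (plain tf.keras, not transformers code): the pooler is only instantiated when add_pooling_layer is True, so heads that never use the pooled output do not create, or expect, those weights.

```python
import tensorflow as tf

# Toy layer, names and sizes are illustrative only.
class ToyMainLayer(tf.keras.layers.Layer):
    def __init__(self, hidden_size=32, add_pooling_layer=True, **kwargs):
        super().__init__(**kwargs)
        self.encoder = tf.keras.layers.Dense(hidden_size, name="encoder")
        # Only build the pooler when it is actually needed, mirroring the diff above.
        self.pooler = (
            tf.keras.layers.Dense(hidden_size, activation="tanh", name="pooler")
            if add_pooling_layer
            else None
        )

    def call(self, inputs):
        sequence_output = self.encoder(inputs)
        pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None
        return sequence_output, pooled_output

layer = ToyMainLayer(add_pooling_layer=False)
seq, pooled = layer(tf.zeros((2, 4, 16)))
print(pooled)  # None: no pooler weights are built for this head
```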
Contributor:
Thanks a lot for adding this - it's great that it works now :-)! Could you maybe also add it for TFAlbert, TFMobileBert and TFLongformer, as is done in the PyTorch versions?

jplu (Contributor, Author):
Ok!

@@ -853,8 +855,7 @@ def call(self, inputs, **kwargs):

@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

    authorized_missing_keys = [r"pooler"]
    authorized_unexpected_keys = [r"pooler", r"nsp___cls"]
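As an aside, the visible effect of these attributes (assuming a transformers version that includes this PR) is that loading this head with output_loading_info no longer reports the pooler or NSP weights; the checkpoint name below is only an example.

```python
from transformers import TFBertForMaskedLM

model, loading_info = TFBertForMaskedLM.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)
# Any "pooler" or "nsp___cls" weights present in the checkpoint should be
# filtered out of unexpected_keys, and an absent pooler filtered out of
# missing_keys, so neither should show up here.
print(loading_info["unexpected_keys"])
print(loading_info["missing_keys"])
```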
Contributor:
awesome :-)

@@ -781,6 +781,8 @@ class TFBertForPreTrainingOutput(ModelOutput):
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    authorized_unexpected_keys = [r"nsp___cls"]
Contributor:
Why do we need to add "nsp___cls" to authorized_unexpected_keys here? I'm not sure that we need it.

Contributor:
I think if one initializes a TFBertModel from a checkpoint that has an NSP layer, it is correct to show a warning that nsp___cls is not used, because nsp___cls is not part of TFBertMainLayer, in contrast to the pooler.

jplu (Contributor, Author):
We don't need it; it's just there so that, if the key appears in the weights, we don't raise the exception. Do you think it would be better to authorize no keys here?

jplu (Contributor, Author):
Fine, let's remove it then!

        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    from tensorflow.python.keras import backend as K
Contributor:
Can we move this import to the top?

patrickvonplaten (Contributor) left a comment

This is a great addition! This function allows for better warnings and lets us remove the unnecessary self.pooler layer for some models.

Two things I would improve:

  • Can you add the add_pooling_layer logic to TFAlbert, TFMobileBert and TFLongformer as well?
  • I don't think we should add layers such as nsp___cls to authorized_unexpected_keys, as explained below.

jplu (Contributor, Author) commented Sep 28, 2020

@patrickvonplaten I have done some updates, let me know if it looks like what you have in mind.

@jplu jplu marked this pull request as draft September 28, 2020 17:43
jplu (Contributor, Author) commented Sep 28, 2020

There is an issue with Longformer apparently.

jplu (Contributor, Author) commented Sep 28, 2020

Ok, I found why, and I should have thought about this much earlier.... 😣

We cannot have None inside an output tuple; the logic only works when return_dict is True.
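A toy illustration of the problem (plain Python, not transformers code): with tuple-style outputs, callers index by position, so a pooled output that is sometimes None either has to occupy a slot as None or be dropped, which shifts every later element; named outputs (return_dict=True) have no such ambiguity.

```python
# Stand-ins for real tensors; the names are illustrative only.
sequence_output = "sequence_output"
pooled_output = None          # no pooler was built for this head
hidden_states = "hidden_states"

# Dropping the None keeps the tuple clean but silently shifts positions:
outputs = tuple(x for x in (sequence_output, pooled_output, hidden_states) if x is not None)
print(outputs[1])  # 'hidden_states' -- a caller expecting the pooled output at index 1 gets the wrong value

# With named outputs there is no positional ambiguity:
outputs_dict = {
    "last_hidden_state": sequence_output,
    "pooler_output": pooled_output,
    "hidden_states": hidden_states,
}
print(outputs_dict["pooler_output"])  # None
```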

@jplu jplu marked this pull request as ready for review September 29, 2020 14:38
jplu (Contributor, Author) commented Oct 5, 2020

@LysandreJik are we able to merge?

LysandreJik (Member) left a comment

Great, thanks a lot @jplu!

Comment on lines -320 to +330
    @slow
    def test_model_from_pretrained(self):
        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
        for model_name in ["bert-base-uncased"]:
            model = TFBertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
        model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
        self.assertIsNotNone(model)

    def test_custom_load_tf_weights(self):
        model, output_loading_info = TFBertForTokenClassification.from_pretrained(
            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
        )
        self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
        for layer in output_loading_info["missing_keys"]:
            self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
LysandreJik (Member):
I think it makes sense to use tiny models here and to test them without @slow in a single model, so that we get feedback on this from each PR's CI rather than relying on the slow testing suite.

patrickvonplaten (Contributor):
Good to merge for me

LysandreJik (Member):
Ran the slow tests, they pass.

LysandreJik merged commit 9cf7b23 into huggingface:master on Oct 5, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
* First try

* Fix TF utils

* Handle authorized unexpected keys when loading weights

* Add several more authorized unexpected keys

* Apply style

* Fix test

* Address Patrick's comments.

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply style

* Make return_dict the default behavior and display a warning message

* Revert

* Replace wrong keyword

* Revert code

* Add forgot key

* Fix bug in loading PT models from a TF one.

* Fix sort

* Add a test for custom load weights in BERT

* Apply style

* Remove unused import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
@jplu jplu deleted the tf-weight-load branch June 13, 2023 14:23