Custom TF weights loading #7422
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master    #7422      +/-   ##
==========================================
- Coverage   78.51%   78.30%   -0.22%
==========================================
  Files         184      181       -3
  Lines       36734    35917     -817
==========================================
- Hits        28843    28125     -718
+ Misses       7891     7792      -99
==========================================
```
This is great, thanks for adding this to the TF side!
I have two nits on the docs, and I know I'm the one to blame since I know where you copied it from ;-). We can fix the `PreTrainedModel` docstrings to match in this PR, or I'll do another one for that.
Just merged your suggestions :)
src/transformers/modeling_tf_bert.py (outdated)

```diff
@@ -518,7 +518,7 @@ def __init__(self, config, **kwargs):
         self.return_dict = config.use_return_dict
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
         self.encoder = TFBertEncoder(config, name="encoder")
-        self.pooler = TFBertPooler(config, name="pooler")
+        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
```
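For readers following the diff, a minimal sketch of how this flag is presumably threaded through (the PR's exact constructor signature may differ; `TFBertEmbeddings`, `TFBertEncoder` and `TFBertPooler` come from `modeling_tf_bert.py`):

```python
import tensorflow as tf

# Sketch only: illustrates the add_pooling_layer idea, not the PR's exact code.
class TFBertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, add_pooling_layer=True, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        # Skip building the pooler when a head model doesn't need it, so its
        # weights never show up as "missing" when loading a checkpoint.
        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
```

A head model can then opt out, e.g. with `self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")` in a model like `TFBertForMaskedLM` that never uses the pooled output.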
Thanks a lot for adding this - it's great that it works now :-)! Could you maybe also add it for TFAlbert, TFMobileBert and TFLongformer, as is done in the PyTorch versions?
Ok!
src/transformers/modeling_tf_bert.py (outdated)

```diff
@@ -853,8 +855,7 @@ def call(self, inputs, **kwargs):

 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

+    authorized_missing_keys = [r"pooler"]
+    authorized_unexpected_keys = [r"pooler", r"nsp___cls"]
```
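For context, matching against these patterns presumably works along these lines (an illustrative sketch, not the PR's exact code; `filter_loading_keys` is a hypothetical helper name):

```python
import re

def filter_loading_keys(missing_keys, unexpected_keys, model_cls):
    # Drop any key the model class explicitly authorizes, so only genuinely
    # suspicious layers end up in the loading warnings.
    for pattern in getattr(model_cls, "authorized_missing_keys", None) or []:
        missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
    for pattern in getattr(model_cls, "authorized_unexpected_keys", None) or []:
        unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
    return missing_keys, unexpected_keys
```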
awesome :-)
src/transformers/modeling_tf_bert.py (outdated)

```diff
@@ -781,6 +781,8 @@ class TFBertForPreTrainingOutput(ModelOutput):
     BERT_START_DOCSTRING,
 )
 class TFBertModel(TFBertPreTrainedModel):
+    authorized_unexpected_keys = [r"nsp___cls"]
```
Why do we need to add `nsp___cls` to `authorized_unexpected_keys` here? I'm not sure that we need it.
I think if one initializes a `TFBertModel` with an nsp layer, it is correct to show a warning that `nsp___cls` is not used, because `nsp___cls` is not part of `TFBertMainLayer`, in contrast to `pooler`.
We don't strictly need it; it's just there so that if the key appears in the weights, we don't raise the exception. Do you think it would be better to authorize no keys here?
Fine, let's remove it then!
```python
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    from tensorflow.python.keras import backend as K
```
Can we move this import to the top?
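For readers unfamiliar with this helper: the snippet above is from the new custom H5 loading function, and a loader of this kind typically does something like the following with the Keras backend. This is a simplified sketch under the assumption that H5 dataset paths can be matched to variable names; the real implementation also handles Keras's name mangling and tracks missing/unexpected keys:

```python
import h5py
from tensorflow.python.keras import backend as K

def load_tf_weights_sketch(model, resolved_archive_file):
    """Illustrative only -- not the PR's actual loader."""
    stored = {}

    def collect(name, obj):
        # Gather every dataset in the H5 checkpoint, keyed by its path.
        if isinstance(obj, h5py.Dataset):
            stored[name] = obj[()]

    with h5py.File(resolved_archive_file, "r") as f:
        f.visititems(collect)

    # Pair each model variable with its stored value (assuming matching names)
    # and assign everything in one batched call instead of one-by-one.
    tuples = [(w, stored[w.name]) for w in model.weights if w.name in stored]
    K.batch_set_value(tuples)
```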
This is a great addition! This function allows for better warnings and removes the unnecessary `self.pooler` layer for some models.
Two things I would improve:

- Can you add the `add_pooling_layer` logic to TFAlbert, TFMobileBert and TFLongformer as well?
- I don't think we should add layers such as `nsp___cls` to `authorized_unexpected_keys`, as explained below.
@patrickvonplaten I have made some updates; let me know if they look like what you had in mind.
There is an issue with Longformer apparently.
Ok, I found why, and I should have thought about this much earlier.... 😣 We cannot have […]
Force-pushed from 6cccda6 to 48104a2
@LysandreJik are we able to merge?
Great, thanks a lot @jplu!
```python
    @slow
    def test_model_from_pretrained(self):
        # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
        for model_name in ["bert-base-uncased"]:
            model = TFBertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

        model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
        self.assertIsNotNone(model)
```
```python
    def test_custom_load_tf_weights(self):
        model, output_loading_info = TFBertForTokenClassification.from_pretrained(
            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
        )
        self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
        for layer in output_loading_info["missing_keys"]:
            self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
```
I think it makes sense to use tiny models here and to test them outside the slow tests for a single model, so we get feedback on this on each PR's CI instead of relying on the slow test suite.
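As a quick usage note, the same loading report the test inspects can be read directly (checkpoint name taken from the test above; the exact key lists depend on the checkpoint/architecture pair):

```python
from transformers import TFBertForTokenClassification

# Load a checkpoint that contains pretraining heads into a token-classification
# model, asking from_pretrained to also return the loading report.
model, info = TFBertForTokenClassification.from_pretrained(
    "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
)
# Layers present in the checkpoint but unused by this architecture:
print(info["unexpected_keys"])  # e.g. ['mlm___cls', 'nsp___cls'] per the test
# Layers of this architecture that were freshly initialized:
print(info["missing_keys"])     # e.g. the classifier and dropout layers
```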
Good to merge for me
Ran the slow tests, they pass.
Force-pushed from 51c36e4 to 294c56b
* First try
* Fix TF utils
* Handle authorized unexpected keys when loading weights
* Add several more authorized unexpected keys
* Apply style
* Fix test
* Address Patrick's comments.
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply style
* Make return_dict the default behavior and display a warning message
* Revert
* Replace wrong keyword
* Revert code
* Add forgot key
* Fix bug in loading PT models from a TF one.
* Fix sort
* Add a test for custom load weights in BERT
* Apply style
* Remove unused import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This reverts commit c8b54af.
This PR provides a custom weight-loading function in order to take dynamic model architecture building into account. More precisely, the new loading function honors the `authorized_unexpected_keys` and `authorized_missing_keys` class attributes, which make it possible to ignore some layers in the models.
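Putting it together, a model class opts into this behavior simply by declaring the two attributes (values below copied from the BERT changes in this PR):

```python
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Regex patterns matched against layer names while loading TF weights:
    authorized_missing_keys = [r"pooler"]                   # fine if absent from the checkpoint
    authorized_unexpected_keys = [r"pooler", r"nsp___cls"]  # fine if present but unused
```

With these in place, loading a full pretraining checkpoint into the MLM model no longer warns about the pooler or the NSP head.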