Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix weight loading issue #14016

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
""" Classes to support TF Encoder-Decoder architectures """


import tempfile
from typing import Optional

import tensorflow as tf
Expand Down Expand Up @@ -254,6 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
Example::
>>> from transformers import TFEncoderDecoderModel
>>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
"""

from_pt = kwargs.pop("from_pt", False)
Expand Down Expand Up @@ -369,6 +375,14 @@ def from_encoder_decoder_pretrained(
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_encoder.get("from_pt", None):
del kwargs_encoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder.save_pretrained(tmp_dirname)
del encoder
encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)

decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
Expand Down Expand Up @@ -397,6 +411,14 @@ def from_encoder_decoder_pretrained(
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_decoder.get("from_pt", None):
del kwargs_decoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
decoder.save_pretrained(tmp_dirname)
del decoder
decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)

# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
if encoder.name != "encoder":
raise ValueError("encoder model must be created with the name `encoder`.")
Expand Down
18 changes: 18 additions & 0 deletions tests/test_modeling_tf_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,14 @@ def test_bert2bert_summarization(self):

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])

# Test with the TF checkpoint
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")

output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])


@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
Expand Down Expand Up @@ -785,6 +793,16 @@ def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

# Make sure `from_pretrained` following `save_pretrained` work and give the same result
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)

logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

# TensorFlow => PyTorch
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
Expand Down