diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
index c59e8c52f1c9..cf682d139336 100644
--- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -15,6 +15,7 @@
 """ Classes to support TF Encoder-Decoder architectures """
 
 
+import tempfile
 from typing import Optional
 
 import tensorflow as tf
@@ -254,6 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             >>> # This is only for copying some specific attributes of this particular model.
             >>> model.config = _model.config
 
+        Example::
+
+            >>> from transformers import TFEncoderDecoderModel
+            >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
+
         """
 
         from_pt = kwargs.pop("from_pt", False)
@@ -369,6 +375,14 @@ def from_encoder_decoder_pretrained(
             kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
             encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
 
+            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+            if kwargs_encoder.get("from_pt", None):
+                del kwargs_encoder["from_pt"]
+                with tempfile.TemporaryDirectory() as tmp_dirname:
+                    encoder.save_pretrained(tmp_dirname)
+                    del encoder
+                    encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)
+
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             if decoder_pretrained_model_name_or_path is None:
@@ -397,6 +411,14 @@ def from_encoder_decoder_pretrained(
             kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
             decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
 
+            # This is necessary to make `from_pretrained` following `save_pretrained` work correctly
+            if kwargs_decoder.get("from_pt", None):
+                del kwargs_decoder["from_pt"]
+                with tempfile.TemporaryDirectory() as tmp_dirname:
+                    decoder.save_pretrained(tmp_dirname)
+                    del decoder
+                    decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)
+
         # Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
         if encoder.name != "encoder":
             raise ValueError("encoder model must be created with the name `encoder`.")
diff --git a/tests/test_modeling_tf_encoder_decoder.py b/tests/test_modeling_tf_encoder_decoder.py
index 1880f70ace61..77737116574f 100644
--- a/tests/test_modeling_tf_encoder_decoder.py
+++ b/tests/test_modeling_tf_encoder_decoder.py
@@ -457,6 +457,14 @@ def test_bert2bert_summarization(self):
 
         self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
 
+        # Test with the TF checkpoint
+        model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
+
+        output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
+        summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
+
 
 @require_tf
 class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
@@ -785,6 +793,16 @@ def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
         max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
         self.assertAlmostEqual(max_diff, 0.0, places=3)
 
+        # Make sure `from_pretrained` following `save_pretrained` works and gives the same result
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            encoder_decoder_tf.save_pretrained(tmp_dirname)
+            encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)
+
+            logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
+
+            max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
+            self.assertAlmostEqual(max_diff, 0.0, places=3)
+
         # TensorFlow => PyTorch
         with tempfile.TemporaryDirectory() as tmp_dirname:
             encoder_decoder_tf.save_pretrained(tmp_dirname)
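
For reviewers who want to reproduce the fix locally, a minimal sketch of the round-trip this patch addresses. The `bert-base-cased` checkpoints are illustrative assumptions, not taken from this diff; `encoder_from_pt`/`decoder_from_pt` are the standard prefixed kwargs that `from_encoder_decoder_pretrained` forwards to the underlying `from_pretrained` calls as `from_pt`:

    import tempfile

    from transformers import TFEncoderDecoderModel

    # Build a TF encoder-decoder from PyTorch checkpoints (illustrative checkpoints;
    # loading with from_pt=True requires torch to be installed). With this patch, the
    # PT-loaded sub-models are re-saved in TF format to a temporary directory and
    # reloaded, so their weight names line up under the fixed `encoder`/`decoder` scopes.
    model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-cased", "bert-base-cased", encoder_from_pt=True, decoder_from_pt=True
    )

    # `from_pretrained` following `save_pretrained` now restores the same weights,
    # mirroring what test_encoder_decoder_save_load_from_encoder_decoder_from_pt checks.
    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        model = TFEncoderDecoderModel.from_pretrained(tmp_dirname)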