Skip to content

Commit

Permalink
Fix weight loading issue (#14016)
Browse files Browse the repository at this point in the history
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh committed Nov 15, 2021
1 parent 74e6111 commit a67d47b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
""" Classes to support TF Encoder-Decoder architectures """


import tempfile
from typing import Optional

import tensorflow as tf
Expand Down Expand Up @@ -254,6 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
Example::
>>> from transformers import TFEncoderDecoderModel
>>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
"""

from_pt = kwargs.pop("from_pt", False)
Expand Down Expand Up @@ -369,6 +375,14 @@ def from_encoder_decoder_pretrained(
kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_encoder.get("from_pt", None):
del kwargs_encoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder.save_pretrained(tmp_dirname)
del encoder
encoder = TFAutoModel.from_pretrained(tmp_dirname, *model_args, **kwargs_encoder)

decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
Expand Down Expand Up @@ -397,6 +411,14 @@ def from_encoder_decoder_pretrained(
kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

# This is necessary to make `from_pretrained` following `save_pretrained` work correctly
if kwargs_decoder.get("from_pt", None):
del kwargs_decoder["from_pt"]
with tempfile.TemporaryDirectory() as tmp_dirname:
decoder.save_pretrained(tmp_dirname)
del decoder
decoder = TFAutoModelForCausalLM.from_pretrained(tmp_dirname, **kwargs_decoder)

# Make sure these 2 `tf.keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
if encoder.name != "encoder":
raise ValueError("encoder model must be created with the name `encoder`.")
Expand Down
18 changes: 18 additions & 0 deletions tests/test_modeling_tf_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,14 @@ def test_bert2bert_summarization(self):

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])

# Test with the TF checkpoint
model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")

output_ids = model.generate(input_ids=input_dict["input_ids"], max_length=None).numpy().tolist()
summary = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])


@require_tf
class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
Expand Down Expand Up @@ -785,6 +793,16 @@ def test_encoder_decoder_save_load_from_encoder_decoder_from_pt(self):
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

# Make sure `from_pretrained` following `save_pretrained` work and give the same result
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
encoder_decoder_tf = TFEncoderDecoderModel.from_pretrained(tmp_dirname)

logits_tf_2 = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)

# TensorFlow => PyTorch
with tempfile.TemporaryDirectory() as tmp_dirname:
encoder_decoder_tf.save_pretrained(tmp_dirname)
Expand Down

0 comments on commit a67d47b

Please sign in to comment.