
Revert "Fix weight loading issue" #14406

Closed

Conversation

patrickvonplaten (Contributor)

Reverts #14016

patrickvonplaten (Contributor, Author) commented Nov 15, 2021

Sorry for reverting the PR here - it's on me! I merged it too quickly. We had some internal discussion and came to the conclusion that this hack is probably not worth the functionality it would give us here.

Saving and loading a model with tempfile inside the from_encoder_decoder_pretrained(...) function is a big hack and it's questionable whether it's worth it.
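
For context, a minimal sketch of roughly what that tempfile round-trip looks like - a simplified illustration under my assumptions, not the actual code in modeling_tf_encoder_decoder.py (the "./encoder" path reuses the example below):

import tempfile

from transformers import TFAutoModel

# Simplified illustration of the hack: load the PyTorch checkpoint through the
# cross-framework from_pt=True path, save the resulting TF weights to a
# temporary directory, and reload them as a regular TF checkpoint.
with tempfile.TemporaryDirectory() as tmp_dir:
    _encoder = TFAutoModel.from_pretrained("./encoder", from_pt=True)
    _encoder.save_pretrained(tmp_dir)
    encoder = TFAutoModel.from_pretrained(tmp_dir)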

Just to compare the current design to how it would look if we reverted the PR, for @LysandreJik, @sgugger, and @Rocketknight1:

If we leave master as it is, one can correctly convert a PyTorch model checkpoint as follows:

Current design:

from transformers import EncoderDecoderModel, TFEncoderDecoderModel

_model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
_model.encoder.save_pretrained("./encoder")
_model.decoder.save_pretrained("./decoder")

model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
)

# then this works:
model.save_pretrained("./")
model = TFEncoderDecoderModel.from_pretrained("./")

If we remove the hack, then the only way (in my opinion) currently to convert a PT checkpoint to TF is the following:

Design after removing the hack:

from transformers import EncoderDecoderModel, TFEncoderDecoderModel, TFAutoModel, TFAutoModelForCausalLM

_model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
_model.encoder.save_pretrained("./encoder")
_model.decoder.save_pretrained("./decoder")

# All of these steps are currently done automatically. If we remove the hack,
# there is not really a way around doing them manually, IMO.
_encoder = TFAutoModel.from_pretrained("./encoder", from_pt=True)
_decoder = TFAutoModelForCausalLM.from_pretrained("./decoder", from_pt=True)
_encoder.save_pretrained("./encoder")
_decoder.save_pretrained("./decoder")

model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("./encoder", "./decoder")

# then this works:
model.save_pretrained("./")
model = TFEncoderDecoderModel.from_pretrained("./")

So we can see that removing the hack would force the user to do manually the exact same thing we are doing automatically right now.

patrickvonplaten (Contributor, Author) commented Nov 15, 2021

Given that the hack only lives in modeling_tf_encoder_decoder.py, and having thought about it again, I'm actually in favor of not merging this PR, but I defer to @LysandreJik and @sgugger to decide here.

ydshieh (Collaborator) commented Nov 15, 2021

No problem for me - I'll leave it to the HF members to make the decision. Just making this clear to users would be fine on my side 😀

LysandreJik (Member)

I agree with closing this and keeping the current hack, as long as we mention that the TFEncoderDecoder is experimental.

In my opinion, TensorFlow is generally ill-suited to managing several models combined into a single one, as is done here, and it will always have some hacky/kind-of-broken edge cases.

I would advocate for keeping the work @ydshieh has done so far and seeing if the community appreciates/uses the feature before spending time refactoring this complex piece of software.

sgugger (Collaborator) commented Nov 15, 2021

I agree with you two, @LysandreJik and @patrickvonplaten, even if I'm really not a fan of the hack behind the scenes. Let's worry about making it better when we have wide adoption of the TFEncoderDecoder :-)

patrickvonplaten deleted the revert-14016-fix_tf_enc_dec_weight_loading branch November 15, 2021 22:25