Use cross_attention_hidden_size in Encoder-Decoder models #14378
Conversation
I ran slow tests for all the encoder-decoder model test scripts, and everything is fine (e.g. …). BTW, is there an easy way to run all cross tests in a test script, i.e. disabling …?
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
The TF encoder-decoder model family doesn't work smoothly with checkpoint loading and requires some hacks to make it work.
In the case here, if a TF composite model (whose weights are created under the scope of the top model) saves its encoder/decoder components separately, the two checkpoints will contain the top model's name, i.e. the encoder/decoder checkpoint weight names will begin with tf_encoder_decoder_model.
This causes problems when we want to load them again, in particular in from_encoder_decoder_pretrained.
However, if a TF composite model is constructed by creating the encoder & decoder models first, their weight names don't contain the top model's name, and we can save the two components and reload them again.
P.S.: Once PR #14016 is merged, the equivalence tests need to be reworked in order to pass.
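To make the naming issue concrete, here is a minimal sketch (not taken from this PR) of the construction order that avoids the top-model prefix; the tiny BERT configs are arbitrary, and the point is only that the components are built and called before the composite model wraps them:

```python
from transformers import BertConfig, TFBertLMHeadModel, TFBertModel, TFEncoderDecoderModel

# Small, arbitrary configs just to build toy components quickly.
encoder_config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)
decoder_config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37,
    is_decoder=True, add_cross_attention=True,
)

# Build the components as standalone models and create their weights by calling
# them on dummy inputs, so the variable names are not scoped under the composite model.
encoder = TFBertModel(encoder_config)
decoder = TFBertLMHeadModel(decoder_config)
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
print(encoder.weights[0].name)  # no "tf_encoder_decoder_model" prefix here

# Wrapping the prebuilt components keeps those names, so the encoder and decoder
# can still be saved and reloaded separately later.
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
```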
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
There is no easy way to deal with enc_to_dec_proj for TF composite models with regard to checkpoint loading, given that we need to load the encoder/decoder components separately.
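As a hedged illustration of why (the class below is a toy stand-in, not the transformers implementation): the projection is a layer of the composite model itself, so checkpoints produced by saving the encoder and decoder separately have nowhere to store its weights.

```python
import tensorflow as tf


class ToyComposite(tf.keras.Model):
    """Toy stand-in for a TF composite encoder-decoder model (illustrative only)."""

    def __init__(self, encoder, decoder, decoder_hidden_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # enc_to_dec_proj belongs to the composite model, not to either component,
        # so saving/loading the encoder and decoder separately cannot restore it.
        self.enc_to_dec_proj = tf.keras.layers.Dense(decoder_hidden_size, name="enc_to_dec_proj")

    def call(self, inputs):
        # Schematic forward pass: project the encoder output to the decoder width.
        encoder_hidden_states = self.encoder(inputs)
        return self.decoder(self.enc_to_dec_proj(encoder_hidden_states))
```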
if (
    self.encoder.config.hidden_size != self.decoder.config.hidden_size
    and self.decoder.config.cross_attention_hidden_size is None
):
Made this block the same as in other encoder/decoder models.
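For context, a hedged sketch of the pattern this condition supports (the helper names below are hypothetical, not the library's API): the projection is only created when the encoder and decoder widths differ and the decoder does not declare its own cross_attention_hidden_size, and the encoder states are passed through it before they reach the decoder's cross-attention.

```python
import torch
from torch import nn


def build_enc_to_dec_proj(encoder_config, decoder_config):
    # Hypothetical helper: create the projection only when the widths differ and
    # the decoder does not declare a cross_attention_hidden_size of its own.
    if (
        encoder_config.hidden_size != decoder_config.hidden_size
        and decoder_config.cross_attention_hidden_size is None
    ):
        return nn.Linear(encoder_config.hidden_size, decoder_config.hidden_size)
    return None


def project_encoder_states(enc_to_dec_proj, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the forward pass: leave the states untouched
    # when no projection was created.
    return encoder_hidden_states if enc_to_dec_proj is None else enc_to_dec_proj(encoder_hidden_states)
```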
if (
    self.encoder.config.hidden_size != self.decoder.config.hidden_size
    and self.decoder.config.cross_attention_hidden_size is None
):
Made this block the same as in other encoder/decoder models.
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. " | ||
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined" | ||
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " | ||
"to be defined." |
Changed it to be the same as the corresponding occurrence in other encoder-decoder models.
LGTM! Thanks for adding this consistency.
Hey @ydshieh, we sadly need to slightly update this PR for the speech encoder-decoder classes so that the newly introduced variable (transformers/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py, line 227 in cea17ac) …
The other files can stay the same :-)
No problem, @patrickvonplaten. But I have a slight doubt about this line: should it be `if …`?
(force-pushed from f5c0df5 to 7b9d31a)
I made the necessary updates where …, despite a slight doubt.
(Fixed) The failed TF/Torch test is due to #14016 being merged into master (and I rebased this PR on master), which is expected. I will take care of this issue.
(force-pushed from d78af6a to f178e1d)
Thanks a lot for working on this!
What does this PR do?
Add a projection layer (`enc_to_dec_proj`) between encoder and decoder models in composite models, incorporating the attribute `cross_attention_hidden_size`.
Adjust the `pt/tf equivalence` and `pt/flax equivalence` tests in the tf/flax composite model test scripts accordingly.
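As a hedged usage sketch of the behaviour this PR brings to the TF/Flax composites, shown here with the existing PyTorch EncoderDecoderModel (the small config values are arbitrary):

```python
from transformers import BertConfig, BertLMHeadModel, BertModel, EncoderDecoderModel

# Encoder and decoder with different hidden sizes; the decoder leaves
# cross_attention_hidden_size at its default of None.
encoder = BertModel(BertConfig(hidden_size=768, num_hidden_layers=2, num_attention_heads=4))
decoder = BertLMHeadModel(
    BertConfig(hidden_size=512, num_hidden_layers=2, num_attention_heads=4,
               is_decoder=True, add_cross_attention=True)
)

# Because the widths differ and cross_attention_hidden_size is None, the composite
# model adds the enc_to_dec_proj projection (768 -> 512) between the two parts.
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
print(hasattr(model, "enc_to_dec_proj"))  # True
```

Conversely, if the decoder config sets `cross_attention_hidden_size` to the encoder's width, the condition above is not met, no projection is created, and the decoder's cross-attention is expected to handle the size mismatch itself.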