Add cross attentions to TFGPT2Model #14038
Conversation
from transformers import AutoTokenizer

tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")

"""Not working, because the PT checkpoint has `encoder.encoder.layer...` while the TF model has `encoder.bert.layer...`."""
Add another data point to the "We need to think about how to do cross-framework model loading better" chart
This looks very similar to the previous PR adding cross-attention to other models. Tests look good and are passing, so I'm happy to approve it.
Just correct the line about
For tf here, it has
Force-pushed from 4b93999 to 403600d
This looks good to me, but asking @patrickvonplaten for a review as well.
Sorry, I tried fixing some merge conflicts, but I think this introduced a new error. @ydshieh could you maybe quickly go into the PR again and fix those last tests? :-) The PR looks good to me otherwise!
It's OK now :)
What does this PR do?
Add cross attention to TFGPT2.
This was previously done in #13222, but we decided to move this to a new PR.
I also added `TFGPT2EncoderDecoderModelTest` with `test_bert2gpt2_summarization`.
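For illustration, a rough usage sketch (not taken from the PR itself) of the cross-attention path this adds on the TF side; the token ids and the encoder tensor are made up:

import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

# Enable the cross-attention layers in the TF GPT-2 model.
config = GPT2Config(add_cross_attention=True)
decoder = TFGPT2Model(config)

input_ids = tf.constant([[464, 3290, 318, 13779]])               # decoder input tokens (illustrative)
encoder_hidden_states = tf.random.normal((1, 8, config.n_embd))  # e.g. hidden states from a BERT encoder

outputs = decoder(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
)
print(outputs.cross_attentions[0].shape)  # (batch, num_heads, target_len, source_len)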