
Add cross attentions to TFGPT2Model #14038

Merged

Conversation

@ydshieh (Collaborator) commented Oct 17, 2021

What does this PR do?

Add cross attention to TFGPT2.

This was previously done in #13222, but we decided to move this to a new PR.
I also added TFGPT2EncoderDecoderModelTest with test_bert2gpt2_summarization.

```python
tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")
```

"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.layer...`.
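For background, cross-attention means the decoder's queries attend over the encoder's hidden states (which supply the keys and values), rather than over the decoder's own states. A minimal pure-Python sketch of the mechanism — illustrative only, not the TFGPT2 implementation:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def matmul(a, b):
    # a: (n, d), b: (d, m) -> (n, m), as nested lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def cross_attention(decoder_hidden, encoder_hidden, w_q, w_k, w_v):
    """Queries come from the decoder; keys/values from the encoder."""
    q = matmul(decoder_hidden, w_q)
    k = matmul(encoder_hidden, w_k)
    v = matmul(encoder_hidden, w_v)
    d_k = len(w_q[0])
    # scores[i][j]: how much decoder position i attends to encoder position j
    scores = [[sum(qi * kj for qi, kj in zip(q_row, k_row)) / math.sqrt(d_k)
               for k_row in k] for q_row in q]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v), weights

# Tiny toy sizes: hidden size 2, identity projections (all values hypothetical).
dec = [[1.0, 0.0], [0.0, 1.0]]              # 2 decoder positions
enc = [[1.0, 1.0], [0.5, 0.0], [0.0, 1.0]]  # 3 encoder positions
w = [[1.0, 0.0], [0.0, 1.0]]
out, attn = cross_attention(dec, enc, w, w, w)
print(len(out), len(out[0]))                  # 2 2
print([round(sum(row), 6) for row in attn])   # [1.0, 1.0]
```

The attention-weight matrix has shape (decoder_len, encoder_len), and each row sums to 1 — each decoder position distributes its attention across all encoder positions.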
A Member commented:
Add another data point to the "We need to think about how to do cross-framework model loading better" chart

@Rocketknight1 (Member) left a comment:

This looks very similar to the previous PR adding cross-attention to other models. Tests look good and are passing, so I'm happy to approve it.

@ydshieh (Collaborator, Author) commented Oct 19, 2021

Just to correct the line about

"""Not working, because pt checkpoint has `encoder.encoder.layer...` while tf model has `encoder.bert.layer...`.

For TF here, the model actually has `encoder.bert.encoder.layer...` instead.
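The naming mismatch above is exactly what cross-framework checkpoint loading has to reconcile: the same weights live under different prefixes in the PT and TF models. A minimal, hypothetical sketch of a prefix-remapping step (this helper is illustrative only, not part of the transformers API):

```python
def remap_keys(state_dict, old_prefix, new_prefix):
    """Rewrite weight names whose prefix differs between frameworks."""
    remapped = {}
    for name, tensor in state_dict.items():
        if name.startswith(old_prefix):
            name = new_prefix + name[len(old_prefix):]
        remapped[name] = tensor
    return remapped

# Hypothetical PT checkpoint entry, renamed to the TF model's naming scheme.
pt_weights = {"encoder.encoder.layer.0.weight": [0.1]}
tf_style = remap_keys(pt_weights,
                      "encoder.encoder.layer",
                      "encoder.bert.encoder.layer")
print(list(tf_style))  # ['encoder.bert.encoder.layer.0.weight']
```

In practice the real loader also has to handle transposed kernels and per-layer shape checks, so a bare rename like this is only the first step.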

@ydshieh force-pushed the add_crossattn_to_tf_gpt2 branch from 4b93999 to 403600d on October 28, 2021.
@LysandreJik (Member) left a comment:

This looks good to me, but asking @patrickvonplaten for a review as well.

@patrickvonplaten (Contributor) commented:

Sorry, I tried fixing some merge conflicts, but I think this introduced a new error. @ydshieh could you maybe quickly go into the PR again and fix those last tests? :-) The PR looks good for me otherwise!

@ydshieh (Collaborator, Author) commented Nov 2, 2021

> Sorry, I tried fixing some merge conflicts, but I think this introduced a new error. @ydshieh could you maybe quickly go into the PR again and fix those last tests? :-) The PR looks good for me otherwise!

It's OK now :)

@patrickvonplaten patrickvonplaten merged commit bd21ed4 into huggingface:master Nov 3, 2021
@ydshieh ydshieh deleted the add_crossattn_to_tf_gpt2 branch November 3, 2021 09:14