
Add TFBartForConditionalGeneration #5411

Merged
merged 88 commits into huggingface:master on Oct 21, 2020

Conversation

@sshleifer (Contributor) commented Jun 30, 2020

  • Adds TFBartForConditionalGeneration, which can generate summaries equivalent to the PyTorch implementation.

TODO this PR:

  • fast tests besides two
  • reasonable xsum generations
  • tests passing
  • fix slow cnn test (tf needs to call adjust_logits_during_generation)
  • functional dropout
  • simplify torch and tf caching logic
  • docs
  • upload applicable tf/h5 weights.
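One item above, the `adjust_logits_during_generation` hook needed by the slow CNN test, can be illustrated with a framework-free sketch. NumPy stands in for TF tensors here, and the body is an assumption about the hook's behavior (forcing BOS at the first step and EOS at the last slot), not the actual transformers implementation:

```python
import numpy as np

NEG_INF = -1e9

def adjust_logits_during_generation(logits, cur_len, max_length, bos_id, eos_id):
    """Force a specific token by masking every other logit to a large negative value."""
    def force_token(token_id):
        masked = np.full_like(logits, NEG_INF)
        masked[..., token_id] = logits[..., token_id]
        return masked

    if cur_len == 1:
        return force_token(bos_id)       # first generated token must be BOS
    if cur_len == max_length - 1:
        return force_token(eos_id)       # last slot must be EOS
    return logits                        # otherwise leave logits untouched
```

Without an equivalent hook on the TF side, beam search can pick a different token at the forced positions, which is one way TF and PyTorch generations drift apart.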

Future PRs:

@sshleifer sshleifer changed the title Bart tensorflow [WIP, dont merge] TFBart Jun 30, 2020
@sshleifer sshleifer linked an issue Jun 30, 2020 that may be closed by this pull request
@sshleifer sshleifer linked an issue Oct 15, 2020 that may be closed by this pull request
@sshleifer sshleifer changed the title [WIP] TFBart Add TFBartForConditionalGeneration Oct 15, 2020
@LysandreJik (Member) left a comment

Great, very clean! The testing suite is impressive.

There are only a few items left to do before we merge:

  • Could you add TFBartModel and TFBartForConditionalGeneration to the auto models?
  • Please remove the "defaults to :obj:None" in docstrings; we don't do that anymore
  • Please make sure that docstring lines are no longer than 119 characters, as in the rest of the repo. You can just add a line break and keep the following line at the same indentation level to make sphinx happy.
  • Most assertions do not have an error message. Please add messages so that they're easier to debug for users and for us.
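The last bullet, assertion messages, in a minimal sketch (the helper name and shapes are hypothetical, not from this PR):

```python
import numpy as np

def check_hidden_shape(hidden_states, expected_shape):
    # A bare `assert cond` gives users nothing to act on; include the
    # expected and observed values in the message so failures are debuggable.
    assert hidden_states.shape == expected_shape, (
        f"expected hidden_states of shape {expected_shape}, "
        f"got {hidden_states.shape}"
    )

check_hidden_shape(np.zeros((2, 5, 16)), (2, 5, 16))  # passes silently
```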

Also, you seem to have added all the if/else statements corresponding to mBART, Pegasus, and Blenderbot. Have you tried using these models with this code? If it works, a PR porting all of these models to TF would be big!

Comment on lines -720 to +721:
- Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
+ Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""
Member:
Nice catch 👌


@require_torch
# @slow
class FastIntegrationTests(unittest.TestCase):
Member:
Very cool test!

Comment on lines 115 to 117
def test_compile_tf_model(self):
# This passes for TFBartForConditionalGeneration, fails for TFBartModel
pass
Member:

Why does it fail? Compilation seems like something necessary in TF.

sshleifer (Contributor, Author):

This passes for TFBartForConditionalGeneration.

To make it pass for TFBartModel, the decoder and encoder would need to always return a Tuple, plus a bunch of other hacks @patrickvonplaten had to add for T5 (like https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_t5.py#L1149)

Since so few people use BartModel directly, I decided that supporting this compilation case (which would likely never be used) was not worth losing the readability benefits of ModelOutputs and adding 30 lines of annoying/hacky code. If we want that feature later, we can revisit this decision.

What do you think?
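The readability tradeoff in question can be sketched with a plain-Python stand-in for ModelOutput (illustrative only, not the transformers implementation): Keras compile-time tracing prefers plain tuples, while ModelOutput-style objects give named field access.

```python
from collections import OrderedDict

class ModelOutputSketch(OrderedDict):
    """Dict-like output with named fields plus tuple-style positional access."""

    def __getitem__(self, key):
        if isinstance(key, int):             # positional access, like a tuple
            return list(self.values())[key]
        return super().__getitem__(key)      # named access, for readability

    def to_tuple(self):
        # What a compile-friendly model would return instead of the dict-like object
        return tuple(self.values())

out = ModelOutputSketch(last_hidden_state="h", past_key_values="pkv")
assert out["last_hidden_state"] == out[0]    # both access styles agree
```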

Contributor:

This test is one of the most important, because compile is a crucial function in TF: basically, if you cannot compile a model, you cannot train it with .fit(). That will become a big problem given that the TF Trainer will move to .compile() + .fit() to train models.

sshleifer (Contributor, Author):

OK. I added test coverage to make sure that TFBartForConditionalGeneration can be compiled.

sshleifer and others added 3 commits October 16, 2020 11:22
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@sshleifer sshleifer mentioned this pull request Oct 16, 2020
sshleifer and others added 2 commits October 16, 2020 13:22
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@sshleifer (Contributor, Author) commented Oct 16, 2020

Thanks for the review @LysandreJik !

  • mBART, Pegasus, Blenderbot, and Marian will be in the next PR (this one is already too big for me to hold in my tiny brain).
  • Your 4 bullets: will do!

@sgugger (Collaborator) left a comment

Thanks for all the cleanup! LGTM!

@LysandreJik (Member) left a comment

Great, thanks a lot for iterating!

@LysandreJik LysandreJik merged commit 8298421 into huggingface:master Oct 21, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
* half done

* doc improvement

* Cp test file

* brokedn

* broken test

* undo some mess

* ckpt

* borked

* Halfway

* 6 passing

* boom boom

* Much progress but still 6

* boom boom

* merged master

* 10 passing

* boom boom

* Style

* no t5 changes

* 13 passing

* Integration test failing, but not gibberish

* Frustrated

* Merged master

* 4 fail

* 4 fail

* fix return_dict

* boom boom

* Still only 4

* prepare method

* prepare method

* before delete classif

* Skip tests to avoid adding boilerplate

* boom boom

* fast tests passing

* style

* boom boom

* Switch to supporting many input types

* remove FIXMENORM

* working

* Fixed past_key_values/decoder_cached_states confusion

* new broken test

* Fix attention mask kwarg name

* undo accidental

* Style and reviewers

* style

* Docs and common tests

* Cleaner assert messages

* copy docs

* style issues

* Sphinx fix

* Simplify caching logic

* test does not require torch

* copy _NoLayerEmbedTokens

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_bart.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Line length and dont document None

* Add pipeline test coverage

* assert msg

* At parity

* Assert messages

* mark slow

* Update compile test

* back in init

* Merge master

* Fix tests

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
Labels: TensorFlow (Anything TensorFlow)
Projects: None yet
Development: Successfully merging this pull request may close these issues:

  • Does bart need to cache prev_key_padding_mask?
  • TF BART ?

5 participants