Add TFBartForConditionalGeneration #5411
Conversation
Great, very clean! The testing suite is impressive.
There are only a few items left to do before we merge:
- Could you add `TFBartModel` and `TFBartForConditionalGeneration` to the auto models?
- Please remove the "defaults to :obj:`None`" in docstrings; we don't do that anymore.
- Please make sure that docstring lines are no longer than 119 characters, as in the rest of the repo. You can just add a line break and keep the following line at the same indentation level to make Sphinx happy.
- Most assertions do not have an error message. Please add messages so that they're easier to debug, for users and for us.

Also, you seem to have added all the if/else statements corresponding to mBART, Pegasus and Blenderbot. Have you tried using these models with this code? If it works, a PR porting all of these models to TF would be big!
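The review item about assertion messages can be sketched as follows. This is an illustrative pattern only, not code from the PR; the names `d_model` and `num_heads` are hypothetical stand-ins for whatever the real assertions check.

```python
# Hedged sketch of the requested pattern: every bare `assert` gets a message
# that surfaces the offending values, so failures are easy to debug.
d_model = 768   # illustrative values, not taken from the PR
num_heads = 12

# Before: assert d_model % num_heads == 0
# After: the message tells the user exactly which values were wrong.
assert d_model % num_heads == 0, (
    f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
)
```

With a message attached, a failing run reports the actual values involved instead of a bare `AssertionError`.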
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""
Nice catch 👌
@require_torch
# @slow
class FastIntegrationTests(unittest.TestCase):
Very cool test!
tests/test_modeling_tf_bart.py (outdated)
def test_compile_tf_model(self):
    # This passes for TFBartForConditionalGeneration, fails for TFBartModel
    pass
Why does it fail? Compilation seems like something necessary in TF
This passes for `TFBartForConditionalGeneration`.

To make it pass for `TFBartModel`, the decoder and encoder need to always return `Tuple`, plus a bunch of other hacks @patrickvonplaten had to add for T5 (like https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_t5.py#L1149).

Since so few people use `BartModel` directly, I decided that supporting this compilation case (which would likely never be used) was not worth losing the readability benefit of `ModelOutput`s and adding 30 lines of annoying/hacky code. If we want that feature later, we can revisit this decision.

What do you think?
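The trade-off being debated can be illustrated with a small sketch. This is not the transformers implementation, just the general pattern: a structured output object gives readable, named access, while graph compilation historically wanted a flat tuple of tensors, forcing a `to_tuple()`-style escape hatch (all names here are hypothetical).

```python
# Illustrative sketch only: named-field outputs vs. the flat tuples a graph
# compiler expects. `SketchModelOutput` is a made-up stand-in for ModelOutput.
from dataclasses import dataclass, fields
from typing import Optional, Tuple


@dataclass
class SketchModelOutput:
    """Readable, named access to model outputs."""
    last_hidden_state: list
    past_key_values: Optional[list] = None

    def to_tuple(self) -> Tuple:
        # Drop None entries so the tuple only carries populated outputs.
        return tuple(
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        )


out = SketchModelOutput(last_hidden_state=[1.0, 2.0])
hidden = out.last_hidden_state   # named access for readers of the code
flat = out.to_tuple()            # positional form for a compiler that wants flat outputs
```

The review discussion is about whether maintaining the tuple code path everywhere (the "30 lines of annoying/hacky code") is worth it for a rarely compiled base model.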
This test is one of the most necessary, because `compile` is a very important function in TF: basically, if you cannot compile a model, you cannot train it with `.fit()`. That will become a big problem given that the TF Trainer will move to `.compile()` + `.fit()` to train a model.
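For context, the `.compile()` + `.fit()` workflow referred to above is the standard Keras training loop. A minimal sketch with a toy model (not BART; the architecture and data here are placeholders):

```python
# Minimal sketch of the Keras compile/fit workflow the TF Trainer is moving
# toward. Toy model and random data, not BART.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(1),
])

# A model must compile successfully before .fit() can train it,
# which is why the compile test matters.
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(16, 8).astype("float32")
y = np.random.rand(16, 1).astype("float32")
history = model.fit(x, y, epochs=1, verbose=0)
```

If `compile` fails for a model class, this entire training path is unavailable to users of that class.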
OK. I added test coverage to make sure that `TFBartForConditionalGeneration` can be compiled.
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Thanks for the review @LysandreJik !
Thanks for all the cleanup! LGTM!
Great, thanks a lot for iterating!
* half done
* doc improvement
* Cp test file
* brokedn
* broken test
* undo some mess
* ckpt
* borked
* Halfway
* 6 passing
* boom boom
* Much progress but still 6
* boom boom
* merged master
* 10 passing
* boom boom
* Style
* no t5 changes
* 13 passing
* Integration test failing, but not gibberish
* Frustrated
* Merged master
* 4 fail
* 4 fail
* fix return_dict
* boom boom
* Still only 4
* prepare method
* prepare method
* before delete classif
* Skip tests to avoid adding boilerplate
* boom boom
* fast tests passing
* style
* boom boom
* Switch to supporting many input types
* remove FIXMENORM
* working
* Fixed past_key_values/decoder_cached_states confusion
* new broken test
* Fix attention mask kwarg name
* undo accidental
* Style and reviewers
* style
* Docs and common tests
* Cleaner assert messages
* copy docs
* style issues
* Sphinx fix
* Simplify caching logic
* test does not require torch
* copy _NoLayerEmbedTokens
* Update src/transformers/modeling_tf_bart.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)
* Update tests/test_modeling_tf_bart.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)
* Update src/transformers/modeling_tf_bart.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)
* Update src/transformers/modeling_tf_bart.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)
* Update src/transformers/modeling_tf_bart.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)
* Line length and dont document None
* Add pipeline test coverage
* assert msg
* At parity
* Assert messages
* mark slow
* Update compile test
* back in init
* Merge master
* Fix tests

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This reverts commit eb42d58.
`TFBartForConditionalGeneration`, which can generate summaries that are equivalent to PyTorch.

TODO this PR: `adjust_logits_during_generation`

Future PRs: