[Generate] Facilitate PyTorch generate using ModelOutputs #6735
Conversation
Force-pushed from 5ebc827 to 670ee39
Codecov Report

@@            Coverage Diff             @@
##           master    #6735      +/-  ##
==========================================
- Coverage   80.48%   79.46%   -1.03%
==========================================
  Files         157      157
  Lines       28794    28822      +28
==========================================
- Hits        23175    22903     -272
- Misses       5619     5919     +300

Continue to review full report at Codecov.
ModelOutputs
Definitely improves readability, awesome!
Since it's in a risky part of the repo, I made you an alias to run tons of integration tests:
run_generation_integration_tests () {
    # assumes USE_CUDA is exported, rather than passed
    RUN_SLOW=1 pytest tests/test_modeling_pegasus.py
    RUN_SLOW=1 pytest tests/test_modeling_bart.py
    RUN_SLOW=1 pytest tests/test_modeling_t5.py
    RUN_SLOW=1 pytest tests/test_modeling_marian.py
    RUN_SLOW=1 pytest tests/test_modeling_mbart.py
    RUN_SLOW=1 pytest tests/test_modeling_encoder_decoder.py
    RUN_SLOW=1 pytest tests/test_pipelines.py
    RUN_SLOW=1 pytest tests/test_modeling_gpt2.py
}
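Assuming the function is sourced into your shell, usage might look like the following. USE_CUDA=1 is an assumption about how the test suite reads that variable; the alias's own comment only says it must be exported:

export USE_CUDA=1
run_generation_integration_tests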
Very cool, love the refactor.
I don't think the breaking change is an issue: users should not have been relying on the past output to get the encoder outputs, given that the encoder outputs are already a separate output in the same tuple.
LGTM, very cool!
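For context, a hypothetical sketch of the tuple layout being discussed (positions simplified for illustration, not the exact signatures from the diff):

# before this PR (second slot mixed the cache with a duplicated encoder output):
#   outputs = (logits, (decoder_past_key_values, encoder_outputs), encoder_outputs, ...)
# after this PR (second slot holds only the cache):
#   outputs = (logits, decoder_past_key_values, encoder_outputs, ...)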
This looks great, thanks a lot for all the work! The renamings are fine by me, but they should be done before the next release IMO.
Also, note that forcing return_dict=True will not work with jit, but I don't know if there are plans to have the generate method support tracing.
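To illustrate the jit limitation, a minimal sketch with toy modules (not the actual transformers code; the exact tracer behavior may vary by PyTorch version): tracing works with tuple outputs, but by default the tracer rejects dict-style outputs such as ModelOutput.

import torch

class ToyTuple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return (self.linear(x),)  # tuple outputs are traceable

class ToyDict(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return {"logits": self.linear(x)}  # dict-style outputs are not (by default)

x = torch.randn(2, 4)
traced = torch.jit.trace(ToyTuple(), x)  # fine
# torch.jit.trace(ToyDict(), x)  # raises: the tracer rejects dict outputs
#                                # unless strict=False is passed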
LGTM! Looks like we can now deprecate the _use_cache function in the GenerationMixin, no?
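A hedged sketch of why such a helper becomes redundant: with ModelOutput objects, the cache can be looked up by name instead of by inspecting tuple positions (the function and variable names here are illustrative, not the actual GenerationMixin code):

def prepare_next_step_inputs(outputs, model_inputs):
    # the cache is looked up by field name rather than by guessing
    # tuple positions, so a _use_cache-style heuristic is unnecessary
    past = getattr(outputs, "past_key_values", None)
    if past is not None:
        model_inputs["past_key_values"] = past  # feed the cache into the next step
    return model_inputs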
Force-pushed from 0eee734 to e2166b1
yes!
IMPORTANT: This PR does a bigger renaming from "decoder_past_key_values" to "past_key_values", as suggested by @sshleifer. This required changes for
Would be great if @LysandreJik (and @sgugger, @sshleifer, depending on time difference) can review this quickly one last time.
The change from decoder_past_key_values to past_key_values looks good to me, as you implemented backwards compatibility. LGTM!
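As a rough illustration of what such a backward-compatibility shim can look like (a hypothetical sketch, not the actual code from this PR):

import warnings

class DummySeq2SeqModel:
    def forward(self, input_ids=None, past_key_values=None, **kwargs):
        # accept the deprecated kwarg name, warn, and remap it to the new one
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "`decoder_past_key_values` is deprecated, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        # ... the rest of the forward pass only ever sees `past_key_values`
        return past_key_values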
@sshleifer - all EncoderDecoder slow tests pass. There was one Bart test that failed because of a broken internet connection. I ran this single test again separately and it was fine. PR looks good to me now -> merging.
…ace#6735)

* fix generate for GPT2 Double Head
* fix gpt2 double head model
* fix bart / t5
* also add for no beam search
* fix no beam search
* fix encoder decoder
* simplify t5
* simplify t5
* fix t5 tests
* fix BART
* fix transfo-xl
* fix conflict
* integrating sylvains and sams comments
* fix tf past_decoder_key_values
* fix enc dec test
…uggingface#6735)" This reverts commit bf4184e.
This PR:

* forces return_dict=True for generation in PyTorch. This should not lead to any problem because .generate() cannot be used to compute gradients.
* simplifies encoder_outputs for Bart, T5 and the Encoder. Previously, there was an ugly hack that forced the second position of the encoder/decoder outputs for T5 and Bart to be a tuple containing both decoder_past_key_values and encoder_outputs, whereas encoder_outputs was a duplicated output. With the new cleaner API, only the decoder_past_key_values are returned, because the encoder_outputs can be accessed differently.
* makes the attentions of .generate() accessible, a feature many people asked for. This is now possible by forcing "return_dict" in .generate(), so that "decoder_attentions" and "attentions" can be accessed by keyword.

Important: the new handling of encoder_outputs introduces a small breaking change: "output[1]" is now not a mixed tuple of encoder_outputs and decoder_past_key_values, but only decoder_past_key_values, whereas the encoder_outputs can be accessed as before.
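A hypothetical usage sketch of the resulting API (the model checkpoint and exact output field names are assumptions based on the description above and current transformers conventions, not taken from this PR's diff):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, return_dict=True, output_attentions=True)

# fields are now accessed by keyword instead of by tuple index
past = outputs.past_key_values                   # the slot that used to mix in encoder outputs
encoder_out = outputs.encoder_last_hidden_state  # encoder output keeps its own field
decoder_attn = outputs.decoder_attentions        # available thanks to return_dict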