
[Generate] Facilitate PyTorch generate using ModelOutputs #6735

Merged

Conversation

patrickvonplaten (Contributor) commented Aug 25, 2020

This PR:

  • forces the use of return_dict=True for generation in PyTorch. This should not cause any problems because .generate() cannot be used to compute gradients.
  • The handling of encoder_outputs is simplified for Bart, T5 and the Encoder. Previously, an ugly hack forced the second position of the encoder/decoder outputs for T5 and Bart to be a tuple containing both decoder_past_key_values and encoder_outputs, even though encoder_outputs was a duplicate of an output already returned elsewhere. With the new, cleaner API, only the decoder_past_key_values are returned, because the encoder_outputs can be accessed separately.
  • Fixes num_beams error in GPT2DoubleHead model #6319
  • Adds better documentation for the Encoder Decoder model + test + better example
  • Most importantly, this PR lays the groundwork for a better GenerationOutput object (@sgugger). Returning a list of all attentions and all hidden states when using .generate() is a feature many people have asked for. This is now made possible by forcing return_dict=True in .generate(), so that "decoder_attentions" and "attentions" can be accessed by keyword.

Important: The new handling of encoder_outputs introduces a small breaking change: "output[1]" is no longer a mixed tuple of encoder_outputs and decoder_past_key_values, but only decoder_past_key_values; the encoder_outputs can still be accessed as before.
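For illustration, here is a minimal sketch of what keyword access to the dict-style outputs looks like compared to positional indexing. This is not code from the PR: the model choice is arbitrary, and the attribute names follow current transformers versions, where the cache attribute is called past_key_values (the name this PR later adopts).

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: Hello.", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("Hallo.", return_tensors="pt").input_ids

# With return_dict=True the forward pass returns a ModelOutput, so the decoder cache and the
# encoder outputs are retrieved by name instead of by position in a tuple.
outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    use_cache=True,
    return_dict=True,
)
cache = outputs.past_key_values                      # only the decoder cache
encoder_hidden = outputs.encoder_last_hidden_state   # encoder outputs accessed separately

# With tuple outputs, output[1] now holds only the cache (the breaking change described above).
tuple_outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    use_cache=True,
    return_dict=False,
)
cache_from_tuple = tuple_outputs[1]
```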

codecov bot commented Aug 25, 2020

Codecov Report

Merging #6735 into master will decrease coverage by 1.02%.
The diff coverage is 78.33%.


@@            Coverage Diff             @@
##           master    #6735      +/-   ##
==========================================
- Coverage   80.48%   79.46%   -1.03%     
==========================================
  Files         157      157              
  Lines       28794    28822      +28     
==========================================
- Hits        23175    22903     -272     
- Misses       5619     5919     +300     
| Impacted Files | Coverage Δ |
|---|---|
| src/transformers/modeling_bart.py | 94.24% <50.00%> (-1.35%) ⬇️ |
| src/transformers/modeling_t5.py | 76.70% <57.14%> (-7.14%) ⬇️ |
| src/transformers/modeling_tf_t5.py | 89.57% <70.27%> (-1.37%) ⬇️ |
| src/transformers/modeling_encoder_decoder.py | 92.00% <93.33%> (-0.40%) ⬇️ |
| src/transformers/generation_utils.py | 96.93% <100.00%> (+0.26%) ⬆️ |
| src/transformers/modeling_gpt2.py | 86.82% <100.00%> (+0.14%) ⬆️ |
| src/transformers/modeling_openai.py | 23.87% <100.00%> (-48.39%) ⬇️ |
| src/transformers/modeling_outputs.py | 100.00% <100.00%> (ø) |
| src/transformers/modeling_tf_gpt2.py | 95.01% <100.00%> (ø) |
| src/transformers/modeling_tf_openai.py | 22.58% <100.00%> (-72.26%) ⬇️ |
| ... and 17 more | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

patrickvonplaten changed the title from "[GPT2] Allow GPT2DoubleHead to generate" to "[WIP, GPT2] Allow GPT2DoubleHead to generate" on Aug 25, 2020
patrickvonplaten changed the title from "[WIP, GPT2] Allow GPT2DoubleHead to generate" to "[Generate] Facilitate PyTorch generate using ModelOutputs" on Aug 26, 2020
sshleifer (Contributor) left a comment


Definitely improves readability, awesome!
Since it's in a risky part of the repo, I made you an alias to run tons of integration tests:

run_generation_integration_tests () {
	# assumes USE_CUDA is exported, rather than passed
	RUN_SLOW=1 pytest tests/test_modeling_pegasus.py
	RUN_SLOW=1 pytest tests/test_modeling_bart.py
	RUN_SLOW=1 pytest tests/test_modeling_t5.py
	RUN_SLOW=1 pytest tests/test_modeling_marian.py
	RUN_SLOW=1 pytest tests/test_modeling_mbart.py
	RUN_SLOW=1 pytest tests/test_modeling_encoder_decoder.py
	RUN_SLOW=1 pytest tests/test_pipelines.py
	RUN_SLOW=1 pytest tests/test_modeling_gpt2.py
}
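Assumed usage (not stated in the original comment): export USE_CUDA=1 or USE_CUDA=0 in the shell first, then call run_generation_integration_tests from the repository root so the tests/ paths resolve.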

LysandreJik (Member) left a comment


Very cool, love the refactor.

I don't think the breaking change is an issue: users should not have been using the past output to get the encoder outputs, given that the encoder outputs are already returned as a separate element of the same tuple.

LGTM, very cool!

sgugger (Collaborator) left a comment


This looks great, thanks a lot for all the work! The renamings are fine by me but should be done before the next release IMO.
Also, note that forcing return_dict=True will not work with jit, but I don't know whether there are plans to have the generate method support tracing.
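For context, a minimal sketch of why dict-style outputs and tracing clash (not code from this PR; the model choice is arbitrary and the flags reflect the usual transformers TorchScript workflow as I understand it): torch.jit.trace expects tensors or nested tuples of tensors as outputs, so tracing is typically done with the torchscript config flag, which forces tuple returns instead of ModelOutput objects.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# torchscript=True forces plain tuple outputs (torch.jit.trace cannot handle dict-like
# ModelOutput objects); use_cache=False keeps the output a simple tuple of tensors.
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True, use_cache=False)
model.eval()

traced_model = torch.jit.trace(model, (input_ids,))
logits = traced_model(input_ids)[0]
```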

yjernite (Member) left a comment


LGTM! Looks like we can now deprecate the _use_cache function in the GenerationMixin, no?

patrickvonplaten (Contributor, Author)

> LGTM! Looks like we can now deprecate the _use_cache function in the GenerationMixin, no?

yes!

patrickvonplaten (Contributor, Author) commented Sep 1, 2020

IMPORTANT: This PR also performs a larger renaming from "decoder_past_key_values" to "past_key_values", as suggested by @sshleifer. This required changes to T5, TFT5 and Bart. For each of the three models, it is ensured that decoder_past_key_values can still be used as an input, to keep backwards compatibility.
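As an illustration of the backward-compatibility handling, here is a sketch of the usual deprecation pattern (the helper name is hypothetical and this is not the PR's exact code): the old keyword is accepted, mapped onto the new one, and flagged with a warning.

```python
import warnings


def _map_deprecated_cache_kwarg(past_key_values=None, **kwargs):
    """Hypothetical helper: map the deprecated `decoder_past_key_values` kwarg onto `past_key_values`."""
    if "decoder_past_key_values" in kwargs:
        warnings.warn(
            "`decoder_past_key_values` is deprecated; use `past_key_values` instead.",
            FutureWarning,
        )
        if past_key_values is None:
            past_key_values = kwargs.pop("decoder_past_key_values")
    return past_key_values


# A caller still passing the old keyword keeps working but sees a FutureWarning.
cache = _map_deprecated_cache_kwarg(decoder_past_key_values=("dummy_layer_cache",))
```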

It would be great if @LysandreJik (and @sgugger, @sshleifer, depending on the time difference) could review this quickly one last time.

LysandreJik (Member) left a comment


The change from decoder_past_key_values to past_key_values looks good to me, as you implemented backwards compatibility. LGTM!

patrickvonplaten (Contributor, Author)

@sshleifer - all EncoderDecoder slow tests pass. There was one Bart test that failed because of a broken internet connection; I ran that single test again separately and it was fine. The PR looks good to me now -> merging.

patrickvonplaten merged commit afc4ece into huggingface:master on Sep 1, 2020
patrickvonplaten deleted the fix_double_head_gpt2 branch on September 1, 2020 at 10:43
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
…ace#6735)

* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating sylvains and sams comments

* fix tf past_decoder_key_values

* fix enc dec test
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
Development

Successfully merging this pull request may close these issues:

  • num_beams error in GPT2DoubleHead model #6319
5 participants