Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py #19477
Conversation
The documentation is not available anymore as the PR was closed or merged.
@sgugger @patrickvonplaten context: this is the implementation by the authors of this NeurIPS paper, as first proposed in #19182 -- a new generation strategy with very interesting results!
Thanks a lot for adding this new method. The added code in `generate` looks clean to me, I just left a couple of nits.
For the example, it would really be good if you could write a new one leveraging the auto-APIs instead of copying the old run_generation.py, which is severely outdated and only works for very few models.
src/transformers/generation_utils.py
Outdated
@@ -3446,3 +3766,152 @@ def top_k_top_p_filtering(
    )

    return logits


# ========== utils for contrastive search decoding method ========= #
Those utils would probably be best in a submodule (like we have `generation_beam_search` or `generation_beam_constraint`) to avoid `generation_utils` being too big.
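For illustration, a minimal sketch of the core helper such a submodule could hold (the function name and tensor shapes are my own assumptions; only the scoring formula, (1 - alpha) * model confidence - alpha * degeneration penalty, comes from the paper):

```python
import torch

def contrastive_score(
    context_hidden: torch.Tensor,    # (num_candidates, seq_len, hidden_size)
    candidate_hidden: torch.Tensor,  # (num_candidates, hidden_size)
    candidate_probs: torch.Tensor,   # (num_candidates,) model confidence per candidate
    alpha: float,
) -> torch.Tensor:
    # normalize so that dot products become cosine similarities
    context = context_hidden / context_hidden.norm(dim=-1, keepdim=True)
    candidate = candidate_hidden / candidate_hidden.norm(dim=-1, keepdim=True)
    # max cosine similarity between a candidate and any context token
    cosine = torch.bmm(context, candidate.unsqueeze(-1)).squeeze(-1)  # (num_candidates, seq_len)
    degeneration_penalty = cosine.max(dim=-1).values
    return (1.0 - alpha) * candidate_probs - alpha * degeneration_penalty
```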
# kwargs["language"] = tokenizer.lang2id[language]

# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
# XLM masked-language modeling (MLM) models need masked token
# is_xlm_mlm = "mlm" in args.model_name_or_path
# if is_xlm_mlm:
#     kwargs["mask_token_id"] = tokenizer.mask_token_id
To clean up?
Hi, @sgugger, thank you so much for your suggestions. I will fix these problems quickly!
Hello @sgugger, is there any document or introduction for the auto-APIs?
The documentation would be the place to start. You can also look at all the other examples!
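For context, the auto-APIs refer to the `Auto*` classes, which pick the right architecture from the checkpoint so the example script needs no per-model branches (the checkpoint name here is only illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
```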
Hello, @sgugger, I have fixed the problems based on your valuable suggestions! Besides, I have updated the test script to use the auto-APIs for inference. The command line to run this test script can be found in its docstring.
I mostly have nits on the docstrings. The new example looks great, thanks a lot!
Oh, I am still working on the integration test.
@gmftbyGMFTBY you probably have to add the `@slow` decorator; our CI doesn't run tests marked with `@slow` by default.
Ok, I got it!
Cool, seems ready to go in (except for the minor comment I added)
@patrickvonplaten can you do a final check, please? :)
src/transformers/generation_utils.py
Outdated
# 10. prepare logits warper: get the TopKLogitsWarper for contrastive_search
logits_warper = self._get_logits_warper(
    top_k=top_k,
    top_p=top_p,
    typical_p=typical_p,
    temperature=temperature,
    num_beams=num_beams,
    renormalize_logits=renormalize_logits,
)
This shouldn't have been removed :)
During testing, I found that adding the logits_warper (the TopKLogitsWarper is used) changes the generations of contrastive search: because the TopKLogitsWarper filters out the logits of all tokens outside the top-k and then calculates the softmax, the model confidence differs from the case where the TopKLogitsWarper is not used.
So, in this case, I think contrastive search should disable the logits_warper by default.
Is there any solution so that the TopKLogitsWarper is not activated for contrastive search? For example, editing the `_get_logits_warper` function and passing the `penalty_alpha` parameter to indicate whether the `top_k` parameter is used?
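To make the effect concrete, a standalone sketch (toy logits, not library code) of how top-k renormalization inflates the surviving tokens' probabilities and thus shifts the model confidence term:

```python
import torch

logits = torch.tensor([4.0, 3.0, 2.0, 1.0, 0.0])

# confidence over the full vocabulary
full_probs = torch.softmax(logits, dim=-1)       # p(best token) ~ 0.64

# what TopKLogitsWarper effectively does: mask everything outside the top-k,
# then let softmax renormalize over the survivors
top_k = 2
kth = torch.topk(logits, top_k).values[-1]
filtered = logits.masked_fill(logits < kth, float("-inf"))
topk_probs = torch.softmax(filtered, dim=-1)     # p(best token) ~ 0.73

# same token, different confidence, hence a different contrastive-search ranking
```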
Interesting. In that case, I think we can do without the logits warper for now :)
Okay!
src/transformers/generation_utils.py
Outdated
return self.contrastive_search(
    input_ids,
    top_k=top_k,
    penalty_alpha=penalty_alpha,
    logits_processor=logits_processor,
    logits_warper=logits_warper,
(related to the comment above)
src/transformers/generation_utils.py
Outdated
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
Suggested change: remove the line `max_length: Optional[int] = None,`
src/transformers/generation_utils.py
Outdated
stopping_criteria (`StoppingCriteriaList`, *optional*):
    An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
    used to tell if the generation loop should stop.
max_length (`int`, *optional*, defaults to 20):
Suggested change: remove the `max_length` docstring entry.
Let's maybe not add a deprecated argument :-)
src/transformers/generation_utils.py
Outdated
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
I think we should remove this if statement
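For reference, a sketch of the non-deprecated way to bound length here, via a stopping criterion instead of the `max_length` argument (64 is an arbitrary example value):

```python
from transformers import MaxLengthCriteria, StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
```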
)
# compute the candidate tokens by the language model and collect their hidden_states
output = self(output_hidden_states=True, **next_model_inputs)
past_key_values = output.past_key_values
(nit) - not all language generation models can return `past_key_values` (e.g. TransfoXL or XLNet), and these models are still used surprisingly often:
- https://huggingface.co/xlnet-base-cased (> 200k downloads)

Maybe we could add a better error message here?
Suggested change:
if "past_key_values" not in output:
    raise ValueError(f"{self.__class__} cannot return `past_key_values` and can therefore **not** be used for contrastive generation.")
past_key_values = output.past_key_values
Okay!
items = []
# item is either the key or the value matrix
for item in layer:
    bsz_and_beam, num_head, seq_len, esz = item.size()
I'm not sure it always holds true that the past_key_values have this size. Did we test contrastive search on all of the following models?
- GPT2
- T5
- GPT-J
- BART

It should work at least on those 4 models.
Yes, GPT2, T5, BART, GPT-J, and OPT models work fine.
Our implementation is compatible with encoder-decoder models (the degeneration penalty is calculated on the decoder's hidden states). However, we have not yet carefully conducted a human evaluation on encoder-decoder models such as T5 and BART, so whether contrastive search significantly boosts their performance is still an open question for us.
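As a side note, a quick way to check the shape assumption for a given model, as a sketch (GPT-2 is used only as an example checkpoint; cache layouts differ for models like TransfoXL or XLNet):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Contrastive search is", return_tensors="pt")
output = model(**inputs, use_cache=True)

for layer in output.past_key_values:
    for item in layer:  # the key matrix, then the value matrix
        bsz_and_beam, num_head, seq_len, esz = item.size()
        # for gpt2: num_head == 12, esz == 64, seq_len == number of input tokens
```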
@@ -1693,6 +1693,25 @@ def test_diverse_beam_search(self):
        ],
    )

    @slow
    def test_contrastive_search(self):
If possible I'd be really happy if we could also test this on BART, T5 and GPT-J. Then we should have covered 95% of the model architectures. But it is ok to do this in a follow-up PR. Currently I don't expect the method to work with T5.
Ok for me to merge, but it'd be nice to always make sure the method works for all of T5, BART, GPT2 and GPT-J. Also, we currently have only slow tests, which is dangerous given that changes in generate can also affect contrastive search.
If it doesn't take too much time, I'd advocate adding at least 7 more tests:
- 4 fast tests with dummy models that just check that contrastive search outputs the correct shape, one for each of GPT2, T5, BART, GPT-J
- 3 more slow tests exactly like test_contrastive_search, for T5, BART, GPT-J

I leave it up to you @gante to decide :-)
Overall, great work! Thanks a lot @gmftbyGMFTBY for adapting the code so quickly here!
EDIT: Sorry, actually let's please remove the deprecated max_length before merging - that's actually a "must-do" before merging IMO (so not a full approval here 😅)
Okay, I am working on it! Thanks a lot for your reviews!
BTW @gmftbyGMFTBY, just read through your extremely nice issue! It seems like you experimented with OPT as well, so maybe let's add a test for OPT as well then? :-)
Also, if the paper is only concerned with open-ended generation (so less with encoder-decoder architectures), I'm also totally fine with not testing T5 and BART (it's a nice-to-have, but if it takes too much time and it's not too important, I'm happy to skip it!). Regarding the fast dummy tests, could you maybe make use of those dummy models:
The tests could look very similar to:
just much shorter, i.e. they only need to test for shape equality; a sketch of such a test follows below.
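A minimal sketch of what such a fast test could look like (the tiny checkpoint name and the hyperparameters are my assumptions, not the PR's final tests):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_contrastive_search_output_shape():
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

    # penalty_alpha > 0 together with top_k > 1 triggers contrastive search
    output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_new_tokens=8)

    # a fast test only checks shapes, not a reference string (the sequence can be
    # shorter than the budget if the random model emits EOS early)
    assert output.shape[0] == 1
    assert output.shape[1] <= input_ids.shape[1] + 8
```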
Yeah, we have already tested the OPT models, and they work fine. I will add more tests for the pre-trained models you mentioned.
@patrickvonplaten more tests for these models have been added:
These tests pass successfully. Can you do the final check on this PR?
Thank you for being part of this process @gmftbyGMFTBY 🙌 All queries have been addressed and the PR looks in a good state, merging!
@gante @patrickvonplaten @sgugger Wow, thank you very much for your help and support. Love the Hugging Face team!
@gante @patrickvonplaten @sgugger -- Many thanks for your kind help throughout the process! It means a great deal to me and @gmftbyGMFTBY. Hugging Face is the best!
Great work @gmftbyGMFTBY and @yxuansu, thanks for bearing with us through the PR :-) |
…he codebase of generation_utils.py (huggingface#19477)

* add: the contrastive search for generaton_utils
* add: testing scripts for contrastive search under examples/text-generation
* update the quality of codes
* revise the docstring; make the generation_contrastive_search.py scripts;
* revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format
* revise the necessary documents
* fix: revise the docstring of generation_contrastive_search.py
* Fix the code indentation
* fix: revise the nits and examples in contrastive_search docstring.
* fix the copyright
* delete generation_contrastive_search.py
* revise the logic in contrastive_search
* update the intergration test and the docstring
* run the tests over
* add the slow decorate to the contrastive_search intergrate test
* add more test
* do the style, quality, consistency checks
Adding the state-of-the-art contrastive search decoding method for the `generation_utils` codebase

Fixes #19182
In this PR, I add the source code of our proposed state-of-the-art decoding method for off-the-shelf neural text generation models. The main changes are in the following files: (1) `src/transformers/generation_utils.py`; (2) `examples/pytorch/text-generation/run_generation_contrastive_search.py`. To run the test script, please follow the commands in its docstring (a sketch of the equivalent auto-API call is shown below).
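As a sketch of what the example script boils down to with the auto-APIs (the checkpoint, prompt, and hyperparameters here are illustrative, not prescribed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")

input_ids = tokenizer("DeepMind Company is", return_tensors="pt").input_ids
# top_k and penalty_alpha are the two contrastive-search knobs
output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```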
Before submitting
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Following @gante's suggestion, @patrickvonplaten and @sgugger can review this PR.