Add custom stop token ids for generation #20727
Conversation
ed18fb2 to 6f0812d
The documentation is not available anymore as the PR was closed or merged.
cc @gante
Think we could actually allow `eos_token_id` to also accept a list of token ids here?
Hi @tokestermw 👋 Like my colleagues, I also think this would be a helpful feature! I also agree with @patrickvonplaten: allowing the existing argument (`eos_token_id`) to also accept a list of integers would result in a cleaner interface and fewer lines of code to maintain :) It is also easier to port to TF/FLAX, which do not use `StoppingCriterion`.

In a nutshell, if `eos_token_id` can be a list of integers, we can replace the existing check (https://github.com/huggingface/transformers/blob/26dd041c6e45379141302e2d293ab4cd9cf805d4/src/transformers/generation/utils.py#L2154) with

unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long())

as long as we always cast `eos_token_id` to a list before the generation loop. In other words, 2 lines of change (per generation method) would probably do the trick!

@tokestermw WDYT?
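For concreteness, a minimal runnable sketch of this idea, assuming PyTorch tensors as in the generation loop (the values are made up, and the check is written out as "a sequence stays alive only while its next token matches none of the eos ids"); this is an illustration, not the merged code:

```python
import torch

# Made-up values: two candidate stop ids and a batch of four next tokens.
eos_token_id = [198, 50256]  # may also arrive as a single int
next_tokens = torch.tensor([198, 50256, 42, 7])
unfinished_sequences = torch.ones(4, dtype=torch.long)

# Cast once before the generation loop so the check below always sees a list.
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]

# A sequence stays unfinished only if its next token matches none of the eos ids.
still_going = (sum(next_tokens == i for i in eos_token_id) == 0).long()
unfinished_sequences = unfinished_sequences.mul(still_going)
print(unfinished_sequences)  # tensor([0, 0, 1, 1]) -> first two sequences finished
```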
Got it, thanks for the suggestion! I can certainly make it so we use `eos_token_id`.

> It is also easier to port to TF/FLAX, which do not use StoppingCriterion.

ah good to know :) I can look at this again this weekend
Hi @gante,
You can see the changes here: https://github.com/tokestermw/transformers/pull/1/files If this change looks good, I can merge it into this PR and start polishing (fixing tests, docs, removing dead code, etc.). Thanks!
@tokestermw that's a comprehensive set of changes, it looks great to me! ❤️
LGTM! 👍
Thanks for working on this and nice new tests! Just make sure all the docstrings using `eos_token_id` are updated and we should be good to merge!
@@ -183,7 +183,7 @@ class GenerationConfig(PushToHubMixin):
         The id of the *padding* token.
     bos_token_id (`int`, *optional*):
         The id of the *beginning-of-sequence* token.
-    eos_token_id (`int`, *optional*):
+    eos_token_id (`Union[int, List[int]]`, *optional*):
         The id of the *end-of-sequence* token.
Can we adapt the text of the doc here? The docstrings also need to be updated in beam_search.py
FYI.
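For illustration only (not part of the diff), the updated type means a `GenerationConfig` can carry several end-of-sequence ids at once; the ids below are placeholders:

```python
from transformers import GenerationConfig

# eos_token_id may be a single id or a list of ids after this PR.
generation_config = GenerationConfig(max_new_tokens=50, eos_token_id=[198, 50256])
print(generation_config.eos_token_id)  # [198, 50256]
```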
@@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
         that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
         add_special_tokens=False).input_ids`.
-    eos_token_id (`int`):
+    eos_token_id (`Union[int, List[int]]`):
         The id of the *end-of-sequence* token.
Same comment here.
@@ -671,23 +684,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
     exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
         This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
         starts and `decay_factor` represents the factor of exponential decay
-    eos_token_id (`int`):
+    eos_token_id (`Union[int, List[int]]`):
         The id of the *end-of-sequence* token.
Here too!
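To make the int-or-list handling behind these docstring changes concrete, here is a small standalone sketch (not the library's actual implementation) of a logits-processor-style class that normalizes `eos_token_id` up front and then treats it uniformly as a list:

```python
from typing import List, Union
import torch


class ForceEOSAtMaxLength:
    """Standalone illustration: once max_length is reached, only the eos ids stay allowed."""

    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        # Normalize to a list so the rest of the code never branches on the type.
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            # Mask every vocabulary entry except the allowed eos ids.
            masked = torch.full_like(scores, float("-inf"))
            masked[:, self.eos_token_id] = scores[:, self.eos_token_id]
            return masked
        return scores
```

Casting to a list once in `__init__` is the same "cast before the loop" pattern suggested earlier in the thread.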
* Add StopIdStoppingCriteria
* add a working test for stop id criteria
* add to global scope
* add stop_ids to generate
* add pipeline test
* use tokenizer encode in test
* add test to generation utils
* reformat
* fixup
* make-fix-copies
* rename to stop_token_id
* use stop_tokens instead
* add to text to text generation
* make fixup
* make repo-consistency
* Add support for list of ints for eos_token_id inside generation/utils.py
* Instead of having if elses, cast the eos_token_id into a List[int]
* Add List[int] support for logits_process.py
* add List[int] for beam_search.py
* add List[int] for forced_eos_token_id
* revert stop token id stopping criteria changes
* make fixup
* fix tests
* add eos_token_id to generation/utils.py and added tests test_utils.py
* add eos_token_id type hints and fix for pad tokens
* add comments
* remove some prints and remove forced false test
* fix
* put back test_stop_sequence_stopping_criteria
* remove unused import and make fixup
* add a none check
* update docstring
* add more docstring for list ints
* make fixup
Ideally, generation should stop at '\n', but this feature is brand new on transformers (huggingface/transformers#20727)
Is this feature already available in the transformers version on pip (4.25.1)? I have tried enabling it and the generation continued even though I set the stop token ids. (I'm also not sure why 2 integers are returned when I tokenize the stop string.)

EDIT: Nevermind, I got it working.
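For anyone hitting the same thing, a hedged sketch of turning a stop string such as '\n' into token ids and passing them to `generate` once `eos_token_id` accepts a list (the checkpoint is a placeholder; depending on the tokenizer, a string like '\n' can map to more than one id, which would explain seeing 2 integers):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A stop string may tokenize to one or several ids depending on the tokenizer.
stop_ids = tokenizer("\n", add_special_tokens=False).input_ids
print(stop_ids)  # GPT-2 gives a single id here; other tokenizers may give several

inputs = tokenizer("Q: What is the capital of France?\nA:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=30, eos_token_id=stop_ids)
print(tokenizer.decode(outputs[0]))
```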
Update (using eos_token_id instead): #20727 (comment)
What does this PR do?
Hi 🤗 team!
This adds stop token ids inside `generate`, e.g. `model.generate(..., stop_token_ids=[10, 25])`, and syntactic sugar for the generation pipelines, e.g. `pipeline(..., stop_tokens=['\n'])`. When the generation detects the specified token ids for all examples in the batch, it will stop.

Rationale

- Lets users stop generation at specific tokens rather than relying only on `max_new_tokens`, without digging into `StoppingCriterion`.
- Avoids requiring users to construct `StoppingCriteria` objects by hand.

Usage Example
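A minimal sketch of the intended usage. The description above proposes a new `stop_token_ids` argument, but the interface that was ultimately merged (see the update link above) extends `eos_token_id` to accept a list instead, so this sketch uses that; the checkpoint and ids are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")

# A sequence is marked finished as soon as it produces any of the listed ids.
outputs = model.generate(**inputs, max_new_tokens=20, eos_token_id=[10, 25])
print(tokenizer.decode(outputs[0]))
```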
How to Test
Related PR(s)
There is a `stop_sequence` argument for the `TextGeneration` pipeline: #18444. But it's limited to a single token, only available in the text generation pipeline, and overwrites `eos_token_id`. Instead, we use `StoppingCriteria` directly.

This PR overlaps a bit with the above, so please let me know if this approach is not optimal.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. (☢️ noting I've tried to update the docs from the instructions, but they don't seem correct)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten @Narsil