
Add custom stop token ids for generation #20727

Merged

Conversation

@tokestermw (Contributor) commented Dec 12, 2022

Update (using eos_token_id instead): #20727 (comment)

What does this PR do?

Hi 🤗 team!

This adds stop token ids to generation, e.g. model.generate(..., stop_token_ids=[10, 25]), plus syntactic sugar for the generation pipelines, e.g. pipeline(..., stop_tokens=['\n']). When generation detects the specified token ids for all examples in the batch, it stops.

Rationale

  • It's common to set a stop id/token for text generation tasks. For example, for dialogue, we may want to stop when the speaker changes.
  • It's convenient to have arguments for stop tokens, similar to max_new_tokens, without digging into StoppingCriteria.
  • Some servers, like DeepSpeed MII, use gRPC, and it's difficult to pass StoppingCriteria objects.

Usage Example

from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline

torch_device = "cpu"  # or "cuda" if available

# in pipeline
prompt = """Hello I believe in"""
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stop_tokens=[' fe'])
text_generator(prompt)

# from generate
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

stop_token_ids = gpt2_tokenizer.encode(" fe")  # token ids whose appearance should stop generation
gpt2_model.generate(input_ids=input_ids, stop_token_ids=stop_token_ids)
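
As the update note at the top indicates, the review below converged on extending eos_token_id rather than adding a new stop_token_ids argument. A minimal sketch of the interface that was ultimately merged, reusing the objects defined above:

# eos_token_id now also accepts a list of ints: generation stops once any of
# the listed ids is produced (reuses gpt2_tokenizer, gpt2_model, input_ids).
stop_token_ids = gpt2_tokenizer.encode(" fe")
gpt2_model.generate(input_ids=input_ids, eos_token_id=stop_token_ids)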

How to Test

pytest tests/generation/test_stopping_criteria.py::StoppingCriteriaTestCase::test_stop_token_id_criteria
pytest tests/generation/test_utils.py::GenerationIntegrationTests::test_stop_token_ids_stopping_criteria
pytest tests/pipelines/test_pipelines_text_generation.py::TextGenerationPipelineTests::test_stop_token_ids_stopping_criteria
pytest tests/pipelines/test_pipelines_text_generation.py::TextGenerationPipelineTests::test_stop_tokens_stopping_criteria

Related PR(s)

There is a stop_sequence argument for the TextGeneration pipeline: #18444

But it's limited to a single token, only in the text generation pipeline, and overwrites eos_token_id. Instead, we use StoppingCriteria directly.

This PR overlaps a bit with the one above, so please let me know if this approach is not optimal.
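
For context, the StoppingCriteria mechanism referenced above works roughly as in the sketch below; the criteria class and its name are hypothetical, for illustration only:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopTokenIdsCriteria(StoppingCriteria):  # hypothetical class, for illustration
    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop once every sequence in the batch ends with one of the stop ids.
        return all(int(seq[-1]) in self.stop_token_ids for seq in input_ids)

# usage: model.generate(input_ids, stopping_criteria=StoppingCriteriaList([StopTokenIdsCriteria([10, 25])]))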

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. (☢️ noting I've tried to update the docs from the instructions, but they don't seem correct)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten @Narsil

@tokestermw force-pushed the add-custom-stop-token-ids-for-generation branch from ed18fb2 to 6f0812d on December 12, 2022 05:21
@HuggingFaceDocBuilderDev commented Dec 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) commented Dec 12, 2022

cc @gante

@patrickvonplaten (Contributor) commented

Think we could actually allow eos_token_id to be both an integer and a list of integers, no? Both in the config and in the input.

@gante (Member) commented Dec 16, 2022

Hi @tokestermw 👋

Like my colleagues, I also think this would be a helpful feature! I also agree with @patrickvonplaten: allowing the existing argument (eos_token_id) to also accept a list of integers would result in a cleaner interface and fewer lines of code to maintain :) It is also easier to port to TF/FLAX, which do not use StoppingCriteria.

In a nutshell, if eos_token_id can be a list of integers, we can replace the existing check with

unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id) == 0).long())

as long as we always cast eos_token_id to a list before the generation loop. In other words, 2 lines of change (per generation method) would probably do the trick!

@tokestermw WDYT?
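
To make the proposed check concrete, here is a self-contained sketch with toy values (not the library's exact loop):

import torch

eos_token_id = [10, 25]                                  # already cast to a list
next_tokens = torch.tensor([10, 7, 25, 3])               # newest token per batch item
unfinished_sequences = torch.ones(4, dtype=torch.long)   # 1 = still generating

# A sequence finishes once its newest token matches any of the eos ids.
unfinished_sequences = unfinished_sequences.mul(
    (sum(next_tokens == i for i in eos_token_id) == 0).long()
)
print(unfinished_sequences)  # tensor([0, 1, 0, 1])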

@tokestermw (Contributor, Author) commented Dec 16, 2022

Hi @gante,

  • Made eos_token_id a Union[int, List[int]] type. It is converted into a list at the beginning of the respective functions (a minimal sketch of this cast follows this comment). Also, eos_token_id turns out to be used in a few more places, e.g. beam_search.py.
  • In some places where we insert the eos_token_id, only the first token id is inserted, here and here

You can see the changes here: https://github.com/tokestermw/transformers/pull/1/files

If this change looks good, I can merge it into this PR and start polishing (fixing tests and docs, removing dead code, etc.).

thanks!
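
A minimal sketch of the cast mentioned in the first bullet above (the helper name is hypothetical; in the PR the normalization is inlined at the start of each generation method):

from typing import List, Optional, Union

def normalize_eos_token_id(eos_token_id: Optional[Union[int, List[int]]]) -> Optional[List[int]]:
    # Accept a single id or a list of ids; downstream code then only handles lists.
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id

assert normalize_eos_token_id(2) == [2]
assert normalize_eos_token_id([2, 5]) == [2, 5]
assert normalize_eos_token_id(None) is None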

@gante (Member) commented Dec 21, 2022

@tokestermw that's a comprehensive set of changes, it looks great to me! ❤️

@tokestermw changed the title from "Add custom stop token ids for generation" to "ANTS-310: Add custom stop token ids for generation" on Dec 21, 2022
@patrickvonplaten (Contributor) commented

Awesome, this looks nice to me. @gante @sgugger, ok for you?

@gante (Member) left a comment

LGTM! 👍

@gante gante requested a review from sgugger January 2, 2023 17:16
@sgugger (Collaborator) left a comment

Thanks for working on this and nice new tests! Just make sure all the docstrings using eos_token_id are updated and we should be good to merge!

@@ -183,7 +183,7 @@ class GenerationConfig(PushToHubMixin):
             The id of the *padding* token.
         bos_token_id (`int`, *optional*):
             The id of the *beginning-of-sequence* token.
-        eos_token_id (`int`, *optional*):
+        eos_token_id (`Union[int, List[int]]`, *optional*):
             The id of the *end-of-sequence* token.
@sgugger:

Can we adapt the text of the doc here? The docstrings also need to be updated in beam_search.py FYI.

@@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
             List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
             that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
             add_special_tokens=False).input_ids`.
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token.
@sgugger:

Same comment here.

@@ -671,23 +684,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
         exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
             This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
             starts and `decay_factor` represents the factor of exponential decay
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token.
@sgugger:

Here too!

@sgugger sgugger merged commit 45da7ce into huggingface:main Jan 3, 2023
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jan 4, 2023
* Add StopIdStoppingCriteria

* add a working test for stop id criteria

* add to global scope

* add stop_ids to generate

* add pipeline test

* use tokenizer encode in test

* add test to generation utils

* reformat

* fixup

* make-fix-copies

* rename to stop_token_id

* use stop_tokens instead

* add to text to text generation

* make fixup

* make repo-consistency

* Add support for list of ints for eos_token_id inside generation/utils.py

* Instead of having if elses, cast the eos_token_id into a List[int]

* Add List[int] support for logits_process.py

* add List[int] for beam_search.py

* add List[int] for forced_eos_token_id

* revert stop token id stopping criteria changes

* make fixup

* fix tests

* add eos_token_id to generation/utils.py and added tests test_utils.py

* add eos_token_id type hints and fix for pad tokens

* add comments

* remove some prints and remove forced false test

* fix

* put back test_stop_sequence_stopping_criteria

* remove unused import and make fixup

* add a none check

* update docstring

* add more docstring for list ints

* make fixup
gante added a commit that referenced this pull request Jan 4, 2023
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023
oobabooga added a commit to oobabooga/text-generation-webui that referenced this pull request Jan 8, 2023
Ideally, generation should stop at '\n', but this feature is brand new
on transformers (huggingface/transformers#20727)
@oobabooga (Contributor) commented Jan 8, 2023

Is this feature already available in the transformers version available through pip (4.25.1)? I have tried enabling it, but generation continued even though I set eos_token_id, in my case to

tokenizer.encode('\n', return_tensors='pt')[1]

(I'm also not sure why two integers are returned by encode instead of just one.)

EDIT: Nevermind, I got it working:

# the tokenizer returns two ids for '\n' here; the newline id is the second
n = tokenizer.encode('\n', return_tensors='pt')[0][1]
output = model.generate(input_ids, eos_token_id=n).cuda()
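
For reference, with the list support added by this PR (in releases after 4.25.1), several stop ids can be passed at once; a sketch reusing the names from the snippet above:

# eos_token_id also accepts a list after this PR, so generation stops on any
# of the given ids (tokenizer, model, and input_ids as in the snippet above).
newline_id = int(tokenizer.encode('\n', return_tensors='pt')[0][1])
output = model.generate(input_ids, eos_token_id=[newline_id])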

venkat-natchi pushed a commit to venkat-natchi/transformers that referenced this pull request Jan 22, 2023
miyu386 pushed a commit to miyu386/transformers that referenced this pull request Feb 9, 2023
Ph0rk0z pushed a commit to Ph0rk0z/text-generation-webui-testing that referenced this pull request Apr 17, 2023