Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move eos_token_id to stopping criteria #29459

Merged
merged 29 commits into from Mar 27, 2024

Conversation

zucchini-nlp
Copy link
Member

What does this PR do?

This PR is a small step for torch.compile and generate compatibility. It moves EOS token to stopping criteria, so now we can loop while stopping_criteria and get rid of extra checks on EOS at the end of each generate method.

All the generate tests, including slow are passing.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good 💪 A few comments to further refine the idea

In the body of generate, we set pad_token_id to eos_token_id when the latter exists and the former is None.

As such, we can further modify the decoding functions as follows:

  1. the block
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - 
unfinished_sequences)

can become

            # finished sentences should have their next token be a padding token
            if pad_token_id is None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

because the case that raises the exception can never be triggered from generate (although we should add an integration test that ensures this remains true, if it doesn't already exist!).

  1. see my in-code comment :)

src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
tests/generation/test_stopping_criteria.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
@gante
Copy link
Member

gante commented Mar 5, 2024

Ah, the pipeline tests are ignored by default (just like slow tests). You need to add RUN_PIPELINE_TESTS=1 to run them :)

@zucchini-nlp
Copy link
Member Author

Done for all comments. I checked that users have to possibility of messing up when calling the these methods directly and ran again all tests (+slow, +pipeline)

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for iterating 🙌

src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
@gante gante requested a review from ArthurZucker March 5, 2024 19:41
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, would be nice to make sure our assumption is always correct, nice cleanup otherwise!

src/transformers/generation/__init__.py Outdated Show resolved Hide resolved
src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
src/transformers/generation/stopping_criteria.py Outdated Show resolved Hide resolved
Comment on lines +138 to +139
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to also support list of lists ? To have stopping tokens (if the eos is ["<","eos",">"])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is it really true for existing models to have eos tokens as a sequence? I guess you are referring to custom eos tokens, when users want to stop at "" but they haven't trained the model with "" as special token. If that's the case, there is StopStringsCriteria PR coming or users are free to write their custom criteria

@gante wdyt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the case where we want to stop on a string will be covered by #28932 (and is way more complex)

This PR is exclusively to port the existing EOS logic into its own stopping criteria :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right I forgot it was being covered

Comment on lines 1922 to 1923
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should rather be set in the generation_config I guess? Why can't we leave this as a kwarg and set the generation_config.eos_token_id ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the general idea, we want to stop accepting eos as input argument. The warning is for backward compatibility, since we used to give priority to user defined eos_token_id from args before checking the generation_config.eos_token_id

Copy link
Member

@gante gante Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ as @zucchini-nlp wrote. This is merely for backward compatibility, if users want to keep calling the decoding methods directly (at their own risk, since we're deprecating their public API)

The decoding methods don't set up new logits processors nor stopping criteria. As such, The correct replacement for the EOS feature is to pass the new EOSTokenCriteria :)

The long-term goal is to disable calling decoding methods directly at all, so we can optimize the codebase.

.prod(dim=0)
.bool()
)
last_assistant_token_is_eos = stopping_criteria[-1](candidate_input_ids, None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this means we the last stopping_criteria is always EOSTokenCriteria. It's not necessarily obvious and if people use custom passed criteria I am not sure this will always be the case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, will fix it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker good catch, I missed this one too :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gante need you to look. I removed some code, given that assisted decoding works only with one batch, so we can make some assumptions about when to stop generating.

zucchini-nlp and others added 14 commits March 6, 2024 13:52
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One change in the speculative decoding diff and we're ready to go :)

.prod(dim=0)
.bool()
)
is_done_candidate = stopping_criteria(candidate_input_ids, None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧠 (this is actually much more versatile than the previous version!)

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@gante
Copy link
Member

gante commented Mar 8, 2024

@zucchini-nlp needs a rebase with main and it should be ready to be merged

(@ArthurZucker I'm assuming this is ready to be merged, since all comments were addressed. I'm deciding to merge since this unblocks me on the torch.compile front, let us know if you like some post-merge changes :) )

@amyeroberts
Copy link
Collaborator

amyeroberts commented Mar 8, 2024

@gante Woah, hold on before merging! You still need a core maintainers approval: even if the comments have been addressed it's important to make sure the amendments are approved

@gante
Copy link
Member

gante commented Mar 18, 2024

@zucchini-nlp this PR needs to be rebased with main to fix CI :)

@gante
Copy link
Member

gante commented Mar 20, 2024

@ArthurZucker ping for approval, if all tasks are complete :)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Make sure to run the slow tests for important models before merging! + testing generation with Llama and eos_token_id list!

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
@@ -2364,10 +2370,26 @@ def _greedy_search(
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here

zucchini-nlp and others added 5 commits March 26, 2024 19:31
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@zucchini-nlp
Copy link
Member Author

Rebased main and ran all tests in generation again, including slow, ran a few generation models' tests. Everything is passing, can be merged now. @gante

No idea why the examples test is failing, does not have anything to do with this PR.

@gante gante merged commit 0efcf32 into huggingface:main Mar 27, 2024
21 checks passed
@agnosticlines
Copy link

Just a heads up this commit breaks transformers generation on Apple Silicon as isin is not implemented for the MPS backend

itazap pushed a commit that referenced this pull request May 14, 2024
* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/__init__.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style  and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* lets fix this typo in docs

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants