Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move eos_token_id to stopping criteria #29459

Merged
merged 29 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
584de18
add eos stopping criteria
zucchini-nlp Mar 5, 2024
79a47c4
minor fix
zucchini-nlp Mar 5, 2024
6c93f8b
Update tests/generation/test_stopping_criteria.py
zucchini-nlp Mar 5, 2024
f59b83f
check eos is not None and fix tests
zucchini-nlp Mar 5, 2024
8ebad2d
make style and fixup
zucchini-nlp Mar 5, 2024
3e2507b
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
b77b6ab
Update tests/generation/test_utils.py
zucchini-nlp Mar 6, 2024
edc76df
Update tests/generation/test_utils.py
zucchini-nlp Mar 6, 2024
f71a687
Update src/transformers/generation/__init__.py
zucchini-nlp Mar 6, 2024
387be0e
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
14ece04
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
bc3eea9
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
7518aea
camel case everywhere
zucchini-nlp Mar 6, 2024
8e5ec57
call stopping criteria list for candidate ids
zucchini-nlp Mar 7, 2024
2544d12
make style and fixup
zucchini-nlp Mar 7, 2024
12acbc4
Empty commit
zucchini-nlp Mar 7, 2024
1107673
Empty commit to pass flaky test
zucchini-nlp Mar 7, 2024
1ffc554
set max length in PromptLookupCandidateGenerator
zucchini-nlp Mar 7, 2024
ca0a414
Merge 'upstream/main' into stopping_criteria
zucchini-nlp Mar 7, 2024
ce093c1
Update src/transformers/generation/utils.py
zucchini-nlp Mar 8, 2024
aa7fae4
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 8, 2024
5375d97
lets fix this typo in docs
zucchini-nlp Mar 9, 2024
d103d29
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 15, 2024
9f59abb
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 18, 2024
48f33fc
Update src/transformers/generation/utils.py
zucchini-nlp Mar 26, 2024
801af07
Update src/transformers/generation/utils.py
zucchini-nlp Mar 26, 2024
a385c6d
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 26, 2024
7c00bb1
update PR
zucchini-nlp Mar 26, 2024
530064d
empty commit
zucchini-nlp Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"EosTokenCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
Expand Down Expand Up @@ -216,6 +217,7 @@
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
The number of tokens to be output as candidate tokens.
max_length (`int`):
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
Defaults to 20, which is the max length used as default in generation config.
"""

def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = None,
max_length: int = 20,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
self.max_length = max_length

if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
Expand All @@ -273,6 +278,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
"""
input_length = input_ids.size(1)

# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
if self.max_length == input_length + 1:
return input_ids, None

chosen_ids = None
match_found = False
for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
Expand All @@ -292,7 +301,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length)
end_idx = min(end_idx, input_length, self.max_length)

if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]
Expand Down
23 changes: 22 additions & 1 deletion src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional
from typing import List, Optional, Union

import torch

Expand Down Expand Up @@ -129,6 +129,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


class EosTokenCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
By default, it uses the `model.generation_config.eos_token_id`.

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Comment on lines +138 to +139
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to also support list of lists ? To have stopping tokens (if the eos is ["<","eos",">"])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is it really true for existing models to have eos tokens as a sequence? I guess you are referring to custom eos tokens, when users want to stop at "" but they haven't trained the model with "" as special token. If that's the case, there is StopStringsCriteria PR coming or users are free to write their custom criteria

@gante wdyt?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the case where we want to stop on a string will be covered by #28932 (and is way more complex)

This PR is exclusively to port the existing EOS logic into its own stopping criteria :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right I forgot it was being covered

"""

def __init__(self, eos_token_id: Union[int, List[int]]):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = torch.tensor(eos_token_id)

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
return is_done


class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
Expand Down