Move eos_token_id to stopping criteria (#29459)
* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/__init__.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* let's fix this typo in docs

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
3 people committed Mar 27, 2024
1 parent 31c575b commit 0efcf32
Showing 6 changed files with 215 additions and 76 deletions.
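In user terms the change is invisible: `eos_token_id` is still accepted by `generate()`, but the end-of-sequence check now happens through a dedicated stopping criterion instead of ad-hoc checks inside each decoding loop. A minimal sketch of that unchanged surface behavior (the checkpoint and prompt below are illustrative, not from the PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# eos_token_id is accepted as before; after this commit generate() wraps it
# in an EosTokenCriteria on the internal StoppingCriteriaList.
output = model.generate(**inputs, max_new_tokens=20, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))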
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -82,6 +82,7 @@
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
         "validate_stopping_criteria",
@@ -216,6 +217,7 @@
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
11 changes: 10 additions & 1 deletion src/transformers/generation/candidate_generator.py
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
             The maximum ngram size to be considered for matching in the prompt
         num_output_tokens (`int`):
             The number of tokens to be output as candidate tokens.
+        max_length (`int`):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
+            Defaults to 20, which is the max length used as default in generation config.
     """

     def __init__(
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        max_length: int = 20,
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_length = max_length

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -264,6 +269,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         input_length = input_ids.size(1)

+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        if self.max_length == input_length + 1:
+            return input_ids, None
+
         chosen_ids = None
         match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
@@ -283,7 +292,7 @@ for idx in match_indices:
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
+                end_idx = min(end_idx, input_length, self.max_length)

                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
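A small sketch of the new `max_length` handling (token ids and lengths below are toy values; the import path is the internal module changed above):

import torch

from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

# Toy prompt of length 6; the repeated bigram (1, 2) gives prompt lookup a match.
input_ids = torch.tensor([[1, 2, 3, 4, 1, 2]])

# The prompt is already at max_length - 1, so no candidates are proposed:
# the target model will produce the single remaining token itself.
generator = PromptLookupCandidateGenerator(num_output_tokens=10, max_length=7)
candidates, logits = generator.get_candidates(input_ids)
print(candidates.shape, logits)  # torch.Size([1, 6]) None -- input returned unchanged

# With a larger budget, the continuation after the matched bigram is proposed.
generator = PromptLookupCandidateGenerator(num_output_tokens=10, max_length=20)
candidates, _ = generator.get_candidates(input_ids)
print(candidates)  # tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2]])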
23 changes: 22 additions & 1 deletion src/transformers/generation/stopping_criteria.py
@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Optional
+from typing import List, Optional, Union

 import torch

@@ -129,6 +129,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


+class EosTokenCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+    By default, it uses the `model.generation_config.eos_token_id`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+    """
+
+    def __init__(self, eos_token_id: Union[int, List[int]]):
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = torch.tensor(eos_token_id)
+
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+        return is_done
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
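The new criterion can also be used on its own. A minimal sketch (the EOS id and batch below are illustrative):

import torch

from transformers.generation import EosTokenCriteria, StoppingCriteriaList

# Token id 2 stands in for a real EOS id, normally taken from
# model.generation_config.eos_token_id.
criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=2)])

# Batch of two sequences: only the first ends in the EOS token.
input_ids = torch.tensor([[5, 8, 2], [5, 8, 9]])
scores = torch.zeros(2, 32)  # unused by EosTokenCriteria, required by the interface

print(criteria(input_ids, scores))  # tensor([ True, False])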
