Add custom stop token ids for generation #20727

Merged
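For context, here is a minimal usage sketch of what this PR enables: passing a list of ids as eos_token_id so that generation stops on whichever id appears first. The checkpoint, prompt, and extra stop id below are illustrative, not taken from the PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# After this PR, eos_token_id may be a list of ints: generation stops
# as soon as ANY of the listed ids is produced.
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    eos_token_id=[tokenizer.eos_token_id, tokenizer.encode(".")[0]],
)
print(tokenizer.decode(outputs[0]))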
Changes shown from 32 of 35 commits (all by tokestermw):
ecdd003  Add StopIdStoppingCriteria (Dec 4, 2022)
739674d  add a working test for stop id criteria (Dec 4, 2022)
3bce1cd  add to global scope (Dec 4, 2022)
99904eb  add stop_ids to generate (Dec 4, 2022)
20aeeca  add pipeline test (Dec 4, 2022)
683c320  use tokenizer encode in test (Dec 5, 2022)
2781b53  add test to generation utils (Dec 5, 2022)
5316041  reformat (Dec 5, 2022)
947af73  fixup (Dec 5, 2022)
15834a8  make-fix-copies (Dec 5, 2022)
0e6eb18  rename to stop_token_id (Dec 10, 2022)
8c6e474  use stop_tokens instead (Dec 10, 2022)
305e349  add to text to text generation (Dec 10, 2022)
64556a8  make fixup (Dec 10, 2022)
6f0812d  make repo-consistency (Dec 10, 2022)
70c2dda  Add support for list of ints for eos_token_id inside generation/utils.py (Dec 19, 2022)
b0cccd6  Instead of having if elses, cast the eos_token_id into a List[int] (Dec 19, 2022)
fba3345  Add List[int] support for logits_process.py (Dec 19, 2022)
405f79c  add List[int] for beam_search.py (Dec 19, 2022)
8298e87  add List[int] for forced_eos_token_id (Dec 19, 2022)
dd6f36d  Merge branch 'main' into add-custom-stop-token-ids-for-generation (Dec 31, 2022)
0062ddf  revert stop token id stopping criteria changes (Dec 31, 2022)
6d69af0  make fixup (Dec 31, 2022)
9ad7689  fix tests (Dec 31, 2022)
4df3b46  add eos_token_id to generation/utils.py and added tests test_utils.py (Dec 31, 2022)
96ccb52  add eos_token_id type hints and fix for pad tokens (Dec 31, 2022)
dd34d52  add comments (Dec 31, 2022)
fc789bf  remove some prints and remove forced false test (Dec 31, 2022)
e9bd3a9  fix (Dec 31, 2022)
b25f052  put back test_stop_sequence_stopping_criteria (Dec 31, 2022)
33e49ef  remove unused import and make fixup (Dec 31, 2022)
6428562  add a none check (Jan 2, 2023)
a60afdb  update docstring (Jan 3, 2023)
f402cb0  add more docstring for list ints (Jan 3, 2023)
84b8e1d  make fixup (Jan 3, 2023)
32 changes: 23 additions & 9 deletions src/transformers/generation/beam_search.py
@@ -16,7 +16,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -212,7 +212,7 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
@@ -234,6 +234,9 @@
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -253,7 +256,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
@@ -307,11 +310,14 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -376,7 +382,8 @@ def finalize(
indices[i, : len(best_idx)] = torch.tensor(best_idx)

if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
@@ -491,7 +498,7 @@ def process(
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.Tensor]:
r"""
Args:
@@ -549,6 +556,9 @@
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -568,7 +578,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):

# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
@@ -773,10 +783,13 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -840,7 +853,8 @@
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
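The pattern repeated throughout this file — cast eos_token_id to a list once, then use membership tests, with eos_token_id[0] reserved for padding — can be sketched standalone. The helper name below is illustrative, not part of the PR:

from typing import List, Optional, Union

def normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> List[int]:
    # Cast a bare int (or None) to a list so downstream code can
    # uniformly check `token in eos_token_id`.
    if eos_token_id is None:
        return []
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id

eos = normalize_eos(50256)
assert 50256 in eos
eos = normalize_eos([50256, 13])
assert 13 in eos   # any listed id can end a hypothesis
first = eos[0]     # only the first id is used when padding out short sequences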
8 changes: 4 additions & 4 deletions src/transformers/generation/configuration_utils.py
@@ -142,7 +142,7 @@ class GenerationConfig(PushToHubMixin):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
language token.
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
@@ -152,10 +152,10 @@ class GenerationConfig(PushToHubMixin):
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
penalty starts and `decay_factor` represents the factor of exponential decay
suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
@@ -183,7 +183,7 @@
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token.
Collaborator: Can we adapt the text of the doc here? The docstrings also need to be updated in beam_search.py FYI.


> Generation parameters exclusive to encoder-decoder models
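A short sketch of the widened config field, assuming the token ids below are placeholders:

from transformers import GenerationConfig

# eos_token_id now also accepts a list of ids; generation ends when
# any one of them is generated.
config = GenerationConfig(max_new_tokens=32, eos_token_id=[50256, 13])
print(config.eos_token_id)  # [50256, 13]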
52 changes: 34 additions & 18 deletions src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -100,24 +100,27 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
"""

def __init__(self, min_length: int, eos_token_id: int):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.min_length = min_length
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores


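To see the multi-id handling above in action, a small sketch with a made-up vocabulary size and ids:

import torch
from transformers import MinLengthLogitsProcessor

processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=[3, 7])

input_ids = torch.tensor([[0, 1, 2]])  # current length 3 < min_length 5
scores = torch.zeros(1, 10)            # batch of 1, vocabulary of 10
scores = processor(input_ids, scores)

# Both EOS candidates are masked while the sequence is still too short.
assert scores[0, 3] == -float("inf") and scores[0, 7] == -float("inf")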
@@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
Collaborator: Same comment here.

"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):

if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
@@ -413,7 +416,7 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)

bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

bad_words_ids = list(
filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
)
self.bad_words_id_length_1 = []
self.bad_words_id_length_greater_than_1 = []
for word in bad_words_ids:
@@ -628,20 +638,23 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached.
"""

def __init__(self, max_length: int, eos_token_id: int):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = 0
return scores


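A sketch of the forced-EOS behavior with a list, again with toy sizes:

import torch
from transformers import ForcedEOSTokenLogitsProcessor

processor = ForcedEOSTokenLogitsProcessor(max_length=4, eos_token_id=[3, 7])

input_ids = torch.tensor([[0, 1, 2]])  # cur_len == max_length - 1, so EOS is forced
scores = torch.zeros(1, 10)
scores = processor(input_ids, scores)

# Every id outside the EOS list is masked; both EOS ids keep score 0,
# so one of them is guaranteed to be the final token.
assert torch.isinf(scores[0, 0]) and scores[0, 3] == 0 and scores[0, 7] == 0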
@@ -671,23 +684,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`int`):
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token.
Collaborator: Here too!

input_ids_seq_length (`int`):
The length of the input sequence.
"""

def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
for i in self.eos_token_id:
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
return scores


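And the exponential decay penalty applied across several EOS ids, with toy numbers (the import path matches this file's location; ids and factors are made up):

import torch
from transformers.generation.logits_process import ExponentialDecayLengthPenalty

# Penalty starts 1 token after a 2-token prompt, with decay factor 1.5.
processor = ExponentialDecayLengthPenalty(
    exponential_decay_length_penalty=(1, 1.5),
    eos_token_id=[3, 7],
    input_ids_seq_length=2,
)

input_ids = torch.tensor([[0, 1, 2, 4]])  # cur_len 4 > regulation_start 3
scores = torch.ones(1, 10)
scores = processor(input_ids, scores)

# Both EOS ids are boosted by 1.5 ** (4 - 3); all other ids are untouched.
assert scores[0, 3] == 1.5 and scores[0, 7] == 1.5 and scores[0, 0] == 1.0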