Add custom stop token ids for generation (huggingface#20727)
* Add StopIdStoppingCriteria

* add a working test for stop id criteria

* add to global scope

* add stop_ids to generate

* add pipeline test

* use tokenizer encode in test

* add test to generation utils

* reformat

* fixup

* make-fix-copies

* rename to stop_token_id

* use stop_tokens instead

* add to text to text generation

* make fixup

* make repo-consistency

* Add support for list of ints for eos_token_id inside generation/utils.py

* Instead of having if elses, cast the eos_token_id into a List[int]

* Add List[int] support for logits_process.py

* add List[int] for beam_search.py

* add List[int] for forced_eos_token_id

* revert stop token id stopping criteria changes

* make fixup

* fix tests

* add eos_token_id to generation/utils.py and added tests test_utils.py

* add eos_token_id type hints and fix for pad tokens

* add comments

* remove some prints and remove forced false test

* fix

* put back test_stop_sequence_stopping_criteria

* remove unused import and make fixup

* add a none check

* update docstring

* add more docstring for list ints

* make fixup
tokestermw authored and miyu386 committed Feb 9, 2023
1 parent 1fe81c5 commit abfa0cc
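
For context, a minimal sketch of the behavior this commit adds: `generate` accepting either a single `eos_token_id` or a list of ids, stopping at whichever one appears first. The checkpoint, prompt, and extra stop id below are illustrative assumptions, not part of the commit.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    # eos_token_id may now be an int or a list of ints; generation stops
    # as soon as any listed id is produced
    eos_token_id=[tokenizer.eos_token_id, tokenizer.encode(".")[0]],
)
print(tokenizer.decode(outputs[0]))
```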
Showing 5 changed files with 210 additions and 64 deletions.
44 changes: 29 additions & 15 deletions src/transformers/generation/beam_search.py
@@ -16,7 +16,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -42,8 +42,8 @@
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
@@ -74,8 +74,8 @@
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
@@ -212,7 +212,7 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
@@ -234,6 +234,9 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -253,7 +256,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
@@ -307,11 +310,14 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -376,7 +382,8 @@ def finalize(
indices[i, : len(best_idx)] = torch.tensor(best_idx)

if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
@@ -491,7 +498,7 @@ def process(
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.Tensor]:
r"""
Args:
@@ -512,8 +519,8 @@
The scores of all tokens in the vocabulary for each of the beam hypotheses.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
@@ -549,6 +556,9 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
@@ -568,7 +578,7 @@
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
if (eos_token_id is not None) and (next_token.item() in eos_token_id):

# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
@@ -773,10 +783,13 @@ def finalize(
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
Expand Down Expand Up @@ -840,7 +853,8 @@ def finalize(
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]

return UserDict(
{
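
The pattern running through this file is to cast an optional int to a list once and then replace equality checks with membership tests (`next_token.item() in eos_token_id`); when padding a finished hypothesis in `finalize`, only the first listed id is written as the terminator. A standalone sketch of the normalization step, not code from the diff:

```python
from typing import List, Optional, Union

def normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> List[int]:
    # Normalize once so downstream code can always use `token in eos_token_id`
    # instead of branching on int vs. list.
    if eos_token_id is None:
        return []
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    return eos_token_id

assert normalize_eos(2) == [2]
assert normalize_eos([2, 50256]) == [2, 50256]
assert normalize_eos(None) == []
```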
13 changes: 7 additions & 6 deletions src/transformers/generation/configuration_utils.py
@@ -142,8 +142,9 @@ class GenerationConfig(PushToHubMixin):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
language token.
forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached.
forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
Note that using `remove_invalid_values` can slow down generation.
@@ -152,10 +153,10 @@ class GenerationConfig(PushToHubMixin):
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where
penalty starts and `decay_factor` represents the factor of exponential decay
suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their
log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
@@ -183,8 +184,8 @@ class GenerationConfig(PushToHubMixin):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
> Generation parameters exclusive to encoder-decoder models
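
A short sketch of the updated config surface, with placeholder token ids:

```python
from transformers import GenerationConfig

# eos_token_id now documents support for an int or a list of ints
config = GenerationConfig(max_new_tokens=32, eos_token_id=[2, 50256])
print(config.eos_token_id)  # [2, 50256]
```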
61 changes: 39 additions & 22 deletions src/transformers/generation/logits_process.py
@@ -15,7 +15,7 @@

import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -100,24 +100,27 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""

def __init__(self, min_length: int, eos_token_id: int):
def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

self.min_length = min_length
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = -float("inf")
return scores
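
A quick standalone check of the updated processor: below `min_length`, every listed EOS id is masked to `-inf`. The vocabulary size and token ids are dummy values:

```python
import torch
from transformers import MinLengthLogitsProcessor

processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=[2, 3])
input_ids = torch.tensor([[0, 1]])  # current length 2 < min_length
scores = torch.zeros(1, 10)         # dummy logits over a 10-token vocabulary
out = processor(input_ids, scores)
print(out[0, 2].item(), out[0, 3].item())  # -inf -inf
```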


@@ -431,11 +434,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
add_special_tokens=False).input_ids`.
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
"""

def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):

if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
@@ -449,7 +452,14 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)

bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
if eos_token_id is None:
eos_token_id = []
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]

bad_words_ids = list(
filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
)
self.bad_words_id_length_1 = []
self.bad_words_id_length_greater_than_1 = []
for word in bad_words_ids:
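
The reworked filter drops any bad-word sequence that is exactly one of the EOS ids, mirroring the old single-int behavior. An equivalent standalone sketch with placeholder ids:

```python
eos_token_id = [2, 50256]
bad_words_ids = [[5, 6], [2], [50256], [7]]

# keep a sequence only if it differs from every single-id EOS sequence
filtered = [seq for seq in bad_words_ids if all(seq != [i] for i in eos_token_id)]
print(filtered)  # [[5, 6], [7]]
```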
@@ -664,20 +674,24 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
eos_token_id (`int`):
The id of the token to force as the last generated token when `max_length` is reached.
eos_token_id (`Union[int, List[int]]`):
The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
list to set multiple *end-of-sequence* tokens.
"""

def __init__(self, max_length: int, eos_token_id: int):
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
for i in self.eos_token_id:
scores[:, i] = 0
return scores
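
At the final position the updated processor masks every id that is not in the list and zeroes the listed ones, so one of them is guaranteed to be picked. A standalone check with placeholder ids:

```python
import torch
from transformers import ForcedEOSTokenLogitsProcessor

processor = ForcedEOSTokenLogitsProcessor(max_length=4, eos_token_id=[7, 8])
input_ids = torch.tensor([[0, 1, 2]])  # cur_len == max_length - 1 triggers forcing
scores = torch.randn(1, 10)
out = processor(input_ids, scores)
print(out[0, 7].item(), out[0, 8].item())  # 0.0 0.0; every other id is -inf
```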


@@ -707,23 +721,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
eos_token_id (`int`):
The id of the *end-of-sequence* token.
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
input_ids_seq_length (`int`):
The length of the input sequence.
"""

def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int):
def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
for i in self.eos_token_id:
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
return scores
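
The loop above scales each listed EOS score by the decay factor compounded over the number of steps past `regulation_start`. The arithmetic in isolation, with made-up numbers:

```python
import torch

eos_token_id = [2, 3]  # placeholder ids
regulation_start, regulation_factor = 4, 1.5
cur_len = 6

scores = torch.zeros(1, 5)
scores[:, eos_token_id] = 1.0

if cur_len > regulation_start:
    for i in eos_token_id:
        # the factor compounds once per step past the start
        scores[:, i] = scores[:, i] * pow(regulation_factor, cur_len - regulation_start)

print(scores[0, 2].item())  # 2.25 == 1.5 ** (6 - 4)
```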

