Add custom stop token ids for generation #20727
Merged: sgugger merged 35 commits into huggingface:main from tokestermw:add-custom-stop-token-ids-for-generation on Jan 3, 2023
Changes shown below are from 32 of the 35 commits.
All 35 commits are by tokestermw:

ecdd003  Add StopIdStoppingCriteria
739674d  add a working test for stop id criteria
3bce1cd  add to global scope
99904eb  add stop_ids to generate
20aeeca  add pipeline test
683c320  use tokenizer encode in test
2781b53  add test to generation utils
5316041  reformat
947af73  fixup
15834a8  make-fix-copies
0e6eb18  rename to stop_token_id
8c6e474  use stop_tokens instead
305e349  add to text to text generation
64556a8  make fixup
6f0812d  make repo-consistency
70c2dda  Add support for list of ints for eos_token_id inside generation/utils.py
b0cccd6  Instead of having if elses, cast the eos_token_id into a List[int]
fba3345  Add List[int] support for logits_process.py
405f79c  add List[int] for beam_search.py
8298e87  add List[int] for forced_eos_token_id
dd6f36d  Merge branch 'main' into add-custom-stop-token-ids-for-generation
0062ddf  revert stop token id stopping criteria changes
6d69af0  make fixup
9ad7689  fix tests
4df3b46  add eos_token_id to generation/utils.py and added tests test_utils.py
96ccb52  add eos_token_id type hints and fix for pad tokens
dd34d52  add comments
fc789bf  remove some prints and remove forced false test
e9bd3a9  fix
b25f052  put back test_stop_sequence_stopping_criteria
33e49ef  remove unused import and make fixup
6428562  add a none check
a60afdb  update docstring
f402cb0  add more docstring for list ints
84b8e1d  make fixup
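The net effect of the PR: `eos_token_id` arguments across generation accept either a single int or a list of ints, so decoding stops at whichever stop token appears first. A minimal, library-independent sketch of that behavior in a greedy loop (all names here are illustrative, not the transformers API):

```python
def greedy_generate(step_fn, prompt, eos_token_id, max_new_tokens=20):
    # Normalize: accept a single id or a list of ids (the PR's convention)
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        next_id = step_fn(tokens)  # pick the next token id
        tokens.append(next_id)
        if next_id in eos_token_id:  # stop on ANY of the stop ids
            break
    return tokens

# Toy "model": emits incrementing ids; generation halts at 3, the first stop id reached
print(greedy_generate(lambda t: t[-1] + 1, [0], eos_token_id=[3, 7]))  # [0, 1, 2, 3]
```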
@@ -15,7 +15,7 @@
 import inspect
 import math
-from typing import Callable, Iterable, List, Optional, Tuple
+from typing import Callable, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -100,24 +100,27 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     Args:
         min_length (`int`):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token.
     """

-    def __init__(self, min_length: int, eos_token_id: int):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

-        if not isinstance(eos_token_id, int) or eos_token_id < 0:
-            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

         self.min_length = min_length
         self.eos_token_id = eos_token_id

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len < self.min_length:
-            scores[:, self.eos_token_id] = -float("inf")
+            for i in self.eos_token_id:
+                scores[:, i] = -float("inf")
         return scores
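The core pattern of this hunk — cast a scalar `eos_token_id` into a `List[int]` once in `__init__`, then loop over the ids when masking — can be exercised in isolation. A standalone sketch, where plain Python lists stand in for a row of the `scores` tensor and `normalize_eos`/`mask_eos` are illustrative helpers, not part of the PR:

```python
from typing import List, Union

def normalize_eos(eos_token_id: Union[int, List[int]]) -> List[int]:
    # Cast a single id into a one-element list, as the PR does in __init__
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
        raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
    return eos_token_id

def mask_eos(scores_row: List[float], eos_token_id: Union[int, List[int]]) -> List[float]:
    # Mirrors MinLengthLogitsProcessor.__call__: set every eos id's score to -inf
    out = list(scores_row)
    for i in normalize_eos(eos_token_id):
        out[i] = -float("inf")
    return out

print(mask_eos([0.1, 0.2, 0.3, 0.4], [1, 3]))  # [0.1, -inf, 0.3, -inf]
```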
@@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
         that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
         add_special_tokens=False).input_ids`.
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
             The id of the *end-of-sequence* token.

Review comment on the changed docstring: Same comment here.

     """

-    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
+    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):

         if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
             raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")

@@ -413,7 +416,14 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
                 f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
             )

-        bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
+        if eos_token_id is None:
+            eos_token_id = []
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
+        bad_words_ids = list(
+            filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)
+        )
         self.bad_words_id_length_1 = []
         self.bad_words_id_length_greater_than_1 = []
         for word in bad_words_ids:
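The updated filter drops any single-token bad-word sequence that equals one of the eos ids (previously a single scalar comparison). Its behavior in isolation, with illustrative token ids:

```python
eos_token_id = [2, 50256]                 # illustrative stop ids
bad_words_ids = [[2], [50256], [2, 9], [7]]

# Same filter as the PR: drop the bare sequence [i] for every eos id i, keep everything else
filtered = list(
    filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
)
print(filtered)  # [[2, 9], [7]] — the multi-token [2, 9] survives, bare eos ids do not
```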
@@ -628,20 +638,23 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     Args:
         max_length (`int`):
             The maximum length of the sequence to be generated.
-        eos_token_id (`int`):
+        eos_token_id (`Union[int, List[int]]`):
            The id of the token to force as the last generated token when `max_length` is reached.
     """

-    def __init__(self, max_length: int, eos_token_id: int):
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
         self.max_length = max_length
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
         self.eos_token_id = eos_token_id

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len == self.max_length - 1:
             num_tokens = scores.shape[1]
-            scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
-            scores[:, self.eos_token_id] = 0
+            scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf")
+            for i in self.eos_token_id:
+                scores[:, i] = 0
         return scores
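With a list, the membership test `i not in self.eos_token_id` keeps every eos id reachable while everything else is masked out. A list-based sketch of that final-step masking (`force_eos` is an illustrative helper operating on one row of scores, not the transformers API):

```python
def force_eos(scores_row, eos_token_id):
    # Mirrors ForcedEOSTokenLogitsProcessor at cur_len == max_length - 1:
    # only the eos ids stay reachable (score 0), everything else becomes -inf
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    num_tokens = len(scores_row)
    return [0.0 if i in eos_token_id else -float("inf") for i in range(num_tokens)]

print(force_eos([1.5, -0.2, 0.7, 2.1], [1, 3]))  # [-inf, 0.0, -inf, 0.0]
```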
@@ -671,23 +684,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): | |
exponential_decay_length_penalty (`tuple(int, float)`, *optional*): | ||
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty | ||
starts and `decay_factor` represents the factor of exponential decay | ||
eos_token_id (`int`): | ||
eos_token_id (`Union[int, List[int]]`): | ||
The id of the *end-of-sequence* token. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too! |
||
input_ids_seq_length (`int`): | ||
The length of the input sequence. | ||
""" | ||
|
||
def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int): | ||
def __init__( | ||
self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int | ||
): | ||
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length | ||
self.regulation_factor = exponential_decay_length_penalty[1] | ||
if isinstance(eos_token_id, int): | ||
eos_token_id = [eos_token_id] | ||
self.eos_token_id = eos_token_id | ||
|
||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: | ||
cur_len = input_ids.shape[-1] | ||
if cur_len > self.regulation_start: | ||
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow( | ||
self.regulation_factor, cur_len - self.regulation_start | ||
) | ||
for i in self.eos_token_id: | ||
scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start) | ||
return scores | ||
|
||
|
||
|
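The loop applies the same scalar multiplier, `regulation_factor ** (cur_len - regulation_start)`, to the score of each eos id. The arithmetic, with illustrative numbers (`decay_eos_scores` is a sketch over one row of scores, not the transformers API):

```python
def decay_eos_scores(scores_row, eos_token_id, regulation_start, regulation_factor, cur_len):
    # Mirrors ExponentialDecayLengthPenalty.__call__ for one row of scores
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    out = list(scores_row)
    if cur_len > regulation_start:
        for i in eos_token_id:
            out[i] = out[i] * pow(regulation_factor, cur_len - regulation_start)
    return out

# Two eos ids, 3 steps past the penalty start, decay factor 1.5: both scores scale by 1.5**3 = 3.375
print(decay_eos_scores([1.0, 2.0, 4.0], [0, 2], regulation_start=5, regulation_factor=1.5, cur_len=8))
# [3.375, 2.0, 13.5]
```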
Review comment: Can we adapt the text of the doc here? The docstrings also need to be updated in beam_search.py, FYI.