Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Move eos_token_id to stopping criteria #29459

Merged
merged 29 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
29 commits
Select commit. Hold Shift + click to select a range.
584de18
add eos stopping criteria
zucchini-nlp Mar 5, 2024
79a47c4
minor fix
zucchini-nlp Mar 5, 2024
6c93f8b
Update tests/generation/test_stopping_criteria.py
zucchini-nlp Mar 5, 2024
f59b83f
check eos is not None and fix tests
zucchini-nlp Mar 5, 2024
8ebad2d
make style and fixup
zucchini-nlp Mar 5, 2024
3e2507b
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
b77b6ab
Update tests/generation/test_utils.py
zucchini-nlp Mar 6, 2024
edc76df
Update tests/generation/test_utils.py
zucchini-nlp Mar 6, 2024
f71a687
Update src/transformers/generation/__init__.py
zucchini-nlp Mar 6, 2024
387be0e
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
14ece04
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
bc3eea9
Update src/transformers/generation/stopping_criteria.py
zucchini-nlp Mar 6, 2024
7518aea
camel case everywhere
zucchini-nlp Mar 6, 2024
8e5ec57
call stopping criteria list for candidate ids
zucchini-nlp Mar 7, 2024
2544d12
make style and fixup
zucchini-nlp Mar 7, 2024
12acbc4
Empty commit
zucchini-nlp Mar 7, 2024
1107673
Empty commit to pass flaky test
zucchini-nlp Mar 7, 2024
1ffc554
set max length in PromptLookupCandidateGenerator
zucchini-nlp Mar 7, 2024
ca0a414
Merge 'upstream/main' into stopping_criteria
zucchini-nlp Mar 7, 2024
ce093c1
Update src/transformers/generation/utils.py
zucchini-nlp Mar 8, 2024
aa7fae4
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 8, 2024
5375d97
lets fix this typo in docs
zucchini-nlp Mar 9, 2024
d103d29
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 15, 2024
9f59abb
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 18, 2024
48f33fc
Update src/transformers/generation/utils.py
zucchini-nlp Mar 26, 2024
801af07
Update src/transformers/generation/utils.py
zucchini-nlp Mar 26, 2024
a385c6d
Merge remote-tracking branch 'upstream/main' into stopping_crtiteria
zucchini-nlp Mar 26, 2024
7c00bb1
update PR
zucchini-nlp Mar 26, 2024
530064d
empty commit
zucchini-nlp Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,11 @@ class EOSTokenCriteria(StoppingCriteria):
def __init__(self, eos_token_id: Union[int, List[int]]):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
self.eos_token_id = torch.tensor(eos_token_id)

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
eos_token_ids = torch.tensor(self.eos_token_id, dtype=torch.int64, device=input_ids.device)
is_done = (input_ids[:, -1].unsqueeze(1) == eos_token_ids).any(dim=1)
is_done = torch.isin(input_ids, self.eos_token_id.to(input_ids.device))[:, -1]
zucchini-nlp marked this conversation as resolved.
Show resolved Hide resolved
return is_done


Expand Down
80 changes: 48 additions & 32 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,6 @@ def generate(
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand All @@ -1557,7 +1556,6 @@ def generate(
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand All @@ -1577,7 +1575,6 @@ def generate(
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1606,7 +1603,6 @@ def generate(
logits_warper=logits_warper,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1640,7 +1636,6 @@ def generate(
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1680,7 +1675,6 @@ def generate(
logits_warper=logits_warper,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1714,7 +1708,6 @@ def generate(
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1788,7 +1781,6 @@ def typeerror():
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
Expand Down Expand Up @@ -1933,12 +1925,15 @@ def _contrastive_search(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -2400,12 +2395,15 @@ def _greedy_search(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -2704,12 +2702,15 @@ def _sample(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -3037,12 +3038,15 @@ def _beam_search(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private and beam scorer refactored
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -3445,12 +3449,15 @@ def _beam_sample(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private and beam scorer refactored
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -3806,12 +3813,15 @@ def _group_beam_search(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private and beam scorer refactored
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -4231,12 +4241,15 @@ def _constrained_beam_search(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private and beam scorer refactored
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down Expand Up @@ -4588,12 +4601,15 @@ def _assisted_decoding(
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
# TODO remove when the method is totally private and beam scorer refactored
# need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None and self.generation_config.eos_token_id is not None:
eos_token_id = self.generation_config.eos_token_id
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
Expand Down
1 change: 1 addition & 0 deletions tests/generation/test_stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_eos_token_criteria(self):

input_ids, scores = self._get_tensors(5)
input_ids[:2, -1] = 0
input_ids[2, -1] = 1
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])

input_ids, scores = self._get_tensors(5)
Expand Down