Move eos_token_id to stopping criteria #29459

Status: Merged (29 commits, merged Mar 27, 2024)
Showing changes from 2 commits.
Commits (29, all by zucchini-nlp):

584de18  add eos stopping criteria (Mar 5, 2024)
79a47c4  minor fix (Mar 5, 2024)
6c93f8b  Update tests/generation/test_stopping_criteria.py (Mar 5, 2024)
f59b83f  check eos is not None and fix tests (Mar 5, 2024)
8ebad2d  make style and fixup (Mar 5, 2024)
3e2507b  Update src/transformers/generation/stopping_criteria.py (Mar 6, 2024)
b77b6ab  Update tests/generation/test_utils.py (Mar 6, 2024)
edc76df  Update tests/generation/test_utils.py (Mar 6, 2024)
f71a687  Update src/transformers/generation/__init__.py (Mar 6, 2024)
387be0e  Update src/transformers/generation/stopping_criteria.py (Mar 6, 2024)
14ece04  Update src/transformers/generation/stopping_criteria.py (Mar 6, 2024)
bc3eea9  Update src/transformers/generation/stopping_criteria.py (Mar 6, 2024)
7518aea  camel case everywhere (Mar 6, 2024)
8e5ec57  call stopping criteria list for candidate ids (Mar 7, 2024)
2544d12  make style and fixup (Mar 7, 2024)
12acbc4  Empty commit (Mar 7, 2024)
1107673  Empty commit to pass flaky test (Mar 7, 2024)
1ffc554  set max length in PromptLookupCandidateGenerator (Mar 7, 2024)
ca0a414  Merge 'upstream/main' into stopping_criteria (Mar 7, 2024)
ce093c1  Update src/transformers/generation/utils.py (Mar 8, 2024)
aa7fae4  Merge remote-tracking branch 'upstream/main' into stopping_crtiteria (Mar 8, 2024)
5375d97  lets fix this typo in docs (Mar 9, 2024)
d103d29  Merge remote-tracking branch 'upstream/main' into stopping_crtiteria (Mar 15, 2024)
9f59abb  Merge remote-tracking branch 'upstream/main' into stopping_crtiteria (Mar 18, 2024)
48f33fc  Update src/transformers/generation/utils.py (Mar 26, 2024)
801af07  Update src/transformers/generation/utils.py (Mar 26, 2024)
a385c6d  Merge remote-tracking branch 'upstream/main' into stopping_crtiteria (Mar 26, 2024)
7c00bb1  update PR (Mar 26, 2024)
530064d  empty commit (Mar 26, 2024)
src/transformers/generation/__init__.py (2 additions, 0 deletions)
@@ -82,6 +82,7 @@
"MaxNewTokensCriteria",
"MaxLengthCriteria",
"MaxTimeCriteria",
"EOSTokenCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
@@ -218,6 +219,7 @@
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
EOSTokenCriteria,
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
src/transformers/generation/stopping_criteria.py (23 additions, 1 deletion)
@@ -2,7 +2,7 @@
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional
from typing import List, Optional, Union

import torch

@@ -129,6 +129,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


class EOSTokenCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the "end-of-sequence" token is generated.
By default, it uses the `EOS` token from the model's generation config.

Args:
eos_token_id (`Union[int, List[int]]`):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Comment on lines +138 to +139

Collaborator:
Would be nice to also support a list of lists, to have stopping token sequences (e.g. if the eos is ["<", "eos", ">"]).

Member (Author):
Hmm, is it really true that existing models have EOS tokens as a sequence? I guess you are referring to custom EOS tokens, when users want to stop at "" but haven't trained the model with "" as a special token. If that's the case, there is a StopStringsCriteria PR coming, or users are free to write their own custom criteria.

@gante wdyt?

Member:
Yeah, the case where we want to stop on a string will be covered by #28932 (and is way more complex).

This PR is exclusively to port the existing EOS logic into its own stopping criteria :)

Collaborator:
Right, I forgot it was being covered.

"""

def __init__(self, eos_token_id: Union[int, List[int]]):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id

@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
eos_token_ids = torch.tensor(self.eos_token_id, dtype=torch.int64, device=input_ids.device)
is_done = (input_ids[:, -1].unsqueeze(1) == eos_token_ids).any(dim=1)
return is_done


class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
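Outside the diff itself, a minimal usage sketch of the new criterion as exported above. The class name follows this commit (a later commit in this PR, "camel case everywhere", touches naming, so the final name may differ); the checkpoint and prompt are illustrative, not taken from the PR.

# Minimal sketch: stop generation on the tokenizer's EOS token via the new criterion.
# Checkpoint name and prompt are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import EOSTokenCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# One id or a list of ids is accepted, mirroring generation_config.eos_token_id.
criteria = StoppingCriteriaList([EOSTokenCriteria(eos_token_id=tokenizer.eos_token_id)])

outputs = model.generate(**inputs, max_new_tokens=20, stopping_criteria=criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))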
src/transformers/generation/utils.py (125 additions, 51 deletions)
@@ -75,6 +75,7 @@
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (
EOSTokenCriteria,
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteria,
@@ -942,6 +943,8 @@ def _get_stopping_criteria(
)
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
if generation_config.eos_token_id is not None:
criteria.append(EOSTokenCriteria(eos_token_id=generation_config.eos_token_id))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria

@@ -1922,11 +1925,24 @@ def _contrastive_search(
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
sequential = sequential if sequential is not None else self.generation_config.low_memory
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
Collaborator:
It should rather be set in the generation_config I guess? Why can't we leave this as a kwarg and set generation_config.eos_token_id?

Member (Author):
That's the general idea: we want to stop accepting eos as an input argument. The warning is for backward compatibility, since we used to give priority to the user-defined eos_token_id from the args before checking generation_config.eos_token_id.

Member (@gante, Mar 6, 2024):
^ As @zucchini-nlp wrote. This is merely for backward compatibility, if users want to keep calling the decoding methods directly (at their own risk, since we're deprecating their public API).

The decoding methods don't set up new logits processors or stopping criteria. As such, the correct replacement for the EOS feature is to pass the new EOSTokenCriteria :)

The long-term goal is to disable calling decoding methods directly at all, so we can optimize the codebase.

FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
sequential = sequential if sequential is not None else self.generation_config.low_memory
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
output_attentions = (
@@ -2198,15 +2214,8 @@ def _contrastive_search(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)

# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

if unfinished_sequences.max() == 0:
this_peer_finished = True

@@ -2383,9 +2392,23 @@ def _greedy_search(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
Collaborator:
Same comment here.

"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2487,14 +2510,7 @@ def _greedy_search(
model_inputs=model_inputs,
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
@@ -2680,10 +2696,23 @@ def _sample(
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
output_attentions = (
@@ -2786,14 +2815,7 @@ def _sample(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
@@ -3007,7 +3029,21 @@ def _beam_search(
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3401,7 +3437,21 @@ def _beam_sample(
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3748,7 +3798,21 @@ def _group_beam_search(
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4159,7 +4223,21 @@ def _constrained_beam_search(
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4502,11 +4580,23 @@ def _assisted_decoding(
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if eos_token_id is not None and pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
if eos_token_id is not None:
warnings.warn(
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
FutureWarning,
)
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
else:
eos_token_id = [
criteria.eos_token_id for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
]
if not eos_token_id and self.generation_config.eos_token_id:
stopping_criteria.append(EOSTokenCriteria(eos_token_id=eos_token_id))
eos_token_id = self.generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
output_attentions = (
@@ -4562,13 +4652,7 @@ def _assisted_decoding(
candidate_logits = candidate_logits.to(self.device)

candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = (
~candidate_input_ids[:, -1]
.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
.bool()
)
last_assistant_token_is_eos = stopping_criteria[-1](candidate_input_ids, None)
Collaborator:
This means the last entry in stopping_criteria is always assumed to be EOSTokenCriteria. That's not necessarily obvious, and if people pass custom criteria I am not sure this will always be the case?

Member (Author):
Right, will fix it.

Member:
@ArthurZucker good catch, I missed this one too :)

Member (Author):
@gante need you to look. I removed some code, given that assisted decoding works only with one batch, so we can make some assumptions about when to stop generating.
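
For readers of this thread, a rough sketch (not part of this PR) of a lookup that would not rely on the criterion's position in the list. It assumes the surrounding _assisted_decoding variables (stopping_criteria, candidate_input_ids) and selects the EOS criterion by type:

# Hypothetical alternative to `stopping_criteria[-1](...)`: find the EOS criterion by type,
# so the check does not depend on where it sits in the list.
eos_criteria = [c for c in stopping_criteria if isinstance(c, EOSTokenCriteria)]
if eos_criteria:
    last_assistant_token_is_eos = eos_criteria[0](candidate_input_ids, None)
else:
    # No EOS criterion configured: the assistant's last token can never count as EOS.
    last_assistant_token_is_eos = torch.zeros(
        candidate_input_ids.shape[0], dtype=torch.bool, device=candidate_input_ids.device
    )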


# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
@@ -4701,17 +4785,7 @@ def _assisted_decoding(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
input_ids[:, -1]
.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)

# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
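Before the tests, a hedged before/after sketch of the deprecation path that the repeated warning above describes, for callers that still invoke a decoding method directly. The method calls are shown as comments because their exact signatures are not part of this excerpt; only the stopping-criteria replacement comes from the PR.

from transformers.generation import EOSTokenCriteria, StoppingCriteriaList

# Before (deprecated by this PR): pass eos_token_id straight to the decoding method.
# outputs = model._greedy_search(input_ids, eos_token_id=tokenizer.eos_token_id, ...)

# After: express the EOS condition as a stopping criterion instead.
criteria = StoppingCriteriaList([EOSTokenCriteria(eos_token_id=tokenizer.eos_token_id)])
# outputs = model._greedy_search(input_ids, stopping_criteria=criteria, ...)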
tests/generation/test_stopping_criteria.py (16 additions, 0 deletions)
@@ -26,6 +26,7 @@
import torch

from transformers.generation import (
EOSTokenCriteria,
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
@@ -98,6 +99,21 @@ def test_max_time_criteria(self):
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(all(criteria(input_ids, scores)))

def test_eos_token_criteria(self):
criteria = EOSTokenCriteria(eos_token_id=0)

input_ids, scores = self._get_tensors(5)
input_ids[:, -1] = 0
self.assertTrue(all(criteria(input_ids, scores)))

input_ids, scores = self._get_tensors(5)
input_ids[:2, -1] = 0
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])

input_ids, scores = self._get_tensors(5)
input_ids[:, -1] = 1
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])

def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

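Closing out, a small self-contained snippet illustrating the batch semantics the new test exercises. The token ids are made up; the import path follows the __init__.py change in this PR.

import torch
from transformers.generation import EOSTokenCriteria

# Several EOS ids are accepted; the criterion returns one boolean per batch row.
criteria = EOSTokenCriteria(eos_token_id=[0, 2])

input_ids = torch.tensor([
    [5, 4, 0],  # last token is eos id 0 -> True
    [5, 4, 2],  # last token is eos id 2 -> True
    [5, 4, 1],  # last token is not an eos id -> False
])
scores = torch.zeros(3, 10)  # unused by this criterion, but part of the __call__ interface

print(criteria(input_ids, scores).tolist())  # [True, True, False]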