From 0efcf32351d5368e6aa0754e9503d576f7b5ca36 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 27 Mar 2024 17:18:10 +0500
Subject: [PATCH] Move `eos_token_id` to stopping criteria (#29459)

* add eos stopping criteria

* minor fix

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: Joao Gante

* check eos is not None and fix tests

* make style and fixup

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante

* Update tests/generation/test_utils.py

Co-authored-by: Joao Gante

* Update src/transformers/generation/__init__.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* camel case everywhere

* call stopping criteria list for candidate ids

* make style and fixup

* Empty commit

* Empty commit to pass flaky test

* set max length in PromptLookupCandidateGenerator

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante

* lets fix this typo in docs

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update PR

* empty commit

---------

Co-authored-by: Joao Gante
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/generation/__init__.py    |   2 +
 .../generation/candidate_generator.py       |  11 +-
 .../generation/stopping_criteria.py         |  23 +-
 src/transformers/generation/utils.py        | 236 ++++++++++++------
 tests/generation/test_stopping_criteria.py  |  17 ++
 tests/generation/test_utils.py              |   2 -
 6 files changed, 215 insertions(+), 76 deletions(-)

diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index 8f2a6ad9600d9..315d5b08a7594 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -82,6 +82,7 @@
         "MaxNewTokensCriteria",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "EosTokenCriteria",
         "StoppingCriteria",
         "StoppingCriteriaList",
         "validate_stopping_criteria",
@@ -216,6 +217,7 @@
         WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 3ed65c3816738..0859021956153 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -238,15 +238,20 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
             The maximum ngram size to be considered for matching in the prompt
         num_output_tokens (`int`):
             The number of tokens to be output as candidate tokens.
+        max_length (`int`):
+            The maximum total number of tokens that can be generated; for decoder-only models this includes the prompt length.
+            Defaults to 20, the default `max_length` used in the generation config.
     """

     def __init__(
         self,
         num_output_tokens: int = 10,
         max_matching_ngram_size: int = None,
+        max_length: int = 20,
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_length = max_length

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -264,6 +269,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         """
         input_length = input_ids.size(1)

+        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
+        if self.max_length == input_length + 1:
+            return input_ids, None
+
         chosen_ids = None
         match_found = False
         for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
@@ -283,7 +292,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
-                end_idx = min(end_idx, input_length)
+                end_idx = min(end_idx, input_length, self.max_length)

                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index f4624296d237f..bac537b71b96e 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Optional
+from typing import List, Optional, Union

 import torch

@@ -129,6 +129,27 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


+class EosTokenCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+    By default, it uses the `model.generation_config.eos_token_id`.
+
+    Args:
+        eos_token_id (`Union[int, List[int]]`):
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+    """
+
+    def __init__(self, eos_token_id: Union[int, List[int]]):
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = torch.tensor(eos_token_id)
+
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
+        return is_done
+
+
 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c0d3b1dd6078d..a958c8c86a92b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -75,6 +75,7 @@
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )
 from .stopping_criteria import (
+    EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
     StoppingCriteria,
@@ -690,6 +691,7 @@ def _get_candidate_generator(
             candidate_generator = PromptLookupCandidateGenerator(
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_length=generation_config.max_length,
             )
         else:
             candidate_generator = AssistedCandidateGenerator(
@@ -892,6 +894,8 @@ def _get_stopping_criteria(
             )
         if generation_config.max_time is not None:
             criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
+        if generation_config.eos_token_id is not None:
+            criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria

@@ -1306,7 +1310,7 @@ def generate(

         Return:
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
-            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

                 If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                 [`~utils.ModelOutput`] types are:
@@ -1515,7 +1519,6 @@ def generate(
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1530,7 +1533,6 @@ def generate(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1550,7 +1552,6 @@ def generate(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1579,7 +1580,6 @@ def generate(
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1613,7 +1613,6 @@ def generate(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1653,7 +1652,6 @@ def generate(
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1687,7 +1685,6 @@ def generate(
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1761,7 +1758,6 @@ def typeerror():
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1916,11 +1912,28 @@ def _contrastive_search(
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        sequential = sequential if sequential is not None else self.generation_config.low_memory
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2186,12 +2199,6 @@ def _contrastive_search(
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -2365,10 +2372,27 @@ def _greedy_search(
         )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2463,12 +2487,6 @@ def _greedy_search(
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0

@@ -2650,10 +2668,27 @@ def _sample(
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2751,12 +2786,6 @@ def _sample(
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0

@@ -2966,7 +2995,25 @@ def _beam_search(
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3351,7 +3398,25 @@ def _beam_sample(
         )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3688,7 +3753,25 @@ def _group_beam_search(
         )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4089,7 +4172,25 @@ def _constrained_beam_search(
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4421,12 +4522,27 @@ def _assisted_decoding(
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if eos_token_id is not None and pad_token_id is None:
-            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -4462,9 +4578,6 @@ def _assisted_decoding(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        # other auxiliary variables
-        max_len = stopping_criteria[0].max_length
-
         this_peer_finished = False
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
@@ -4476,13 +4589,7 @@ def _assisted_decoding(
                 candidate_logits = candidate_logits.to(self.device)

             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
-            last_assistant_token_is_eos = (
-                ~candidate_input_ids[:, -1]
-                .tile(eos_token_id_tensor.shape[0], 1)
-                .ne(eos_token_id_tensor.unsqueeze(1))
-                .prod(dim=0)
-                .bool()
-            )
+            is_done_candidate = stopping_criteria(candidate_input_ids, None)

             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
             # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
@@ -4525,15 +4632,13 @@ def _assisted_decoding(
             # 3. Select the accepted tokens. There are two possible cases:
             # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
             # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
-            max_matches = max_len - cur_len - 1
             if do_sample and candidate_logits is not None:
                 valid_tokens, n_matches = _speculative_sampling(
                     candidate_input_ids,
                     candidate_logits,
                     candidate_length,
                     new_logits,
-                    last_assistant_token_is_eos,
-                    max_matches,
+                    is_done_candidate,
                 )

             # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
@@ -4550,9 +4655,8 @@ def _assisted_decoding(
                 n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

                 # Ensure we don't generate beyond max_len or an EOS token
-                if last_assistant_token_is_eos and n_matches == candidate_length:
+                if is_done_candidate and n_matches == candidate_length:
                     n_matches -= 1
-                n_matches = min(n_matches, max_matches)
                 valid_tokens = selected_tokens[:, : n_matches + 1]

             # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
@@ -4625,15 +4729,6 @@ def _assisted_decoding(
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    input_ids[:, -1]
-                    .tile(eos_token_id_tensor.shape[0], 1)
-                    .ne(eos_token_id_tensor.unsqueeze(1))
-                    .prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0

@@ -4678,8 +4773,7 @@ def _speculative_sampling(
     candidate_logits,
     candidate_length,
     new_logits,
-    last_assistant_token_is_eos,
-    max_matches,
+    is_done_candidate,
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
@@ -4704,16 +4798,14 @@ def _speculative_sampling(
     n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
-    if last_assistant_token_is_eos and n_matches == candidate_length:
+    if is_done_candidate and n_matches == candidate_length:
         # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
         # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
         valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        n_matches = min(n_matches, max_matches)
-
         # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-        gamma = min(candidate_logits.shape[1], max_matches)
+        gamma = candidate_logits.shape[1]
         p_n_plus_1 = p[:, n_matches, :]
         if n_matches < gamma:
             q_n_plus_1 = q[:, n_matches, :]
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index 7fa118c9e3550..0c770972a7fdf 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -26,6 +26,7 @@
     import torch

     from transformers.generation import (
+        EosTokenCriteria,
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,
@@ -98,6 +99,22 @@ def test_max_time_criteria(self):
         criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
         self.assertTrue(all(criteria(input_ids, scores)))

+    def test_eos_token_criteria(self):
+        criteria = EosTokenCriteria(eos_token_id=0)
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 0
+        self.assertTrue(all(criteria(input_ids, scores)))
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:2, -1] = 0
+        input_ids[2, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
+
+        input_ids, scores = self._get_tensors(5)
+        input_ids[:, -1] = 1
+        self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
+
     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d82f137d95630..99f6e84a3036e 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1899,14 +1899,12 @@ def test_speculative_sampling(self):
             ]
         )
         last_assistant_token_is_eos = False
-        max_matches = 5
         validated_tokens, n_matches = _speculative_sampling(
             candidate_input_ids,
             candidate_logits,
             candidate_length,
             new_logits,
             last_assistant_token_is_eos,
-            max_matches,
         )
         self.assertTrue(n_matches.item() == 2)
         self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
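
Taken together, the patch makes EOS detection a composable stopping criterion instead of an argument threaded through every decoding loop: `_get_stopping_criteria` appends an `EosTokenCriteria` whenever `generation_config.eos_token_id` is set, and passing `eos_token_id` directly to the private decoding methods now only appends the same criterion and emits a deprecation warning. A minimal usage sketch of the new classes follows; it is not part of the patch and assumes a transformers build containing this commit, plus the (implied by the loops' `unfinished_sequences & ~stopping_criteria(...)` usage) behavior that `StoppingCriteriaList` ORs the per-sequence results of its members:

```python
import torch

from transformers.generation import EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList

# Build the same kind of criteria list that `_get_stopping_criteria` now assembles:
# a length limit plus EOS detection, with multiple EOS ids allowed.
criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=20),
        EosTokenCriteria(eos_token_id=[2, 32000]),
    ]
)

input_ids = torch.tensor([[5, 7, 2], [5, 7, 9]])  # first sequence ends in EOS id 2
scores = torch.zeros(2, 100)  # dummy logits; EosTokenCriteria only inspects input_ids

print(criteria(input_ids, scores))  # expected: tensor([ True, False])
```

The payoff of the design is visible in the diff itself: every decoding loop reduces its finished-sequence bookkeeping to the single line `unfinished_sequences & ~stopping_criteria(input_ids, scores)`.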
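The per-token EOS bookkeeping that the patch deletes from `_greedy_search`, `_sample`, `_contrastive_search`, and `_assisted_decoding` was a `tile`/`ne`/`prod` construction over an `eos_token_id_tensor`; `EosTokenCriteria` replaces it with a single `torch.isin` over the last generated token. A small self-contained check of the equivalence (both idioms are copied from the diff; the variable names around them are illustrative only):

```python
import torch

eos_token_id = torch.tensor([2, 32000])
next_tokens = torch.tensor([2, 11, 32000, 7])

# Old idiom (removed by the patch): 1 where the token differs from every EOS id.
old_unfinished = next_tokens.tile(eos_token_id.shape[0], 1).ne(eos_token_id.unsqueeze(1)).prod(dim=0)

# New idiom (inside EosTokenCriteria.__call__): True where the token is any EOS id.
is_done = torch.isin(next_tokens, eos_token_id)

# The two encode the same per-sequence information, with opposite polarity.
assert torch.equal(old_unfinished.bool(), ~is_done)
```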
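`PromptLookupCandidateGenerator` previously had no notion of an overall length budget, so `_assisted_decoding` clamped candidates itself via `max_len` and `max_matches`; the patch moves that responsibility into the generator through the new `max_length` argument. A sketch of the early exit this adds, assuming the constructor and `get_candidates` behave exactly as in the hunk above (when only one slot remains before `max_length`, no candidates are proposed, since the target model will generate that final token on its own):

```python
import torch

from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

generator = PromptLookupCandidateGenerator(
    num_output_tokens=10,
    max_matching_ngram_size=2,
    max_length=7,  # total budget, prompt included
)

input_ids = torch.tensor([[1, 2, 3, 4, 1, 2]])  # prompt length 6: exactly one slot left
candidate_ids, candidate_logits = generator.get_candidates(input_ids)

# `max_length == input_length + 1`, so the generator returns the prompt
# unchanged and no candidate logits.
assert torch.equal(candidate_ids, input_ids) and candidate_logits is None
```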