From ecdd003a29969643d33de8a92bfbb6b2d1329f60 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 3 Dec 2022 20:24:29 -0800 Subject: [PATCH 01/34] Add StopIdStoppingCriteria --- .../generation/stopping_criteria.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 7023fa9998c94..67893d8319c87 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -107,6 +107,38 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return time.time() - self.initial_timestamp > self.max_time +class StopIdStoppingCriteria(StoppingCriteria): + """ + This class can be used to stop a generation once the model generates the specified token id. + + Args: + stop_id (`int`): + The stop token id. This corresponds to the token id of the token you would like to stop the generation on. + early_stopping (`bool`): + If set to `True`, the generation will stop once it detects at least one stop token id. + Otherwise, it will wait till it sees at least one stop token id for each example in the batch. + + Examples: + ```python + >>> stop_id = tokenizer.convert_tokens_to_ids('\n') + >>> stopping_criteria = StopIdStoppingCriteria(stop_id) + >>> model.generate(text, stopping_criteria=[stopping_criteria]) + ``` + """ + def __init__(self, stop_id: int, early_stopping: bool = False): + self.stop_id = stop_id + self.early_stopping = early_stopping + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: + batch_size = input_ids.shape[0] + how_many_finished = ((input_ids[:, -1] == self.stop_id).sum() > 0).sum() + if self.early_stopping: + return how_many_finished > 0 + else: + return how_many_finished == batch_size + + class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: From 739674d41cb2ea69aae348012ed409b0454445b6 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 3 Dec 2022 21:08:50 -0800 Subject: [PATCH 02/34] add a working test for stop id criteria --- src/transformers/generation/__init__.py | 2 ++ tests/generation/test_stopping_criteria.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index b1d8e8acad5f1..3f8ac9e727fd7 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -67,6 +67,7 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", + "StopIdStoppingCriteria", "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", @@ -184,6 +185,7 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, + StopIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index dfc5308359ffb..5176e3a396398 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
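Beyond the unit test added below, the new criterion can be exercised directly on toy tensors. The following sketch (made-up token ids; it assumes the patch above is applied and imports from `transformers.generation.stopping_criteria`) illustrates the `early_stopping` semantics described in the docstring:

```python
import torch

from transformers.generation.stopping_criteria import StopIdStoppingCriteria

# Two sequences; only the second one has just produced the hypothetical stop id 11.
input_ids = torch.tensor([[5, 9, 3], [4, 8, 11]])
scores = torch.ones(2, 3)  # scores are not inspected by this criterion

# early_stopping=True: stop as soon as any sequence ends with the stop id.
print(bool(StopIdStoppingCriteria(stop_id=11, early_stopping=True)(input_ids, scores)))  # True

# early_stopping=False (the default): wait until every sequence ends with the stop id.
print(bool(StopIdStoppingCriteria(stop_id=11)(input_ids, scores)))  # False
```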
+import random import time import unittest @@ -29,6 +30,7 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, + StopIdStoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, ) @@ -36,11 +38,11 @@ @require_torch class StoppingCriteriaTestCase(unittest.TestCase): - def _get_tensors(self, length): + def _get_tensors(self, length, rng=None): batch_size = 3 vocab_size = 250 - input_ids = ids_tensor((batch_size, length), vocab_size) + input_ids = ids_tensor((batch_size, length), vocab_size, rng=rng) scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length return input_ids, scores @@ -98,6 +100,18 @@ def test_max_time_criteria(self): criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) self.assertTrue(criteria(input_ids, scores)) + def test_stop_id_criteria(self): + input_ids, scores = self._get_tensors(5, rng=random.Random(42)) + + criteria = StopIdStoppingCriteria(stop_id=5) + self.assertFalse(criteria(input_ids, scores)) + + criteria = StopIdStoppingCriteria(stop_id=22, early_stopping=False) + self.assertFalse(criteria(input_ids, scores)) + + criteria = StopIdStoppingCriteria(stop_id=22, early_stopping=True) + self.assertTrue(criteria(input_ids, scores)) + def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) From 3bce1cd43b877bfe10461eaae3232c4bda736206 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 3 Dec 2022 21:09:03 -0800 Subject: [PATCH 03/34] add to global scope --- src/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bf630e8a9cf0e..04a8faa08c5b0 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4007,6 +4007,7 @@ LogitsWarper, MaxLengthCriteria, MaxTimeCriteria, + StopIdStoppingCriteria, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, From 99904eb21e31fe6723f1be16eea88f406b90364b Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 3 Dec 2022 21:38:50 -0800 Subject: [PATCH 04/34] add stop_ids to generate --- src/transformers/generation/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3d945b2be37a7..698300d6f9bee 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -60,6 +60,7 @@ from .stopping_criteria import ( MaxLengthCriteria, MaxTimeCriteria, + StopIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, @@ -875,13 +876,16 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] + self, max_length: Optional[int], max_time: Optional[float], stop_ids: Optional[List[int]], stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if max_length is not None: criteria.append(MaxLengthCriteria(max_length=max_length)) if max_time is not None: criteria.append(MaxTimeCriteria(max_time=max_time)) + if stop_ids is not None: + for stop_id in stop_ids: + criteria.append(StopIdStoppingCriteria(stop_id=stop_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1028,6 +1032,7 @@ def generate( prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 
logits_processor: Optional[LogitsProcessorList] = None, renormalize_logits: Optional[bool] = None, + stop_ids: Optional[List[int]] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, @@ -1171,6 +1176,8 @@ def generate( Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. + stop_ids (`List[int]`, *optional*): + When specified, the generation will stop at one of the token ids specified. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a model's config. If a stopping criteria is passed that is already created with the arguments or a @@ -1505,7 +1512,7 @@ def generate( # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria + max_length=max_length, max_time=max_time, stop_ids=stop_ids, stopping_criteria=stopping_criteria ) # 9. go into different generation modes if is_greedy_gen_mode: From 20aeeca6505160cfce37ba0219b3498b7e0a73f8 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sat, 3 Dec 2022 22:01:47 -0800 Subject: [PATCH 05/34] add pipeline test --- tests/pipelines/test_pipelines_text_generation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index ca0e101158457..4cdfa97e041db 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -159,6 +159,17 @@ def test_stop_sequence_stopping_criteria(self): output = text_generator(prompt, stop_sequence=" fe") self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) + def test_stop_ids_stopping_criteria(self): + prompt = """Hello I believe in""" + text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2") + output = text_generator(prompt) + self.assertEqual( + output, + [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], + ) + output = text_generator(prompt, stop_ids=[641]) + self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) + def run_pipeline_test(self, text_generator, _): model = text_generator.model tokenizer = text_generator.tokenizer From 683c320e42a58128ee6ea291f4e25be0ef256ed8 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 4 Dec 2022 22:59:08 -0800 Subject: [PATCH 06/34] use tokenizer encode in test --- tests/pipelines/test_pipelines_text_generation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 4cdfa97e041db..e5f74427923d1 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -167,7 +167,8 @@ def test_stop_ids_stopping_criteria(self): output, [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], ) - output = text_generator(prompt, stop_ids=[641]) + stop_ids = text_generator.tokenizer.encode(' fe') + output = text_generator(prompt, stop_ids=stop_ids) self.assertEqual(output, [{"generated_text": "Hello I 
believe in fe"}]) def run_pipeline_test(self, text_generator, _): From 2781b53a2ff31af45c4192b7d9aab8e8ecea636f Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 4 Dec 2022 23:15:33 -0800 Subject: [PATCH 07/34] add test to generation utils --- tests/generation/test_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a03f0d12b9d14..e6a6914e48bf3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3025,3 +3025,17 @@ def test_validate_generation_inputs(self): # However, valid model_kwargs are accepted valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} model.generate(input_ids, **valid_model_kwargs) + + def test_stop_ids_stopping_criteria(self): + prompt = """Hello I believe in""" + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + input_ids = gpt2_tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device) + + stop_ids = gpt2_tokenizer.encode(' fe') + self.assertEqual(stop_ids, [641]) + + output = gpt2_model.generate(input_ids=input_ids, stop_ids=stop_ids) + generated_text = gpt2_tokenizer.batch_decode(output) + + self.assertEqual(generated_text, ['Hello I believe in fe']) From 531604147c3736f8378335ff19038021b4a1d5e6 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 4 Dec 2022 23:40:18 -0800 Subject: [PATCH 08/34] reformat --- src/transformers/generation/stopping_criteria.py | 1 + src/transformers/generation/utils.py | 6 +++++- tests/generation/test_utils.py | 4 ++-- tests/pipelines/test_pipelines_text_generation.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 67893d8319c87..580fde56240f1 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -125,6 +125,7 @@ class StopIdStoppingCriteria(StoppingCriteria): >>> model.generate(text, stopping_criteria=[stopping_criteria]) ``` """ + def __init__(self, stop_id: int, early_stopping: bool = False): self.stop_id = stop_id self.early_stopping = early_stopping diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 698300d6f9bee..5ce6198e5c047 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -876,7 +876,11 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, max_length: Optional[int], max_time: Optional[float], stop_ids: Optional[List[int]], stopping_criteria: Optional[StoppingCriteriaList] + self, + max_length: Optional[int], + max_time: Optional[float], + stop_ids: Optional[List[int]], + stopping_criteria: Optional[StoppingCriteriaList], ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if max_length is not None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e6a6914e48bf3..87eacc148289a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3032,10 +3032,10 @@ def test_stop_ids_stopping_criteria(self): gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) input_ids = gpt2_tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device) - stop_ids = gpt2_tokenizer.encode(' fe') + stop_ids = gpt2_tokenizer.encode(" fe") 
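For reference, `stop_ids` is shorthand for building the stopping criteria by hand: `_get_stopping_criteria` (patch 04) appends one `StopIdStoppingCriteria` per id. A rough equivalent of the `generate(..., stop_ids=stop_ids)` call in the test above, under the same assumptions (tiny test checkpoint, patches 01-04 applied), would be:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.generation.stopping_criteria import StopIdStoppingCriteria, StoppingCriteriaList

tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer("Hello I believe in", return_tensors="pt").input_ids

# One criterion per stop id: exactly what passing stop_ids=[...] expands to.
criteria = StoppingCriteriaList(
    StopIdStoppingCriteria(stop_id=stop_id) for stop_id in tokenizer.encode(" fe")
)
output = model.generate(input_ids, stopping_criteria=criteria)
print(tokenizer.batch_decode(output))  # expected to match the stop_ids-based call in the test
```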
self.assertEqual(stop_ids, [641]) output = gpt2_model.generate(input_ids=input_ids, stop_ids=stop_ids) generated_text = gpt2_tokenizer.batch_decode(output) - self.assertEqual(generated_text, ['Hello I believe in fe']) + self.assertEqual(generated_text, ["Hello I believe in fe"]) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index e5f74427923d1..e2519d42d47dc 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -167,7 +167,7 @@ def test_stop_ids_stopping_criteria(self): output, [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], ) - stop_ids = text_generator.tokenizer.encode(' fe') + stop_ids = text_generator.tokenizer.encode(" fe") output = text_generator(prompt, stop_ids=stop_ids) self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) From 947af736aa0a22070b3822fbb8b720134ae760e7 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 4 Dec 2022 23:41:57 -0800 Subject: [PATCH 09/34] fixup --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 04a8faa08c5b0..cbaf14fbf20fa 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4007,13 +4007,13 @@ LogitsWarper, MaxLengthCriteria, MaxTimeCriteria, - StopIdStoppingCriteria, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PhrasalConstraint, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, + StopIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, TemperatureLogitsWarper, From 15834a801ca4b0b67bd75548db03bfd3005c9484 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 4 Dec 2022 23:42:50 -0800 Subject: [PATCH 10/34] make-fix-copies --- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index daaefd5297fa9..75aad68b52034 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -234,6 +234,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class StopIdStoppingCriteria(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StoppingCriteria(metaclass=DummyObject): _backends = ["torch"] From 0e6eb1819196831fd1f4cc58446409cf6c726cd8 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 9 Dec 2022 20:48:34 -0800 Subject: [PATCH 11/34] rename to stop_token_id --- src/transformers/__init__.py | 2 +- src/transformers/generation/__init__.py | 4 ++-- src/transformers/generation/stopping_criteria.py | 14 +++++++------- src/transformers/generation/utils.py | 16 ++++++++-------- src/transformers/utils/dummy_pt_objects.py | 2 +- tests/generation/test_stopping_criteria.py | 10 +++++----- tests/generation/test_utils.py | 8 ++++---- .../pipelines/test_pipelines_text_generation.py | 7 ++++--- 8 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cbaf14fbf20fa..829bea5d05949 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4013,7 +4013,7 @@ PhrasalConstraint, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, - StopIdStoppingCriteria, + StopTokenIdStoppingCriteria, StoppingCriteria, 
StoppingCriteriaList, TemperatureLogitsWarper, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 3f8ac9e727fd7..be020ede2c909 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -67,7 +67,7 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", - "StopIdStoppingCriteria", + "StopTokenIdStoppingCriteria", "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", @@ -185,7 +185,7 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, - StopIdStoppingCriteria, + StopTokenIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 580fde56240f1..a937a5cc1b585 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -107,12 +107,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return time.time() - self.initial_timestamp > self.max_time -class StopIdStoppingCriteria(StoppingCriteria): +class StopTokenIdStoppingCriteria(StoppingCriteria): """ This class can be used to stop a generation once the model generates the specified token id. Args: - stop_id (`int`): + stop_token_id (`int`): The stop token id. This corresponds to the token id of the token you would like to stop the generation on. early_stopping (`bool`): If set to `True`, the generation will stop once it detects at least one stop token id. @@ -120,20 +120,20 @@ class StopIdStoppingCriteria(StoppingCriteria): Examples: ```python - >>> stop_id = tokenizer.convert_tokens_to_ids('\n') - >>> stopping_criteria = StopIdStoppingCriteria(stop_id) + >>> stop_token_id = tokenizer.convert_tokens_to_ids('\n') + >>> stopping_criteria = StopTokenIdStoppingCriteria(stop_token_id) >>> model.generate(text, stopping_criteria=[stopping_criteria]) ``` """ - def __init__(self, stop_id: int, early_stopping: bool = False): - self.stop_id = stop_id + def __init__(self, stop_token_id: int, early_stopping: bool = False): + self.stop_token_id = stop_token_id self.early_stopping = early_stopping @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: batch_size = input_ids.shape[0] - how_many_finished = ((input_ids[:, -1] == self.stop_id).sum() > 0).sum() + how_many_finished = ((input_ids[:, -1] == self.stop_token_id).sum() > 0).sum() if self.early_stopping: return how_many_finished > 0 else: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5ce6198e5c047..ebf2fc3d1c513 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -60,7 +60,7 @@ from .stopping_criteria import ( MaxLengthCriteria, MaxTimeCriteria, - StopIdStoppingCriteria, + StopTokenIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, @@ -879,7 +879,7 @@ def _get_stopping_criteria( self, max_length: Optional[int], max_time: Optional[float], - stop_ids: Optional[List[int]], + stop_token_ids: Optional[List[int]], stopping_criteria: Optional[StoppingCriteriaList], ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() @@ -887,9 +887,9 @@ def _get_stopping_criteria( criteria.append(MaxLengthCriteria(max_length=max_length)) if max_time is not None: criteria.append(MaxTimeCriteria(max_time=max_time)) - if stop_ids is 
not None: - for stop_id in stop_ids: - criteria.append(StopIdStoppingCriteria(stop_id=stop_id)) + if stop_token_ids is not None: + for stop_token_id in stop_token_ids: + criteria.append(StopTokenIdStoppingCriteria(stop_token_id=stop_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1036,7 +1036,7 @@ def generate( prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = None, renormalize_logits: Optional[bool] = None, - stop_ids: Optional[List[int]] = None, + stop_token_ids: Optional[List[int]] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, @@ -1180,7 +1180,7 @@ def generate( Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. - stop_ids (`List[int]`, *optional*): + stop_token_ids (`List[int]`, *optional*): When specified, the generation will stop at one of the token ids specified. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a @@ -1516,7 +1516,7 @@ def generate( # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stop_ids=stop_ids, stopping_criteria=stopping_criteria + max_length=max_length, max_time=max_time, stop_token_ids=stop_token_ids, stopping_criteria=stopping_criteria ) # 9. go into different generation modes if is_greedy_gen_mode: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 75aad68b52034..60083b42ad710 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -234,7 +234,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class StopIdStoppingCriteria(metaclass=DummyObject): +class StopTokenIdStoppingCriteria(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 5176e3a396398..d69c800f70bfd 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -30,7 +30,7 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, - StopIdStoppingCriteria, + StopTokenIdStoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, ) @@ -100,16 +100,16 @@ def test_max_time_criteria(self): criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) self.assertTrue(criteria(input_ids, scores)) - def test_stop_id_criteria(self): + def test_stop_token_id_criteria(self): input_ids, scores = self._get_tensors(5, rng=random.Random(42)) - criteria = StopIdStoppingCriteria(stop_id=5) + criteria = StopTokenIdStoppingCriteria(stop_token_id=5) self.assertFalse(criteria(input_ids, scores)) - criteria = StopIdStoppingCriteria(stop_id=22, early_stopping=False) + criteria = StopTokenIdStoppingCriteria(stop_token_id=22, early_stopping=False) self.assertFalse(criteria(input_ids, scores)) - criteria = StopIdStoppingCriteria(stop_id=22, early_stopping=True) + criteria = StopTokenIdStoppingCriteria(stop_token_id=22, 
early_stopping=True) self.assertTrue(criteria(input_ids, scores)) def test_validate_stopping_criteria(self): diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 87eacc148289a..c35f6df8e8668 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3026,16 +3026,16 @@ def test_validate_generation_inputs(self): valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} model.generate(input_ids, **valid_model_kwargs) - def test_stop_ids_stopping_criteria(self): + def test_stop_token_ids_stopping_criteria(self): prompt = """Hello I believe in""" gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) input_ids = gpt2_tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device) - stop_ids = gpt2_tokenizer.encode(" fe") - self.assertEqual(stop_ids, [641]) + stop_token_ids = gpt2_tokenizer.encode(" fe") + self.assertEqual(stop_token_ids, [641]) - output = gpt2_model.generate(input_ids=input_ids, stop_ids=stop_ids) + output = gpt2_model.generate(input_ids=input_ids, stop_token_ids=stop_token_ids) generated_text = gpt2_tokenizer.batch_decode(output) self.assertEqual(generated_text, ["Hello I believe in fe"]) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index e2519d42d47dc..793f7a837ead0 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -159,7 +159,8 @@ def test_stop_sequence_stopping_criteria(self): output = text_generator(prompt, stop_sequence=" fe") self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) - def test_stop_ids_stopping_criteria(self): + # TODO: Fix, we will use stop_tokens for the pipeline + def test_stop_token_ids_stopping_criteria(self): prompt = """Hello I believe in""" text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2") output = text_generator(prompt) @@ -167,8 +168,8 @@ def test_stop_ids_stopping_criteria(self): output, [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], ) - stop_ids = text_generator.tokenizer.encode(" fe") - output = text_generator(prompt, stop_ids=stop_ids) + stop_token_ids = text_generator.tokenizer.encode(" fe") + output = text_generator(prompt, stop_token_ids=stop_token_ids) self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) def run_pipeline_test(self, text_generator, _): From 8c6e474e7e751351f63f0f85ca5216e609a33b8b Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 9 Dec 2022 21:33:56 -0800 Subject: [PATCH 12/34] use stop_tokens instead --- src/transformers/pipelines/text_generation.py | 18 ++++++++++++++++++ .../test_pipelines_text_generation.py | 18 +++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 2ac4cdeaf511a..5ea08b428b124 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,10 +1,12 @@ import enum +from typing import List import warnings from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING from ..utils import add_end_docstrings, is_tf_available from .base import PIPELINE_INIT_ARGS, Pipeline +from ..generation.stopping_criteria import StoppingCriteriaList, 
StopTokenIdStoppingCriteria if is_tf_available(): @@ -97,6 +99,7 @@ def _sanitize_parameters( prefix=None, handle_long_generation=None, stop_sequence=None, + stop_tokens: List[str] = None, **generate_kwargs ): preprocess_params = {} @@ -147,6 +150,21 @@ def _sanitize_parameters( ) generate_kwargs["eos_token_id"] = stop_sequence_ids[0] + if stop_tokens is not None: + stop_token_ids = [] + for stop_token in stop_tokens: + _stop_token_ids = self.tokenizer.encode(stop_token, add_special_tokens=False) + if len(_stop_token_ids) > 1: + raise ValueError( + f"The stop_token {stop_token} has more than one associated token id: {_stop_token_ids}." + ) + stop_token_id = _stop_token_ids[0] + stop_token_ids.append(stop_token_id) + if 'stop_token_ids' not in generate_kwargs: + generate_kwargs['stop_token_ids'] = stop_token_ids + else: + generate_kwargs['stop_token_ids'].extend(stop_token_ids) + return preprocess_params, forward_params, postprocess_params # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 793f7a837ead0..ad303dadd6bc5 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -159,7 +159,6 @@ def test_stop_sequence_stopping_criteria(self): output = text_generator(prompt, stop_sequence=" fe") self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) - # TODO: Fix, we will use stop_tokens for the pipeline def test_stop_token_ids_stopping_criteria(self): prompt = """Hello I believe in""" text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2") @@ -172,6 +171,23 @@ def test_stop_token_ids_stopping_criteria(self): output = text_generator(prompt, stop_token_ids=stop_token_ids) self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) + def test_stop_tokens_stopping_criteria(self): + prompt = """Hello I believe in""" + text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stop_tokens=[' fe']) + output = text_generator(prompt) + self.assertEqual( + output, + [{"generated_text": "Hello I believe in fe"}], + ) + + prompt = """Hello I believe in""" + text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stopping_criteria=[], stop_tokens=[' fe']) + output = text_generator(prompt) + self.assertEqual( + output, + [{"generated_text": "Hello I believe in fe"}], + ) + def run_pipeline_test(self, text_generator, _): model = text_generator.model tokenizer = text_generator.tokenizer From 305e349dd2fda6228461e7f2202bd107e6ad2345 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 9 Dec 2022 21:38:34 -0800 Subject: [PATCH 13/34] add to text to text generation --- .../pipelines/text2text_generation.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index a9f73218ad54f..30a89d9c2cddf 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,4 +1,5 @@ import enum +from typing import List import warnings from ..tokenization_utils import TruncationStrategy @@ -78,6 +79,7 @@ def _sanitize_parameters( clean_up_tokenization_spaces=None, truncation=None, stop_sequence=None, + stop_tokens: List[str] = None, **generate_kwargs ): preprocess_params = {} @@ -104,6 +106,21 @@ def 
_sanitize_parameters( ) generate_kwargs["eos_token_id"] = stop_sequence_ids[0] + if stop_tokens is not None: + stop_token_ids = [] + for stop_token in stop_tokens: + _stop_token_ids = self.tokenizer.encode(stop_token, add_special_tokens=False) + if len(_stop_token_ids) > 1: + raise ValueError( + f"The stop_token {stop_token} has more than one associated token id: {_stop_token_ids}." + ) + stop_token_id = _stop_token_ids[0] + stop_token_ids.append(stop_token_id) + if 'stop_token_ids' not in generate_kwargs: + generate_kwargs['stop_token_ids'] = stop_token_ids + else: + generate_kwargs['stop_token_ids'].extend(stop_token_ids) + return preprocess_params, forward_params, postprocess_params def check_inputs(self, input_length: int, min_length: int, max_length: int): From 64556a8ba9efc4e780a2f9b82346ce81f7b0b42a Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 9 Dec 2022 21:53:47 -0800 Subject: [PATCH 14/34] make fixup --- src/transformers/__init__.py | 2 +- src/transformers/generation/__init__.py | 2 +- src/transformers/generation/utils.py | 7 +++++-- src/transformers/pipelines/text2text_generation.py | 8 ++++---- src/transformers/pipelines/text_generation.py | 10 +++++----- tests/generation/test_stopping_criteria.py | 2 +- tests/pipelines/test_pipelines_text_generation.py | 6 ++++-- 7 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 829bea5d05949..4f783fd689429 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4013,9 +4013,9 @@ PhrasalConstraint, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, - StopTokenIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, + StopTokenIdStoppingCriteria, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index be020ede2c909..73a3ab70ce061 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -185,9 +185,9 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, - StopTokenIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, + StopTokenIdStoppingCriteria, validate_stopping_criteria, ) from .utils import ( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ebf2fc3d1c513..89a036f43d8f1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -60,9 +60,9 @@ from .stopping_criteria import ( MaxLengthCriteria, MaxTimeCriteria, - StopTokenIdStoppingCriteria, StoppingCriteria, StoppingCriteriaList, + StopTokenIdStoppingCriteria, validate_stopping_criteria, ) @@ -1516,7 +1516,10 @@ def generate( # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( - max_length=max_length, max_time=max_time, stop_token_ids=stop_token_ids, stopping_criteria=stopping_criteria + max_length=max_length, + max_time=max_time, + stop_token_ids=stop_token_ids, + stopping_criteria=stopping_criteria, ) # 9. 
go into different generation modes if is_greedy_gen_mode: diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 30a89d9c2cddf..c4db1ff034238 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,6 +1,6 @@ import enum -from typing import List import warnings +from typing import List from ..tokenization_utils import TruncationStrategy from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging @@ -116,10 +116,10 @@ def _sanitize_parameters( ) stop_token_id = _stop_token_ids[0] stop_token_ids.append(stop_token_id) - if 'stop_token_ids' not in generate_kwargs: - generate_kwargs['stop_token_ids'] = stop_token_ids + if "stop_token_ids" not in generate_kwargs: + generate_kwargs["stop_token_ids"] = stop_token_ids else: - generate_kwargs['stop_token_ids'].extend(stop_token_ids) + generate_kwargs["stop_token_ids"].extend(stop_token_ids) return preprocess_params, forward_params, postprocess_params diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 5ea08b428b124..f5044cd42490f 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,12 +1,12 @@ import enum -from typing import List import warnings +from typing import List from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING +from ..generation.stopping_criteria import StoppingCriteriaList, StopTokenIdStoppingCriteria from ..utils import add_end_docstrings, is_tf_available from .base import PIPELINE_INIT_ARGS, Pipeline -from ..generation.stopping_criteria import StoppingCriteriaList, StopTokenIdStoppingCriteria if is_tf_available(): @@ -160,10 +160,10 @@ def _sanitize_parameters( ) stop_token_id = _stop_token_ids[0] stop_token_ids.append(stop_token_id) - if 'stop_token_ids' not in generate_kwargs: - generate_kwargs['stop_token_ids'] = stop_token_ids + if "stop_token_ids" not in generate_kwargs: + generate_kwargs["stop_token_ids"] = stop_token_ids else: - generate_kwargs['stop_token_ids'].extend(stop_token_ids) + generate_kwargs["stop_token_ids"].extend(stop_token_ids) return preprocess_params, forward_params, postprocess_params diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index d69c800f70bfd..0b0ad566091fe 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -30,8 +30,8 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, - StopTokenIdStoppingCriteria, StoppingCriteriaList, + StopTokenIdStoppingCriteria, validate_stopping_criteria, ) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index ad303dadd6bc5..afcced4e65954 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -173,7 +173,7 @@ def test_stop_token_ids_stopping_criteria(self): def test_stop_tokens_stopping_criteria(self): prompt = """Hello I believe in""" - text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stop_tokens=[' fe']) + text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stop_tokens=[" fe"]) output = text_generator(prompt) self.assertEqual( output, @@ -181,7 +181,9 @@ def test_stop_tokens_stopping_criteria(self): ) prompt = """Hello I believe 
in""" - text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", stopping_criteria=[], stop_tokens=[' fe']) + text_generator = pipeline( + "text-generation", model="hf-internal-testing/tiny-random-gpt2", stopping_criteria=[], stop_tokens=[" fe"] + ) output = text_generator(prompt) self.assertEqual( output, From 6f0812d7b2a4d840a295e2f1410e36684eb1fc3c Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 9 Dec 2022 21:55:31 -0800 Subject: [PATCH 15/34] make repo-consistency --- src/transformers/utils/dummy_pt_objects.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 60083b42ad710..6253383bb4a92 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -234,21 +234,21 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class StopTokenIdStoppingCriteria(metaclass=DummyObject): +class StoppingCriteria(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class StoppingCriteria(metaclass=DummyObject): +class StoppingCriteriaList(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class StoppingCriteriaList(metaclass=DummyObject): +class StopTokenIdStoppingCriteria(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 70c2ddab2a527e3d2d62d7551c9e105f6c039f6b Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 18 Dec 2022 22:07:19 -0800 Subject: [PATCH 16/34] Add support for list of ints for eos_token_id inside generation/utils.py --- src/transformers/generation/utils.py | 59 ++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 89a036f43d8f1..d14d206840f67 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -569,11 +569,16 @@ def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], - eos_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) + if isinstance(eos_token_id, int): + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) + elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) + else: + raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: @@ -768,8 +773,9 @@ def _get_logits_processor( bad_words_ids: List[List[int]], min_length: int, max_length: int, - eos_token_id: int, + eos_token_id: Union[int, List[int]], forced_bos_token_id: int, + # TODO: This should be optionally List[int] forced_eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, @@ -919,7 +925,7 @@ def compute_transition_beam_scores( 
sequences: torch.Tensor, scores: Tuple[torch.Tensor], beam_indices: torch.Tensor, - eos_token_id: int = None, + eos_token_id: Union[int, List[int]] = None, ): """compute the transition probabilities of sequences given generation scores and beam indices""" @@ -1022,7 +1028,7 @@ def generate( force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: Optional[int] = None, @@ -1332,7 +1338,13 @@ def generate( "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - pad_token_id = eos_token_id + if isinstance(eos_token_id, int): + pad_token_id = eos_token_id + elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): + # Setting the first eos_token_id + pad_token_id = eos_token_id[0] + else: + raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1816,7 +1828,7 @@ def contrastive_search( logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2118,7 +2130,12 @@ def contrastive_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + if isinstance(eos_token_id, int): + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) + else: + raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2155,7 +2172,7 @@ def greedy_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2346,7 +2363,12 @@ def greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + if isinstance(eos_token_id, int): + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) + else: + raise 
ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2384,7 +2406,7 @@ def sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2597,7 +2619,12 @@ def sample( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + if isinstance(eos_token_id, int): + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) + else: + raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2635,7 +2662,7 @@ def beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2945,7 +2972,7 @@ def beam_sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3260,7 +3287,7 @@ def group_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3622,7 +3649,7 @@ def constrained_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, From b0cccd65eacfe4b30fbb0561e1a37af9bf70668c Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 18 Dec 2022 22:19:32 -0800 Subject: [PATCH 17/34] Instead of having if elses, cast the eos_token_id into a List[int] --- src/transformers/generation/utils.py | 54 ++++++++++++---------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d14d206840f67..efea4e34486eb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -574,11 +574,8 @@ def _prepare_attention_mask_for_generation( is_input_ids = 
len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) - elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) - else: - raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') + eos_token_id = [eos_token_id] + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: @@ -807,6 +804,8 @@ def _get_logits_processor( ) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty forced_bos_token_id = ( forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id @@ -1331,20 +1330,18 @@ def generate( if eos_token_id is None and hasattr(self.config, "decoder"): eos_token_id = self.config.decoder.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if pad_token_id is None and eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
) - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - if isinstance(eos_token_id, int): - pad_token_id = eos_token_id - elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): - # Setting the first eos_token_id - pad_token_id = eos_token_id[0] - else: - raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id[0]} for open-end generation.") + # Setting the first eos_token_id + pad_token_id = eos_token_id[0] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1912,6 +1909,8 @@ def contrastive_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2130,12 +2129,7 @@ def contrastive_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - if isinstance(eos_token_id, int): - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) - else: - raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2272,6 +2266,8 @@ def greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2363,12 +2359,7 @@ def greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - if isinstance(eos_token_id, int): - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) - else: - raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2525,6 +2516,8 @@ def 
sample( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2619,12 +2612,7 @@ def sample( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - if isinstance(eos_token_id, int): - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - elif isinstance(eos_token_id, list) and isinstance(eos_token_id[0], int): - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) - else: - raise ValueError(f'`eos_token_id` should be of type `int` or `List[int]`.') + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens == i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -3411,6 +3399,8 @@ def group_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -3780,6 +3770,8 @@ def constrained_beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From fba33459970300de88438bc803601f307a02729b Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 18 Dec 2022 22:41:56 -0800 Subject: [PATCH 18/34] Add List[int] support for logits_process.py --- src/transformers/generation/logits_process.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 5f54008e16a84..721383c34e762 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -100,16 +100,18 @@ class MinLengthLogitsProcessor(LogitsProcessor): Args: min_length (`int`): The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. 
- eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. """ - def __init__(self, min_length: int, eos_token_id: int): + def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): if not isinstance(min_length, int) or min_length < 0: raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]): + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") self.min_length = min_length self.eos_token_id = eos_token_id @@ -117,7 +119,8 @@ def __init__(self, min_length: int, eos_token_id: int): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len < self.min_length: - scores[:, self.eos_token_id] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = -float("inf") return scores @@ -395,11 +398,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): List of list of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. - eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. """ - def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): + def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") @@ -413,7 +416,10 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int): f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." ) - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + bad_words_ids = list(filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)) self.bad_words_id_length_1 = [] self.bad_words_id_length_greater_than_1 = [] for word in bad_words_ids: @@ -628,20 +634,23 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): Args: max_length (`int`): The maximum length of the sequence to be generated. - eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the token to force as the last generated token when `max_length` is reached. 
""" - def __init__(self, max_length: int, eos_token_id: int): + def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): self.max_length = max_length + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len == self.max_length - 1: num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") - scores[:, self.eos_token_id] = 0 + scores[:, [i for i in range(num_tokens) if i in self.eos_token_id]] = -float("inf") + for i in self.eos_token_id: + scores[:, i] = 0 return scores @@ -671,23 +680,26 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): exponential_decay_length_penalty (`tuple(int, float)`, *optional*): This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`int`): + eos_token_id (`Union[int, List[int]]`): The id of the *end-of-sequence* token. input_ids_seq_length (`int`): The length of the input sequence. """ - def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int): + def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] if cur_len > self.regulation_start: - scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow( - self.regulation_factor, cur_len - self.regulation_start - ) + for i in self.eos_token_id: + scores[:, i] = scores[:, i] * pow( + self.regulation_factor, cur_len - self.regulation_start + ) return scores From 405f79cadfe725a11953d20058a69e55373d76e8 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 18 Dec 2022 22:45:43 -0800 Subject: [PATCH 19/34] add List[int] for beam_search.py --- src/transformers/generation/beam_search.py | 32 ++++++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index d22fbaf280dee..402db5f24e1e7 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -16,7 +16,7 @@ import warnings from abc import ABC, abstractmethod from collections import UserDict -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -212,7 +212,7 @@ def process( next_tokens: torch.LongTensor, next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] @@ -234,6 +234,9 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + for batch_idx, 
beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: if self.num_beams < len(beam_hyp): @@ -253,7 +256,7 @@ def process( ): batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): + if (eos_token_id is not None) and (next_token.item() in eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: @@ -307,11 +310,14 @@ def finalize( final_beam_indices: torch.LongTensor, max_length: int, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -376,7 +382,8 @@ def finalize( indices[i, : len(best_idx)] = torch.tensor(best_idx) if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] return UserDict( { @@ -491,7 +498,7 @@ def process( next_indices: torch.LongTensor, scores_for_all_vocab: torch.FloatTensor, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, ) -> Tuple[torch.Tensor]: r""" Args: @@ -549,6 +556,9 @@ def process( next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: if self.num_beams < len(beam_hyp): @@ -568,7 +578,7 @@ def process( ): batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token.item() == eos_token_id): + if (eos_token_id is not None) and (next_token.item() in eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size @@ -773,10 +783,13 @@ def finalize( final_beam_indices: torch.LongTensor, max_length: int, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -840,7 +853,8 @@ def finalize( for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < sent_max_len: - decoded[i, sent_lengths[i]] = eos_token_id + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] return UserDict( { From 8298e87384830e88023c10b78db1cd7cb7587def Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 18 Dec 2022 22:47:33 -0800 Subject: [PATCH 20/34] add List[int] for forced_eos_token_id --- 
src/transformers/generation/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index efea4e34486eb..1c83854e946c6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -772,8 +772,7 @@ def _get_logits_processor( max_length: int, eos_token_id: Union[int, List[int]], forced_bos_token_id: int, - # TODO: This should be optionally List[int] - forced_eos_token_id: int, + forced_eos_token_id: Union[int, List[int]], prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, num_beam_groups: int, @@ -1049,7 +1048,7 @@ def generate( output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[Union[int, List[int]]] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, From 0062ddf82c3bfb451208a7cd4ff8f4233655ac9c Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 21:30:23 -0800 Subject: [PATCH 21/34] revert stop token id stopping criteria changes --- src/transformers/__init__.py | 1 - src/transformers/generation/__init__.py | 2 -- .../generation/stopping_criteria.py | 33 ------------------- .../pipelines/text2text_generation.py | 17 ---------- src/transformers/pipelines/text_generation.py | 18 ---------- src/transformers/utils/dummy_pt_objects.py | 7 ---- tests/generation/test_stopping_criteria.py | 18 ++-------- tests/generation/test_utils.py | 19 ----------- .../test_pipelines_text_generation.py | 31 ----------------- 9 files changed, 2 insertions(+), 144 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5962d0dfd7606..bf53737e96899 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -4147,7 +4147,6 @@ RepetitionPenaltyLogitsProcessor, StoppingCriteria, StoppingCriteriaList, - StopTokenIdStoppingCriteria, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 73a3ab70ce061..b1d8e8acad5f1 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -67,7 +67,6 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", - "StopTokenIdStoppingCriteria", "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", @@ -187,7 +186,6 @@ MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, - StopTokenIdStoppingCriteria, validate_stopping_criteria, ) from .utils import ( diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index a937a5cc1b585..7023fa9998c94 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -107,39 +107,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return time.time() - self.initial_timestamp > self.max_time -class StopTokenIdStoppingCriteria(StoppingCriteria): - """ - This class can be used to stop a generation once the model generates the specified token id. - - Args: - stop_token_id (`int`): - The stop token id. This corresponds to the token id of the token you would like to stop the generation on. 
- early_stopping (`bool`): - If set to `True`, the generation will stop once it detects at least one stop token id. - Otherwise, it will wait till it sees at least one stop token id for each example in the batch. - - Examples: - ```python - >>> stop_token_id = tokenizer.convert_tokens_to_ids('\n') - >>> stopping_criteria = StopTokenIdStoppingCriteria(stop_token_id) - >>> model.generate(text, stopping_criteria=[stopping_criteria]) - ``` - """ - - def __init__(self, stop_token_id: int, early_stopping: bool = False): - self.stop_token_id = stop_token_id - self.early_stopping = early_stopping - - @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: - batch_size = input_ids.shape[0] - how_many_finished = ((input_ids[:, -1] == self.stop_token_id).sum() > 0).sum() - if self.early_stopping: - return how_many_finished > 0 - else: - return how_many_finished == batch_size - - class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index c4db1ff034238..a9f73218ad54f 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,6 +1,5 @@ import enum import warnings -from typing import List from ..tokenization_utils import TruncationStrategy from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging @@ -79,7 +78,6 @@ def _sanitize_parameters( clean_up_tokenization_spaces=None, truncation=None, stop_sequence=None, - stop_tokens: List[str] = None, **generate_kwargs ): preprocess_params = {} @@ -106,21 +104,6 @@ def _sanitize_parameters( ) generate_kwargs["eos_token_id"] = stop_sequence_ids[0] - if stop_tokens is not None: - stop_token_ids = [] - for stop_token in stop_tokens: - _stop_token_ids = self.tokenizer.encode(stop_token, add_special_tokens=False) - if len(_stop_token_ids) > 1: - raise ValueError( - f"The stop_token {stop_token} has more than one associated token id: {_stop_token_ids}." 
- ) - stop_token_id = _stop_token_ids[0] - stop_token_ids.append(stop_token_id) - if "stop_token_ids" not in generate_kwargs: - generate_kwargs["stop_token_ids"] = stop_token_ids - else: - generate_kwargs["stop_token_ids"].extend(stop_token_ids) - return preprocess_params, forward_params, postprocess_params def check_inputs(self, input_length: int, min_length: int, max_length: int): diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index b3160b2ca96dd..b19d58f4ffbb4 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,10 +1,8 @@ import enum import warnings -from typing import List from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING -from ..generation.stopping_criteria import StoppingCriteriaList, StopTokenIdStoppingCriteria from ..utils import add_end_docstrings, is_tf_available from .base import PIPELINE_INIT_ARGS, Pipeline @@ -99,7 +97,6 @@ def _sanitize_parameters( prefix=None, handle_long_generation=None, stop_sequence=None, - stop_tokens: List[str] = None, **generate_kwargs ): preprocess_params = {} @@ -156,21 +153,6 @@ def _sanitize_parameters( ) generate_kwargs["eos_token_id"] = stop_sequence_ids[0] - if stop_tokens is not None: - stop_token_ids = [] - for stop_token in stop_tokens: - _stop_token_ids = self.tokenizer.encode(stop_token, add_special_tokens=False) - if len(_stop_token_ids) > 1: - raise ValueError( - f"The stop_token {stop_token} has more than one associated token id: {_stop_token_ids}." - ) - stop_token_id = _stop_token_ids[0] - stop_token_ids.append(stop_token_id) - if "stop_token_ids" not in generate_kwargs: - generate_kwargs["stop_token_ids"] = stop_token_ids - else: - generate_kwargs["stop_token_ids"].extend(stop_token_ids) - return preprocess_params, forward_params, postprocess_params # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9bc533132a7b4..178a0b5ae6e55 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -248,13 +248,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class StopTokenIdStoppingCriteria(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - class TemperatureLogitsWarper(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 0b0ad566091fe..dfc5308359ffb 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import random import time import unittest @@ -31,18 +30,17 @@ MaxNewTokensCriteria, MaxTimeCriteria, StoppingCriteriaList, - StopTokenIdStoppingCriteria, validate_stopping_criteria, ) @require_torch class StoppingCriteriaTestCase(unittest.TestCase): - def _get_tensors(self, length, rng=None): + def _get_tensors(self, length): batch_size = 3 vocab_size = 250 - input_ids = ids_tensor((batch_size, length), vocab_size, rng=rng) + input_ids = ids_tensor((batch_size, length), vocab_size) scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length return input_ids, scores @@ -100,18 +98,6 @@ def test_max_time_criteria(self): criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) self.assertTrue(criteria(input_ids, scores)) - def test_stop_token_id_criteria(self): - input_ids, scores = self._get_tensors(5, rng=random.Random(42)) - - criteria = StopTokenIdStoppingCriteria(stop_token_id=5) - self.assertFalse(criteria(input_ids, scores)) - - criteria = StopTokenIdStoppingCriteria(stop_token_id=22, early_stopping=False) - self.assertFalse(criteria(input_ids, scores)) - - criteria = StopTokenIdStoppingCriteria(stop_token_id=22, early_stopping=True) - self.assertTrue(criteria(input_ids, scores)) - def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c35f6df8e8668..f1505f55aa528 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2110,25 +2110,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa [1, 18], ) - def test_stop_sequence_stopping_criteria(self): - - prompt = """Hello I believe in""" - generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") - output = generator(prompt) - self.assertEqual( - output, - [ - { - "generated_text": ( - "Hello I believe in in in number number number number number number number number number" - ) - } - ], - ) - - output = generator(prompt, stop_sequence=" number") - self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) - def test_custom_logits_processor(self): bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 73b863f6e7937..4de6f878dd22c 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -159,37 +159,6 @@ def test_stop_sequence_stopping_criteria(self): output = text_generator(prompt, stop_sequence=" fe") self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) - def test_stop_token_ids_stopping_criteria(self): - prompt = """Hello I believe in""" - text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2") - output = text_generator(prompt) - self.assertEqual( - output, - [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], - ) - stop_token_ids = text_generator.tokenizer.encode(" fe") - output = text_generator(prompt, stop_token_ids=stop_token_ids) - self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) - - def test_stop_tokens_stopping_criteria(self): - prompt = """Hello I believe in""" - text_generator = pipeline("text-generation", 
model="hf-internal-testing/tiny-random-gpt2", stop_tokens=[" fe"]) - output = text_generator(prompt) - self.assertEqual( - output, - [{"generated_text": "Hello I believe in fe"}], - ) - - prompt = """Hello I believe in""" - text_generator = pipeline( - "text-generation", model="hf-internal-testing/tiny-random-gpt2", stopping_criteria=[], stop_tokens=[" fe"] - ) - output = text_generator(prompt) - self.assertEqual( - output, - [{"generated_text": "Hello I believe in fe"}], - ) - def run_pipeline_test(self, text_generator, _): model = text_generator.model tokenizer = text_generator.tokenizer From 6d69af0b6919769f4ec580345f009be17aa98ea2 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 21:33:44 -0800 Subject: [PATCH 22/34] make fixup --- src/transformers/generation/logits_process.py | 12 +++++++----- tests/generation/test_utils.py | 1 - 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 721383c34e762..ebba9c9a81c2d 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -419,7 +419,9 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - bad_words_ids = list(filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids)) + bad_words_ids = list( + filter(lambda bad_token_seq: all([bad_token_seq != [i] for i in eos_token_id]), bad_words_ids) + ) self.bad_words_id_length_1 = [] self.bad_words_id_length_greater_than_1 = [] for word in bad_words_ids: @@ -686,7 +688,9 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): The length of the input sequence. 
""" - def __init__(self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int): + def __init__( + self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], input_ids_seq_length: int + ): self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] if isinstance(eos_token_id, int): @@ -697,9 +701,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Float cur_len = input_ids.shape[-1] if cur_len > self.regulation_start: for i in self.eos_token_id: - scores[:, i] = scores[:, i] * pow( - self.regulation_factor, cur_len - self.regulation_start - ) + scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start) return scores diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f1505f55aa528..0277464c4e02d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -39,7 +39,6 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, VisionEncoderDecoderModel, - pipeline, top_k_top_p_filtering, ) from transformers.generation import ( From 9ad7689c60f41ecced344ac6d2ca20a5517b8ab3 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 21:43:07 -0800 Subject: [PATCH 23/34] fix tests --- src/transformers/generation/logits_process.py | 2 +- tests/generation/test_utils.py | 14 -------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index ebba9c9a81c2d..f4d467a5d5844 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -650,7 +650,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to cur_len = input_ids.shape[-1] if cur_len == self.max_length - 1: num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i in self.eos_token_id]] = -float("inf") + scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf") for i in self.eos_token_id: scores[:, i] = 0 return scores diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0277464c4e02d..b6ab8872e6b14 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3005,17 +3005,3 @@ def test_validate_generation_inputs(self): # However, valid model_kwargs are accepted valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} model.generate(input_ids, **valid_model_kwargs) - - def test_stop_token_ids_stopping_criteria(self): - prompt = """Hello I believe in""" - gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - input_ids = gpt2_tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device) - - stop_token_ids = gpt2_tokenizer.encode(" fe") - self.assertEqual(stop_token_ids, [641]) - - output = gpt2_model.generate(input_ids=input_ids, stop_token_ids=stop_token_ids) - generated_text = gpt2_tokenizer.batch_decode(output) - - self.assertEqual(generated_text, ["Hello I believe in fe"]) From 4df3b468b6601f0440fba5950e24d35e4e5a5167 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 22:40:09 -0800 Subject: [PATCH 24/34] add eos_token_id to generation/utils.py and added tests test_utils.py --- src/transformers/generation/utils.py | 20 
++++- tests/generation/test_utils.py | 111 ++++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0db1005d95b89..da9413872d370 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1697,6 +1697,8 @@ def contrastive_search( stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -1919,7 +1921,7 @@ def contrastive_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2056,6 +2058,8 @@ def greedy_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2151,7 +2155,7 @@ def greedy_search( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2309,6 +2313,8 @@ def sample( logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2407,7 +2413,7 @@ def sample( # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if 
unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -2565,6 +2571,8 @@ def beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -2889,6 +2897,8 @@ def beam_sample( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -3202,6 +3212,8 @@ def group_beam_search( stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -3575,6 +3587,8 @@ def constrained_beam_search( warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b6ab8872e6b14..96c74b27a7e82 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -17,6 +17,8 @@ import inspect import unittest +import pytest + from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device @@ -90,8 +92,9 @@ def _get_input_ids_and_config(self): max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` - config.pad_token_id = config.eos_token_id - + if isinstance(config.eos_token_id, int): + config.eos_token_id = [config.eos_token_id] + config.pad_token_id = config.eos_token_id[0] # TransfoXL has no attention mask if "transfoxl" in config.__class__.__name__.lower(): attention_mask = None @@ -3005,3 +3008,107 @@ def test_validate_generation_inputs(self): # However, valid model_kwargs are 
accepted valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)} model.generate(input_ids, **valid_model_kwargs) + + def test_eos_token_id_int_and_list_greedy_search(self): + generation_kwargs = { + 'do_sample': False, + 'num_beams': 1, + } + expectation = 13 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 873 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [873] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_contrastive_search(self): + generation_kwargs = { + 'do_sample': False, + 'num_beams': 1, + 'penalty_alpha': 0.6, + 'top_k': 4, + } + expectation = 17 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 225 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + print('generated_tokens', generated_tokens) + print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [225] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_top_k_top_sampling(self): + generation_kwargs = { + 'do_sample': True, + 'num_beams': 1, + 'top_p': 0.7, + 'top_k': 10, + 'temperature': 0.7, + } + expectation = 15 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 846 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + print('generated_tokens', generated_tokens) + print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [846] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + + def test_eos_token_id_int_and_list_beam_search(self): + generation_kwargs = { + 'do_sample': False, + 'num_beams': 3, + } + expectation = 13 + + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = """Hello, my dog is cute and""" + tokens = tokenizer(text, return_tensors="pt") + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + torch.manual_seed(0) + eos_token_id = 873 + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + 
print('generated_tokens', generated_tokens) + print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) + self.assertTrue(expectation == len(generated_tokens[0])) + + torch.manual_seed(0) + eos_token_id = [873] + generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) + self.assertTrue(expectation == len(generated_tokens[0])) + assert False From 96ccb522b330800575599b89a755cecd41aa4d44 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 22:50:06 -0800 Subject: [PATCH 25/34] add eos_token_id type hints and fix for pad tokens --- src/transformers/generation/utils.py | 29 ++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index da9413872d370..21c5f36e11529 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -575,11 +575,13 @@ def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: Optional[int], - eos_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: @@ -891,7 +893,7 @@ def compute_transition_beam_scores( sequences: torch.Tensor, scores: Tuple[torch.Tensor], beam_indices: torch.Tensor, - eos_token_id: int = None, + eos_token_id: Union[int, List[int]] = None, ): """compute the transition probabilities of sequences given generation scores and beam indices""" @@ -1154,10 +1156,13 @@ def generate( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{generation_config.eos_token_id} for open-end generation." + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." ) - generation_config.pad_token_id = generation_config.eos_token_id + generation_config.pad_token_id = eos_token_id # 3. 
Define model inputs # inputs_tensor has to be defined @@ -1613,7 +1618,7 @@ def contrastive_search( logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -1958,7 +1963,7 @@ def greedy_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2193,7 +2198,7 @@ def sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2451,7 +2456,7 @@ def beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -2767,7 +2772,7 @@ def beam_sample( logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3088,7 +3093,7 @@ def group_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -3456,7 +3461,7 @@ def constrained_beam_search( stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, From dd34d52f040d2b6744e658060405b9eaa2bdbf14 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 22:52:45 -0800 Subject: [PATCH 26/34] add comments --- src/transformers/generation/configuration_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a477ebe4203c4..a16e8eaa6e1cd 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -142,7 +142,7 @@ class GenerationConfig(PushToHubMixin): The id of the token to force as the first generated token after the `decoder_start_token_id`. 
Useful for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target language token. - forced_eos_token_id (`int`, *optional*, defaults to `model.config.forced_eos_token_id`): + forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`): The id of the token to force as the last generated token when `max_length` is reached. remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`): Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. @@ -152,10 +152,10 @@ class GenerationConfig(PushToHubMixin): generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay suppress_tokens (`List[int]`, *optional*): - A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set their + A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. begin_suppress_tokens (`List[int]`, *optional*): - A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens` logit + A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. forced_decoder_ids (`List[List[int]]`, *optional*): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be @@ -183,7 +183,7 @@ class GenerationConfig(PushToHubMixin): The id of the *padding* token. bos_token_id (`int`, *optional*): The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): + eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. > Generation parameters exclusive to encoder-decoder models From fc789bf39abdf8a8fa5386c0761740603a18203f Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 22:58:38 -0800 Subject: [PATCH 27/34] remove some prints and remove forced false test --- src/transformers/generation/utils.py | 4 +--- tests/generation/test_utils.py | 33 +++++++++++----------------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 21c5f36e11529..a464880455c66 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1159,9 +1159,7 @@ def generate( eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." - ) + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") generation_config.pad_token_id = eos_token_id # 3. 
Define model inputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 96c74b27a7e82..0250461aa66e3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3011,8 +3011,8 @@ def test_validate_generation_inputs(self): def test_eos_token_id_int_and_list_greedy_search(self): generation_kwargs = { - 'do_sample': False, - 'num_beams': 1, + "do_sample": False, + "num_beams": 1, } expectation = 13 @@ -3034,10 +3034,10 @@ def test_eos_token_id_int_and_list_greedy_search(self): def test_eos_token_id_int_and_list_contrastive_search(self): generation_kwargs = { - 'do_sample': False, - 'num_beams': 1, - 'penalty_alpha': 0.6, - 'top_k': 4, + "do_sample": False, + "num_beams": 1, + "penalty_alpha": 0.6, + "top_k": 4, } expectation = 17 @@ -3050,8 +3050,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self): torch.manual_seed(0) eos_token_id = 225 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) - print('generated_tokens', generated_tokens) - print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) self.assertTrue(expectation == len(generated_tokens[0])) torch.manual_seed(0) @@ -3061,11 +3059,11 @@ def test_eos_token_id_int_and_list_contrastive_search(self): def test_eos_token_id_int_and_list_top_k_top_sampling(self): generation_kwargs = { - 'do_sample': True, - 'num_beams': 1, - 'top_p': 0.7, - 'top_k': 10, - 'temperature': 0.7, + "do_sample": True, + "num_beams": 1, + "top_p": 0.7, + "top_k": 10, + "temperature": 0.7, } expectation = 15 @@ -3078,8 +3076,6 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self): torch.manual_seed(0) eos_token_id = 846 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) - print('generated_tokens', generated_tokens) - print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) self.assertTrue(expectation == len(generated_tokens[0])) torch.manual_seed(0) @@ -3089,8 +3085,8 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self): def test_eos_token_id_int_and_list_beam_search(self): generation_kwargs = { - 'do_sample': False, - 'num_beams': 3, + "do_sample": False, + "num_beams": 3, } expectation = 13 @@ -3103,12 +3099,9 @@ def test_eos_token_id_int_and_list_beam_search(self): torch.manual_seed(0) eos_token_id = 873 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) - print('generated_tokens', generated_tokens) - print('tokenizer.batch_decode(generated_tokens)', tokenizer.batch_decode(generated_tokens)) self.assertTrue(expectation == len(generated_tokens[0])) torch.manual_seed(0) eos_token_id = [873] generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) - assert False From e9bd3a95a1fdc9b2281d216f99244329d0e9065e Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 23:48:38 -0800 Subject: [PATCH 28/34] fix --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a464880455c66..0df75ceb3994a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -581,7 +581,7 @@ def _prepare_attention_mask_for_generation( is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) if isinstance(eos_token_id, int): eos_token_id 
= [eos_token_id] - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id in eos_token_id) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: From b25f05230c996b473c2c57afeef97db962ad5b1d Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 23:51:29 -0800 Subject: [PATCH 29/34] put back test_stop_sequence_stopping_criteria --- tests/generation/test_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0250461aa66e3..597df73ab0d30 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -19,7 +19,7 @@ import pytest -from transformers import is_torch_available +from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device from ..test_modeling_common import floats_tensor, ids_tensor @@ -2112,6 +2112,25 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa [1, 18], ) + def test_stop_sequence_stopping_criteria(self): + + prompt = """Hello I believe in""" + generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") + output = generator(prompt) + self.assertEqual( + output, + [ + { + "generated_text": ( + "Hello I believe in in in number number number number number number number number number" + ) + } + ], + ) + + output = generator(prompt, stop_sequence=" number") + self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) + def test_custom_logits_processor(self): bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" From 33e49ef65a0252f8c191b3a6c990952fb703b9ef Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Fri, 30 Dec 2022 23:53:50 -0800 Subject: [PATCH 30/34] remove unused import and make fixup --- tests/generation/test_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 597df73ab0d30..dfe8be0efd7e8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -17,8 +17,6 @@ import inspect import unittest -import pytest - from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device @@ -2118,14 +2116,14 @@ def test_stop_sequence_stopping_criteria(self): generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") output = generator(prompt) self.assertEqual( - output, - [ - { - "generated_text": ( - "Hello I believe in in in number number number number number number number number number" - ) - } - ], + output, + [ + { + "generated_text": ( + "Hello I believe in in in number number number number number number number number number" + ) + } + ], ) output = generator(prompt, stop_sequence=" number") From 642856274746fe525020668ce52d3e37c9b2b6b6 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Sun, 1 Jan 2023 20:46:08 -0800 Subject: [PATCH 31/34] add a none check --- src/transformers/generation/logits_process.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/logits_process.py 
b/src/transformers/generation/logits_process.py index f4d467a5d5844..bee30fc614507 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -416,6 +416,8 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." ) + if eos_token_id is None: + eos_token_id = [] if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] From a60afdb432b39477a1f01cddcdf0a5107d7279dc Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Tue, 3 Jan 2023 11:25:10 -0800 Subject: [PATCH 32/34] update docstring --- src/transformers/generation/beam_search.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index 402db5f24e1e7..46846b0c32d9b 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -42,7 +42,7 @@ Beam indices indicating to which beam hypothesis the `next_tokens` correspond. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): + eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Return: @@ -74,7 +74,7 @@ The beam indices indicating to which beam the `final_beam_tokens` shall be added. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): + eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Return: @@ -519,7 +519,7 @@ def process( The scores of all tokens in the vocabulary for each of the beam hypotheses. pad_token_id (`int`, *optional*): The id of the *padding* token. - eos_token_id (`int`, *optional*): + eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Return: From f402cb044283a4c2eccbe2a48f30f12b0a6f2aa9 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Tue, 3 Jan 2023 11:29:57 -0800 Subject: [PATCH 33/34] add more docstring for list ints --- src/transformers/generation/beam_search.py | 6 +++--- src/transformers/generation/configuration_utils.py | 3 ++- src/transformers/generation/logits_process.py | 7 ++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index 46846b0c32d9b..6e4f9cb936e8f 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -43,7 +43,7 @@ pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. Return: `UserDict`: A dictionary composed of the fields as defined above: @@ -75,7 +75,7 @@ pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. Return: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. @@ -520,7 +520,7 @@ def process( pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. 
+                The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

         Return:
             `UserDict`: A dictionary composed of the fields as defined above:
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index a16e8eaa6e1cd..3181bf3e8bfc0 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -144,6 +144,7 @@ class GenerationConfig(PushToHubMixin):
             language token.
         forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`):
             The id of the token to force as the last generated token when `max_length` is reached.
+            Optionally, use a list to set multiple *end-of-sequence* tokens.
         remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`):
             Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash.
             Note that using `remove_invalid_values` can slow down generation.
@@ -184,7 +185,7 @@ class GenerationConfig(PushToHubMixin):
         bos_token_id (`int`, *optional*):
             The id of the *beginning-of-sequence* token.
         eos_token_id (`Union[int, List[int]]`, *optional*):
-            The id of the *end-of-sequence* token.
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

         > Generation parameters exclusive to encoder-decoder models

diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index bee30fc614507..0b8b2a3876789 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -101,7 +101,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
         min_length (`int`):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token.
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
@@ -399,7 +399,7 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
             that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
             add_special_tokens=False).input_ids`.
         eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token.
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
@@ -640,6 +640,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
             The maximum length of the sequence to be generated.
         eos_token_id (`Union[int, List[int]]`):
             The id of the token to force as the last generated token when `max_length` is reached.
+            Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
@@ -685,7 +686,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
             This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
             starts and `decay_factor` represents the factor of exponential decay
         eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token.
+            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
         input_ids_seq_length (`int`):
             The length of the input sequence.
""" From 84b8e1d4d43098148e6c8111301f51acb7dff126 Mon Sep 17 00:00:00 2001 From: Motoki Wu Date: Tue, 3 Jan 2023 11:51:46 -0800 Subject: [PATCH 34/34] make fixup --- src/transformers/generation/configuration_utils.py | 6 +++--- src/transformers/generation/logits_process.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 3181bf3e8bfc0..a01222c8b41ec 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -143,8 +143,8 @@ class GenerationConfig(PushToHubMixin): multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target language token. forced_eos_token_id (`Union[int, List[int]]`, *optional*, defaults to `model.config.forced_eos_token_id`): - The id of the token to force as the last generated token when `max_length` is reached. - Optionally, use a list to set multiple *end-of-sequence* tokens. + The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a + list to set multiple *end-of-sequence* tokens. remove_invalid_values (`bool`, *optional*, defaults to `model.config.remove_invalid_values`): Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation. @@ -185,7 +185,7 @@ class GenerationConfig(PushToHubMixin): bos_token_id (`int`, *optional*): The id of the *beginning-of-sequence* token. eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. > Generation parameters exclusive to encoder-decoder models diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0b8b2a3876789..0012354f45c3d 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -639,8 +639,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): max_length (`int`): The maximum length of the sequence to be generated. eos_token_id (`Union[int, List[int]]`): - The id of the token to force as the last generated token when `max_length` is reached. - Optionally, use a list to set multiple *end-of-sequence* tokens. + The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a + list to set multiple *end-of-sequence* tokens. """ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):