🚨🚨 TextGenerationPipeline: rely on the tokenizer default kwargs (#31747)
* rely on the tokenizer default kwargs

* fix a few tests
gante committed Jul 2, 2024
1 parent a970195 · commit 82486e5
Showing 3 changed files with 20 additions and 23 deletions.
28 changes: 15 additions & 13 deletions src/transformers/pipelines/text_generation.py
```diff
@@ -266,31 +266,33 @@ def preprocess(
         prompt_text,
         prefix="",
         handle_long_generation=None,
-        add_special_tokens=False,
+        add_special_tokens=None,
         truncation=None,
-        padding=False,
+        padding=None,
         max_length=None,
         **generate_kwargs,
     ):
         if isinstance(prompt_text, Chat):
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
             inputs = self.tokenizer.apply_chat_template(
                 prompt_text.messages,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
                 add_generation_prompt=True,
                 return_dict=True,
                 return_tensors=self.framework,
+                **tokenizer_kwargs,
             )
         else:
-            inputs = self.tokenizer(
-                prefix + prompt_text,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-                return_tensors=self.framework,
-            )
+            # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
+            tokenizer_kwargs = {}
+            for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]:
+                if locals()[tokenizer_kwarg_name] is not None:
+                    tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name]
+            inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
+
         inputs["prompt_text"] = prompt_text
 
         if handle_long_generation == "hole":
```
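The core trick in this change is to collect tokenizer kwargs into a dict and include only the ones the caller actually set, so anything left as `None` falls through to the tokenizer's own defaults. A minimal standalone sketch of the same pattern (the `collect_tokenizer_kwargs` helper below is illustrative, not part of the pipeline API):

```python
# Minimal sketch of the "forward only non-None kwargs" pattern used above.
# `collect_tokenizer_kwargs` is a hypothetical helper for illustration only.
def collect_tokenizer_kwargs(truncation=None, padding=None, max_length=None):
    kwargs = {}
    for name, value in (("truncation", truncation), ("padding", padding), ("max_length", max_length)):
        if value is not None:  # unset kwargs are omitted, so the tokenizer's defaults apply
            kwargs[name] = value
    return kwargs  # in the pipeline, this dict is splatted into the tokenizer call


print(collect_tokenizer_kwargs())                 # {} -> the tokenizer's defaults win
print(collect_tokenizer_kwargs(truncation=True))  # {'truncation': True} -> only explicit kwargs are forwarded
```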
13 changes: 4 additions & 9 deletions tests/generation/test_utils.py
```diff
@@ -2087,24 +2087,19 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
             [1, 18],
         )
 
+    # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
     def test_stop_sequence_stopping_criteria(self):
         # PT-only test: TF doesn't have StoppingCriteria
         prompt = """Hello I believe in"""
         generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
         output = generator(prompt)
         self.assertEqual(
             output,
-            [
-                {
-                    "generated_text": (
-                        "Hello I believe in in in number number number number number number number number number"
-                    )
-                }
-            ],
+            [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
         )
 
-        output = generator(prompt, stop_sequence=" number")
-        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+        output = generator(prompt, stop_sequence=" we")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
 
     def test_generate_non_nlp_input_ids_as_kwarg(self):
         # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
```
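The expected strings change because the pipeline no longer hardcodes `add_special_tokens=False`: with the tokenizer's defaults in effect, the tiny BART test model is likely fed BOS/EOS tokens and therefore generates a different continuation. A usage sketch of the new behavior (exact outputs depend on the model and library version):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")

# After this change, the tokenizer's own defaults apply during preprocessing.
out_default = generator("Hello I believe in")

# To recover the pre-change behavior, pass the kwargs explicitly:
out_legacy = generator("Hello I believe in", add_special_tokens=False)
```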
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipelines_text_generation.py
```diff
@@ -398,7 +398,7 @@ def run_pipeline_test(self, text_generator, _):
             self.assertEqual(outputs, [{"generated_text": ANY(str)}])
         else:
             with self.assertRaises((ValueError, AssertionError)):
-                outputs = text_generator("")
+                outputs = text_generator("", add_special_tokens=False)
 
         if text_generator.framework == "tf":
             # TF generation does not support max_new_tokens, and it's impossible
```
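The explicit `add_special_tokens=False` keeps this test meaningful: under tokenizer defaults, many tokenizers attach special tokens even to an empty string, so an empty prompt would no longer be truly empty and might not raise. A sketch of the distinction (the token IDs in the comments are illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")

# With the tokenizer default (add_special_tokens=True for BART-style tokenizers),
# an empty string still yields special tokens, so generation has something to start from:
print(tok("").input_ids)                            # e.g. [bos_id, eos_id] -> non-empty
# Opting out reproduces the truly-empty input the test expects to fail on:
print(tok("", add_special_tokens=False).input_ids)  # [] -> empty, triggers the error path
```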
