
Add stop sequence to text generation pipeline #18444

Merged · 13 commits merged into huggingface:main · Sep 30, 2022

Conversation

@KMFODA (Contributor) commented Aug 3, 2022

What does this PR do?

As per the conversation in #17562, this draft PR adds a stop_sequence option to text generation pipelines.

Fixes # (issue)
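For illustration, a minimal usage sketch of the option this PR proposes (a tiny test model is assumed here; the exact generated text depends on the model, and `max_new_tokens` is only used to keep the example short):

```python
from transformers import pipeline

# A tiny random model from hf-internal-testing keeps the example fast; any causal LM works.
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# Without a stop sequence, generation runs until max length or the model's own EOS token.
print(generator("Hello I believe in", max_new_tokens=20))

# With the proposed option, generation halts once the stop sequence is produced
# (internally the string is tokenized and handed to generate as eos_token_id).
print(generator("Hello I believe in", stop_sequence=" number", max_new_tokens=20))
```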

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Narsil

Models: All

Library:

@HuggingFaceDocBuilderDev commented Aug 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@KMFODA (Contributor, Author) commented Aug 8, 2022

Hey @Narsil. I've managed to get this working for greedy decoding and multinomial sampling. For beam search, what would be the best approach to deal with a stop_sequence? I've assumed that if a stop_sequence appears in any of the beams, we stop the generation process.

Should we instead wait until each beam reaches the stop_sequence, or any other stopping criterion, before stopping the generation process?

@Narsil (Contributor) left a comment

LGTM.

I pinged other maintainers to get their advice.

The main thing is that EOS is already handled without a stopping criterion, so I don't know if we should add the new StoppingCriteria.

Also we should add some simple tests.

Ideally, just set up a random model from hf-internal-testing, generate 5 tokens, look at the results and use token 3 as the new eos_token_id, decode it to get it as a string, then rerun the generation with generate(..., stop_sequence='xx') and verify we stopped at token 3.

(We can leave the first steps with checks in place just so that readers of the test can understand why we're supposed to stop at token 3.)
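A rough sketch of that test recipe, for illustration only: `eos_token_id` is passed to `generate` directly, since mapping a string onto a token id is what a pipeline-level `stop_sequence` would do, and the model id and test name are assumptions.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_stop_on_third_generated_token():
    # A tiny random model keeps the test fast; greedy decoding keeps it deterministic.
    model_id = "hf-internal-testing/tiny-random-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Hello I believe in", return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # Step 1: generate 5 tokens and look at the result
    # (assumes the model does not emit its own EOS within those 5 tokens).
    full = model.generate(**inputs, do_sample=False, max_new_tokens=5)
    new_tokens = full[0, prompt_len:]

    # Step 2: use token 3 as the new eos_token_id and decode it to a string
    # (the string form is what a stop_sequence argument would receive).
    stop_token_id = int(new_tokens[2])
    stop_string = tokenizer.decode([stop_token_id])
    print("stopping on:", repr(stop_string))

    # Step 3: rerun generation stopping on that token and verify we stop at token 3
    # (or earlier, if the same token id happens to appear before position 3).
    stopped = model.generate(**inputs, do_sample=False, max_new_tokens=5, eos_token_id=stop_token_id)
    assert stopped.shape[1] <= prompt_len + 3
    assert int(stopped[0, -1]) == stop_token_id
```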

src/transformers/pipelines/text2text_generation.py (outdated)
            stop_sequence
        )
        if len(stop_sequence_ids) > 3:
            warnings.warn(
                f"Stopping on a multiple token sequence is not yet supported on transformers. The first token of the stop sequence will be used as the stop sequence string in the interim."
            )
Contributor:

Great message.

src/transformers/pipelines/text2text_generation.py (outdated)
Comment on lines 112 to 118
        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence)
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    f"Stopping on a multiple token sequence is not yet supported on transformers. The first token of the stop sequence will be used as the stop sequence string in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
Contributor:

I think implementing it either in generate or here is enough.
We shouldn't try to implement it everywhere.

@gante @patrickvonplaten Are you OK with it being directly included in generate? (Otherwise we can keep it just in the pipeline.)

src/transformers/pipelines/text_generation.py (outdated)

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return sum([self.eos_token_id in i for i in input_ids]) == input_ids.shape[0]
Contributor:

I think this line is wrong and the following one is correct, no?

@KMFODA (Contributor, Author) commented Aug 9, 2022

Hey @Narsil, thanks for the helpful comments. And apologies for the few quality errors; I was planning on addressing those after we decided on the strategy for eos_token_id.

Here, specifically, I had 2 different returns because I was experimenting with different approaches to stopping a beam search. One stops if any of the beams reaches the eos_token_id; the other waits for all of them to reach the eos_token_id.
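For reference, a sketch of the two conditions being discussed, using illustrative function names rather than the PR's actual StoppingCriteria code:

```python
import torch


def all_sequences_finished(input_ids: torch.LongTensor, eos_token_id: int) -> bool:
    # Stop only once every row (beam / batch member) contains the EOS token somewhere.
    return sum(eos_token_id in row for row in input_ids) == input_ids.shape[0]


def any_sequence_finished(input_ids: torch.LongTensor, eos_token_id: int) -> bool:
    # Stop as soon as any row has just produced the EOS token as its last token.
    return bool((input_ids[:, -1] == eos_token_id).any())


beams = torch.tensor([[5, 7, 2], [4, 6, 9]])  # only the first beam has hit EOS (id 2)
print(all_sequences_finished(beams, eos_token_id=2))  # False
print(any_sequence_finished(beams, eos_token_id=2))   # True
```

The "all sequences" condition matches the return statement quoted above; the "any sequence" variant is roughly what checking only the last generated token would give.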

Contributor:

Oh I see. I think beam search should wait on all beams being "done". (And you keep the EOS-terminated ones as long as their score allows.)

I haven't touched the beam search code enough to be sure how to handle that.
Also, I think you only need to check the final tokens to avoid looping over the whole input_ids:

input_ids[:, -1] == eos_token_id, no?

@gante (Member) commented Aug 9, 2022

It's a bit more complex than that -- we can only check new tokens, but different batch sequences may generate eos_token_id at different times.

The generate functions use an auxiliary variable to keep track of which members of the batch have already finished (see here).

In any case, I'd recommend doing this change in a separate PR :) In addition to adding the stopping criteria, we also need to remove the existing equivalent code from generate.
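A small, self-contained sketch of the bookkeeping described here, loosely modeled on the unfinished-sequences logic in generate's greedy loop (the scripted tokens and variable names are made up for illustration; this is not the library's exact code):

```python
import torch

eos_token_id = 2
pad_token_id = 2  # GPT-2-style setup where pad_token_id is set to eos_token_id
batch_size = 3

# Pretend these are the tokens a model would pick at each step for each batch member.
scripted_steps = torch.tensor(
    [
        [5, 7, 9],  # step 0: nobody has finished yet
        [2, 7, 9],  # step 1: sequence 0 emits EOS
        [4, 2, 9],  # step 2: sequence 1 emits EOS (sequence 0 now just pads)
        [4, 6, 2],  # step 3: sequence 2 emits EOS -> the whole batch is done
    ]
)

unfinished = torch.ones(batch_size, dtype=torch.long)  # 1 = still generating
generated = []

for next_tokens in scripted_steps:
    # Sequences that already finished emit pad tokens instead of real tokens.
    next_tokens = next_tokens * unfinished + pad_token_id * (1 - unfinished)
    generated.append(next_tokens)

    # Only the newly generated token is checked, so pad tokens on already-finished
    # sequences (which share the EOS id here) cannot flip the bookkeeping back.
    unfinished = unfinished.mul((next_tokens != eos_token_id).long())

    if unfinished.max() == 0:  # every member of the batch has finished
        break

print(torch.stack(generated, dim=1))  # per-sequence outputs, padded after their EOS
```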

@Narsil (Contributor) commented Aug 8, 2022

> Should it instead be that we wait until each beam reaches the stop_sequence or any other stopping criteria before stopping the generation process?

@KMFODA I think eos_token_id is already handled for beam search, see my comment on the StoppingCriteria.

I will let others comment on the best way to do this in .generate, but I think we don't need the criteria; just let the regular eos_token_id logic apply (it's handled separately from StoppingCriteria).

@Narsil (Contributor) commented Aug 8, 2022

For the tests, removing the breakpoint should help. Then, for code quality:

pip install -e .[quality]
make fixup

Should do the trick.

@gante (Member) commented Aug 8, 2022

@Narsil @KMFODA I'm in favor of moving it to a StoppingCriteria, so that all conditions that can terminate generation fall under the same class. However, it should be noted that this is not a requirement to complete the issue, i.e. to add a stop sequence to the text generation pipeline :P

It is already implemented in the various generation strategies (e.g. here for greedy search). Also, the existing implementation is different from this PR -- it only checks whether the eos_token is present in the newly generated tokens. This is because models like GPT-2 often set pad_token_id to eos_token_id, and we don't want the pad tokens to trigger this condition.

@KMFODA (Contributor, Author) commented Aug 9, 2022

Thanks @Narsil @gante. Okay, so for the sake of deploying iteratively, I've removed the eos_token_id from the StoppingCriteria and will add it as a separate PR.

I've added a test for the stop_sequence being fed in at the pipeline level. Once @Narsil's comment about whether the stop sequence should be handled in the pipeline or in the generation_kwargs is addressed, I can alter this test accordingly.

@KMFODA marked this pull request as ready for review on August 9, 2022 at 07:19
@Narsil (Contributor) left a comment

This looks good to me.

I think the test needs to change if it's going to live in the generate testing files: it should only use generate.

We should implement stop_sequence only once (probably in generate) but we could have 2 tests if you want to test the full pipeline too. (Probably in tests/pipelines/test_pipelines_text_generation.py for instance.)



Comment on lines 1980 to 1985
    def test_stop_sequence_stopping_criteria(self):
        prompt = """Hello I believe in"""
        generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
        output = generator(prompt, stop_sequence=" number")
        self.assertEqual(output[0]["generated_text"].split()[-1], "number")

Contributor:

Suggested change (current test first, suggested replacement below):

    def test_stop_sequence_stopping_criteria(self):
        prompt = """Hello I believe in"""
        generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
        output = generator(prompt, stop_sequence=" number")
        self.assertEqual(output[0]["generated_text"].split()[-1], "number")

    def test_stop_sequence_stopping_criteria(self):
        prompt = """Hello I believe in"""
        generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
        output = generator(prompt)
        self.assertEqual(output, [{'generated_text': 'Hello I believe in in in number number number number number number number number number'}])
        output = generator(prompt, stop_sequence=" number")
        self.assertEqual(output, [{'generated_text': 'Hello I believe in in in number'}])

I think this formulation conveys the intent of the test a tiny bit better.
Also, since you were only testing the last generated token, your test would still pass even if we deactivated the whole option, because the model just generates "number" all the time.
Wdyt?

Contributor (Author):

This makes sense. I'll change it to that.

@KMFODA (Contributor, Author) commented Aug 10, 2022

> We should implement stop_sequence only once (probably in generate) but we could have 2 tests if you want to test the full pipeline too. (Probably in tests/pipelines/test_pipelines_text_generation.py for instance.)

If we were to move stop_sequence into generate, wouldn't we have to tokenize it first? In that case, what's the reasoning behind feeding it as a stop_sequence instead of an eos_token_id?

@Narsil (Contributor) commented Aug 12, 2022

> If we were to move stop_sequence into generate, wouldn't we have to tokenize it first? In that case, what's the reasoning behind feeding it as a stop_sequence instead of an eos_token_id?

You're entirely right; oversight on my part. eos_token_id already does the job. So we just need to implement stop_sequence in the pipeline: tokenize the stop_sequence to produce the eos_token_id and feed that to generate.
So no additional code in generate should be needed.

Sorry, I failed to see that.
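A minimal sketch of that flow, assuming a causal LM and a hypothetical helper name; it mirrors the pipeline snippet reviewed above (tokenize the stop string, hand its first token to generate as eos_token_id):

```python
import warnings

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def generate_with_stop_sequence(prompt: str, stop_sequence: str, **generate_kwargs) -> str:
    """Tokenize the stop sequence and pass its first token to generate as eos_token_id."""
    stop_sequence_ids = tokenizer.encode(stop_sequence)
    if len(stop_sequence_ids) > 1:
        warnings.warn(
            "Stopping on a multiple token sequence is not yet supported; "
            "only the first token of the stop sequence will be used."
        )
    generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, **generate_kwargs)
    return tokenizer.decode(output_ids[0])


print(generate_with_stop_sequence("Hello I believe in", " number", max_new_tokens=20, do_sample=False))
```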

@KMFODA (Contributor, Author) commented Aug 15, 2022

No problem. I've just moved the stop_sequence back to the pipeline function and added the tests you requested in tests/pipelines/test_pipelines_text_generation.py. This should make the PR ready for review now.

While playing with stop_sequence, though, I found that sometimes adding a specific stop_sequence changes the output so that it avoids mentioning the word entirely. I don't have live examples right now, but I just wanted to check whether this is normal behaviour. If not, I can find examples on public models and share them in a different issue.

@Narsil (Contributor) left a comment

LGTM!

@@ -1063,7 +1063,7 @@ def generate(
        exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
            This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
            generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
-           where penalty starts and `decay_factor` represents the factor of exponential decay
+           where penalty starts and `decay_factor` represents the factor of exponential decays
Contributor:

I think without an s is actually better, no? There's only a single decay.

@@ -147,6 +147,24 @@ def get_test_pipeline(self, model, tokenizer, feature_extractor):
        text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
        return text_generator, ["This is a test", "Another test"]

    def test_stop_sequence_stopping_criteria(self):
        prompt = """Hello I believe in"""
        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
Contributor:

Bart is a seq2seq model, so it will fail.

You can use https://huggingface.co/hf-internal-testing/tiny-random-gpt2 instead, I think.

Member:

^ +1

@gante (Member) left a comment

Other than the two comments I added and the failing tests, LGTM as well 👍


@@ -107,6 +107,24 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
        return time.time() - self.initial_timestamp > self.max_time


class EndOfStringCriteria(StoppingCriteria):
Member:

Since it is not used anywhere, I'd suggest adding this class in a follow-up PR, where we implement it and use it instead of the current logic for the eos token :)

@github-actions (bot) commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@patrickvonplaten (Contributor) commented

@KMFODA I think your PR is almost ready to be merged! Would you like to try to fix the final problems and apply the review suggestions? :-)

@KMFODA (Contributor, Author) commented Sep 28, 2022

Hey @patrickvonplaten. My apologies, I was out sick over the past month. I've worked through the suggestions now. Hopefully this should be good to merge, but if not let me know!

@patrickvonplaten (Contributor) left a comment

Very nice! @gante, I'll let you merge the PR :-)

@gante (Member) commented Sep 30, 2022

I'm happy with the PR, except for the EndOfStringCriteria class -- it is not being used, and it is not a good practice to add unused classes/functions.

@KMFODA can you remove it for now, and perhaps reintroduce it in a follow-up PR (with use cases)? :)

@KMFODA (Contributor, Author) commented Sep 30, 2022

Hi @gante, yes of course. I had removed it locally, but somehow the changes didn't push through with one of the commits. I've force-pushed the change now. Hopefully it looks good now :).

@gante merged commit e396358 into huggingface:main on Sep 30, 2022