Add stop sequence to text generation pipeline #18444
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hey @Narsil. I've managed to get this working for greedy decoding and multinomial sampling. For beam search, what would be the best approach to deal with a `stop_sequence`? I've assumed that if a `stop_sequence` appears in any of the beams, we stop the generation process. Should we instead wait until each beam reaches the `stop_sequence` or another stopping criterion before stopping generation?
LGTM.
I pinged other maintainers to get advice.
The main thing is that EOS is already handled without a stopping criterion, so I don't know if we should add the new `StoppingCriteria`.
Also, we should add some simple tests. Ideally, just set up a random model from `hf-internal-testing`, generate 5 tokens, look at the results, and use token 3 as the new `eos_token_id`; decode it to get it as a string, then rerun generation with `generate(..., stop_sequence='xx')` and verify we stopped at token 3. (We can leave the first steps in with checks, just so readers of the test understand why we're supposed to stop at token 3.)
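A minimal sketch of that test flow, assuming a tiny causal-LM checkpoint and greedy decoding (the model id, prompt, and the use of `eos_token_id` for the final check are illustrative stand-ins, not the PR's final code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed test checkpoint; any tiny causal LM from hf-internal-testing would do.
model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=5, do_sample=False)
new_tokens = output[0, inputs["input_ids"].shape[1]:]

# Use the third generated token as the stop point and decode it to a string.
stop_token_id = int(new_tokens[2])
stop_string = tokenizer.decode([stop_token_id])

# Rerun generation; with this PR, the same check would go through the
# pipeline's stop_sequence kwarg instead of a raw eos_token_id.
output2 = model.generate(**inputs, max_new_tokens=5, do_sample=False, eos_token_id=stop_token_id)
assert output2.shape[1] <= inputs["input_ids"].shape[1] + 3
```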
```python
                stop_sequence
            )
            if len(stop_sequence_ids) > 3:
                warnings.warn(f"Stopping on a multiple token sequence is not yet supported on transformers. The first token of the stop sequence will be used as the stop sequence string in the interim.")
```
Great message.
```python
        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(
                stop_sequence
            )
            if len(stop_sequence_ids) > 1:
                warnings.warn(f"Stopping on a multiple token sequence is not yet supported on transformers. The first token of the stop sequence will be used as the stop sequence string in the interim.")
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
```
I think implementing it either in `generate` or here is enough. We shouldn't try to implement it everywhere.
@gante @patrickvonplaten Are you ok with it being directly included in `generate`? (Otherwise we can keep it just in the pipeline.)
```python
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return sum([self.eos_token_id in i for i in input_ids]) == input_ids.shape[0]
```
I think this line is wrong and the following one is correct, no?
Hey @Narsil, thanks for the helpful comments. And apologies for the few quality errors; I was planning on addressing those after we decided on the strategy for `eos_token_id`.
Here specifically I had 2 different returns because I was experimenting with different approaches to stopping a beam search: one stops if any of the beams reaches the `eos_token_id`, the other waits for all of them to reach the `eos_token_id`.
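Those two semantics amount to an any/all check over the batch of beams; a hedged sketch, not the PR's exact code:

```python
import torch

def any_beam_done(input_ids: torch.LongTensor, eos_token_id: int) -> bool:
    # Stop as soon as one sequence in the batch contains the eos token.
    return any(eos_token_id in seq for seq in input_ids)

def all_beams_done(input_ids: torch.LongTensor, eos_token_id: int) -> bool:
    # Stop only once every sequence in the batch contains the eos token.
    return all(eos_token_id in seq for seq in input_ids)
```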
Oh I see. I think beam search should wait on all beams being "done" (and you keep the eos-terminated ones as long as their score allows). I haven't touched the beam search code enough to be sure how to handle that.
Also, I think you only need to check the final tokens to avoid looping over the whole `input_ids`:
`input_ids[:, -1] == eos_token_id`
no?
It's a bit more complex than that -- we can only check new tokens, but different batch sequences may generate `eos_token_id` at different times. The generate functions use an auxiliary variable to keep track of which members of the batch have already finished, see here.
In any case, I'd recommend doing this change in a separate PR :) In addition to adding the stopping criteria, we would also need to remove the existing equivalent code from `generate`.
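For reference, a simplified sketch of that bookkeeping pattern (the names approximate the library's `unfinished_sequences` variable; this is not the exact `generate` code):

```python
import torch

batch_size, eos_token_id = 4, 2
# 1 = still generating, 0 = already produced eos at some earlier step.
unfinished_sequences = torch.ones(batch_size, dtype=torch.long)

for step in range(10):  # stand-in for the real decoding loop
    next_tokens = torch.randint(0, 5, (batch_size,))  # stand-in for the model's picks
    # Once a sequence emits eos, mark it finished for all remaining steps.
    unfinished_sequences = unfinished_sequences * (next_tokens != eos_token_id).long()
    if unfinished_sequences.max() == 0:  # every sequence in the batch is done
        break
```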
@KMFODA I think I will let others comment on the best way to do this in `generate`.
For the tests, removing the breakpoint should help; that should do the trick for code quality too.
@Narsil @KMFODA I'm in favor of moving it to a `StoppingCriteria`. It is already implemented on the multiple generation strategies (e.g. here for greedy search). Also, the existing implementation is different from the current PR -- the existing implementation only checks whether the newly generated token is the `eos_token_id`.
Thanks @Narsil @gante. Okay, so for the sake of deploying iteratively I've removed the `stop_sequence` logic from the pipeline and moved it into `generate`. I've added a test for the `stop_sequence` stopping criteria.
This looks good to me.
I think the test needs to change if it's going to be in the `generate` testing files, to only use `generate`.
We should implement `stop_sequence` only once (probably in `generate`), but we could have 2 tests if you want to test the full pipeline too (probably in `tests/pipelines/test_pipelines_text_generation.py`, for instance).
```python
def test_stop_sequence_stopping_criteria(self):
    prompt = """Hello I believe in"""
    generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
    output = generator(prompt, stop_sequence=" number")
    self.assertEqual(output[0]["generated_text"].split()[-1], "number")
```
Suggested change:

```python
def test_stop_sequence_stopping_criteria(self):
    prompt = """Hello I believe in"""
    generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
    output = generator(prompt)
    self.assertEqual(output, [{'generated_text': 'Hello I believe in in in number number number number number number number number number'}])
    output = generator(prompt, stop_sequence=" number")
    self.assertEqual(output, [{'generated_text': 'Hello I believe in in in number'}])
```
I think this formulation conveys the intent of the test a tiny bit better.
Also, since you were only testing the last generated token, if we deactivated the whole option your test would still pass, since the model just generates `number` all the time.
Wdyt?
This makes sense. I'll change it to that.
If we were to move `stop_sequence` into `generate`, wouldn't we need access to the tokenizer there to encode the stop string?
You're entirely right, oversight on my part. Sorry, I failed to see that.
No problem, I've just moved the `stop_sequence` back to the pipeline function and added the tests you requested in the pipelines testing file.
When I was playing with the `stop_sequence`, though, I found that sometimes when I add a specific `stop_sequence` the output changes and avoids mentioning the word entirely. I don't have live examples now, but I just wanted to check whether this is normal behaviour. If not, I can find examples on public models and share them in a different issue.
LGTM!
src/transformers/generation_utils.py (outdated)
```diff
@@ -1063,7 +1063,7 @@ def generate(
     exponential_decay_length_penalty (`tuple(int, float)`, *optional*, defaults to `model.config.exponential_decay_length_penalty`):
         This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
         generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
-        where penalty starts and `decay_factor` represents the factor of exponential decay
+        where penalty starts and `decay_factor` represents the factor of exponential decays
```
I think without an `s` is actually better, no? There's only a single decay.
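As an aside, the documented `(start_index, decay_factor)` behaviour can be illustrated with a small sketch (assuming the factor multiplies the eos score for each token past `start_index`; the library's actual processor may differ in detail):

```python
# Hedged illustration of an exponential length penalty of the documented
# (start_index, decay_factor) form; not the library's exact implementation.
def length_penalty_multiplier(cur_len: int, start_index: int, decay_factor: float) -> float:
    if cur_len <= start_index:
        return 1.0  # no penalty before the start index
    # The penalty grows exponentially with each token generated past start_index.
    return decay_factor ** (cur_len - start_index)

# e.g. (start_index=10, decay_factor=1.5): at length 12 the eos score is scaled by 1.5**2
print(length_penalty_multiplier(12, 10, 1.5))  # 2.25
```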
```diff
@@ -147,6 +147,24 @@ def get_test_pipeline(self, model, tokenizer, feature_extractor):
     text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
     return text_generator, ["This is a test", "Another test"]

+    def test_stop_sequence_stopping_criteria(self):
+        prompt = """Hello I believe in"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
```
bart is a seq2seq model, so it will fail. You can use https://huggingface.co/hf-internal-testing/tiny-random-gpt2 instead, I think.
^ +1
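The suggested swap is a one-line change in the test; the rest stays the same (tiny-random-gpt2 is a causal LM, so the text-generation pipeline can run it):

```python
from transformers import pipeline

# Use a tiny causal-LM checkpoint instead of a seq2seq (bart) one.
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
```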
Other than the two comments I added and the failing tests, LGTM as well 👍
```diff
@@ -107,6 +107,24 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
     return time.time() - self.initial_timestamp > self.max_time


+class EndOfStringCriteria(StoppingCriteria):
```
Since it is not used anywhere, I'd suggest adding this class in a follow-up PR, where we implement it and use it instead of the current logic for the eos token :)
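For the follow-up, such a criterion could plausibly look like the sketch below, building on the last-token check discussed earlier in the thread (a hedged sketch, not the class as committed in this PR):

```python
import torch
from transformers import StoppingCriteria

class EndOfStringCriteria(StoppingCriteria):
    """Hedged sketch: stop once every sequence in the batch has emitted the
    eos token, checking only the newly generated (final) token each step."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id
        self.done = None  # lazily sized to the batch on the first call

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.done is None:
            self.done = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        # Only inspect the last token, as suggested, instead of the whole input_ids.
        self.done |= input_ids[:, -1] == self.eos_token_id
        return bool(self.done.all())
```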
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@KMFODA I think your PR is almost ready to be merged! Would you like to try to fix the final problems and apply the review suggestions? :-)
Hey @patrickvonplaten. My apologies, I was out sick over the past month. I've worked on the suggestions now. Hopefully this should be good to merge, but if not let me know!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! @gante I let you merge the PR :-)
I'm happy with the PR, except for the `EndOfStringCriteria` class. @KMFODA can you remove it for now, and perhaps reintroduce it in a follow-up PR (with use cases)? :)
Hi @gante, yes of course. I had removed it locally but somehow the changes didn't push through with one of the commits. Force-pushed the change now. Hopefully that looks good now :).
What does this PR do?
As per the conversation in #17562, creating this draft PR to add a stop_sequence option to text generation pipelines.
Who can review?
@Narsil