Stop generation once end_token_id is seen #769
Conversation
No unit tests added because the user-side behavior is unchanged. Also, I will mark this as ready after merging the cache PR.
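For readers skimming the thread, the core idea can be sketched roughly as follows. This is a minimal, hedged sketch of exiting the decoding loop early once every sequence has produced the end token, not the code in this PR; `prompt`, `end_token_id`, `current_index`, and `max_length` follow names visible in the diff below, while `next_token_fn` and everything else is hypothetical.

```python
import tensorflow as tf

def generate(next_token_fn, prompt, end_token_id, current_index, max_length):
    """Decode into a pre-padded `prompt` of shape [batch, max_length] and exit
    the loop early once every sequence has produced `end_token_id`.
    A rough sketch of the idea only, not this PR's implementation."""
    batch_size = tf.shape(prompt)[0]
    end_token_id = tf.cast(end_token_id, prompt.dtype)
    # True for sequences that have not emitted the end token yet.
    alive = tf.ones([batch_size], dtype=tf.bool)

    def cond(prompt, alive, index):
        # Stop when every sequence has seen the end token, or at max_length.
        return tf.reduce_any(alive) & (index < max_length)

    def body(prompt, alive, index):
        # `next_token_fn` stands in for the model call + sampling step.
        next_token = next_token_fn(prompt, index)  # shape [batch]
        # Finished sequences keep emitting the end token.
        next_token = tf.where(alive, next_token, end_token_id)
        alive = alive & (next_token != end_token_id)
        # Write the sampled token into column `index` of the padded prompt.
        column_mask = tf.one_hot(index, max_length, dtype=prompt.dtype)
        prompt = prompt * (1 - column_mask) + next_token[:, tf.newaxis] * column_mask
        return prompt, alive, index + 1

    prompt, _, _ = tf.while_loop(cond, body, (prompt, alive, current_index))
    return prompt
```

The point is only that the loop's `cond` can track a per-sequence "alive" flag, so decoding terminates as soon as every sequence in the batch has emitted the end token.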
Force-pushed from ccc4116 to 8d4a8a8
One meta comment on approach.
Another thing we should keep in mind is that this is, I think, precisely where left padding will become more efficient.
If you left pad your prompts originally, you start generating for all prompts right away, so it's more likely they will all "die" earlier. If you right pad, you might not even start generating on some sequences in the batch for a while.
Not something we need to solve on this PR, but something to think about!
@mattdangerw Yeah, I actually tried left padding earlier. My finding: left padding does not work well with GPT2 and other models that use absolute positional embeddings. In my experiment, the generated text becomes chaotic when left padding is applied.
I actually think this would be a bug in the attention mask and position embedding setup (both of which are complex in the left-pad setup!). But if you do everything correctly, the computation is exactly the same as I understand it (e.g. greedy search output will be one-to-one identical). I can try to put together a colab with huggingface showing this.
One thing that might be worth noting in the left-pad setup is that you really need to switch to a gather op for the position embedding, because your indices for the position embedding start varying per sample. But overall, I am totally down to look at that as a follow-up. Just wanted to point out a place where we are starting to leave performance on the table.
@mattdangerw Yeah, I also suspected my code was buggy, and it was based off a much earlier version, so it's definitely worth a second trial.
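As a side note on the gather point above, here is a minimal sketch of per-sample position ids under left padding, assuming a `[batch, seq_len]` padding mask with 1 for real tokens and 0 for padding. The helper names are illustrative, not keras-nlp APIs.

```python
import tensorflow as tf

def left_pad_position_ids(padding_mask):
    """padding_mask: [batch, seq_len], 1 for real tokens, 0 for left padding."""
    padding_mask = tf.cast(padding_mask, tf.int32)
    # Cumulative count of real tokens gives 1-based positions; subtract 1 and
    # clamp so the padded slots all map to position 0 (they are masked anyway).
    position_ids = tf.cumsum(padding_mask, axis=-1) - 1
    return tf.maximum(position_ids, 0)

def embed_positions(position_embedding_matrix, padding_mask):
    """position_embedding_matrix: [max_length, hidden_dim]."""
    position_ids = left_pad_position_ids(padding_mask)
    # Gather per-sample positions instead of slicing a contiguous 0..seq_len-1 range.
    return tf.gather(position_embedding_matrix, position_ids)

# Example: two sequences of lengths 2 and 4, left padded to length 4.
mask = tf.constant([[0, 0, 1, 1],
                    [1, 1, 1, 1]])
print(left_pad_position_ids(mask))
# [[0 0 0 1]
#  [0 1 2 3]]
```

Because the start index differs per sample, a fixed slice of the position embedding table no longer works; the gather indexed by these per-sample ids is what keeps absolute position embeddings consistent under left padding.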
Force-pushed from 8d4a8a8 to e292f53
This would need tests too before we land.
output_dim=self.feature_size,
),
keras.layers.Dense(self.vocab_size),
keras.layers.Softmax(),
Why change this?
It's because we now set from_logits=True by default (it was False earlier for mysterious reasons), so I am aligning the unit test with it.
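For illustration, aligning a test model with `from_logits=True` typically means ending at the vocabulary projection and applying softmax only where probabilities are explicitly needed. A hedged sketch, not the actual test in this PR; the sizes and layers are made up.

```python
import tensorflow as tf
from tensorflow import keras

vocab_size, feature_size = 10, 16

# With from_logits=True the model should output raw logits: it ends at the
# Dense projection over the vocabulary, with no trailing Softmax layer.
model = keras.Sequential([
    keras.layers.Embedding(vocab_size, feature_size),
    keras.layers.Dense(vocab_size),  # logits, no Softmax
])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

tokens = tf.constant([[1, 2, 3]])
labels = tf.constant([[2, 3, 4]])
value = loss(labels, model(tokens))

# If probabilities are ever needed (e.g. for sampling), apply the softmax
# explicitly instead of baking it into the model:
probs = tf.nn.softmax(model(tokens), axis=-1)
```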
LGTM! approving!
Note that GitHub Actions appears to be down, so make sure to test any changes locally! Left some comments about some weird XLA compilation errors I was seeing.
# The index of the last non-padding token in prompt. Since all sequences
# are aligned to the right side, the index is the same for all.
current_index = max_length - num_steps
original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32)
we compute this twice (above as well), should we just pass it through?
It's fairly cheap to compute. I was doing it this way because `sample`
already has a confusing arg list, and I would like to keep it shorter (though not by much...).
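For illustration only, and only a guess at the intent rather than the PR's actual logic: one way a precomputed `original_padding_mask` can feed the stopping check is to ignore end tokens that were already part of the prompt, so that only *generated* end tokens count. Everything besides `end_token_id` and `original_padding_mask` here is an assumption.

```python
import tensorflow as tf

def sequences_finished(prompt, original_padding_mask, end_token_id):
    """Return a [batch] bool tensor: True where a sequence has produced
    `end_token_id` in a generated position, i.e. outside the original prompt."""
    end_token_seen = tf.cast(tf.equal(prompt, end_token_id), tf.int32)
    # Positions covered by the original prompt should not count towards
    # the stopping criterion.
    generated_positions = 1 - original_padding_mask
    return tf.reduce_any(
        tf.cast(end_token_seen * generated_positions, tf.bool), axis=-1
    )
```

`tf.reduce_all(sequences_finished(...))` could then serve as the loop's early-exit condition.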
Resolves #749