Stop generation once end_token_id is seen #769
Changes from all commits
3e4f2f5
e292f53
ea63d50
17b1c1e
7bbcc41
a67ea77
6d2ef24
1a675d6
9583a8c
@@ -236,8 +236,9 @@ def __call__(
             token_probability_fn,
             mask,
             max_length - shortest_prompt_len,
-            cache=cache,
             from_logits=from_logits,
+            end_token_id=end_token_id,
+            cache=cache,
         )
         # Mask out tokens after `end_token_id`.
         if end_token_id is not None:
@@ -269,6 +270,7 @@ def sample(
        mask,
        num_steps,
        from_logits=True,
+       end_token_id=None,
        cache=None,
    ):
        """Sampling logic implementation.
@@ -284,6 +286,9 @@ def sample(
            from_logits: bool, defaults to True. Indicate if the
                `token_probability_fn` returns logits. If False,
                `token_probability_fn` returns probability distributions.
+           end_token_id: int, defaults to None. The token marking the end of
+               the sequence, once encountered the generation is finished for
+               the exact sequence.
            cache: a dense int tensor, the cache used in decoding. The cache
                stores the key and value of each
                `keras_nlp.layers.CachedMultiHeadAttention` layer to make the
@@ -299,8 +304,9 @@ def sample(
        # The index of the last non-padding token in prompt. Since all sequences
        # are aligned to the right side, the index is the same for all.
        current_index = max_length - num_steps
+       original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32)

-       def one_step(
+       def body(
            current_index,
            prompt,
            mask,

Review comment on `original_padding_mask`: we compute this twice (above as well), should we just pass it through?

Reply: It's fairly cheap to compute. I was doing this way because
@@ -316,7 +322,10 @@ def one_step(
                )
                next_token_probs = tf.squeeze(probs, axis=1)
            else:
-               probs = token_probability_fn(prompt, mask)
+               probs = token_probability_fn(
+                   prompt,
+                   mask,
+               )
                next_token_probs = tf.gather(
                    probs,
                    tf.repeat(current_index - 1, batch_size),
@@ -343,24 +352,32 @@ def one_step(
            current_index = tf.add(current_index, 1)
            if cache is None:
                return current_index, prompt, mask
-           return [current_index, prompt, mask, cache]
+           return current_index, prompt, mask, cache

+       def cond(current_index, prompt, mask, cache=None):
+           if end_token_id is None:
+               return True
+           end_token_seen = (prompt == end_token_id) & (
+               original_padding_mask == 0
+           )
+           sequence_done = tf.reduce_any(end_token_seen, axis=-1)
+           all_done = tf.reduce_all(sequence_done)
+           return not all_done
+
        if cache is None:
            _, prompt, _ = tf.while_loop(
-               cond=lambda current_index, prompt, mask: tf.less(
-                   current_index, max_length
-               ),
-               body=one_step,
-               loop_vars=[current_index, prompt, mask],
+               cond=cond,
+               body=body,
+               loop_vars=(current_index, prompt, mask),
                maximum_iterations=num_steps,
            )
            return prompt
        # Run a while loop till `max_length` of tokens has been generated.
        _, prompt, _, _ = tf.while_loop(
-           cond=lambda current_index, prompt, mask, cache: tf.less(
-               current_index, max_length
-           ),
-           body=one_step,
-           loop_vars=[current_index, prompt, mask, cache],
+           cond=cond,
+           body=body,
+           loop_vars=(current_index, prompt, mask, cache),
            maximum_iterations=num_steps,
        )
        return prompt
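
The key behavioral change above is the new `cond` passed to `tf.while_loop`: instead of only stopping once `current_index` reaches `max_length`, the loop now also exits early once every sequence in the batch has produced `end_token_id` at a position that was padding in the original prompt. Below is a minimal, self-contained sketch of that check; the tensor values and shapes are made up for illustration, and only the comparison/reduction logic mirrors the diff.

```python
import tensorflow as tf

# Hypothetical values for illustration; only the masking/reduction logic
# mirrors the `cond` function in the diff above.
end_token_id = 2

# Two sequences of length 6. The first two positions held the original
# prompt tokens; the rest were padding that decoding fills in step by step.
prompt = tf.constant(
    [
        [5, 7, 2, 9, 1, 0],  # `end_token_id` (2) generated at index 2
        [3, 4, 6, 8, 5, 0],  # no end token generated yet
    ]
)
original_padding_mask = tf.constant(
    [
        [1, 1, 0, 0, 0, 0],  # 1 = original prompt token, 0 = padding
        [1, 1, 0, 0, 0, 0],
    ]
)

# An end token only counts if it appears where the original prompt had
# padding, so an `end_token_id` inside the user-provided prompt does not
# stop generation.
end_token_seen = (prompt == end_token_id) & (original_padding_mask == 0)
sequence_done = tf.reduce_any(end_token_seen, axis=-1)
all_done = tf.reduce_all(sequence_done)

print(sequence_done.numpy())  # [ True False]
print(all_done.numpy())       # False -> `cond` returns True, keep looping
```

Note that `maximum_iterations=num_steps` still caps the loop, so even a batch where some sequence never emits the end token cannot run past `max_length`.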
Review comment: Why change this?

Reply: It's because we by default set `from_logits=True` (was False earlier for mysterious reason), so I am aligning the unit test with it.
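
For context on the `from_logits` flag referenced here: it tells `sample` whether `token_probability_fn` returns raw logits or already-normalized probabilities (see the docstring hunk above). The stubs below are a hypothetical illustration of the two styles; the vocabulary size, shapes, and function names are made up and are not part of this PR.

```python
import tensorflow as tf

VOCAB_SIZE = 5  # hypothetical vocabulary size

def token_logits_fn(prompt, mask):
    # Stand-in for a real model: returns raw, unnormalized scores,
    # matching the `from_logits=True` default.
    batch_size = tf.shape(prompt)[0]
    seq_len = tf.shape(prompt)[1]
    return tf.random.normal([batch_size, seq_len, VOCAB_SIZE])

def token_probs_fn(prompt, mask):
    # Same stand-in, but normalized with softmax; a sampler would need
    # `from_logits=False` to treat these as probability distributions.
    return tf.nn.softmax(token_logits_fn(prompt, mask), axis=-1)
```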