top p search and testing #233
Conversation
Thanks! Mainly looks good, left some comments.
keras_nlp/utils/text_generation.py
Outdated
pred = tf.keras.activations.softmax(pred, axis=-1)
# Sort preds in descending order.
sorted_preds, sorted_indices = tf.math.top_k(
    pred, k=pred.shape[-1], sorted=True
)
Let's be consistent with `-1` or `1` here.
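For context, on a rank-2 prediction tensor `axis=-1` and `axis=1` address the same vocabulary axis, so either spelling works here; a minimal sketch of the equivalence (example shapes assumed, not from the PR):

import tensorflow as tf

# For rank-2 inputs the last axis is axis 1, so both calls agree;
# the review only asks to spell it one way consistently.
pred = tf.random.uniform((2, 5))
a = tf.keras.activations.softmax(pred, axis=-1)
b = tf.keras.activations.softmax(pred, axis=1)
tf.debugging.assert_near(a, b)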
edited
keras_nlp/utils/text_generation.py
Outdated
        replaced with `pad_token_id`.
    pad_token_id: int, defaults to 0. The pad token after `end_token_id`
        is received.
    filter_value: float, defaults to -Inf. The value for filtering out
Do we want to allow customization of `filter_value`? The code uses `filter_value` to set certain tokens to have probability 0?
I don't think we need to allow customization; we can just mask out unused tokens by setting `filter_value = 0`.
Yea, in that case, let's remove this argument.
keras_nlp/utils/text_generation.py
Outdated
)
# Filter out unmasked tokens and sample from filtered distribution.
probs = tf.where(
    shifted_keep_mask, sorted_preds, tf.fill(pred.shape, filter_value)
)
Should `filter_value` always be 0? I saw in the docstring it defaults to -Inf.
Yep, `filter_value` should be 0. Edited.
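A minimal sketch of why `filter_value = 0` suffices (the example values are illustrative, not from the PR): `tf.random.categorical` expects log-probabilities and normalizes internally, so tokens masked to probability 0 have log-probability -inf and can never be drawn.

import tensorflow as tf

# Zero out tokens outside the nucleus; log(0) = -inf means
# tf.random.categorical will never sample them.
sorted_preds = tf.constant([[0.4, 0.3, 0.2, 0.1]])
shifted_keep_mask = tf.constant([[True, True, False, False]])
probs = tf.where(shifted_keep_mask, sorted_preds, tf.zeros_like(sorted_preds))
next_token = tf.random.categorical(tf.math.log(probs), num_samples=1)
# next_token is always 0 or 1 in this example.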
self.assertAllEqual(output_logit, output_probs)


class TopPSamplingTextGenerationTest(tf.test.TestCase):
Let's keep the test name and method name the same: `TopPSearchTextGenerationTest`.
Edited this and the other test names.
Only a few minor comments!
prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
return tf.repeat(prob, batch_size, axis=0)

# Test that it only samples from tokens that sum up to p
nit: period at end.
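With the fixture above, the assertion is straightforward: given probabilities [0.4, 0.3, 0.2, 0.1] and p=0.7, only token ids 0 and 1 fall inside the nucleus, so generated ids should stay in that set. A hypothetical sketch (the `top_p_search` call and its arguments are assumed from the PR title, not quoted):

import tensorflow as tf

def token_probability_fn(inputs):
    # Fixed distribution: ids 2 and 3 lie outside the p=0.7 nucleus.
    prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
    return tf.repeat(prob, inputs.shape[0], axis=0)

# Inside the TestCase, something like:
#     outputs = top_p_search(token_probability_fn, prompt, max_length=10, p=0.7)
#     self.assertAllInSet(outputs, [0, 1])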
This looks great! Nice work.
Just a few minor comments.
keras_nlp/utils/text_generation.py
Outdated
)
if p <= 0 or p >= 1:
    raise ValueError(
        "p should be a float strictly between 0 and 1 (0 < p < 1)."
    )
Surround arg names in backticks (`p`), and also show what value we received, something like:

`p` should be in the range (0, 1). Received: `p={p}`.
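A sketch of validation matching that suggestion (the helper name `check_p` is illustrative, not from the PR):

def check_p(p):
    # Backtick the arg name and echo the received value in the message.
    if p <= 0 or p >= 1:
        raise ValueError(
            f"`p` should be in the range (0, 1). Received: `p={p}`."
        )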
keras_nlp/utils/text_generation.py
Outdated
""" | ||
Text generation utility based on top-p (nucleus) sampling. | ||
Top-p search filters the top token with probabilities that sum up to p, and |
Maybe we can explain this a little more clearly, and also remember to backtick argument names so it's clear what they are:

Top-p search selects tokens from the smallest subset of output probabilities that sum to greater than `p`. Put another way, top-p will first order token predictions by likelihood, and ignore all tokens after the cumulative probability of selected tokens exceeds `p`.
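A minimal sketch of that selection rule, reusing the mask names from the quoted diff (the concrete numbers are illustrative):

import tensorflow as tf

p = 0.6
pred = tf.constant([[0.4, 0.3, 0.2, 0.1]])
# Sort descending, then keep the smallest prefix whose cumulative
# probability exceeds `p`.
sorted_preds, sorted_indices = tf.math.top_k(pred, k=pred.shape[-1], sorted=True)
cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)  # [0.4, 0.7, 0.9, 1.0]
keep_mask = cumulative_probs <= p  # [True, False, False, False]
# Shift right by one so the token that pushes the sum past `p` is kept.
shifted_keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)  # [True, True, False, False]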
prompt = tf.fill((BATCH_SIZE, 1), START_ID)
# Print the generated sequence (token ids).
typo?
fixed