Random Sampling Util for Text Generation #228
Conversation
Looks good! Left some initial comments
keras_nlp/utils/text_generation.py
Outdated
)
if not isinstance(prompt, tf.Tensor):
    prompt = tf.convert_to_tensor(prompt)
input_is_1d = prompt.shape.rank == 1
I think this would be simpler if you left the upranking in the main function. Currently the upranking and downranking are split apart from each other, which is bad for readability.
I've moved the section out of the helper function and into the main function.
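The shape handling being discussed can be sketched as follows. This is a minimal NumPy stand-in for the TF logic, not the actual keras-nlp code; the point is that the uprank and downrank steps live in the same function:

```python
import numpy as np

def generate(prompt):
    # Uprank a 1D prompt to 2D so the body only deals with batches.
    prompt = np.asarray(prompt)
    input_is_1d = prompt.ndim == 1
    if input_is_1d:
        prompt = prompt[np.newaxis, :]

    # ... the generation loop would append tokens here ...

    # Downrank back to 1D in the same function, keeping both shape
    # transformations next to each other for readability.
    if input_is_1d:
        prompt = np.squeeze(prompt, axis=0)
    return prompt

print(generate([1, 2, 3]).shape)        # 1D in, 1D out
print(generate([[1, 2, 3]]).shape)      # 2D in, 2D out
```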
keras_nlp/utils/text_generation.py
Outdated
    return prompt, input_is_1d


def _mask_tokens_after_end_token(
We don't need to lead with an underscore here. We choose which functions to "export" in our API by listing them in __init__.py. The underscore is only needed for private class methods.
Removed the underscores for both helper functions.
keras_nlp/utils/text_generation.py
Outdated
# Build a mask including end_token and replace tokens after end_token
# with `pad_token_id`.
valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
prompt = tf.where(valid_indices, prompt, pad_token_id)
return directly
Does this mean I should do "return tf.where(valid_indices, prompt, pad_token_id)"? Just edited
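The masking step under discussion amounts to something like the following. This is a NumPy sketch for illustration; the real code uses tf.sequence_mask and tf.where, and the helper name here mirrors the one in the diff:

```python
import numpy as np

def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    # Find the index of the first end token in each row; rows without
    # an end token keep their full length.
    is_end = prompt == end_token_id
    end_indices = np.where(is_end.any(axis=1), is_end.argmax(axis=1), max_length)
    # Keep positions up to and including the end token, replace the rest
    # with the pad id, and return the result directly (per the review
    # comment) rather than reassigning to `prompt` first.
    positions = np.arange(max_length)[np.newaxis, :]
    valid_indices = positions <= end_indices[:, np.newaxis]
    return np.where(valid_indices, prompt, pad_token_id)

prompt = np.array([[5, 7, 2, 9, 9]])  # 2 is the end token id
print(mask_tokens_after_end_token(prompt, 5, 2, 0))
# [[5 7 2 0 0]]
```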
keras_nlp/utils/text_generation.py
Outdated
    pad_token_id=0,
):
    """
    Text generation utility based on random sampling.
... randomly sampling the entire probability distribution.
edited
from keras_nlp.utils.text_generation import random_sampling


class TextGenerationTest(tf.test.TestCase):
We should split out a separate class for greedy, random, etc. They should mirror each other, but not be mixed in the same unit tests.
keras_nlp/utils/text_generation.py
Outdated
    return prompt


def random_sampling(
We need to make sure we have naming uniformity. Do we want greedy_search, random_search, top_k_search? Or greedy_sampling, random_sampling, top_k_sampling? Or something else?
The way I was going to name them was greedy_search, beam_search, random_sampling, top_k_sampling, and top_p_sampling, because the latter three use probabilistic sampling techniques. But if uniformly calling them all "search" is better, I'll change it to that.
I would prefer the name random_sampling_search. @mattdangerw Matt, what do you think?
keras_nlp/utils/text_generation.py
Outdated
        append generated tokens.

    Returns:
        a 2D Tensor, the prompt with shape [batch_size, max_length].
This is not correct? I don't see padding in this function, so the width is not necessarily max_length.
edited, this now says "a 1D or 2D Tensor, with the same shape as prompt."
inputs = tf.constant([1])
outputs = greedy_search(self.token_probability_fn, inputs, max_length=5)
self.assertEquals(outputs.shape, [5])
outputs = random_sampling(
Now each test case is testing different generation algos, which could become unreadable when we have more utilities. Let's use parameterized test to pass in the utility to test. For example: https://github.com/keras-team/keras/blob/v2.9.0/keras/optimizers/optimizer_experimental/optimizer_test.py#L283
There is one thing I am not clear about: for different algos, the expected generation results can differ, so how should we test them in a clear way?
The decoding methods I'll be adding are randomized, hence harder to test. Most of the tests don't test for value, but rather shape, so those can be reused across decoding methods.
My testing would be based on seeding the decoding methods so that they generate a specific value, which is seen here: https://github.com/jessechancy/keras-nlp/blob/b08d6c6b49392e4791e2b9ece7f16c941837fa55/keras_nlp/utils/text_generation_test.py#L82.
I also have specific tests for top-k and top-p, which run the algorithm and make sure that the tokens that are supposed to be cut off are never selected.
I had a comment above; I think we should just split out a separate test class for each utility. Parameterized will get messy.
Re random testing, we can always set a random seed for the test and check the entire output. Though that won't check we are actually sampling the distribution correctly, that's definitely harder.
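Seeding a randomized sampler so a test can assert an exact output can be sketched like this. This is a pure-Python stand-in for illustration; the actual utilities would thread a seed through TF's random ops:

```python
import random

def random_sample(probs, num_tokens, seed=None):
    # Draw tokens from a categorical distribution. With a fixed seed the
    # drawn sequence is deterministic, so a test can check the full output.
    rng = random.Random(seed)
    tokens = list(range(len(probs)))
    return [rng.choices(tokens, weights=probs)[0] for _ in range(num_tokens)]

a = random_sample([0.1, 0.2, 0.7], 5, seed=42)
b = random_sample([0.1, 0.2, 0.7], 5, seed=42)
assert a == b  # same seed, same output
```

As the comment above notes, this checks reproducibility but not that the distribution itself is sampled correctly; that would need a statistical test over many draws.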
Thanks! Mainly looks good, some minor comments!
keras_nlp/utils/text_generation.py
Outdated
    return prompt


def random_sampling(
I would prefer the name random_sampling_search. @mattdangerw Matt, what do you think?
    rtol=0.2,
)


def test_seeded_end_token_id(self):
This test case's name is a little confusing to me, what are we testing here? Are we testing generation with a given seed?
This one is testing whether the end token is detected and everything after it is filled with the pad token. It needs to be seeded because it's still randomized when the end token is detected.
I see, let's just call it test_end_token_id or test_handle_end_token_id; "seed" here is something we don't need to expose to readers. The current name suggests that we are "seeding" the end_token_id.
lgtm once we add the line to utils/__init__.py
def validate_prompt(prompt):
    """
Hmm, with a whole docstring for a helper it's hard to tell it's just a helper function. In general you don't need the full args/returns structure for something small like this.
Just: "Helper function to validate input to text_generation utils."
keras_nlp/utils/text_generation.py
Outdated
def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    """
    Mask the tokens after the end token.
Same here. Mention it's a helper, maybe kill the whole args/returns section.
    return prompt


def random_search(
We need to add this to the __init__.py for the utils dir, so it gets exported.
Thanks!
Randomly samples using the probability distribution provided by the input function.
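The per-step behavior that sentence describes can be sketched as follows. This is a hypothetical pure-Python outline, not the keras-nlp implementation: at each step the model's probability function is called on the sequence so far and one token is drawn from the returned distribution.

```python
import random

def random_search(token_probability_fn, prompt, max_length, seed=None):
    # Sketch of random-sampling decoding: at each step, ask the model for
    # a distribution over the vocabulary and draw one token from it.
    rng = random.Random(seed)
    prompt = list(prompt)
    while len(prompt) < max_length:
        probs = token_probability_fn(prompt)  # assumed shape: [vocab_size]
        token = rng.choices(range(len(probs)), weights=probs)[0]
        prompt.append(token)
    return prompt

# Toy model: uniform distribution over a 4-token vocabulary.
out = random_search(lambda p: [0.25, 0.25, 0.25, 0.25], [1], max_length=5, seed=0)
print(len(out))  # 5
```

Unlike greedy search, which always picks the argmax token, this can select any token with nonzero probability, which is what makes the outputs randomized and the tests above dependent on seeding.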