Conversation

jessechancy
Contributor

Randomly samples using the probability distribution provided by the input function.

Member

@mattdangerw mattdangerw left a comment

Looks good! Left some initial comments

)
if not isinstance(prompt, tf.Tensor):
prompt = tf.convert_to_tensor(prompt)
input_is_1d = prompt.shape.rank == 1
Member

I think this would be simpler if you left the upranking in the main function. Currently the upranking and downranking are split apart from each other, which is bad for readability.

Contributor Author

I've moved the section out of the helper function into the main function.

return prompt, input_is_1d
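As a sketch of the suggestion above, keeping the uprank and downrank next to each other in the main function (the function name and the elided generation loop here are illustrative, not the PR's exact code):

```python
import tensorflow as tf

def generate(token_probability_fn, prompt, max_length):
    # Uprank a 1D prompt to 2D on entry, remembering the original rank.
    prompt = tf.convert_to_tensor(prompt)
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]
    # ... the generation loop would append tokens to `prompt` here ...
    # Downrank back to 1D on exit, right next to the uprank above.
    if input_is_1d:
        prompt = tf.squeeze(prompt, axis=0)
    return prompt
```

This keeps the rank bookkeeping in one place, so a reader never has to chase it into a helper.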


def _mask_tokens_after_end_token(
Member

We don't need to lead with an underscore here. We choose which functions to "export" in our API by listing them in `__init__.py`. Underscores are only needed for private class methods.

Contributor Author

Removed underscores for both helper functions

# Build a mask including end_token and replace tokens after end_token
# with `pad_token_id`.
valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
prompt = tf.where(valid_indices, prompt, pad_token_id)
Member

return directly

Contributor Author

Does this mean I should do `return tf.where(valid_indices, prompt, pad_token_id)`? Just edited.
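For reference, a self-contained sketch of the masking helper with the direct return; the first-occurrence `argmax` trick here assumes an end token is present in each row, and the signature mirrors the snippets in this thread:

```python
import tensorflow as tf

def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    # Index of the first `end_token_id` in each row (assumes one is present;
    # rows without an end token would need separate handling).
    end_indices = tf.math.argmax(
        tf.cast(tf.equal(prompt, end_token_id), tf.int32), axis=-1
    )
    # Build a mask including end_token, then replace tokens after end_token
    # with `pad_token_id` and return directly.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
    return tf.where(valid_indices, prompt, pad_token_id)
```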

pad_token_id=0,
):
"""
Text generation utility based on random sampling.
Member

... randomly sampling the entire probability distribution.

Contributor Author

edited

from keras_nlp.utils.text_generation import random_sampling


class TextGenerationTest(tf.test.TestCase):
Member

We should split a separate class for greedy, random, etc. They should mirror each other, but not be mixed in the same unit tests.

return prompt


def random_sampling(
Member

We need to make sure we have naming uniformity. Do we want

`greedy_search`, `random_search`, `top_k_search`? Or `greedy_sampling`, `random_sampling`, `top_k_sampling`?

Or something else?

Contributor Author

The way I was going to name them was greedy_search, beam_search, random_sampling, top_k_sampling, and top_p_sampling, because the latter three use probabilistic sampling techniques. But if uniformly calling them search is better, I'll change it to that.

Contributor

I would prefer a name random_sampling_search. @mattdangerw Matt, what do you think?

append generated tokens.
Returns:
a 2D Tensor, the prompt with shape [batch_size, max_length].
Contributor

This is not correct? I don't see padding in this function, so the width is not necessarily max_length.

Contributor Author

edited, this now says "a 1D or 2D Tensor, with the same shape as prompt."

inputs = tf.constant([1])
outputs = greedy_search(self.token_probability_fn, inputs, max_length=5)
self.assertEquals(outputs.shape, [5])
outputs = random_sampling(
Contributor

Now each test case is testing different generation algos, which could become unreadable when we have more utilities. Let's use parameterized test to pass in the utility to test. For example: https://github.com/keras-team/keras/blob/v2.9.0/keras/optimizers/optimizer_experimental/optimizer_test.py#L283

There is one thing I am not clear on: for different algos, the expected generation results can differ, so how should we test them in a clear way?

Contributor Author

The decoding methods I'll be adding are randomized, which makes them harder to test. Most of the tests don't test for value, but rather shape, so those can be reused for different decoding methods.

My testing would be based on seeding the decoding methods so that they generate a specific value, which is seen here: https://github.com/jessechancy/keras-nlp/blob/b08d6c6b49392e4791e2b9ece7f16c941837fa55/keras_nlp/utils/text_generation_test.py#L82.

I also have specific tests for top-k and top-p, which run the algorithm and make sure that the tokens that are supposed to be cut off won't ever be selected.

Member

I had a comment above, I think we should just split a separate test class for each utility. Parameterized will get messy.

Re random testing, we can always set a random seed for the test and check the entire output. Though that won't check we are actually sampling the distribution correctly, that's definitely harder.
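One way to make such tests deterministic is stateless sampling: with a fixed seed the draw is reproducible, so a test can assert exact outputs. A minimal sketch (the helper name is illustrative, not the PR's code):

```python
import tensorflow as tf

def sample_next_token(probs, seed):
    # Stateless categorical draw: the same seed always yields the same token.
    return tf.random.stateless_categorical(tf.math.log(probs), 1, seed=seed)

probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])
# Two draws with the same seed are identical, so a seeded test can check
# the entire output; it still won't verify the distribution is sampled
# correctly, only that generation is reproducible.
a = sample_next_token(probs, seed=[42, 0])
b = sample_next_token(probs, seed=[42, 0])
```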

Contributor

@chenmoneygithub chenmoneygithub left a comment

Thanks! Mainly looks good, some minor comments!

rtol=0.2,
)

def test_seeded_end_token_id(self):
Contributor

This test case's name is a little confusing to me, what are we testing here? Are we testing generation with a given seed?

Contributor Author

This one is testing whether the end token is detected and anything after it is filled with the pad token. It needs to be seeded because it's still randomized when the end token is detected.

Contributor

I see, let's just call it test_end_token_id or test_handle_end_token_id; the seed here is something we don't need to expose to readers. The current name suggests that we are "seeding" the end_token_id.

Member

@mattdangerw mattdangerw left a comment

lgtm once we add the line to utils/__init__.py



def validate_prompt(prompt):
"""
Member

Hmm, with the whole docstring for a helper it's hard to tell it's just a helper function. In general I don't think you would need the whole args/return structure for something small like this.

Just `Helper function to validate input to text_generation utils.`


def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
"""
Mask the tokens after the end token.
Member

Same here. Mention it's a helper, maybe kill the whole args/returns section.

return prompt


def random_search(
Member

We need to add this to the __init__.py for the utils dir, so this gets exported.
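The export would be a one-line addition per utility in the utils package's `__init__.py` (exact module path assumed from the imports shown earlier in this conversation):

```python
# keras_nlp/utils/__init__.py (sketch)
from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
```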

Member

@mattdangerw mattdangerw left a comment

Thanks!

@mattdangerw mattdangerw merged commit 65349af into keras-team:master Jun 21, 2022