From 8be0f83236774d8693704a6f8a33a787ea564770 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Thu, 16 Jun 2022 13:32:38 -0700 Subject: [PATCH 1/9] reformatted greedy search with helper functions and added random sampling util --- keras_nlp/utils/text_generation.py | 170 +++++++++++++++++++++--- keras_nlp/utils/text_generation_test.py | 48 ++++++- 2 files changed, 195 insertions(+), 23 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index f07f9b9a89..329105fce4 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -16,6 +16,58 @@ import tensorflow as tf +def _validate_prompt(prompt): + """ + Validate the prompt and reformat for use. + + Args: + prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to + append generated tokens. + + Returns: + a 2D Tensor, the prompt with shape [batch_size, max_length]. + """ + if isinstance(prompt, tf.RaggedTensor): + raise ValueError( + "RaggedTensor `prompt` is not supported, please " + "provide `prompt` as a list or Tensor." + ) + if not isinstance(prompt, tf.Tensor): + prompt = tf.convert_to_tensor(prompt) + input_is_1d = prompt.shape.rank == 1 + if input_is_1d: + prompt = prompt[tf.newaxis, :] + return prompt, input_is_1d + + +def _mask_tokens_after_end_token( + prompt, max_length, end_token_id, pad_token_id +): + """ + Mask the tokens after the end token. + + Args: + prompt: a 2D Tensor, the prompt with shape [batch_size, max_length]. + max_length: an integer, the maximum length of the prompt. + end_token_id: an integer, the id of the end token. + pad_token_id: an integer, the id of the padding token. + + Returns: + a 2D Tensor, the masked prompt with shape [batch_size, max_length]. All + tokens after encountering `end_token_id` will be replaced with + `pad_token_id`. + """ + # Mask out tokens after `end_token_id` is encountered. + # Find index of first end_token_id. + end_indices = tf.math.argmax(prompt == end_token_id, -1) + # Use max_length if no `end_token_id` is found. + end_indices = tf.where(end_indices == 0, max_length, end_indices) + # Build a mask including end_token and replace tokens after end_token + # with `pad_token_id`. + valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) + prompt = tf.where(valid_indices, prompt, pad_token_id) + return prompt + def greedy_search( token_probability_fn, @@ -88,17 +140,8 @@ def token_probability_fn(inputs): "tf.function or run `tf.config.run_functions_eagerly(True)` to run " "tf.function in eager mode." ) - if isinstance(prompt, tf.RaggedTensor): - raise ValueError( - "RaggedTensor `prompt` is not supported, please " - "provide `prompt` as a list or Tensor." - ) - if not isinstance(prompt, tf.Tensor): - prompt = tf.convert_to_tensor(prompt) - input_is_1d = prompt.shape.rank == 1 - if input_is_1d: - prompt = prompt[tf.newaxis, :] + prompt, input_is_1d = _validate_prompt(prompt) i = prompt.shape[1] while i < max_length: # If the prompt has reached our desired length, exit while loop. @@ -109,16 +152,105 @@ def token_probability_fn(inputs): i += 1 if end_token_id is not None: - # Mask out tokens after `end_token_id` is encountered. - # Find index of first end_token_id. - end_indices = tf.math.argmax(prompt == end_token_id, -1) - # Use max_length if no `end_token_id` is found. - end_indices = tf.where(end_indices == 0, max_length, end_indices) - # Build a mask including end_token and replace tokens after end_token - # with `pad_token_id`. 
-        valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
-        prompt = tf.where(valid_indices, prompt, pad_token_id)
+        prompt = _mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
 
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+def random_sampling(
+    token_probability_fn,
+    prompt,
+    max_length,
+    seed=None,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on random sampling.
+
+    Random sampling samples the next token from the probability distribution
+    provided by `token_probability_fn` and appends it to the existing sequence.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            which generated tokens are appended.
+        max_length: int. The max length of generated text.
+        seed: int, defaults to None. The random seed used for sampling.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence; once encountered, generation is finished for that
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The token used to pad sequences
+            after `end_token_id` is encountered.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.random_sampling(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        end_token_id=0,)
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.random_sampling` currently requires an eager "
+            "execution context. Please call `random_sampling` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+
+    prompt, input_is_1d = _validate_prompt(prompt)
+    i = prompt.shape[1]
+    while i < max_length:
+        # If the prompt has reached our desired length, exit while loop.
+        pred = token_probability_fn(prompt)
+        next_token = tf.cast(
+            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
+            dtype=prompt.dtype,
+        )
+        # Append the next token to current sequence.
+ prompt = tf.concat([prompt, next_token], axis=-1) + i += 1 + + if end_token_id is not None: + prompt = _mask_tokens_after_end_token( + prompt, max_length, end_token_id, pad_token_id + ) + if input_is_1d: + return tf.squeeze(prompt) + return prompt \ No newline at end of file diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index dc4686832f..71106b2001 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -15,9 +15,9 @@ import tensorflow as tf - +import numpy as np from keras_nlp.utils.text_generation import greedy_search - +from keras_nlp.utils.text_generation import random_sampling class TextGenerationTest(tf.test.TestCase): def setUp(self): @@ -47,26 +47,41 @@ def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [5]) + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [2, 5]) + def test_generate_with_list_prompt(self): inputs = [[1], [1]] outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): greedy_search(self.token_probability_fn, inputs, max_length=5) + with self.assertRaises(ValueError): + random_sampling(self.token_probability_fn, inputs, max_length=5) def test_assert_generation_is_correct(self): def token_probability_fn(inputs): batch_size = inputs.shape[0] - prob = tf.constant([[0.1, 0.2, 0.3, 0.4]]) + prob = tf.constant([[0.01, 0.01, 0.08, 0.9]]) return tf.repeat(prob, batch_size, axis=0) batch_size = 10 @@ -78,11 +93,23 @@ def token_probability_fn(inputs): self.assertAllEqual( outputs, 3 * tf.ones(shape=[batch_size, max_length]) ) + outputs = random_sampling( + token_probability_fn, inputs, max_length=max_length, seed=42 + ) + # Random sampling result with seed 42 + seeded_result = 3 * np.ones(shape=[batch_size, max_length]) + seeded_result[2][2] = 2 + seeded_result[5][2] = 2 + seeded_result[6][2] = 2 + seeded_result[7][1] = 2 + seeded_result[8][1] = 2 + seeded_result[8][2] = 2 + self.assertAllEqual(outputs, seeded_result) def test_end_token_id(self): def token_probability_fn(inputs): batch_size = inputs.shape[0] - prob = tf.constant([[0.1, 0.2, 0.3, 0.4]]) + prob = tf.constant([[0.01, 0.01, 0.08, 0.9]]) return tf.repeat(prob, batch_size, axis=0) max_length = 5 @@ -96,5 +123,18 @@ def token_probability_fn(inputs): ) expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) expected_outputs = tf.concat([inputs, expected_outputs], axis=1) + self.assertAllEqual(outputs, expected_outputs) + outputs = random_sampling( + token_probability_fn, + inputs, + max_length=max_length, + seed=42, + end_token_id=2, + pad_token_id=0, + ) + # Random sampling result with seed 42 + expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) + expected_outputs = tf.concat([inputs, expected_outputs], axis=1) 
self.assertAllEqual(outputs, expected_outputs) + From b08d6c6b49392e4791e2b9ece7f16c941837fa55 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Thu, 16 Jun 2022 13:33:25 -0700 Subject: [PATCH 2/9] reformat files --- keras_nlp/utils/text_generation.py | 4 +++- keras_nlp/utils/text_generation_test.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 329105fce4..53eb4a8446 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -16,6 +16,7 @@ import tensorflow as tf + def _validate_prompt(prompt): """ Validate the prompt and reformat for use. @@ -160,6 +161,7 @@ def token_probability_fn(inputs): return tf.squeeze(prompt) return prompt + def random_sampling( token_probability_fn, prompt, @@ -253,4 +255,4 @@ def token_probability_fn(inputs): ) if input_is_1d: return tf.squeeze(prompt) - return prompt \ No newline at end of file + return prompt diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index 71106b2001..b4eabdc5ec 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -14,11 +14,13 @@ """Tests for Text Generation Utils.""" -import tensorflow as tf import numpy as np +import tensorflow as tf + from keras_nlp.utils.text_generation import greedy_search from keras_nlp.utils.text_generation import random_sampling + class TextGenerationTest(tf.test.TestCase): def setUp(self): super().setUp() @@ -61,7 +63,6 @@ def test_generate_with_2d_prompt(self): ) self.assertEquals(outputs.shape, [2, 5]) - def test_generate_with_list_prompt(self): inputs = [[1], [1]] outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) @@ -137,4 +138,3 @@ def token_probability_fn(inputs): expected_outputs = tf.tile([[3], [0]], [1, max_length - 2]) expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) - From 9bcf5a8a474b6d76e87a64d9a4f4025368ea6640 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Thu, 16 Jun 2022 16:40:16 -0700 Subject: [PATCH 3/9] split testing into two classes + minor changes --- keras_nlp/utils/text_generation.py | 34 +++--- keras_nlp/utils/text_generation_test.py | 131 +++++++++++++++++++----- 2 files changed, 123 insertions(+), 42 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 53eb4a8446..668bc37481 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -17,7 +17,7 @@ import tensorflow as tf -def _validate_prompt(prompt): +def validate_prompt(prompt): """ Validate the prompt and reformat for use. @@ -26,7 +26,7 @@ def _validate_prompt(prompt): append generated tokens. Returns: - a 2D Tensor, the prompt with shape [batch_size, max_length]. + a 1D or 2D Tensor, with the same shape as prompt. """ if isinstance(prompt, tf.RaggedTensor): raise ValueError( @@ -35,13 +35,10 @@ def _validate_prompt(prompt): ) if not isinstance(prompt, tf.Tensor): prompt = tf.convert_to_tensor(prompt) - input_is_1d = prompt.shape.rank == 1 - if input_is_1d: - prompt = prompt[tf.newaxis, :] - return prompt, input_is_1d + return prompt -def _mask_tokens_after_end_token( +def mask_tokens_after_end_token( prompt, max_length, end_token_id, pad_token_id ): """ @@ -66,8 +63,7 @@ def _mask_tokens_after_end_token( # Build a mask including end_token and replace tokens after end_token # with `pad_token_id`. 
valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length) - prompt = tf.where(valid_indices, prompt, pad_token_id) - return prompt + return tf.where(valid_indices, prompt, pad_token_id) def greedy_search( @@ -142,7 +138,12 @@ def token_probability_fn(inputs): "tf.function in eager mode." ) - prompt, input_is_1d = _validate_prompt(prompt) + prompt = validate_prompt(prompt) + + input_is_1d = prompt.shape.rank == 1 + if input_is_1d: + prompt = prompt[tf.newaxis, :] + i = prompt.shape[1] while i < max_length: # If the prompt has reached our desired length, exit while loop. @@ -153,7 +154,7 @@ def token_probability_fn(inputs): i += 1 if end_token_id is not None: - prompt = _mask_tokens_after_end_token( + prompt = mask_tokens_after_end_token( prompt, max_length, end_token_id, pad_token_id ) @@ -171,7 +172,8 @@ def random_sampling( pad_token_id=0, ): """ - Text generation utility based on random sampling. + Text generation utility based on randomly sampling the entire probability + distribution. Random sampling samples the next token from the probability distribution provided by `token_probability_fn` and appends it to the existing sequence. @@ -236,7 +238,11 @@ def token_probability_fn(inputs): "tf.function in eager mode." ) - prompt, input_is_1d = _validate_prompt(prompt) + prompt = validate_prompt(prompt) + input_is_1d = prompt.shape.rank == 1 + if input_is_1d: + prompt = prompt[tf.newaxis, :] + i = prompt.shape[1] while i < max_length: # If the prompt has reached our desired length, exit while loop. @@ -250,7 +256,7 @@ def token_probability_fn(inputs): i += 1 if end_token_id is not None: - prompt = _mask_tokens_after_end_token( + prompt = mask_tokens_after_end_token( prompt, max_length, end_token_id, pad_token_id ) if input_is_1d: diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index b4eabdc5ec..e79a7e2a31 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -21,7 +21,7 @@ from keras_nlp.utils.text_generation import random_sampling -class TextGenerationTest(tf.test.TestCase): +class GreedySearchTextGenerationTest(tf.test.TestCase): def setUp(self): super().setUp() vocab_size = 10 @@ -49,36 +49,22 @@ def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [5]) - outputs = random_sampling( - self.token_probability_fn, inputs, max_length=5 - ) - self.assertEquals(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) - outputs = random_sampling( - self.token_probability_fn, inputs, max_length=5 - ) - self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_list_prompt(self): inputs = [[1], [1]] outputs = greedy_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) - outputs = random_sampling( - self.token_probability_fn, inputs, max_length=5 - ) - self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): greedy_search(self.token_probability_fn, inputs, max_length=5) - with self.assertRaises(ValueError): - random_sampling(self.token_probability_fn, inputs, max_length=5) - + def test_assert_generation_is_correct(self): def token_probability_fn(inputs): batch_size = 
inputs.shape[0] @@ -94,18 +80,6 @@ def token_probability_fn(inputs): self.assertAllEqual( outputs, 3 * tf.ones(shape=[batch_size, max_length]) ) - outputs = random_sampling( - token_probability_fn, inputs, max_length=max_length, seed=42 - ) - # Random sampling result with seed 42 - seeded_result = 3 * np.ones(shape=[batch_size, max_length]) - seeded_result[2][2] = 2 - seeded_result[5][2] = 2 - seeded_result[6][2] = 2 - seeded_result[7][1] = 2 - seeded_result[8][1] = 2 - seeded_result[8][2] = 2 - self.assertAllEqual(outputs, seeded_result) def test_end_token_id(self): def token_probability_fn(inputs): @@ -126,6 +100,107 @@ def token_probability_fn(inputs): expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) +class RandomSamplingTextGenerationTest(tf.test.TestCase): + def setUp(self): + super().setUp() + vocab_size = 10 + feature_size = 16 + + # Create a dummy model to predict the next token. + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=[None]), + tf.keras.layers.Embedding( + input_dim=vocab_size, + output_dim=feature_size, + ), + tf.keras.layers.Dense(vocab_size), + tf.keras.layers.Softmax(), + ] + ) + + def token_probability_fn(inputs): + return model(inputs)[:, -1, :] + + self.token_probability_fn = token_probability_fn + + def test_generate_with_1d_prompt(self): + inputs = tf.constant([1]) + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [5]) + + def test_generate_with_2d_prompt(self): + inputs = tf.constant([[1], [1]]) + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [2, 5]) + + def test_generate_with_list_prompt(self): + inputs = [[1], [1]] + outputs = random_sampling( + self.token_probability_fn, inputs, max_length=5 + ) + self.assertEquals(outputs.shape, [2, 5]) + + def test_generate_with_ragged_prompt(self): + inputs = tf.ragged.constant([[1], [2, 3]]) + with self.assertRaises(ValueError): + random_sampling(self.token_probability_fn, inputs, max_length=5) + + def test_assert_seeded_generation_is_correct(self): + def token_probability_fn(inputs): + batch_size = inputs.shape[0] + prob = tf.constant([[0.01, 0.01, 0.08, 0.9]]) + return tf.repeat(prob, batch_size, axis=0) + + batch_size = 10 + inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) + max_length = 3 + tf.random.set_seed(42) + outputs = random_sampling( + token_probability_fn, inputs, max_length=max_length, seed=42 + ) + # Random sampling result with seed 42 + seeded_result = 3 * np.ones(shape=[batch_size, max_length]) + self.assertAllEqual(outputs, seeded_result) + + def test_assert_probability_distribution_generation_is_correct(self): + def token_probability_fn(inputs): + batch_size = inputs.shape[0] + prob = tf.constant([[0.01, 0.01, 0.08, 0.9]]) + return tf.repeat(prob, batch_size, axis=0) + + batch_size = 10 + inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) + max_length = 3 + + outputs_count = np.array([0, 0, 0, 0]) + tf.random.set_seed(42) + for i in range(500): + outputs = random_sampling( + token_probability_fn, inputs, max_length=max_length, seed=42 + ) + flatten_predictions = tf.reshape(outputs[:, 1:], [-1]) + for pred in flatten_predictions: + outputs_count[pred] += 1 + self.assertAllClose( + outputs_count/np.sum(outputs_count), + [0.01, 0.01, 0.08, 0.9], + rtol=0.2 + ) + + def test_seeded_end_token_id(self): + def token_probability_fn(inputs): + batch_size = inputs.shape[0] + prob = 
tf.constant([[0.01, 0.01, 0.08, 0.9]]) + return tf.repeat(prob, batch_size, axis=0) + + max_length = 5 + inputs = tf.constant([[0, 1], [1, 2]]) + outputs = random_sampling( token_probability_fn, inputs, From d31020f2dd61285700ea3c9e1e78639237df1b55 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Thu, 16 Jun 2022 18:22:32 -0700 Subject: [PATCH 4/9] formatted code --- keras_nlp/utils/text_generation.py | 4 +--- keras_nlp/utils/text_generation_test.py | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 668bc37481..7ec66c2930 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -38,9 +38,7 @@ def validate_prompt(prompt): return prompt -def mask_tokens_after_end_token( - prompt, max_length, end_token_id, pad_token_id -): +def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id): """ Mask the tokens after the end token. diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index e79a7e2a31..e011a8df9b 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -64,7 +64,7 @@ def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): greedy_search(self.token_probability_fn, inputs, max_length=5) - + def test_assert_generation_is_correct(self): def token_probability_fn(inputs): batch_size = inputs.shape[0] @@ -100,6 +100,7 @@ def token_probability_fn(inputs): expected_outputs = tf.concat([inputs, expected_outputs], axis=1) self.assertAllEqual(outputs, expected_outputs) + class RandomSamplingTextGenerationTest(tf.test.TestCase): def setUp(self): super().setUp() @@ -187,9 +188,9 @@ def token_probability_fn(inputs): for pred in flatten_predictions: outputs_count[pred] += 1 self.assertAllClose( - outputs_count/np.sum(outputs_count), - [0.01, 0.01, 0.08, 0.9], - rtol=0.2 + outputs_count / np.sum(outputs_count), + [0.01, 0.01, 0.08, 0.9], + rtol=0.2, ) def test_seeded_end_token_id(self): From b05a01c488b1dd41fce3ad9970f95f4d7509141b Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 21 Jun 2022 11:28:57 -0700 Subject: [PATCH 5/9] naming changes --- keras_nlp/utils/text_generation.py | 2 +- keras_nlp/utils/text_generation_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 7ec66c2930..1047f6f9c5 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -161,7 +161,7 @@ def token_probability_fn(inputs): return prompt -def random_sampling( +def random_sampling_search( token_probability_fn, prompt, max_length, diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index e011a8df9b..37c8f1fcec 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -193,7 +193,7 @@ def token_probability_fn(inputs): rtol=0.2, ) - def test_seeded_end_token_id(self): + def test_end_token_id(self): def token_probability_fn(inputs): batch_size = inputs.shape[0] prob = tf.constant([[0.01, 0.01, 0.08, 0.9]]) From 76785608eda356388b796a4af1ce81ba0704f757 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 21 Jun 2022 11:46:00 -0700 Subject: [PATCH 6/9] naming changes --- keras_nlp/utils/text_generation_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git 
a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index 37c8f1fcec..597c5fa760 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -18,7 +18,7 @@ import tensorflow as tf from keras_nlp.utils.text_generation import greedy_search -from keras_nlp.utils.text_generation import random_sampling +from keras_nlp.utils.text_generation import random_sampling_search class GreedySearchTextGenerationTest(tf.test.TestCase): @@ -127,21 +127,21 @@ def token_probability_fn(inputs): def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) - outputs = random_sampling( + outputs = random_sampling_search( self.token_probability_fn, inputs, max_length=5 ) self.assertEquals(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) - outputs = random_sampling( + outputs = random_sampling_search( self.token_probability_fn, inputs, max_length=5 ) self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_list_prompt(self): inputs = [[1], [1]] - outputs = random_sampling( + outputs = random_sampling_search( self.token_probability_fn, inputs, max_length=5 ) self.assertEquals(outputs.shape, [2, 5]) @@ -149,7 +149,7 @@ def test_generate_with_list_prompt(self): def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): - random_sampling(self.token_probability_fn, inputs, max_length=5) + random_sampling_search(self.token_probability_fn, inputs, max_length=5) def test_assert_seeded_generation_is_correct(self): def token_probability_fn(inputs): @@ -161,7 +161,7 @@ def token_probability_fn(inputs): inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) max_length = 3 tf.random.set_seed(42) - outputs = random_sampling( + outputs = random_sampling_search( token_probability_fn, inputs, max_length=max_length, seed=42 ) # Random sampling result with seed 42 @@ -181,7 +181,7 @@ def token_probability_fn(inputs): outputs_count = np.array([0, 0, 0, 0]) tf.random.set_seed(42) for i in range(500): - outputs = random_sampling( + outputs = random_sampling_search( token_probability_fn, inputs, max_length=max_length, seed=42 ) flatten_predictions = tf.reshape(outputs[:, 1:], [-1]) @@ -202,7 +202,7 @@ def token_probability_fn(inputs): max_length = 5 inputs = tf.constant([[0, 1], [1, 2]]) - outputs = random_sampling( + outputs = random_sampling_search( token_probability_fn, inputs, max_length=max_length, From bfb032d9aca22d1d271d1a782b68f63f96c42344 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 21 Jun 2022 11:48:45 -0700 Subject: [PATCH 7/9] format changes --- keras_nlp/utils/text_generation_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index 597c5fa760..916eee21d2 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -149,7 +149,9 @@ def test_generate_with_list_prompt(self): def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): - random_sampling_search(self.token_probability_fn, inputs, max_length=5) + random_sampling_search( + self.token_probability_fn, inputs, max_length=5 + ) def test_assert_seeded_generation_is_correct(self): def token_probability_fn(inputs): From 0f6757da6f37958829eb7ce6b383fc666a9b6c84 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 21 Jun 2022 14:11:59 -0700 Subject: [PATCH 8/9] naming 
changes to random_search --- keras_nlp/utils/text_generation.py | 2 +- keras_nlp/utils/text_generation_test.py | 24 ++++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 1047f6f9c5..f7332bd7aa 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -161,7 +161,7 @@ def token_probability_fn(inputs): return prompt -def random_sampling_search( +def random_search( token_probability_fn, prompt, max_length, diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py index 916eee21d2..54e4bed209 100644 --- a/keras_nlp/utils/text_generation_test.py +++ b/keras_nlp/utils/text_generation_test.py @@ -18,7 +18,7 @@ import tensorflow as tf from keras_nlp.utils.text_generation import greedy_search -from keras_nlp.utils.text_generation import random_sampling_search +from keras_nlp.utils.text_generation import random_search class GreedySearchTextGenerationTest(tf.test.TestCase): @@ -127,31 +127,23 @@ def token_probability_fn(inputs): def test_generate_with_1d_prompt(self): inputs = tf.constant([1]) - outputs = random_sampling_search( - self.token_probability_fn, inputs, max_length=5 - ) + outputs = random_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [5]) def test_generate_with_2d_prompt(self): inputs = tf.constant([[1], [1]]) - outputs = random_sampling_search( - self.token_probability_fn, inputs, max_length=5 - ) + outputs = random_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_list_prompt(self): inputs = [[1], [1]] - outputs = random_sampling_search( - self.token_probability_fn, inputs, max_length=5 - ) + outputs = random_search(self.token_probability_fn, inputs, max_length=5) self.assertEquals(outputs.shape, [2, 5]) def test_generate_with_ragged_prompt(self): inputs = tf.ragged.constant([[1], [2, 3]]) with self.assertRaises(ValueError): - random_sampling_search( - self.token_probability_fn, inputs, max_length=5 - ) + random_search(self.token_probability_fn, inputs, max_length=5) def test_assert_seeded_generation_is_correct(self): def token_probability_fn(inputs): @@ -163,7 +155,7 @@ def token_probability_fn(inputs): inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) max_length = 3 tf.random.set_seed(42) - outputs = random_sampling_search( + outputs = random_search( token_probability_fn, inputs, max_length=max_length, seed=42 ) # Random sampling result with seed 42 @@ -183,7 +175,7 @@ def token_probability_fn(inputs): outputs_count = np.array([0, 0, 0, 0]) tf.random.set_seed(42) for i in range(500): - outputs = random_sampling_search( + outputs = random_search( token_probability_fn, inputs, max_length=max_length, seed=42 ) flatten_predictions = tf.reshape(outputs[:, 1:], [-1]) @@ -204,7 +196,7 @@ def token_probability_fn(inputs): max_length = 5 inputs = tf.constant([[0, 1], [1, 2]]) - outputs = random_sampling_search( + outputs = random_search( token_probability_fn, inputs, max_length=max_length, From 1f8cec18ff436011988977b87a8ae2a8c8191326 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 21 Jun 2022 16:47:10 -0700 Subject: [PATCH 9/9] removed docstring for helper and added random_search to init --- keras_nlp/utils/__init__.py | 1 + keras_nlp/utils/text_generation.py | 22 ++-------------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/keras_nlp/utils/__init__.py 
b/keras_nlp/utils/__init__.py index 533f0aa8f9..28c1af5a67 100644 --- a/keras_nlp/utils/__init__.py +++ b/keras_nlp/utils/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from keras_nlp.utils.text_generation import greedy_search +from keras_nlp.utils.text_generation import random_search diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index f7332bd7aa..81b09decc1 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -19,14 +19,7 @@ def validate_prompt(prompt): """ - Validate the prompt and reformat for use. - - Args: - prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to - append generated tokens. - - Returns: - a 1D or 2D Tensor, with the same shape as prompt. + Helper function to validate input to text_generation utils. """ if isinstance(prompt, tf.RaggedTensor): raise ValueError( @@ -40,18 +33,7 @@ def validate_prompt(prompt): def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id): """ - Mask the tokens after the end token. - - Args: - prompt: a 2D Tensor, the prompt with shape [batch_size, max_length]. - max_length: an integer, the maximum length of the prompt. - end_token_id: an integer, the id of the end token. - pad_token_id: an integer, the id of the padding token. - - Returns: - a 2D Tensor, the masked prompt with shape [batch_size, max_length]. All - tokens after encountering `end_token_id` will be replaced with - `pad_token_id`. + Helper function to mask the tokens after the end token. """ # Mask out tokens after `end_token_id` is encountered. # Find index of first end_token_id.
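
Taken together, the series leaves two decoding utilities exported from `keras_nlp.utils`: `greedy_search` and the newly added `random_search`. Below is a minimal usage sketch mirroring the docstring examples in the patches; the dummy model, vocabulary size, and random prompt are illustrative assumptions, not part of the committed code.

```python
import tensorflow as tf

from keras_nlp.utils import greedy_search
from keras_nlp.utils import random_search

VOCAB_SIZE = 10
FEATURE_SIZE = 16

# Dummy next-token model, as in the docstring examples: an embedding
# followed by a softmax over the vocabulary.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=[None]),
        tf.keras.layers.Embedding(
            input_dim=VOCAB_SIZE,
            output_dim=FEATURE_SIZE,
        ),
        tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
    ]
)


def token_probability_fn(inputs):
    # Probability distribution over the next token, given the tokens so far.
    return model(inputs)[:, -1, :]


# An illustrative batch of five 5-token prompts.
prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)

# Greedy search: deterministically take the argmax token at each step.
print(greedy_search(token_probability_fn, prompt, max_length=10))

# Random search: sample each next token from the full distribution.
# `seed` makes the draws reproducible; tokens generated after
# `end_token_id` are replaced with the pad token (0 by default).
print(
    random_search(
        token_probability_fn,
        prompt,
        max_length=10,
        seed=42,
        end_token_id=0,
    )
)
```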