From ad45ee8272aebf242bbf010f64df03cdfc809752 Mon Sep 17 00:00:00 2001
From: Jesse Chan
Date: Fri, 24 Jun 2022 15:24:08 -0700
Subject: [PATCH 1/6] top p search and testing

---
 keras_nlp/utils/text_generation.py      | 143 +++++++++++++++++++
 keras_nlp/utils/text_generation_test.py | 177 +++++++++++++++++++++++-
 2 files changed, 318 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index 4c7583633c..101ebb1274 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -393,3 +393,146 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_p_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    p,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+    filter_value=-float("Inf"),
+):
+    """
+    Text generation utility based on top-p (nucleus) sampling.
+
+    Top-p search filters the top token with probabilities that sum up to p, and
+    samples from this subset to get the next token. The probability of each
+    token is provided by `token_probability_fn`.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and output the probability distribution of the next token. If
+            `from_logits` set to True, it should output the logits of the next
+            token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens.
+        max_length: int. The max length of generated text.
+        p: float. The probability that the top tokens sums up to. Should
+            follow the constraint of 0 < p < 1.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence, once encountered the generation is finished for the exact
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
+            is received.
+        filter_value: float, defaults to -Inf. The value for filtering out
+            unused tokens when sampling the probability distribution.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.top_p_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        p=0.8,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_k_search` currently requires an eager "
+            "execution context. Please call `top_k_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if p <= 0 or p >= 1:
+        raise ValueError(
+            "p should be a float strictly between 0 and 1 (0 < p < 1)."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    i = prompt.shape[1]
+    while i < max_length:
+        # If the prompt has reached our desired length, exit while loop.
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # Sort preds in descending order.
+        sorted_preds, sorted_indices = tf.math.top_k(
+            pred, k=pred.shape[-1], sorted=True
+        )
+        # Calculate cumulative probability distribution.
+        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
+        # Create a mask for the tokens to keep.
+        keep_mask = cumulative_probs <= p
+        # Shift to include the last token that exceeds p.
+        shifted_keep_mask = tf.concat(
+            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
+        )
+        # Filter out unmasked tokens and sample from filtered distribution.
+        probs = tf.where(
+            shifted_keep_mask, sorted_preds, tf.fill(pred.shape, filter_value)
+        )
+        sorted_next_token = tf.random.categorical(
+            tf.math.log(probs), 1, seed=seed
+        )
+        next_token = tf.gather_nd(
+            sorted_indices, sorted_next_token, batch_dims=1
+        )
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index fab1e4e0e4..f87f899802 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -20,6 +20,7 @@
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
+from keras_nlp.utils.text_generation import top_p_search
 
 
 class GreedySearchTextGenerationTest(tf.test.TestCase):
@@ -387,13 +388,13 @@ def token_probability_fn(inputs):
     def test_from_logits(self):
         def token_logits_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[1.0, 2.0, 3, 0, 4.0]])
+            prob = tf.constant([[1.0, 2.0, 3.0, 4.0]])
             return tf.repeat(prob, batch_size, axis=0)
 
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
             prob = tf.keras.activations.softmax(
-                tf.constant([[1.0, 2.0, 3, 0, 4.0]])
+                tf.constant([[1.0, 2.0, 3.0, 4.0]])
             )
             return tf.repeat(prob, batch_size, axis=0)
 
@@ -418,3 +419,175 @@ def token_probability_fn(inputs):
             seed=42,
         )
         self.assertAllEqual(output_logit, output_probs)
+
+
+class TopPSamplingTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
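+        # The model embeds each token id and outputs a softmax distribution
+        # over the vocabulary at every position; `token_probability_fn` below
+        # keeps only the last position as the next-token distribution.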
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+
+    def test_assert_seeded_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        tf.random.set_seed(42)
+        outputs = top_p_search(
+            token_probability_fn, inputs, max_length=max_length, p=0.91, seed=42
+        )
+        # Top-p sampling result with seed 42.
+        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
+        seeded_result[3][1] = 2
+        seeded_result[7][1] = 2
+        self.assertAllEqual(outputs, seeded_result)
+
+    def test_assert_probability_distribution_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        for i in range(500):
+            outputs = top_p_search(
+                token_probability_fn,
+                inputs,
+                max_length=max_length,
+                p=0.6,
+                seed=42,
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.0, 0.0, 0.429, 0.571],
+            rtol=0.2,
+        )
+
+    def test_only_choose_from_top_p_tokens(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        # Test that it only samples from tokens that sum up to p
+        for p, n in [(0.3, 1), (0.7, 2), (0.9, 3)]:
+            inputs = tf.constant([[0, 0], [0, 0]])
+            for _ in range(10):
+                outputs = top_p_search(
+                    token_probability_fn, inputs, max_length=5, p=p
+                )
+                self.assertAllEqual(outputs < n, tf.ones_like(outputs))
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        tf.random.set_seed(42)
+        outputs = top_p_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            seed=1,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        # Top-p sampling result with seed 42.
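+        # With p=0.92 only tokens 3 and 2 survive the nucleus filter, so the
+        # first sequence keeps sampling token 3, while the second sequence
+        # already contains `end_token_id` and is padded with `pad_token_id`.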
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+    def test_from_logits(self):
+        def token_logits_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[1.0, 2.0, 3.0, 4.0]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.keras.activations.softmax(
+                tf.constant([[1.0, 2.0, 3.0, 4.0]])
+            )
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        tf.random.set_seed(42)
+        output_logit = top_p_search(
+            token_logits_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            from_logits=True,
+            seed=42,
+        )
+        tf.random.set_seed(42)
+        output_probs = top_p_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            from_logits=False,
+            seed=42,
+        )
+        self.assertAllEqual(output_logit, output_probs)

From aa22db63004d2672f9cb5fcab4da022fb03a2797 Mon Sep 17 00:00:00 2001
From: jessechancy
Date: Fri, 24 Jun 2022 18:16:49 -0700
Subject: [PATCH 2/6] made filter_value a default 0

---
 keras_nlp/utils/text_generation.py      | 7 ++++---
 keras_nlp/utils/text_generation_test.py | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index 101ebb1274..d44ecd2741 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -404,7 +404,6 @@ def top_p_search(
     from_logits=False,
     end_token_id=None,
     pad_token_id=0,
-    filter_value=-float("Inf"),
 ):
     """
     Text generation utility based on top-p (nucleus) sampling.
@@ -504,7 +503,7 @@ def token_probability_fn(inputs):
             pred = tf.keras.activations.softmax(pred, axis=-1)
         # Sort preds in descending order.
         sorted_preds, sorted_indices = tf.math.top_k(
-            pred, k=pred.shape[-1], sorted=True
+            pred, k=pred.shape[1], sorted=True
         )
         # Calculate cumulative probability distribution.
         cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
         # Create a mask for the tokens to keep.
         keep_mask = cumulative_probs <= p
         # Shift to include the last token that exceeds p.
         shifted_keep_mask = tf.concat(
             [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
         )
         # Filter out unmasked tokens and sample from filtered distribution.
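+        # Note: `tf.math.log(0.0)` evaluates to -inf, so tokens zeroed out
+        # below receive a -inf logit and can never be drawn by
+        # `tf.random.categorical`, making an explicit -inf filter value
+        # unnecessary.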
-        probs = tf.where(
-            shifted_keep_mask, sorted_preds, tf.fill(pred.shape, filter_value)
+        probs = tf.where(
+            shifted_keep_mask, 
+            sorted_preds, 
+            tf.zeros(pred.shape, dtype=sorted_preds.dtype)
         )
         sorted_next_token = tf.random.categorical(
             tf.math.log(probs), 1, seed=seed
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index f87f899802..3afc9d7a45 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -103,7 +103,7 @@ def token_probability_fn(inputs):
         self.assertAllEqual(outputs, expected_outputs)
 
 
-class RandomSamplingTextGenerationTest(tf.test.TestCase):
+class RandomSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -245,7 +245,7 @@ def token_probability_fn(inputs):
         self.assertAllEqual(output_logit, output_probs)
 
 
-class TopKSamplingTextGenerationTest(tf.test.TestCase):
+class TopKSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -421,7 +421,7 @@ def token_probability_fn(inputs):
         self.assertAllEqual(output_logit, output_probs)
 
 
-class TopPSamplingTextGenerationTest(tf.test.TestCase):
+class TopPSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10

From cb9e32d69256f73207b5d7b5c64abc9c21f0d86a Mon Sep 17 00:00:00 2001
From: jessechancy
Date: Fri, 24 Jun 2022 18:19:11 -0700
Subject: [PATCH 3/6] style fixes

---
 keras_nlp/utils/text_generation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index d44ecd2741..e3a5818202 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -515,9 +515,9 @@ def token_probability_fn(inputs):
         )
         # Filter out unmasked tokens and sample from filtered distribution.
         probs = tf.where(
-            shifted_keep_mask, 
-            sorted_preds, 
-            tf.zeros(pred.shape, dtype=sorted_preds.dtype)
+            shifted_keep_mask,
+            sorted_preds,
+            tf.zeros(pred.shape, dtype=sorted_preds.dtype),
         )
         sorted_next_token = tf.random.categorical(
             tf.math.log(probs), 1, seed=seed

From 7e6e6c2927e04fb4e18dcab83307c4ef2249d5e1 Mon Sep 17 00:00:00 2001
From: jessechancy
Date: Mon, 27 Jun 2022 12:32:06 -0700
Subject: [PATCH 4/6] minor changes

---
 keras_nlp/utils/text_generation.py      | 2 --
 keras_nlp/utils/text_generation_test.py | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index e3a5818202..fcca415bcd 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -432,8 +432,6 @@ def top_p_search(
             replaced with `pad_token_id`.
         pad_token_id: int, defaults to 0. The pad token after `end_token_id`
             is received.
-        filter_value: float, defaults to -Inf. The value for filtering out
-            unused tokens when sampling the probability distribution.
 
     Returns:
         A 1D int Tensor, or 2D int Tensor representing the generated
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index 3afc9d7a45..1640af6d04 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -525,7 +525,7 @@ def token_probability_fn(inputs):
             prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
             return tf.repeat(prob, batch_size, axis=0)
 
-        # Test that it only samples from tokens that sum up to p
+        # Test that it only samples from tokens that sum up to p.
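+        # With the fixed distribution [0.4, 0.3, 0.2, 0.1], each (p, n) pair
+        # below asserts that only token ids strictly below n are ever drawn.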
         for p, n in [(0.3, 1), (0.7, 2), (0.9, 3)]:
             inputs = tf.constant([[0, 0], [0, 0]])
             for _ in range(10):
                 outputs = top_p_search(
                     token_probability_fn, inputs, max_length=5, p=p
                 )
                 self.assertAllEqual(outputs < n, tf.ones_like(outputs))

From cd33b090f5987042f5e04b8834bff9cb7123e1d9 Mon Sep 17 00:00:00 2001
From: jessechancy
Date: Mon, 27 Jun 2022 17:27:48 -0700
Subject: [PATCH 5/6] minor changes and addition of empty prompt checks

---
 keras_nlp/utils/text_generation.py      | 20 ++++++++++------
 keras_nlp/utils/text_generation_test.py | 32 +++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index fcca415bcd..6b63d3333a 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -26,6 +26,10 @@ def validate_prompt(prompt):
         )
     if not isinstance(prompt, tf.Tensor):
         prompt = tf.convert_to_tensor(prompt)
+    if prompt.shape[-1] == 0:
+        raise ValueError(
+            "Length of `prompt` is 0, please provide a non-empty `prompt`."
+        )
     return prompt
 
 
@@ -112,7 +116,7 @@ def token_probability_fn(inputs):
 
     prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
-    # Print the generated sequence (token ids).
+    # Print the generated Asequence (token ids).
     keras_nlp.utils.greedy_search(
         token_probability_fn,
         prompt,
@@ -357,7 +361,7 @@ def token_probability_fn(inputs):
         "tf.function in eager mode."
     )
     if k <= 0:
-        raise ValueError("k should be strictly positive (greater than 0).")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -408,8 +412,10 @@ def top_p_search(
     """
     Text generation utility based on top-p (nucleus) sampling.
 
-    Top-p search filters the top token with probabilities that sum up to p, and
-    samples from this subset to get the next token. The probability of each
+    Top-p search selects tokens from the smallest subset of output probabilities
+    that sum to greater than `p`. Put another way, top-p will first order
+    token predictions by likelihood, and ignore all tokens after the cumulative
+    probability of selected tokens exceeds `p`. The probability of each
     token is provided by `token_probability_fn`.
 
     Args:
@@ -477,14 +483,14 @@ def token_probability_fn(inputs):
     """
     if not tf.executing_eagerly():
         raise RuntimeError(
-            "`keras_nlp.utils.top_k_search` currently requires an eager "
-            "execution context. Please call `top_k_search` outside "
+            "`keras_nlp.utils.top_p_search` currently requires an eager "
+            "execution context. Please call `top_p_search` outside "
             "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
             "tf.function in eager mode."
         )
     if p <= 0 or p >= 1:
         raise ValueError(
-            "p should be a float strictly between 0 and 1 (0 < p < 1)."
+            f"`p` should be in the range (0, 1). Received: `p={p}`."
         )
 
     prompt = validate_prompt(prompt)
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index 1640af6d04..5b1da81965 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -47,6 +47,14 @@ def token_probability_fn(inputs):
 
         self.token_probability_fn = token_probability_fn
 
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            greedy_search(self.token_probability_fn, inputs, max_length=5)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            greedy_search(self.token_probability_fn, inputs, max_length=5)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = greedy_search(self.token_probability_fn, inputs, max_length=5)
@@ -127,6 +135,14 @@ def token_probability_fn(inputs):
 
         self.token_probability_fn = token_probability_fn
 
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = random_search(self.token_probability_fn, inputs, max_length=5)
@@ -269,6 +285,14 @@ def token_probability_fn(inputs):
 
         self.token_probability_fn = token_probability_fn
 
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            top_k_search(self.token_probability_fn, inputs, max_length=5, k=2)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            top_k_search(self.token_probability_fn, inputs, max_length=5, k=2)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = top_k_search(
@@ -445,6 +469,14 @@ def token_probability_fn(inputs):
 
         self.token_probability_fn = token_probability_fn
 
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = top_p_search(

From 2588e6b9da8624b335041c5cd9c88a46d4e7f678 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Mon, 27 Jun 2022 17:44:05 -0700
Subject: [PATCH 6/6] Fix typo

---
 keras_nlp/utils/text_generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index 6b63d3333a..231d4162dd 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -116,7 +116,7 @@ def token_probability_fn(inputs):
 
     prompt = tf.fill((BATCH_SIZE, 1), START_ID)
 
-    # Print the generated Asequence (token ids).
+    # Print the generated sequence (token ids).
     keras_nlp.utils.greedy_search(
         token_probability_fn,
         prompt,
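
For a quick manual check after applying the series, here is a minimal sketch of the new utility. It mirrors the toy distribution used in `test_assert_probability_distribution_generation_is_correct` above; with `p=0.6`, only tokens 2 and 3 fall inside the nucleus, so every generated id should be one of those two. This snippet is illustrative and not part of the patches.

```python
import tensorflow as tf

from keras_nlp.utils.text_generation import top_p_search

# Toy next-token distribution over a 4-token vocabulary: token 3 is the
# most likely (0.4), token 0 the least (0.1), for every input sequence.
def token_probability_fn(inputs):
    prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
    return tf.repeat(prob, inputs.shape[0], axis=0)

# Two sequences, each starting from token 3.
prompt = tf.fill((2, 1), 3)

# Sorted probabilities are [0.4, 0.3, 0.2, 0.1]; the nucleus for p=0.6 is
# {token 3, token 2}, so all sampled ids should be 2 or 3.
outputs = top_p_search(
    token_probability_fn,
    prompt,
    max_length=8,
    p=0.6,
    seed=42,
)
print(outputs)  # shape (2, 8), values drawn from {2, 3}.
```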