diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index 4c7583633c..231d4162dd 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -26,6 +26,10 @@ def validate_prompt(prompt):
         )
     if not isinstance(prompt, tf.Tensor):
         prompt = tf.convert_to_tensor(prompt)
+    if prompt.shape[-1] == 0:
+        raise ValueError(
+            "Length of `prompt` is 0, please provide a non-empty `prompt`."
+        )
     return prompt


@@ -357,7 +361,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError("k should be strictly positive (greater than 0).")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")

     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -393,3 +397,147 @@ def token_probability_fn(inputs):
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def top_p_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    p,
+    seed=None,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on top-p (nucleus) sampling.
+
+    Top-p search selects tokens from the smallest subset of output
+    probabilities that sum to greater than `p`. Put another way, top-p first
+    orders token predictions by likelihood, then ignores all tokens after the
+    cumulative probability of the selected tokens exceeds `p`. The
+    probability of each token is provided by `token_probability_fn`.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens to.
+        max_length: int. The max length of generated text.
+        p: float. The probability that the top tokens sum up to. Should
+            follow the constraint of 0 < p < 1.
+        seed: int, defaults to None. The random seed used for sampling.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation is finished for the
+            exact sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
+            is received.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.top_p_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        p=0.8,
+        end_token_id=END_ID,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.top_p_search` currently requires an eager "
+            "execution context. Please call `top_p_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if p <= 0 or p >= 1:
+        raise ValueError(
+            f"`p` should be in the range (0, 1). Received: `p={p}`."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    i = prompt.shape[1]
+    while i < max_length:
+        # If the prompt has reached our desired length, exit while loop.
+        pred = token_probability_fn(prompt)
+        if from_logits:
+            pred = tf.keras.activations.softmax(pred, axis=-1)
+        # Sort preds in descending order.
+        sorted_preds, sorted_indices = tf.math.top_k(
+            pred, k=pred.shape[1], sorted=True
+        )
+        # Calculate cumulative probability distribution.
+        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
+        # Create a mask for the tokens to keep.
+        keep_mask = cumulative_probs <= p
+        # Shift to include the last token that exceeds `p`.
+        shifted_keep_mask = tf.concat(
+            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
+        )
+        # Filter out unmasked tokens and sample from filtered distribution.
+        probs = tf.where(
+            shifted_keep_mask,
+            sorted_preds,
+            tf.zeros(pred.shape, dtype=sorted_preds.dtype),
+        )
+        sorted_next_token = tf.random.categorical(
+            tf.math.log(probs), 1, seed=seed
+        )
+        next_token = tf.gather_nd(
+            sorted_indices, sorted_next_token, batch_dims=1
+        )
+        next_token = tf.cast(next_token, dtype=prompt.dtype)
+        # Append the next token to current sequence.
+        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
+        i += 1

+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
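For reviewers, here is a minimal, self-contained sketch of the nucleus filtering step that `top_p_search` performs at each decoding iteration. The helper name `sample_top_p` and the toy distribution are illustrative only, not part of this change, and it assumes TensorFlow 2.x eager execution:

```python
import tensorflow as tf

def sample_top_p(probs, p, seed=None):
    # Sort probabilities in descending order, remembering original indices.
    sorted_probs, sorted_indices = tf.math.top_k(
        probs, k=probs.shape[-1], sorted=True
    )
    # Cumulative distribution over the sorted probabilities.
    cumulative = tf.math.cumsum(sorted_probs, axis=-1)
    # Keep the prefix below `p`, shifted right by one so the token that
    # crosses `p` is also kept.
    keep = cumulative <= p
    keep = tf.concat([tf.ones_like(keep[:, :1]), keep[:, :-1]], axis=-1)
    # Zero out everything outside the nucleus, then sample from the rest.
    filtered = tf.where(keep, sorted_probs, tf.zeros_like(sorted_probs))
    sampled = tf.random.categorical(tf.math.log(filtered), 1, seed=seed)
    # Map the sampled position back to the original token id.
    return tf.gather_nd(sorted_indices, sampled, batch_dims=1)

probs = tf.constant([[0.5, 0.3, 0.1, 0.1]])
print(sample_top_p(probs, p=0.7))  # Only ever prints token id 0 or 1.
```

With `p=0.7`, the cumulative sums are `[0.5, 0.8, 0.9, 1.0]`: the threshold is crossed at index 1, and the shifted mask keeps tokens 0 and 1, which is exactly the behavior the `test_only_choose_from_top_p_tokens` case below asserts.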
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index fab1e4e0e4..5b1da81965 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -20,6 +20,7 @@
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
+from keras_nlp.utils.text_generation import top_p_search


 class GreedySearchTextGenerationTest(tf.test.TestCase):
@@ -46,6 +47,14 @@ def token_probability_fn(inputs):

         self.token_probability_fn = token_probability_fn

+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            greedy_search(self.token_probability_fn, inputs, max_length=5)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            greedy_search(self.token_probability_fn, inputs, max_length=5)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = greedy_search(self.token_probability_fn, inputs, max_length=5)
@@ -102,7 +111,7 @@ def token_probability_fn(inputs):
         self.assertAllEqual(outputs, expected_outputs)


-class RandomSamplingTextGenerationTest(tf.test.TestCase):
+class RandomSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -126,6 +135,14 @@ def token_probability_fn(inputs):

         self.token_probability_fn = token_probability_fn

+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = random_search(self.token_probability_fn, inputs, max_length=5)
@@ -244,7 +261,7 @@ def token_probability_fn(inputs):
         self.assertAllEqual(output_logit, output_probs)


-class TopKSamplingTextGenerationTest(tf.test.TestCase):
+class TopKSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -268,6 +285,14 @@ def token_probability_fn(inputs):

         self.token_probability_fn = token_probability_fn

+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            top_k_search(self.token_probability_fn, inputs, max_length=5, k=2)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            top_k_search(self.token_probability_fn, inputs, max_length=5, k=2)
+
     def test_generate_with_1d_prompt(self):
         inputs = tf.constant([1])
         outputs = top_k_search(
@@ -387,13 +412,13 @@ def token_probability_fn(inputs):
     def test_from_logits(self):
         def token_logits_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[1.0, 2.0, 3, 0, 4.0]])
+            prob = tf.constant([[1.0, 2.0, 3.0, 4.0]])
             return tf.repeat(prob, batch_size, axis=0)

         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
             prob = tf.keras.activations.softmax(
-                tf.constant([[1.0, 2.0, 3, 0, 4.0]])
+                tf.constant([[1.0, 2.0, 3.0, 4.0]])
             )
             return tf.repeat(prob, batch_size, axis=0)
@@ -418,3 +443,183 @@ def token_probability_fn(inputs):
             seed=42,
         )
         self.assertAllEqual(output_logit, output_probs)
+
+
+class TopPSearchTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = top_p_search(
+            self.token_probability_fn, inputs, max_length=5, p=0.8
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            top_p_search(self.token_probability_fn, inputs, max_length=5, p=0.8)
+
+    def test_assert_seeded_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        tf.random.set_seed(42)
+        outputs = top_p_search(
+            token_probability_fn, inputs, max_length=max_length, p=0.91, seed=42
+        )
+        # Top-p sampling result with seed 42.
+        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
+        seeded_result[3][1] = 2
+        seeded_result[7][1] = 2
+        self.assertAllEqual(outputs, seeded_result)
+
+    def test_assert_probability_distribution_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        for i in range(500):
+            outputs = top_p_search(
+                token_probability_fn,
+                inputs,
+                max_length=max_length,
+                p=0.6,
+                seed=42,
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.0, 0.0, 0.429, 0.571],
+            rtol=0.2,
+        )
+
+    def test_only_choose_from_top_p_tokens(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.4, 0.3, 0.2, 0.1]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        # Test that it only samples from tokens that sum up to `p`.
+        for p, n in [(0.3, 1), (0.7, 2), (0.9, 3)]:
+            inputs = tf.constant([[0, 0], [0, 0]])
+            for _ in range(10):
+                outputs = top_p_search(
+                    token_probability_fn, inputs, max_length=5, p=p
+                )
+                self.assertAllEqual(outputs < n, tf.ones_like(outputs))
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        tf.random.set_seed(42)
+        outputs = top_p_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            seed=1,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        # Top-p sampling result with seed 42.
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+    def test_from_logits(self):
+        def token_logits_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[1.0, 2.0, 3.0, 4.0]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.keras.activations.softmax(
+                tf.constant([[1.0, 2.0, 3.0, 4.0]])
+            )
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        tf.random.set_seed(42)
+        output_logit = top_p_search(
+            token_logits_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            from_logits=True,
+            seed=42,
+        )
+        tf.random.set_seed(42)
+        output_probs = top_p_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            p=0.92,
+            from_logits=False,
+            seed=42,
+        )
+        self.assertAllEqual(output_logit, output_probs)
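As a sanity check on the frequencies asserted in `test_assert_probability_distribution_generation_is_correct` above, the expected values can be derived by hand. The snippet below is illustrative only and not part of the diff:

```python
import tensorflow as tf

# With p=0.6, only the two most likely tokens (ids 3 and 2) form the nucleus:
# sorted probs [0.4, 0.3, 0.2, 0.1] have cumulative sums [0.4, 0.7, 0.9, 1.0],
# and the shifted keep-mask also retains the token that crosses p.
probs = tf.constant([0.1, 0.2, 0.3, 0.4])
nucleus = tf.sort(probs, direction="DESCENDING")[:2]
# Renormalized sampling frequencies: 0.4 / 0.7 ~= 0.571 for token id 3 and
# 0.3 / 0.7 ~= 0.429 for token id 2, matching the asserted
# [0.0, 0.0, 0.429, 0.571].
print(nucleus / tf.reduce_sum(nucleus))  # [0.571..., 0.428...]
```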