diff --git a/keras_nlp/utils/__init__.py b/keras_nlp/utils/__init__.py
index 533f0aa8f9..28c1af5a67 100644
--- a/keras_nlp/utils/__init__.py
+++ b/keras_nlp/utils/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from keras_nlp.utils.text_generation import greedy_search
+from keras_nlp.utils.text_generation import random_search
diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py
index f07f9b9a89..81b09decc1 100644
--- a/keras_nlp/utils/text_generation.py
+++ b/keras_nlp/utils/text_generation.py
@@ -17,6 +17,35 @@
 import tensorflow as tf
 
 
+def validate_prompt(prompt):
+    """
+    Helper function to validate input to text_generation utils.
+    """
+    if isinstance(prompt, tf.RaggedTensor):
+        raise ValueError(
+            "RaggedTensor `prompt` is not supported, please "
+            "provide `prompt` as a list or Tensor."
+        )
+    if not isinstance(prompt, tf.Tensor):
+        prompt = tf.convert_to_tensor(prompt)
+    return prompt
+
+
+def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
+    """
+    Helper function to mask the tokens after the end token.
+    """
+    # Mask out tokens after `end_token_id` is encountered.
+    # Find the index of the first `end_token_id`.
+    end_indices = tf.math.argmax(prompt == end_token_id, -1)
+    # Use `max_length` if no `end_token_id` is found.
+    end_indices = tf.where(end_indices == 0, max_length, end_indices)
+    # Build a mask that includes the end token, and replace all tokens after
+    # the end token with `pad_token_id`.
+    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
+    return tf.where(valid_indices, prompt, pad_token_id)
+
+
 def greedy_search(
     token_probability_fn,
     prompt,
@@ -88,13 +117,9 @@ def token_probability_fn(inputs):
         "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
         "tf.function in eager mode."
     )
-    if isinstance(prompt, tf.RaggedTensor):
-        raise ValueError(
-            "RaggedTensor `prompt` is not supported, please "
-            "provide `prompt` as a list or Tensor."
-        )
-    if not isinstance(prompt, tf.Tensor):
-        prompt = tf.convert_to_tensor(prompt)
+
+    prompt = validate_prompt(prompt)
+
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
@@ -109,16 +134,111 @@ def token_probability_fn(inputs):
         i += 1
 
     if end_token_id is not None:
-        # Mask out tokens after `end_token_id` is encountered.
-        # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
-        # Use max_length if no `end_token_id` is found.
-        end_indices = tf.where(end_indices == 0, max_length, end_indices)
-        # Build a mask including end_token and replace tokens after end_token
-        # with `pad_token_id`.
-        valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
-        prompt = tf.where(valid_indices, prompt, pad_token_id)
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
 
     if input_is_1d:
         return tf.squeeze(prompt)
     return prompt
+
+
+def random_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    seed=None,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on randomly sampling the entire probability
+    distribution.
+
+    Random sampling draws the next token from the probability distribution
+    provided by `token_probability_fn` and appends it to the existing sequence.
+
+    Args:
+        token_probability_fn: a callable, which takes in an input sequence
+            and outputs the probability distribution of the next token.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens onto
+            which generated tokens are appended.
+        max_length: int. The max length of the generated text.
+        seed: int, defaults to None. The random seed used for sampling.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation finishes for that
+            exact sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The token used to pad sequences
+            after `end_token_id` is encountered.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated sequences.
+
+    Examples:
+    ```python
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.random_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        end_token_id=0,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.random_search` currently requires an eager "
+            "execution context. Please call `random_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to "
+            "run tf.function in eager mode."
+        )
+
+    prompt = validate_prompt(prompt)
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+
+    i = prompt.shape[1]
+    # Sample tokens until the prompt reaches `max_length`.
+    while i < max_length:
+        pred = token_probability_fn(prompt)
+        next_token = tf.cast(
+            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
+            dtype=prompt.dtype,
+        )
+        # Append the next token to the current sequence.
+        prompt = tf.concat([prompt, next_token], axis=-1)
+        i += 1
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
diff --git a/keras_nlp/utils/text_generation_test.py b/keras_nlp/utils/text_generation_test.py
index dc4686832f..54e4bed209 100644
--- a/keras_nlp/utils/text_generation_test.py
+++ b/keras_nlp/utils/text_generation_test.py
@@ -14,12 +14,14 @@
 """Tests for Text Generation Utils."""
 
+import numpy as np
 import tensorflow as tf
 
 from keras_nlp.utils.text_generation import greedy_search
+from keras_nlp.utils.text_generation import random_search
 
 
-class TextGenerationTest(tf.test.TestCase):
+class GreedySearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
         vocab_size = 10
@@ -66,7 +68,7 @@ def test_generate_with_ragged_prompt(self):
     def test_assert_generation_is_correct(self):
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
             return tf.repeat(prob, batch_size, axis=0)
 
         batch_size = 10
@@ -82,7 +84,7 @@ def token_probability_fn(inputs):
     def test_end_token_id(self):
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
             return tf.repeat(prob, batch_size, axis=0)
 
         max_length = 5
@@ -96,5 +98,113 @@ def token_probability_fn(inputs):
         )
         expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
         expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+
+class RandomSamplingTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEqual(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
+        self.assertEqual(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            random_search(self.token_probability_fn, inputs, max_length=5)
+
+    def test_assert_seeded_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        tf.random.set_seed(42)
+        outputs = random_search(
+            token_probability_fn, inputs, max_length=max_length, seed=42
+        )
+        # Expected random sampling result with seed 42.
+        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
+        self.assertAllEqual(outputs, seeded_result)
+
+    def test_assert_probability_distribution_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+
+        outputs_count = np.array([0, 0, 0, 0])
+        tf.random.set_seed(42)
+        for _ in range(500):
+            outputs = random_search(
+                token_probability_fn, inputs, max_length=max_length, seed=42
+            )
+            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
+            for pred in flatten_predictions:
+                outputs_count[pred] += 1
+        # The empirical distribution of sampled tokens should approach the
+        # true distribution.
+        self.assertAllClose(
+            outputs_count / np.sum(outputs_count),
+            [0.01, 0.01, 0.08, 0.9],
+            rtol=0.2,
+        )
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+
+        outputs = random_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            seed=42,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        # Expected random sampling result with seed 42; tokens after
+        # `end_token_id` are replaced with `pad_token_id`.
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
         self.assertAllEqual(outputs, expected_outputs)
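A note on the core sampling step this diff adds: `tf.random.categorical` draws integer token ids from *logits* (unnormalized log-probabilities), which is why `random_search` maps the softmax output of `token_probability_fn` through `tf.math.log` before sampling. A minimal standalone sketch of that step, with illustrative probability values and seed that are not taken from this PR:

```python
import tensorflow as tf

# Next-token probabilities for a batch of two sequences (illustrative values).
pred = tf.constant(
    [
        [0.01, 0.01, 0.08, 0.9],
        [0.25, 0.25, 0.25, 0.25],
    ]
)

# `tf.random.categorical` expects logits, so the probabilities are mapped
# through `tf.math.log` first, as in `random_search` above.
next_token = tf.random.categorical(tf.math.log(pred), num_samples=1, seed=42)

# Shape (2, 1): one sampled token id per sequence, ready to be concatenated
# onto a (batch, seq_len) prompt along the last axis.
print(next_token)
```

Sampling in log space this way also keeps zero-probability tokens unselectable, since `log(0)` is `-inf` and a `-inf` logit is never drawn.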