1 change: 1 addition & 0 deletions keras_nlp/utils/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search
152 changes: 136 additions & 16 deletions keras_nlp/utils/text_generation.py
@@ -17,6 +17,35 @@
import tensorflow as tf


def validate_prompt(prompt):
    """
    Helper function to validate input to text_generation utils.
    """
    if isinstance(prompt, tf.RaggedTensor):
        raise ValueError(
            "RaggedTensor `prompt` is not supported, please "
            "provide `prompt` as a list or Tensor."
        )
    if not isinstance(prompt, tf.Tensor):
        prompt = tf.convert_to_tensor(prompt)
    return prompt

Review comment (Member): Hmm, with the whole docstring for a helper it's hard to tell it's just a helper function. In general I don't think you would need the whole args/return structure for something small like this. Just `Helper function to validate input to text_generation utils.`


def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    """
    Helper function to mask the tokens after the end token.
    """
    # Mask out tokens after `end_token_id` is encountered.
    # Find index of first end_token_id.
    end_indices = tf.math.argmax(prompt == end_token_id, -1)
    # Use max_length if no `end_token_id` is found.
    end_indices = tf.where(end_indices == 0, max_length, end_indices)
    # Build a mask including end_token and replace tokens after end_token
    # with `pad_token_id`.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
    return tf.where(valid_indices, prompt, pad_token_id)
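
For intuition, here is a minimal standalone sketch (toy data, not part of this PR) of what the helper above computes; token `3` plays the role of `end_token_id` and `0` of `pad_token_id`:

```python
import tensorflow as tf

# Two sequences of length 6; token 3 marks end-of-sequence.
prompt = tf.constant([[5, 2, 3, 7, 8, 9],
                      [4, 4, 4, 4, 4, 4]])  # no end token in this row
end_indices = tf.math.argmax(prompt == 3, -1)             # -> [2, 0]
# argmax returns 0 when the end token is absent, so 0 is treated as "not
# found". (An end token at index 0 is indistinguishable from "not found".)
end_indices = tf.where(end_indices == 0, 6, end_indices)  # -> [2, 6]
valid = tf.sequence_mask(end_indices + 1, maxlen=6)
print(tf.where(valid, prompt, 0))
# [[5 2 3 0 0 0]
#  [4 4 4 4 4 4]]
```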


def greedy_search(
    token_probability_fn,
    prompt,
@@ -88,13 +117,9 @@ def token_probability_fn(inputs):
            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
            "tf.function in eager mode."
        )
-    if isinstance(prompt, tf.RaggedTensor):
-        raise ValueError(
-            "RaggedTensor `prompt` is not supported, please "
-            "provide `prompt` as a list or Tensor."
-        )
-    if not isinstance(prompt, tf.Tensor):
-        prompt = tf.convert_to_tensor(prompt)
+    prompt = validate_prompt(prompt)

    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]
@@ -109,16 +134,111 @@ def token_probability_fn(inputs):
        i += 1

    if end_token_id is not None:
-        # Mask out tokens after `end_token_id` is encountered.
-        # Find index of first end_token_id.
-        end_indices = tf.math.argmax(prompt == end_token_id, -1)
-        # Use max_length if no `end_token_id` is found.
-        end_indices = tf.where(end_indices == 0, max_length, end_indices)
-        # Build a mask including end_token and replace tokens after end_token
-        # with `pad_token_id`.
-        valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
-        prompt = tf.where(valid_indices, prompt, pad_token_id)
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )

    if input_is_1d:
        return tf.squeeze(prompt)
    return prompt
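
As a quick usage sketch (assuming a toy 4-token vocabulary where token 3 always has the highest probability; the toy `token_probability_fn` below is illustrative, not from this PR), greedy decoding deterministically appends the argmax token at every step:

```python
import tensorflow as tf
from keras_nlp.utils import greedy_search

def token_probability_fn(inputs):
    # Token 3 is always the most likely next token.
    prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
    return tf.repeat(prob, inputs.shape[0], axis=0)

prompt = tf.fill([2, 1], 3)
outputs = greedy_search(token_probability_fn, prompt, max_length=5)
print(outputs)  # [[3 3 3 3 3], [3 3 3 3 3]]
```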


Review comment (Member): We need to add this to the `__init__.py` for the utils dir, so this gets exported.

def random_search(
    token_probability_fn,
    prompt,
    max_length,
    seed=None,
    end_token_id=None,
    pad_token_id=0,
):
    """
    Text generation utility based on randomly sampling the entire probability
    distribution.

    Random sampling samples the next token from the probability distribution
    provided by `token_probability_fn` and appends it to the existing
    sequence.

    Args:
        token_probability_fn: a callable, which takes in the input sequence
            and outputs the probability distribution of the next token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            which generated tokens are appended.
        max_length: int. The max length of generated text.
        seed: int, defaults to None. The random seed used for sampling.
        end_token_id: int, defaults to None. The token marking the end of the
            sequence. Once encountered, generation is finished for that
            sequence. If None, every sequence is generated up to `max_length`.
            If set, all tokens after encountering `end_token_id` will be
            replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The token used to pad sequences
            after `end_token_id` is encountered.

    Returns:
        A 1D or 2D int Tensor representing the generated sequences.
    Examples:
    ```python
    VOCAB_SIZE = 10
    FEATURE_SIZE = 16

    # Create a dummy model to predict the next token.
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=[None]),
            tf.keras.layers.Embedding(
                input_dim=VOCAB_SIZE,
                output_dim=FEATURE_SIZE,
            ),
            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
        ]
    )

    # Define a function that outputs the next token's probability given the
    # input sequence.
    def token_probability_fn(inputs):
        return model(inputs)[:, -1, :]

    prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)

    # Print the generated sequence (token ids).
    keras_nlp.utils.random_search(
        token_probability_fn,
        prompt,
        max_length=10,
        end_token_id=0,
    )
    ```
    """
    if not tf.executing_eagerly():
        raise RuntimeError(
            "`keras_nlp.utils.random_search` currently requires an eager "
            "execution context. Please call `random_search` outside "
            "tf.function or run `tf.config.run_functions_eagerly(True)` to "
            "run tf.function in eager mode."
        )

    prompt = validate_prompt(prompt)
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]

    i = prompt.shape[1]
    # Loop until the generated sequence reaches `max_length`.
    while i < max_length:
        pred = token_probability_fn(prompt)
        # `tf.random.categorical` expects logits (unnormalized
        # log-probabilities), so convert the probabilities with
        # `tf.math.log` before sampling.
        next_token = tf.cast(
            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
            dtype=prompt.dtype,
        )
        # Append the sampled token to the current sequence.
        prompt = tf.concat([prompt, next_token], axis=-1)
        i += 1

    if end_token_id is not None:
        prompt = mask_tokens_after_end_token(
            prompt, max_length, end_token_id, pad_token_id
        )
    if input_is_1d:
        return tf.squeeze(prompt)
    return prompt
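
One detail worth calling out: `tf.random.categorical` draws samples from logits, which is why the loop above passes `tf.math.log(pred)` rather than the probabilities themselves. A minimal standalone sketch of that sampling step (toy distribution, not part of this PR):

```python
import tensorflow as tf

# Toy next-token distribution over a 4-token vocabulary, batch of 2.
pred = tf.constant([[0.01, 0.01, 0.08, 0.9],
                    [0.25, 0.25, 0.25, 0.25]])
# `tf.random.categorical` expects logits, so take the log of the
# probabilities before sampling one token id per sequence.
next_token = tf.random.categorical(tf.math.log(pred), 1, seed=42)
print(next_token.shape)  # (2, 1)
```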
116 changes: 113 additions & 3 deletions keras_nlp/utils/text_generation_test.py
@@ -14,12 +14,14 @@
"""Tests for Text Generation Utils."""


import numpy as np
import tensorflow as tf

from keras_nlp.utils.text_generation import greedy_search
from keras_nlp.utils.text_generation import random_search


-class TextGenerationTest(tf.test.TestCase):
+class GreedySearchTextGenerationTest(tf.test.TestCase):
    def setUp(self):
        super().setUp()
        vocab_size = 10
@@ -66,7 +68,7 @@ def test_generate_with_ragged_prompt(self):
    def test_assert_generation_is_correct(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        batch_size = 10
@@ -82,7 +84,7 @@ def token_probability_fn(inputs):
    def test_end_token_id(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
-            prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        max_length = 5
@@ -96,5 +98,113 @@ def token_probability_fn(inputs):
        )
        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
        self.assertAllEqual(outputs, expected_outputs)


class RandomSamplingTextGenerationTest(tf.test.TestCase):
    def setUp(self):
        super().setUp()
        vocab_size = 10
        feature_size = 16

        # Create a dummy model to predict the next token.
        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=[None]),
                tf.keras.layers.Embedding(
                    input_dim=vocab_size,
                    output_dim=feature_size,
                ),
                tf.keras.layers.Dense(vocab_size),
                tf.keras.layers.Softmax(),
            ]
        )

        def token_probability_fn(inputs):
            return model(inputs)[:, -1, :]

        self.token_probability_fn = token_probability_fn

    def test_generate_with_1d_prompt(self):
        inputs = tf.constant([1])
        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
        self.assertEqual(outputs.shape, [5])

    def test_generate_with_2d_prompt(self):
        inputs = tf.constant([[1], [1]])
        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_list_prompt(self):
        inputs = [[1], [1]]
        outputs = random_search(self.token_probability_fn, inputs, max_length=5)
        self.assertEqual(outputs.shape, [2, 5])

    def test_generate_with_ragged_prompt(self):
        inputs = tf.ragged.constant([[1], [2, 3]])
        with self.assertRaises(ValueError):
            random_search(self.token_probability_fn, inputs, max_length=5)

    def test_assert_seeded_generation_is_correct(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        batch_size = 10
        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        max_length = 3
        tf.random.set_seed(42)
        outputs = random_search(
            token_probability_fn, inputs, max_length=max_length, seed=42
        )
        # Random sampling result with seed 42.
        seeded_result = 3 * np.ones(shape=[batch_size, max_length])
        self.assertAllEqual(outputs, seeded_result)

    def test_assert_probability_distribution_generation_is_correct(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        batch_size = 10
        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
        max_length = 3

        outputs_count = np.array([0, 0, 0, 0])
        tf.random.set_seed(42)
        for i in range(500):
            outputs = random_search(
                token_probability_fn, inputs, max_length=max_length, seed=42
            )
            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
            for pred in flatten_predictions:
                outputs_count[pred] += 1
        self.assertAllClose(
            outputs_count / np.sum(outputs_count),
            [0.01, 0.01, 0.08, 0.9],
            rtol=0.2,
        )

    def test_end_token_id(self):
        def token_probability_fn(inputs):
            batch_size = inputs.shape[0]
            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
            return tf.repeat(prob, batch_size, axis=0)

        max_length = 5
        inputs = tf.constant([[0, 1], [1, 2]])

        outputs = random_search(
            token_probability_fn,
            inputs,
            max_length=max_length,
            seed=42,
            end_token_id=2,
            pad_token_id=0,
        )
        # Random sampling result with seed 42.
        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
        self.assertAllEqual(outputs, expected_outputs)