150 changes: 149 additions & 1 deletion keras_nlp/utils/text_generation.py
@@ -26,6 +26,10 @@ def validate_prompt(prompt):
)
if not isinstance(prompt, tf.Tensor):
prompt = tf.convert_to_tensor(prompt)
if prompt.shape[-1] == 0:
raise ValueError(
"Length of `prompt` is 0, please provide a non-empty `prompt`."
)
return prompt

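For context, a minimal sketch of the new empty-prompt check. The import path is an assumption for illustration, since `validate_prompt` is a module-level helper rather than a documented public API:

```python
import tensorflow as tf

# Assumed import path for the private helper; shown for illustration only.
from keras_nlp.utils.text_generation import validate_prompt

# A non-empty prompt is converted to a Tensor and returned unchanged.
prompt = validate_prompt([1, 2, 3])

# An empty prompt now raises a ValueError instead of failing later on.
try:
    validate_prompt(tf.constant([], dtype=tf.int32))
except ValueError as e:
    print(e)  # Length of `prompt` is 0, please provide a non-empty `prompt`.
```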

@@ -357,7 +361,7 @@ def token_probability_fn(inputs):
"tf.function in eager mode."
)
if k <= 0:
raise ValueError("k should be strictly positive (greater than 0).")
        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")

prompt = validate_prompt(prompt)
input_is_1d = prompt.shape.rank == 1
@@ -393,3 +397,147 @@ def token_probability_fn(inputs):
if input_is_1d:
return tf.squeeze(prompt)
return prompt
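Similarly, a quick sketch of the sharpened `k` validation. The `token_probability_fn` below is hypothetical, and the call assumes `top_k_search` keeps its existing `(token_probability_fn, prompt, max_length, k, ...)` signature:

```python
import tensorflow as tf
import keras_nlp

# Hypothetical uniform next-token distribution over a 10-token vocab.
def token_probability_fn(inputs):
    batch_size = tf.shape(inputs)[0]
    return tf.ones((batch_size, 10)) / 10.0

# `k` must be strictly positive; the error now reports the received value.
try:
    keras_nlp.utils.top_k_search(
        token_probability_fn, tf.constant([1]), max_length=5, k=0
    )
except ValueError as e:
    print(e)  # `k` should be strictly positive. Received: `k=0`.
```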


def top_p_search(
token_probability_fn,
prompt,
max_length,
p,
seed=None,
from_logits=False,
end_token_id=None,
pad_token_id=0,
):
"""
Text generation utility based on top-p (nucleus) sampling.

    Top-p search samples from the smallest set of tokens whose cumulative
    probability exceeds `p`. Put another way, top-p first orders token
    predictions by likelihood, then discards every token after the
    cumulative probability of the kept tokens exceeds `p`. The probability
    of each token is provided by `token_probability_fn`.

Args:
        token_probability_fn: a callable, which takes in an input sequence
            and outputs the probability distribution of the next token. If
            `from_logits` is set to True, it should output the logits of
            the next token.
        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
            which generated tokens are appended.
max_length: int. The max length of generated text.
        p: float. The cumulative probability threshold that the kept tokens
            sum up to. Must satisfy 0 < p < 1.
seed: int, defaults to None. The random seed used for sampling.
from_logits: bool. Indicates whether `token_probability_fn` outputs
logits or probabilities.
        end_token_id: int, defaults to None. The token marking the end of
            the sequence; once encountered, generation is finished for that
            sequence. If None, every sequence is generated up to
            `max_length`. If set, all tokens after `end_token_id` is
            encountered will be replaced with `pad_token_id`.
        pad_token_id: int, defaults to 0. The token used to pad sequences
            after `end_token_id` is encountered.

Returns:
        A 1D or 2D int Tensor representing the generated sequences.

Examples:
```python
BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# Create a dummy model to predict the next token.
model = tf.keras.Sequential(
[
tf.keras.Input(shape=[None]),
tf.keras.layers.Embedding(
input_dim=VOCAB_SIZE,
output_dim=FEATURE_SIZE,
),
tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
]
)

# Define a function that outputs the next token's probability given the
# input sequence.
def token_probability_fn(inputs):
return model(inputs)[:, -1, :]

prompt = tf.fill((BATCH_SIZE, 1), START_ID)

# Print the generated sequence (token ids).
keras_nlp.utils.top_p_search(
token_probability_fn,
prompt,
max_length=10,
p=0.8,
end_token_id=END_ID,
)
```

"""
if not tf.executing_eagerly():
raise RuntimeError(
"`keras_nlp.utils.top_p_search` currently requires an eager "
"execution context. Please call `top_p_search` outside "
"tf.function or run `tf.config.run_functions_eagerly(True)` to run "
"tf.function in eager mode."
)
if p <= 0 or p >= 1:
raise ValueError(
f"`p` should be in the range (0, 1). Received: `p={p}`."
)

prompt = validate_prompt(prompt)
input_is_1d = prompt.shape.rank == 1
if input_is_1d:
prompt = prompt[tf.newaxis, :]
validate_token_probability_fn(token_probability_fn, prompt)

    # Generate tokens until the sequence reaches `max_length`.
    i = prompt.shape[1]
    while i < max_length:
        pred = token_probability_fn(prompt)
if from_logits:
pred = tf.keras.activations.softmax(pred, axis=-1)
# Sort preds in descending order.
sorted_preds, sorted_indices = tf.math.top_k(
pred, k=pred.shape[1], sorted=True
)
# Calculate cumulative probability distribution.
cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
# Create a mask for the tokens to keep.
keep_mask = cumulative_probs <= p
        # Shift the mask right by one so the first token that pushes the
        # cumulative probability past `p` is also kept.
shifted_keep_mask = tf.concat(
[tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)
        # Zero out tokens that fall outside the nucleus, then sample from
        # the remaining distribution.
probs = tf.where(
shifted_keep_mask,
sorted_preds,
tf.zeros(pred.shape, dtype=sorted_preds.dtype),
)
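        # Note: `tf.math.log` maps the zeroed entries to -inf, so
        # `tf.random.categorical` never samples them.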
sorted_next_token = tf.random.categorical(
tf.math.log(probs), 1, seed=seed
)
next_token = tf.gather_nd(
sorted_indices, sorted_next_token, batch_dims=1
)
next_token = tf.cast(next_token, dtype=prompt.dtype)
# Append the next token to current sequence.
prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
i += 1

if end_token_id is not None:
prompt = mask_tokens_after_end_token(
prompt, max_length, end_token_id, pad_token_id
)
if input_is_1d:
return tf.squeeze(prompt)
return prompt
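To see the filtering step in isolation, here is a standalone toy walkthrough (not part of the PR) of how the mask construction selects the nucleus:

```python
import tensorflow as tf

# Toy next-token distribution for one sequence over a 5-token vocab.
pred = tf.constant([[0.4, 0.3, 0.15, 0.1, 0.05]])
p = 0.8

# Sort probabilities in descending order, as `top_p_search` does.
sorted_preds, sorted_indices = tf.math.top_k(pred, k=5, sorted=True)
cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
# cumulative_probs == [[0.4, 0.7, 0.85, 0.95, 1.0]]

keep_mask = cumulative_probs <= p  # [[True, True, False, False, False]]
# The shift keeps the 0.15 token as well, since it is the first token to
# push the cumulative probability past p == 0.8.
shifted_keep_mask = tf.concat(
    [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
)
probs = tf.where(shifted_keep_mask, sorted_preds, tf.zeros_like(sorted_preds))
print(probs)  # [[0.4, 0.3, 0.15, 0.0, 0.0]] -> the sampling nucleus
```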