diff --git a/keras_nlp/utils/text_generation.py b/keras_nlp/utils/text_generation.py index 0e381f69e3..8e4f95a8a4 100644 --- a/keras_nlp/utils/text_generation.py +++ b/keras_nlp/utils/text_generation.py @@ -18,6 +18,8 @@ from absl import logging from tensorflow import keras +from keras_nlp.api_export import keras_nlp_export + def _validate_prompt(prompt): """Helper function to validate input to text_generation utils.""" @@ -95,6 +97,7 @@ def _mask_tokens_after_end_token( return tf.where(valid_indices, prompt, pad_token_id) +@keras_nlp_export("keras_nlp.utils.greedy_search") def greedy_search( token_probability_fn, prompt, @@ -210,6 +213,7 @@ def one_step(length, prompt): return tf.squeeze(prompt) if input_is_1d else prompt +@keras_nlp_export("keras_nlp.utils.beam_search") def beam_search( token_probability_fn, prompt, @@ -400,6 +404,7 @@ def one_step(beams, beams_prob, length): return tf.squeeze(prompt) if input_is_1d else prompt +@keras_nlp_export("keras_nlp.utils.random_search") def random_search( token_probability_fn, prompt, @@ -530,6 +535,7 @@ def one_step(length, prompt): return tf.squeeze(prompt) if input_is_1d else prompt +@keras_nlp_export("keras_nlp.utils.top_k_search") def top_k_search( token_probability_fn, prompt, @@ -680,6 +686,7 @@ def one_step(length, prompt): return tf.squeeze(prompt) if input_is_1d else prompt +@keras_nlp_export("keras_nlp.utils.top_p_search") def top_p_search( token_probability_fn, prompt,