diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index d6fbf26e13..4dc02ff8bf 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -17,7 +17,7 @@ import tensorflow as tf -import keras_nlp +from keras_nlp import samplers from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( @@ -25,9 +25,9 @@ ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.samplers import BeamSampler -from keras_nlp.samplers import serialize +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tf_utils import truncate_at_token @keras_nlp_export("keras_nlp.models.GPT2CausalLM") @@ -37,8 +37,13 @@ class GPT2CausalLM(Task): A causal language model (LM) predicts the next token based on previous tokens, which is the way GPT2 gets pretrained. You can finetune `GPT2CausalLM` to generate text similar to - the custom dataset. `GPT2CausalLM` also has a method `generate()`, which - generates text based on given prompt. + the custom dataset. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. This model can optionally be configured with a `preprocessor` layer, in which case it will automatically apply preprocessing to raw inputs during @@ -67,15 +72,13 @@ class GPT2CausalLM(Task): gpt2_lm.generate(["This is a", "Where are you"], max_length=30) ``` - Use a custom sampler for text generation. + Compile the `generate()` function with custom samplers. ```python gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + gpt2_lm.compile(sampler="top_p") + gpt2_lm.generate("I want to say", max_length=30) - # Use string identifier to set sampler. - gpt2_lm.generate("I want to say", max_length=30, sampler="top_p") - - # Construct a sampler instance. - sampler = keras_nlp.samplers.BeamSampler(num_beams=2) + gpt2_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) gpt2_lm.generate("I want to say", max_length=30) ``` @@ -189,7 +192,9 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor - self.sampler = None + self.generate_function = None + # Private sampler set by compile. + self._sampler = samplers.get("top_k") @classproperty def presets(cls): @@ -203,7 +208,7 @@ def backbone_cls(cls): def preprocessor_cls(cls): return GPT2CausalLMPreprocessor - def call_with_cache(self, token_ids, padding_mask, cache, cache_index): + def call_with_cache(self, token_ids, cache, cache_index): """Forward pass of `GPT2CausalLM` with cache. `call_with_cache` adds an additional forward pass for the model for @@ -213,7 +218,6 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index): Args: token_ids: a dense int Tensor, input token ids. - padding_mask: a dense bool Tensor, input padding mask. cache: a dense float Tensor, the cache of key and value. cache_index: int, or int Tensor. The index of current inputs in the whole sequence.
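For reference, a minimal sketch of how the `samplers.get()` resolution used in `__init__` above behaves, mirroring the `sampler_test.py` changes later in this diff. It assumes a `keras_nlp` build that includes these changes; `k=7` is just an illustrative value.

```python
import keras_nlp

# A string identifier resolves to a sampler instance with default arguments.
sampler = keras_nlp.samplers.get("top_k")
print(type(sampler).__name__)  # TopKSampler

# Serialization round-trips through a config dict, and instances pass through.
original = keras_nlp.samplers.TopKSampler(k=7)
config = keras_nlp.samplers.serialize(original)
restored = keras_nlp.samplers.get(config)
print(restored.get_config()["k"])  # 7
```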
@@ -237,7 +241,6 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index): current_cache = caches[i] x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( x, - decoder_padding_mask=padding_mask, cache=current_cache, cache_index=cache_index, ) @@ -251,30 +254,78 @@ def call_with_cache(self, token_ids, padding_mask, cache, cache_index): ) return x, cache - def build_empty_cache(self, batch_size, max_length): + def _build_cache(self, prompt): """Build an empty cache for use with `call_with_cache()`.""" + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] num_layers = self.backbone.num_layers num_heads = self.backbone.num_heads head_dim = self.backbone.hidden_dim // self.backbone.num_heads shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] - return tf.zeros(shape) + cache = tf.zeros(shape) + # Seed the cache. + _, cache = self.call_with_cache(prompt, cache, 0) + return cache - def _get_token_probability( + def compile( self, - prompt, - mask, - cache=None, - cache_index=None, + *args, + run_eagerly=False, + jit_compile=True, + sampler="top_k", + **kwargs, ): - batch_size = tf.shape(prompt)[0] - prompt = tf.slice(prompt, [0, cache_index], [batch_size, 1]) - return self.call_with_cache(prompt, mask, cache, cache_index) + xla_compatible = is_xla_compatible(self) + super().compile( + *args, + run_eagerly=run_eagerly, + # Only `jit_compile` if not eager and in a compatible environment. + jit_compile=jit_compile and xla_compatible and not run_eagerly, + **kwargs, + ) + self._sampler = samplers.get(sampler) + # Clear the compiled generate function. + self.generate_function = None + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + def generate_function(prompt, input_mask, min_length): + # Create and seed cache with a single forward pass. + cache = self._build_cache(prompt) + + def next(prompt, state, index): + # The cache index is the index of our previous token. + cache_index = index - 1 + prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) + logits, state = self.call_with_cache(prompt, state, cache_index) + return tf.squeeze(logits, axis=1), state + + return self._sampler( + next=next, + prompt=prompt, + state=cache, + index=min_length, + mask=input_mask, + end_token_id=self.preprocessor.tokenizer.end_token_id, + ) + + if self.run_eagerly: + self.generate_function = generate_function + else: + # `jit_compile` is a property of keras.Model after TF 2.12. + # Use `getattr()` for backwards compatibility. + jit_compile = getattr(self, "jit_compile", True) + self.generate_function = tf.function( + generate_function, jit_compile=jit_compile + ) + return self.generate_function def generate( self, prompt, max_length, - sampler="top_k", ): """Generate text. @@ -295,46 +346,27 @@ def generate( "`self.preprocessor` is `None`, please make sure " "`preprocessor` is set before calling `generate`." ) - sampler = keras_nlp.samplers.get(sampler) - if sampler.__class__ == BeamSampler: - raise ValueError( - "`BeamSampler` is not supported right now, please choose " - "another sampler, e.g., `TopPSampler`." - ) - if hasattr(self, "jit_compile"): - # `jit_compile` is a public property as of tf 2.12. hasattr is for - # backward compat. 
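For reference, a toy sketch of the fixed-shape cache idea behind `_build_cache()` and `call_with_cache()`: the cache keeps a static shape, and single positions are written into it with `dynamic_update_slice`, which is what keeps the compiled generate function XLA-friendly. The dimensions below are made up, and the real cache is seeded by a forward pass rather than written by hand.

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

# Hypothetical small dimensions; the real values come from the backbone config.
batch_size, num_layers, max_length, num_heads, head_dim = 2, 3, 8, 4, 16
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = tf.zeros(shape)

# Write the key/value update for a single decoded position (index 5) into the
# cache without changing its static shape.
update = tf.ones([batch_size, num_layers, 2, 1, num_heads, head_dim])
cache = dynamic_update_slice(cache, update, [0, 0, 0, 5, 0, 0])
print(cache.shape)  # (2, 3, 2, 8, 4, 16)
```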
- sampler.jit_compile = self.jit_compile - sampler.run_eagerly = self.run_eagerly - if self.sampler and serialize(sampler) == serialize(self.sampler): - # If the new sampler is the same as the older one, we reuse the old - # sampler to avoid recompile. - sampler = self.sampler - else: - self.sampler = sampler # Tokenize. + prompt = tf.convert_to_tensor(prompt) + input_is_scalar = prompt.shape.rank == 0 + prompt = prompt[tf.newaxis] if input_is_scalar else prompt prompt = self.preprocessor.tokenizer(prompt) - # Create and seed the cache before generation. - token_ids = prompt - if prompt.shape.rank == 1: - token_ids = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :]) - token_ids = token_ids.to_tensor(shape=(None, max_length)) - # Pass a padding mask of all ones when seeing the cache. The mask will - # not affect cached key/values for input tokens we care about. - padding_mask = tf.ones_like(token_ids, dtype=tf.bool) - batch_size = tf.shape(token_ids)[0] - cache = self.build_empty_cache(batch_size, max_length) - _, cache = self.call_with_cache(token_ids, padding_mask, cache, 0) - # Run generation. - generated = sampler( - prompt, - self._get_token_probability, - max_length=max_length, - end_token_id=self.preprocessor.tokenizer.end_token_id, - cache=cache, - ) + # Pad ragged to dense tensors. + padded_shape = (None, max_length) + min_length = tf.reduce_min(prompt.row_lengths()) + input_mask = tf.ones_like(prompt, tf.bool).to_tensor(shape=padded_shape) + prompt = prompt.to_tensor(shape=padded_shape) + + # Run the (possibly compiled) generate function on dense inputs. + generate_function = self.make_generate_function() + output = generate_function(prompt, input_mask, min_length) + + # Truncate to ragged by removing tokens after the first end token. + end_token_id = self.preprocessor.tokenizer.end_token_id + output = truncate_at_token(output, end_token_id, input_mask) # Detokenize. - return self.preprocessor.tokenizer.detokenize(generated) + output = self.preprocessor.tokenizer.detokenize(output) + return tf.squeeze(output, 0) if input_is_scalar else output diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py index 37436a2336..ccf447532c 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py @@ -122,18 +122,17 @@ def test_gpt2_causal_lm_fit_no_preprocessing(self, jit_compile): self.causal_lm_no_preprocessing.fit(self.preprocessed_dataset) @parameterized.named_parameters( - ("non_jit_compile_cache", False, True), - ("non_jit_compile_non_cache", False, False), - ("jit_compile_non_cache", True, False), + ("jit_compile_false", False), ("jit_compile_true", True) ) - def test_gpt2_causal_lm_generate(self, jit_compile, use_cache): + def test_generate(self, jit_compile): + # Tensor input. self.causal_lm.compile(jit_compile=jit_compile) self.causal_lm.generate( self.raw_batch, max_length=10, ) - - # String input + first_fn = self.causal_lm.generate_function + # String input. prompt = " airplane" generated = self.causal_lm.generate( prompt, @@ -141,6 +140,11 @@ def test_gpt2_causal_lm_generate(self, jit_compile, use_cache): ) generated = generated.numpy().decode("utf-8") self.assertTrue(prompt in generated) + second_fn = self.causal_lm.generate_function + # Assert we did not recompile. 
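For reference, a small sketch of the ragged-to-dense step the new `generate()` performs before calling the compiled function, with made-up token ids. The compiled function needs dense, statically shaped inputs, so the ragged tokenizer output is padded to `(batch_size, max_length)` alongside a boolean mask of the original positions.

```python
import tensorflow as tf

max_length = 8
# Two tokenized prompts of different lengths.
prompt = tf.ragged.constant([[11, 12, 13], [21, 22, 23, 24, 25]])

# Generation starts at the end of the shortest prompt.
min_length = tf.reduce_min(prompt.row_lengths())
# `True` marks positions that came from the input and must not be overwritten.
input_mask = tf.ones_like(prompt, tf.bool).to_tensor(shape=(None, max_length))
prompt = prompt.to_tensor(shape=(None, max_length))

print(min_length.numpy())     # 3
print(prompt.numpy()[0])      # [11 12 13  0  0  0  0  0]
print(input_mask.numpy()[0])  # [ True  True  True False False False False False]
```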
+ self.assertEqual(first_fn, second_fn) + self.causal_lm.compile(sampler="greedy") + self.assertIsNone(self.causal_lm.generate_function) @parameterized.named_parameters( ("tf_format", "tf", "model"), diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 30b768bc2e..03b390a3a6 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -14,19 +14,16 @@ """Beam Sampler.""" import tensorflow as tf -from absl import logging from tensorflow import keras +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from keras_nlp.api_export import keras_nlp_export from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import base_sampler_args_docstring from keras_nlp.samplers.sampler import call_args_docstring from keras_nlp.utils.python_utils import format_docstring -@format_docstring( - base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring -) +@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.BeamSampler") class BeamSampler(Sampler): """Beam Sampler class. @@ -39,192 +36,142 @@ class BeamSampler(Sampler): Args: num_beams: int. The number of beams that should be kept at each time-step. `num_beams` should be strictly positive. - {{base_sampler_args}} Call Args: {{call_args}} Examples: ```python - VOCAB_SIZE = 10 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=VOCAB_SIZE, - output_dim=16, - ), - keras.layers.Dense(VOCAB_SIZE, activation="softmax"), - ] + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # A uniform distribution over our alphabet. + logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.BeamSampler()( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, ) - - # Define a function that outputs the next token's probability for each token - # in the input sequence. - def token_probability_fn(inputs, mask): - return model(inputs) - - prompt = tf.fill((8, 1), 1) - - sampler = keras_nlp.samplers.BeamSampler(num_beams=3) - # Print the generated sequence (token ids). - print(sampler(prompt, token_probability_fn, max_length=10)) + print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) + # >>> "zzzzzaaaaaaa" ``` """ def __init__( self, num_beams=5, - jit_compile=True, - run_eagerly=False, ): + super().__init__() self.num_beams = num_beams - super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly) - def get_next_token(self, next_token_probs): - # Beam search overrides the whole `sample` method. - pass - - def sample( + def __call__( self, + next, prompt, - mask, - num_steps, - from_logits=True, + state=None, + index=0, + mask=None, end_token_id=None, - cache=None, ): - """Sampling logic implementation. - - Because beam search uses a different loop body, we have to override the - whole `sample` method instead of just the `get_next_token` method. - """ - if cache is not None: - logging.warning( - "`BeamSampler` does not support cache decoding now, the cache " - "will be ignored. To use cache decoding, please use a " - "different sampler, e.g., `TopPSampler`." 
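For reference, a runnable sketch of the `next(prompt, state, index) -> (logits, state)` contract that the reworked samplers consume. Here the scalar `state` just counts decode steps as a stand-in for a real key/value cache, and the vocabulary and length values are made up.

```python
import tensorflow as tf
import keras_nlp

batch_size, length, vocab_size = 1, 8, 5

def next(prompt, state, index):
    # A real model would read its cache from `state`; here it only counts steps.
    state = state + 1
    # Always favor token id 3. Shape must be (batch_size, vocab_size).
    logits = tf.one_hot(tf.fill((batch_size,), 3), vocab_size) * 1e9
    return logits, state

prompt = tf.zeros((batch_size, length), dtype=tf.int32)
output = keras_nlp.samplers.GreedySampler()(
    next=next,
    prompt=prompt,
    state=tf.constant(0),
    index=2,  # sampling starts here, so the first two positions stay untouched
)
print(output.numpy())  # [[0 0 3 3 3 3 3 3]]
```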
- ) batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] - max_length = tf.cast(max_length, num_steps.dtype) - length = max_length - num_steps - dummy_preds = self.token_probability_fn(prompt, mask=mask) - vocab_size = tf.shape(dummy_preds)[-1] - pred_dtype = dummy_preds.dtype - - num_beams = self.num_beams - - # Initialize beam with shape `(batch_size, num_beams, length)`. - beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1) - # Initialize `beams_prob` with shape `(batch_size, num_beams)`. - beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype) - beams_prob = tf.concat( - [beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)], - axis=-1, - ) - - def one_step(beams, beams_prob, length, mask): - flattened_beams = tf.reshape( - beams, shape=[batch_size * num_beams, -1] - ) - repeated_mask = tf.tile(mask, [num_beams, 1]) - probs = self.token_probability_fn(flattened_beams, repeated_mask) - preds = tf.gather( - probs, - tf.repeat(length - 1, batch_size * num_beams), - axis=1, - batch_dims=1, - ) - if from_logits: - preds = keras.activations.softmax(preds, axis=-1) + # Make sure max length and start index are the same dtype. + index = tf.cast(index, max_length.dtype) + + def create_beams(x): + """Add initial beam state.""" + return tf.repeat(x, self.num_beams, axis=0) + + def flatten_beams(x): + """Combine the beam dim and batch dim.""" + flat_shape = [batch_size * self.num_beams] + x.shape.as_list()[2:] + return tf.reshape(x, shape=flat_shape) + + def unflatten_beams(x): + """Separate the beam dim and batch dim.""" + unflat_shape = [batch_size, self.num_beams] + x.shape.as_list()[1:] + return tf.reshape(x, shape=unflat_shape) + + mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask + # `tf.while_loop` will not accept `None` as a value for `loop_vars`. + state = () if state is None else state + # Add extra sequences for each beam. + prompt, mask = create_beams(prompt), create_beams(mask) + state = tf.nest.map_structure(create_beams, state) + # Setup the initial beam log-likelihoods. + # On the first loop, make sure only the original beam is considered. + log_probs = tf.constant([[0.0] + [-1e9] * (self.num_beams - 1)]) + log_probs = flatten_beams(tf.repeat(log_probs, batch_size, axis=0)) + + def cond(prompt, state, index, log_probs): + if end_token_id is None: + return True + # Stop if all sequences have produced a *new* end_token_id. + end_tokens = (prompt == end_token_id) & (~mask) + prompt_done = tf.reduce_any(end_tokens, axis=-1) + return not tf.reduce_all(prompt_done) + + def body(prompt, state, index, log_probs): + # Compute the softmax distribution for the next token. + logits, state = next(prompt, state, index) + vocab_size = tf.shape(logits)[-1] + probs = keras.activations.softmax(logits) + + # Compute the running log-likelihood of each new candidate. + next_log_probs = tf.math.log(probs) + log_probs[..., tf.newaxis] # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`. - preds = tf.reshape(preds, shape=[batch_size, -1]) - - cum_probs = tf.math.log(preds) + tf.repeat( - beams_prob, repeats=vocab_size, axis=1 - ) - - candidate_prob, candidate_indexes = tf.math.top_k( - cum_probs, k=num_beams, sorted=False - ) - - candidate_beam_indexes = candidate_indexes // vocab_size - next_token = candidate_indexes % vocab_size - - beams = tf.gather( - beams, candidate_beam_indexes, axis=1, batch_dims=1 - ) - - # Build a new column of updates to scatter into the beam tensor. 
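For reference, a small numeric sketch of the beam bookkeeping helpers above: `create_beams` repeats each batch row once per beam, and `flatten_beams`/`unflatten_beams` move between `(batch, beams, ...)` and `(batch * beams, ...)` layouts so that `next()` only ever sees an ordinary batch dimension. The shapes are made up.

```python
import tensorflow as tf

batch_size, num_beams, length = 2, 3, 4

def flatten_beams(x):
    # (batch, beams, ...) -> (batch * beams, ...)
    return tf.reshape(x, [batch_size * num_beams] + x.shape.as_list()[2:])

def unflatten_beams(x):
    # (batch * beams, ...) -> (batch, beams, ...)
    return tf.reshape(x, [batch_size, num_beams] + x.shape.as_list()[1:])

prompt = tf.reshape(tf.range(batch_size * length), (batch_size, length))
beams = tf.repeat(prompt, num_beams, axis=0)  # create_beams
print(beams.shape)                                   # (6, 4)
print(unflatten_beams(beams).shape)                  # (2, 3, 4)
print(flatten_beams(unflatten_beams(beams)).shape)   # (6, 4)
```

The initial log-probabilities `[0.0, -1e9, ..., -1e9]` per batch entry make all but one beam effectively impossible on the first step, so the search starts from a single copy of the prompt.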
- next_token = tf.where( - condition=mask[..., length, tf.newaxis], - x=beams[..., length], - y=next_token, - ) - next_token = tf.reshape(next_token, shape=[-1]) - - mask = tf.tensor_scatter_nd_update( - tensor=mask, - indices=tf.stack( - ( - tf.cast(tf.range(batch_size), dtype=length.dtype), - tf.repeat(length, batch_size), - ), - axis=1, - ), - updates=tf.repeat(True, batch_size), - ) + next_log_probs = tf.reshape(next_log_probs, shape=[batch_size, -1]) - # Generate `(batch_index, beam_index)` tuples for each beam. - beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool)) - beam_indices = tf.cast(beam_indices, dtype=length.dtype) - # Build a tensor of repeated `length` values. - length_indices = tf.fill((batch_size * num_beams, 1), length) - # Concatenate to a triplet of `(batch_index, beam_index, length)`. - indices = tf.concat([beam_indices, length_indices], axis=-1) - - # Update `beams[:, :, length]` with `next_token`. - beams = tf.tensor_scatter_nd_update( - tensor=beams, - indices=indices, - updates=next_token, + # Compute the top beam indices and next tokens. + next_log_probs, indices = tf.math.top_k( + next_log_probs, k=self.num_beams, sorted=False ) - - beams_prob = candidate_prob - - length = tf.add(length, 1) - return beams, beams_prob, length, mask - - # Run a while loop till text of length `max_length` has been generated. - beams, beams_prob, length, mask = tf.while_loop( - cond=lambda beams, beams_prob, length, mask: tf.less( - length, max_length - ), - body=one_step, - loop_vars=[beams, beams_prob, length, mask], - # There is a strange issue that when `batch_size=1`, the first loop - # iteration changes `beams_prob`'s shape from [1, None] to - # [None, None], which does not happen for `batch_size>1`. - # As a workaround, we set shape invariants. - shape_invariants=[ - beams.get_shape(), - tf.TensorShape([None, None]), - length.get_shape(), - mask.get_shape(), - ], - ) - - # Get the beam with the maximum probability. - max_indexes = tf.math.argmax(beams_prob, axis=-1) - max_beams = tf.gather( - beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1 + beam_indices = indices // vocab_size + next_token = flatten_beams(indices % vocab_size) + # We need `ensure_shape` as `top_k` will change the static shape. + next_log_probs = flatten_beams(next_log_probs) + log_probs = tf.ensure_shape(next_log_probs, log_probs.shape) + + def gather_beams(x): + x = unflatten_beams(x) + x = tf.gather(x, beam_indices, axis=1, batch_dims=1) + return flatten_beams(x) + + prompt = gather_beams(prompt) + state = tf.nest.map_structure(gather_beams, state) + + # Update each beam with the next token. + next_token = tf.cast(next_token, prompt.dtype) + # Don't overwrite anywhere mask is True. + next_token = tf.where(mask[:, index], prompt[:, index], next_token) + # Update the prompt with the next token. + next_token = next_token[:, tf.newaxis] + prompt = dynamic_update_slice(prompt, next_token, [0, index]) + # Return the iteration of the loop state. + return (prompt, state, index + 1, log_probs) + + prompt, _, _, log_probs = tf.while_loop( + cond=cond, + body=body, + loop_vars=(prompt, state, index, log_probs), + maximum_iterations=(max_length - index), ) - return tf.squeeze(max_beams, axis=1) + # Gather the top beam at each batch index. 
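For reference, a toy sketch of the candidate selection step in `body()` above: after flattening the `(beams, vocab)` scores into one axis per batch entry, integer division and modulo recover which beam a winning candidate extends and which token it appends. The scores below are made up.

```python
import tensorflow as tf

num_beams, vocab_size = 2, 5

# Log-likelihood of every (beam, token) candidate for one batch entry,
# flattened to a single axis of size num_beams * vocab_size.
next_log_probs = tf.constant([[-1.0, -9.0, -9.0, -9.0, -9.0,    # beam 0
                               -9.0, -9.0, -0.5, -9.0, -2.0]])  # beam 1

top_log_probs, indices = tf.math.top_k(next_log_probs, k=num_beams, sorted=False)
beam_indices = indices // vocab_size  # which existing beam each winner extends
next_tokens = indices % vocab_size    # which token id each winner appends

# With sorted=False the order of the two winners is not guaranteed.
print(sorted(indices.numpy()[0]))       # [0, 7]
print(sorted(beam_indices.numpy()[0]))  # [0, 1]
print(sorted(next_tokens.numpy()[0]))   # [0, 2]
```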
+ prompt, log_probs = unflatten_beams(prompt), unflatten_beams(log_probs) + top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis] + prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1) + return tf.squeeze(prompt, axis=1) def get_config(self): config = super().get_config() - - config.update({"num_beams": self.num_beams}) + config.update( + { + "num_beams": self.num_beams, + } + ) return config diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index cc0b2cb0f6..95290b8f1a 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -13,136 +13,83 @@ # limitations under the License. """Tests for Beam sampler.""" -import random - +import numpy as np import tensorflow as tf from absl.testing import parameterized -from tensorflow import keras from keras_nlp.samplers.beam_sampler import BeamSampler -from keras_nlp.samplers.greedy_sampler import GreedySampler class BeamSamplerTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - self.vocab_size = 10 - self.feature_size = 16 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=self.vocab_size, - output_dim=self.feature_size, - ), - keras.layers.Dense(self.vocab_size), - ] + # Use a simple alphabet of lowercase characters to [0, 26). + self.int_lookup = {i: chr(i + ord("a")) for i in range(26)} + self.char_lookup = {v: k for k, v in self.int_lookup.items()} + self.batch_size = 1 + self.length = 12 + self.vocab_size = len(self.int_lookup) + + def next(prompt, state, index): + # Return a distribution favoring the next char in state. + logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 + return logits, state + + self.next = next + self.sampler = BeamSampler(num_beams=5) + + def join_as_string(self, x): + return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] + + def test_stateless_call(self): + def next(prompt, state, index): + # Return a distribution favoring the first token in the vocab. 
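For reference, a small sketch of the final beam selection above: `tf.gather` with `batch_dims=1` picks a (possibly different) best beam for every batch entry. Shapes and scores are made up.

```python
import tensorflow as tf

# Per-batch beam scores, shape (batch, beams), and beams, shape (batch, beams, length).
log_probs = tf.constant([[-3.0, -1.0, -2.0],
                         [-0.5, -4.0, -6.0]])
beams = tf.reshape(tf.range(2 * 3 * 4), (2, 3, 4))

top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis]
best = tf.gather(beams, top_beams, axis=1, batch_dims=1)
best = tf.squeeze(best, axis=1)
print(top_beams.numpy().ravel())  # [1 0]
print(best.numpy())
# [[ 4  5  6  7]
#  [12 13 14 15]]
```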
+ logits = np.zeros((self.batch_size, self.vocab_size)) + logits[:, 0] = 1e9 + return tf.constant(logits, dtype="float32"), state + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=next, + prompt=prompt, + index=5, ) - - def token_probability_fn(inputs, mask): - return model(inputs) - - self.token_probability_fn = token_probability_fn - self.sampler = BeamSampler(num_beams=2) - - def test_generate_with_1d_prompt(self): - inputs = tf.constant([1]) - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [5]) - - def test_generate_with_2d_prompt(self): - inputs = tf.constant([[1], [1]]) - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [2, 5]) - - def test_generate_with_list_prompt(self): - inputs = [[1], [1]] - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [2, 5]) - - def test_generate_with_ragged_prompt(self): - def token_probability_fn(inputs, mask): - batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.tile(prob, [batch_size, seq_length, 1]) - - inputs = tf.ragged.constant([[1], [2, 1, 2]]) - outputs = self.sampler(inputs, token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [2, 5]) - - def test_one_beam_generation(self): - for _ in range(5): - inputs = tf.constant([random.randint(0, 9)]) - beam_sampler = BeamSampler(num_beams=1) - greedy_sampler = GreedySampler() - beam_output = beam_sampler( - inputs, - self.token_probability_fn, - max_length=5, - ) - greedy_output = greedy_sampler( - inputs, - self.token_probability_fn, - max_length=5, - ) - self.assertAllEqual(beam_output, greedy_output) + self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) + + def test_stateful_call(self): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + ) + self.assertEqual(self.join_as_string(output), ["sequentially"]) + + def test_early_stopping(self): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + end_token_id=self.char_lookup["t"], + ) + self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) @parameterized.named_parameters( - ("xla_graph", True, False), - ("non_xla_graph", False, False), - ("eager", False, True), + ("jit_compile_false", False), ("jit_compile_true", True) ) - def test_assert_generation_is_correct(self, jit_compile, run_eagerly): - def token_probability_fn(inputs, mask): - batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.tile(prob, [batch_size, seq_length, 1]) + def test_compilation(self, jit_compile): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) - batch_size = 10 - inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) - max_length = 3 - for i in range(1, 5): - sampler = BeamSampler( - num_beams=i, - jit_compile=jit_compile, - run_eagerly=run_eagerly, 
- ) - outputs = sampler( - inputs, - token_probability_fn, - max_length=max_length, - ) - self.assertAllEqual( - outputs, 3 * tf.ones(shape=[batch_size, max_length]) - ) + @tf.function(jit_compile=jit_compile) + def generate(prompt, state): + return self.sampler(self.next, prompt=prompt, state=state) - def test_end_token_id(self): - def token_probability_fn(inputs, mask): - batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1] - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.tile(prob, [batch_size, seq_length, 1]) - - max_length = 4 - inputs = tf.constant([[0, 1], [1, 2]]) - outputs = self.sampler( - inputs, - token_probability_fn, - max_length=max_length, - end_token_id=2, - ) - # end_token in prompt does not trigger truncation. - expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]]) - self.assertAllEqual(outputs, expected_outputs) - - max_length = 4 - inputs = tf.constant([[0, 1], [1, 3]]) - outputs = self.sampler( - inputs, - token_probability_fn, - max_length=max_length, - end_token_id=3, - ) - expected_outputs = tf.ragged.constant([[0, 1], [1, 3]]) - self.assertAllEqual(outputs, expected_outputs) + output = generate(prompt, state) + self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 90beafa88b..c81d35d155 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -17,14 +17,11 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import base_sampler_args_docstring from keras_nlp.samplers.sampler import call_args_docstring from keras_nlp.utils.python_utils import format_docstring -@format_docstring( - base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring -) +@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.GreedySampler") class GreedySampler(Sampler): """Greedy sampler class. @@ -32,47 +29,33 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. - Args: - {{base_sampler_args}} - Call Args: {{call_args}} Examples: ```python - VOCAB_SIZE = 10 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=VOCAB_SIZE, - output_dim=16, - ), - keras.layers.Dense(VOCAB_SIZE, activation="softmax"), - ] + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # return a uniform distribution over our alphabet. + logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.GreedySampler()( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, ) - - # Define a function that outputs the next token's probability for each token - # in the input sequence. - def token_probability_fn(inputs, mask): - return model(inputs) - - prompt = tf.fill((8, 1), 1) - - sampler = keras_nlp.samplers.GreedySampler() - # Print the generated sequence (token ids). 
- print(sampler(prompt, token_probability_fn, max_length=10)) + print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) + # >>> "zzzzzaaaaaaa" ``` """ - def __init__( - self, - jit_compile=True, - run_eagerly=False, - ): - super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly) + def __init__(self): + super().__init__() - def get_next_token(self, next_token_probs): - return tf.argmax(next_token_probs, axis=-1) + def get_next_token(self, probabilities): + return tf.argmax(probabilities, axis=-1) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 8515416e7d..21c237e56b 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -13,9 +13,9 @@ # limitations under the License. """Tests for Greedy sampler.""" +import numpy as np import tensorflow as tf from absl.testing import parameterized -from tensorflow import keras from keras_nlp.samplers.greedy_sampler import GreedySampler @@ -23,116 +23,88 @@ class GreedySamplerTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - self.vocab_size = 10 - self.feature_size = 16 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=self.vocab_size, - output_dim=self.feature_size, - ), - keras.layers.Dense(self.vocab_size), - keras.layers.Softmax(), - ] - ) - - def token_probability_fn(inputs, mask): - return model(inputs) - - self.token_probability_fn = token_probability_fn - + # Use a simple alphabet of lowercase characters to [0, 26). + self.int_lookup = {i: chr(i + ord("a")) for i in range(26)} + self.char_lookup = {v: k for k, v in self.int_lookup.items()} + self.batch_size = 1 + self.length = 12 + self.vocab_size = len(self.int_lookup) + + def next(prompt, state, index): + # Return a distribution favoring the next char in state. + logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 + return logits, state + + self.next = next self.sampler = GreedySampler() - def test_generate_with_1d_prompt(self): - inputs = tf.constant([1]) - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [5]) - - def test_generate_with_2d_prompt(self): - inputs = tf.constant([[1], [1]]) - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [2, 5]) - - def test_generate_with_list_prompt(self): - inputs = [[1], [1]] - outputs = self.sampler(inputs, self.token_probability_fn, max_length=5) - self.assertEqual(outputs.shape, [2, 5]) - - def test_generate_with_ragged_prompt(self): - max_length = 5 - - def token_probability_fn(inputs, mask): - # Assert that user function is passed only dense tensors. 
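For reference, with the reworked base class a new sampling strategy only needs to override `get_next_token()`, as `GreedySampler` does above. A hypothetical subclass (not part of this PR) that samples from the full distribution might look like this:

```python
import tensorflow as tf
import keras_nlp

class RandomSampler(keras_nlp.samplers.Sampler):
    """Hypothetical sampler drawing from the full next-token distribution."""

    def __init__(self, seed=None):
        super().__init__()
        self.seed = seed

    def get_next_token(self, probabilities):
        # `probabilities` has shape (batch_size, vocab_size) and sums to one.
        next_token = tf.random.categorical(
            tf.math.log(probabilities), 1, seed=self.seed, dtype="int32"
        )
        return tf.squeeze(next_token, axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({"seed": self.seed})
        return config
```

Because no string identifier is registered for it, such a sampler would be passed as an instance, e.g. `gpt2_lm.compile(sampler=RandomSampler())`.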
- self.assertIsInstance(inputs, tf.Tensor) - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1) - - inputs = tf.ragged.constant([[1], [2, 1, 2]]) - outputs = self.sampler(inputs, token_probability_fn, max_length) - self.assertEqual(outputs.shape, [2, 5]) - - def test_assert_generation_is_correct(self): - batch_size = 10 - max_length = 3 - - def token_probability_fn(inputs, mask): - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.repeat( - tf.repeat(prob, batch_size, axis=0), max_length, axis=1 - ) - - inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32) - outputs = self.sampler( - inputs, token_probability_fn, max_length=max_length + def join_as_string(self, x): + return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] + + def test_stateless_call(self): + def next(prompt, state, index): + # Return a distribution favoring the first token in the vocab. + logits = np.zeros((self.batch_size, self.vocab_size)) + logits[:, 0] = 1e9 + return tf.constant(logits), state + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=next, + prompt=prompt, + index=5, ) - self.assertAllEqual( - outputs, 3 * tf.ones(shape=[batch_size, max_length]) + self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) + + def test_stateful_call(self): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, ) - - def test_end_token_id(self): - def token_probability_fn(inputs, mask): - batch_size = tf.shape(inputs)[0] - prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]]) - return tf.repeat( - tf.repeat(prob, batch_size, axis=0), max_length, axis=1 - ) - - max_length = 4 - sampler = GreedySampler() - inputs = tf.constant([[0, 1], [1, 2]]) - outputs = sampler( - inputs, - token_probability_fn, - max_length=max_length, - end_token_id=2, + self.assertEqual(self.join_as_string(output), ["sequentially"]) + + def test_early_stopping(self): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + end_token_id=self.char_lookup["t"], ) - # end_token in prompt does not trigger truncation. - expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]]) - self.assertAllEqual(outputs, expected_outputs) - - outputs = sampler( - inputs, - token_probability_fn, - max_length=max_length, - end_token_id=3, + self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) + + def test_is_greedy(self): + def next(prompt, state, index): + # Return a distribution where each id is progressively less likely. + logits = tf.range(self.vocab_size, 0, -1, dtype="float32") + logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) + return logits, state + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=next, + prompt=prompt, ) - # Generated end_token will be truncated. 
- expected_outputs = tf.ragged.constant([[0, 1], [1, 2]]) - self.assertAllEqual(outputs, expected_outputs) - - def test_compare_xla_noxla_results(self): - inputs = [[1], [1]] - xla_sampler = GreedySampler(jit_compile=True) - outputs_xla = xla_sampler( - inputs, self.token_probability_fn, max_length=5 - ) - - xla_sampler = GreedySampler(jit_compile=False) - outputs_no_xla = xla_sampler( - inputs, self.token_probability_fn, max_length=5 - ) - - self.assertAllEqual(outputs_xla, outputs_no_xla) + output_ids = set(output[0].numpy()) + self.assertContainsSubset(output_ids, [0]) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_compilation(self, jit_compile): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + + @tf.function(jit_compile=jit_compile) + def generate(prompt, state): + return self.sampler(self.next, prompt=prompt, state=state) + + output = generate(prompt, state) + self.assertEqual(self.join_as_string(output), ["sequentially"]) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index d7098d9777..dde0598f68 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -20,246 +20,121 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.utils.python_utils import format_docstring -base_sampler_args_docstring = """ - jit_compile: bool, defaults to True. If True, XLA compilation will be used. - run_eagerly: bool, defaults to False. If True, the sampler will run in - the eager mode. - """ - call_args_docstring = """ - prompt: a list of integers or an integer Tensor, can be 1D or 2D. The - initial tokens to append generated tokens. - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. - max_length: int. The max length of generated sequence. - mask: a tensor, defaults to None. The padding mask of the prompt. - end_token_id: int, defaults to None. The token marking the end of the - sequence, once encountered the generation is finished for the exact - sequence. If None, every sequence is generated up to `max_length`. - If set, all tokens after encountering `end_token_id` will be - replaced with `pad_token_id`. - from_logits: bool, defaults to True. Indicate if the `token_probability_fn` - returns logits. If False, `token_probability_fn` returns probability - distributions. - cache: a dense int tensor, a cache of intermediate key and value tensor - computed by each decoder self-attention layer. These values will only - be computed once for each new token in the generated sequence. + next: A function which takes in the `prompt, state, index` of the + current generation loop, and outputs a tuple `(logits, state)` with the + probability for the next token and state for the next iteration. + prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This + tensor will be iteratively updated column by column with new sampled + values, starting at `index`. + state: Optional. A tensor or nested structure of tensors that will be + updated by each call to `next`. This can be used to cache computations + from early iterations of the generative loop. + index: Optional. The first index to start sampling at. + mask: Optional. A 2D integer tensor with the same shape as `prompt`. + Locations which are `True` in the mask are never updated during + sampling. 
Often this will mark all ids in `prompt` which were present in + the original input. + end_token_id: Optional. The token marking the end of the sequence. If + specified, sampling will stop as soon as all sequences in the prompt + produce a `end_token_id` in a location where `mask` is `False`. """ -@format_docstring( - base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring -) +@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.Sampler") class Sampler: """Base sampler class. - Args: - {{base_sampler_args}} - Call Args: {{call_args}} - The inputs and outputs of Sampler class are both token ids. - - Subclassers should always implement the `get_next_token()` method, which - gets the next token based on probability distribution over vocab tokens. - Please check available subclass samplers for examples. If you need more - control over the sampling process, please implement `sample()` method - instead, see `keras_nlp.samplers.BeamSampler` for examples. - - Examples: - - Basic usage: - ```python - VOCAB_SIZE = 10 - - # Create a dummy model to predict the next token. Note that the output is - # random without training, here we just demo how `samplers` works. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=VOCAB_SIZE, - output_dim=16, - ), - keras.layers.Dense(VOCAB_SIZE, activation="softmax"), - ] - ) + This base class can be extended to implement different auto-regressive + sampling methods. Subclasses can either: - # Define a function that outputs the next token's probability for each token - # in the input sequence. - def token_probability_fn(inputs, mask): - return model(inputs) + - Override the `get_next_token()` method, which computes the next token + based on a probability distribution over all possible vocab entries. + - Override `__call__`, if the sampling method need additional state beyond + the next tokens probability distribution to sample a sequence. - prompt = tf.fill((8, 1), 1) + Please check available subclass samplers for examples. - sampler = keras_nlp.samplers.GreedySampler() - # Print the generated sequence (token ids). - print(sampler(prompt, token_probability_fn, max_length=10, end_token_id=2)) - ``` + Examples: - Use with string inputs: ```python - vocab = ["[UNK]", "[PAD]", "[END]", "the", "quick", "brown", "fox"] - tokenizer = keras_nlp.tokenizers.WordPieceTokenizer( - vocabulary=vocab, - lowercase=True, + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # return a uniform distribution over our alphabet. + logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.GreedySampler()( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, ) - FEATURE_SIZE = 16 - VOCAB_SIZE = len(vocab) - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=VOCAB_SIZE, - output_dim=FEATURE_SIZE, - ), - keras.layers.Dense(VOCAB_SIZE, activation="softmax"), - ] - ) - # Define a function that outputs the next token's probability for each token - # in the input sequence. 
- def token_probability_fn(inputs, mask): - return model(inputs) - - prompt = tokenizer("the quick brown fox") - sampler = keras_nlp.samplers.GreedySampler() - generated = sampler( - prompt, - token_probability_fn, - max_length=10, - end_token_id=tokenizer.token_to_id("[END]") - ) - print(tokenizer.detokenize(generated)) + print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) + # >>> "zzzzzaaaaaaa" ``` """ - def __init__( - self, - jit_compile=True, - run_eagerly=False, - ): - if run_eagerly and jit_compile: - raise ValueError( - "XLA cannot be turned on under eager mode, received " - "`jit_compile=True` and `run_eagerly=True`. Please either set " - "`jit_compile=False` or set `run_eagerly=False`." - ) - self.jit_compile = jit_compile - self.run_eagerly = run_eagerly - - def _validate_prompt_and_mask(self, prompt, mask): - """Helper method to validate input prompt.""" - if not isinstance(prompt, (list, tf.RaggedTensor, tf.Tensor)): - raise ValueError( - "`prompt` must be one of `list`, `tf.RaggedTensor` or " - f"`tf.Tensor`, but received: prompt={type(prompt)}." - ) - - if isinstance(prompt, tf.RaggedTensor): - if mask: - raise ValueError( - "`mask` is only valid when `prompt` is a list or dense " - f"tensor, but received type(prompt)={type(prompt)}." - ) - return prompt, mask - - if isinstance(prompt, list): - prompt = tf.convert_to_tensor(prompt) - if not mask: - mask = tf.cast(tf.ones_like(prompt), dtype=tf.bool) - prompt = tf.ragged.boolean_mask(prompt, mask) - return prompt, mask - - def _pad_prompt(self, prompt, max_length): - """Pad prompt to `max_length`.""" - mask = tf.ones_like(prompt, dtype=tf.bool) - mask = mask.to_tensor(shape=(None, max_length)) - prompt = prompt.to_tensor(shape=(None, max_length)) - return prompt, mask - - def _mask_tokens_after_end_token( - self, - generated_result, - original_padding_mask, - max_length, - end_token_id, - ): - """Helper function to truncate the tokens after the end token.""" - # Create a tensor with True for each end_token_id. - end_tokens = generated_result == end_token_id - # Remove all end_token_ids in the original input. - end_tokens = end_tokens & (original_padding_mask == tf.constant(False)) - # Find index of first end_token_id. - end_indices = tf.math.argmax(end_tokens, -1) - # Use max_length if no `end_token_id` is found. - end_indices = tf.where( - end_indices == 0, - tf.cast(max_length, dtype=end_indices.dtype), - end_indices, - ) - # Truncate out tokens after (including) the end token. - mask_indices = tf.sequence_mask(end_indices, maxlen=max_length) - return tf.ragged.boolean_mask(generated_result, mask_indices) - def __call__( self, + next, prompt, - token_probability_fn, - max_length, + state=None, + index=0, mask=None, end_token_id=None, - from_logits=True, - cache=None, ): - prompt, mask = self._validate_prompt_and_mask(prompt, mask) - self.token_probability_fn = token_probability_fn - input_is_1d = prompt.shape.rank == 1 - if input_is_1d: - prompt = tf.RaggedTensor.from_tensor(prompt[tf.newaxis, :]) - - shortest_prompt_len = tf.reduce_min(prompt.row_lengths()) - # Pad prompt to be a dense Tensor of shape [batch_size, max_length]. - # This step is required for XLA compatibility because XLA requires a - # static shape, which means we cannot concatenate generated token to - # current prompt. - prompt, mask = self._pad_prompt(prompt, max_length) - original_padding_mask = tf.identity(mask) + max_length = tf.shape(prompt)[-1] + # Make sure `max_length` and `index` are the same dtype. 
+ index = tf.cast(index, max_length.dtype) + mask = tf.zeros_like(prompt, dtype=tf.bool) if mask is None else mask + # `tf.while_loop` will not accept `None` as a value for `loop_vars`. + state = () if state is None else state + + def cond(prompt, state, index): + if end_token_id is None: + return True + # Stop if all sequences have produced a *new* end_token_id. + end_tokens = (prompt == end_token_id) & (~mask) + prompt_done = tf.reduce_any(end_tokens, axis=-1) + return not tf.reduce_all(prompt_done) + + def body(prompt, state, index): + # Compute the softmax distribution for the next token. + logits, state = next(prompt, state, index) + probabilities = keras.activations.softmax(logits) + + # Compute the next token. + next_token = self.get_next_token(probabilities) + # Don't overwrite anywhere mask is True. + next_token = tf.cast(next_token, prompt.dtype) + next_token = tf.where(mask[:, index], prompt[:, index], next_token) + # Update the prompt with the next token. + next_token = next_token[:, tf.newaxis] + prompt = dynamic_update_slice(prompt, next_token, [0, index]) + # Return the next prompt, state and incremented index. + return (prompt, state, index + 1) - # Convert `sample` method to a `tf.function` if `self.run_eagerly=False` - # , and turn on `jit_compile` accordingly. - sample = self.sample - if not self.run_eagerly: - if self.jit_compile: - sample = self._sample_graph_xla - else: - sample = self._sample_graph - prompt = sample( - prompt, - mask, - max_length - shortest_prompt_len, - from_logits=from_logits, - end_token_id=end_token_id, - cache=cache, + prompt, _, _ = tf.while_loop( + cond=cond, + body=body, + loop_vars=(prompt, state, index), + maximum_iterations=(max_length - index), ) - # Mask out tokens after `end_token_id`. - if end_token_id is not None: - prompt = self._mask_tokens_after_end_token( - prompt, - original_padding_mask, - max_length, - end_token_id, - ) - - return tf.squeeze(prompt, axis=0) if input_is_1d else prompt + return prompt - def get_next_token(self, next_token_probs): + def get_next_token(self, probabilities): """Get the next token. Args: - next_token_probs: a Tensor, the probability distribution for next + probabilities: a Tensor, the probability distribution for next token over all vocab tokens. Get the next token based on given probability distribution over tokens. @@ -267,158 +142,9 @@ def get_next_token(self, next_token_probs): """ raise NotImplementedError - @tf.function - def _sample_graph( - self, - prompt, - mask, - num_steps, - from_logits=True, - end_token_id=None, - cache=None, - ): - """Wrapper of `sample` method to make it a non-XLA tf graph.""" - return self.sample( - prompt, mask, num_steps, from_logits, end_token_id, cache - ) - - @tf.function(jit_compile=True) - def _sample_graph_xla( - self, - prompt, - mask, - num_steps, - from_logits=True, - end_token_id=None, - cache=None, - ): - """Wrapper of `sample` method to make it an XLA tf graph.""" - return self.sample( - prompt, mask, num_steps, from_logits, end_token_id, cache - ) - - def sample( - self, - prompt, - mask, - num_steps, - from_logits=True, - end_token_id=None, - cache=None, - ): - """Sampling logic implementation. - - Args: - prompt: a dense int Tensor of shape [batch_size, max_length]. The - placeholder for generated sequence. - token_probability_fn: a function that generates the probability of - the next token over the whole vocabulary for each input token. - mask: a dense bool Tensor of shape [batch_size, max_length]. The - mask of prompt. - num_steps: int. 
The remaining number of tokens to generate. - from_logits: bool, defaults to True. Indicate if the - `token_probability_fn` returns logits. If False, - `token_probability_fn` returns probability distributions. - end_token_id: int, defaults to None. The token marking the end of - the sequence, once encountered the generation is finished for - the exact sequence. - cache: a dense int tensor, the cache used in decoding. The cache - stores the key and value of each - `keras_nlp.layers.CachedMultiHeadAttention` layer to make the - decoding faster by avoiding duplicated computation. - - Returns: - A dense int Tensor, representing the generated text in token id - space. - """ - batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] - num_steps = tf.cast(num_steps, tf.int32) - max_length = tf.cast(max_length, tf.int32) - # The index of the last non-padding token in prompt. Since all sequences - # are aligned to the right side, the index is the same for all. - current_index = max_length - num_steps - original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32) - - def body( - current_index, - prompt, - mask, - cache=None, - ): - last_index = current_index - 1 - if cache is not None: - probs, cache = self.token_probability_fn( - prompt, - mask, - cache=cache, - cache_index=last_index, - ) - next_token_probs = tf.squeeze(probs, axis=1) - else: - probs = self.token_probability_fn( - prompt, - mask, - ) - next_token_probs = tf.gather( - probs, - tf.repeat(current_index - 1, batch_size), - axis=1, - batch_dims=1, - ) - - if from_logits: - next_token_probs = keras.activations.softmax( - next_token_probs, axis=-1 - ) - next_token = self.get_next_token(next_token_probs) - next_token = tf.cast(next_token, prompt.dtype) - next_token = tf.where( - mask[:, current_index], - prompt[:, current_index], - next_token, - ) - next_token = next_token[:, tf.newaxis] - next_mask = tf.fill([batch_size, 1], True) - slice_start = [0, current_index] - mask = dynamic_update_slice(mask, next_mask, slice_start) - prompt = dynamic_update_slice(prompt, next_token, slice_start) - current_index = tf.add(current_index, 1) - if cache is None: - return current_index, prompt, mask - return current_index, prompt, mask, cache - - def cond(current_index, prompt, mask, cache=None): - if end_token_id is None: - return True - end_token_seen = (prompt == end_token_id) & ( - original_padding_mask == 0 - ) - sequence_done = tf.reduce_any(end_token_seen, axis=-1) - all_done = tf.reduce_all(sequence_done) - return not all_done - - if cache is None: - _, prompt, _ = tf.while_loop( - cond=cond, - body=body, - loop_vars=(current_index, prompt, mask), - maximum_iterations=num_steps, - ) - return prompt - # Run a while loop till `max_length` of tokens has been generated. 
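For reference, a small eager sketch (with made-up ids) of the early-stopping check used by the new `cond()` above: an `end_token_id` only counts toward stopping if it appears outside the input mask, i.e. if it was actually generated.

```python
import tensorflow as tf

end_token_id = 2
# Two sequences mid-generation; `mask` marks the original prompt positions.
prompt = tf.constant([[5, 2, 7, 2, 0, 0],
                      [5, 6, 7, 8, 0, 0]])
mask = tf.constant([[True, True, True, False, False, False],
                    [True, True, True, True, False, False]])

# An end token only counts if it sits outside the mask (it was generated).
end_tokens = (prompt == end_token_id) & (~mask)
prompt_done = tf.reduce_any(end_tokens, axis=-1)
print(prompt_done.numpy())               # [ True False]
print(bool(tf.reduce_all(prompt_done)))  # False -> keep looping
```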
- _, prompt, _, _ = tf.while_loop( - cond=cond, - body=body, - loop_vars=(current_index, prompt, mask, cache), - maximum_iterations=num_steps, - ) - return prompt - @classmethod def from_config(cls, config): return cls(**config) def get_config(self): - return { - "jit_compile": self.jit_compile, - } + return {} diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py index dd33f40093..49dff87a9d 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/sampler_test.py @@ -16,7 +16,6 @@ import tensorflow as tf import keras_nlp -from keras_nlp.samplers.greedy_sampler import GreedySampler from keras_nlp.samplers.top_k_sampler import TopKSampler @@ -30,12 +29,12 @@ def test_serialization(self): def test_get(self): # Test get from string. - identifier = "greedy" + identifier = "top_k" sampler = keras_nlp.samplers.get(identifier) - self.assertIsInstance(sampler, GreedySampler) + self.assertIsInstance(sampler, TopKSampler) # Test dict identifier. - original_sampler = keras_nlp.samplers.GreedySampler(jit_compile=False) + original_sampler = keras_nlp.samplers.TopKSampler(k=7) config = keras_nlp.samplers.serialize(original_sampler) restored_sampler = keras_nlp.samplers.get(config) self.assertDictEqual( @@ -44,6 +43,6 @@ def test_get(self): ) # Test identifier is already a sampler instance. - original_sampler = keras_nlp.samplers.GreedySampler(jit_compile=False) + original_sampler = keras_nlp.samplers.TopKSampler(k=7) restored_sampler = keras_nlp.samplers.get(original_sampler) self.assertEqual(original_sampler, restored_sampler) diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 7ac1a86b99..54f899fcd6 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -17,14 +17,11 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import base_sampler_args_docstring from keras_nlp.samplers.sampler import call_args_docstring from keras_nlp.utils.python_utils import format_docstring -@format_docstring( - base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring -) +@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopKSampler") class TopKSampler(Sampler): """Top-K Sampler class. @@ -36,37 +33,29 @@ class TopKSampler(Sampler): Args: k: int, the `k` value of top-k. seed: int, defaults to None. The random seed. - {{base_sampler_args}} Call Args: {{call_args}} Examples: ```python - VOCAB_SIZE = 10 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=VOCAB_SIZE, - output_dim=16, - ), - keras.layers.Dense(VOCAB_SIZE, activation="softmax"), - ] + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # A uniform distribution over our alphabet. + logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.TopKSampler(k=3)( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, ) - - # Define a function that outputs the next token's probability for each token - # in the input sequence. 
- def token_probability_fn(inputs, mask): - return model(inputs) - - prompt = tf.fill((8, 1), 1) - - sampler = keras_nlp.samplers.TopKSampler(k=5) - # Print the generated sequence (token ids). - print(sampler(prompt, token_probability_fn, max_length=10)) + print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) + # >>> "zzzzzacbbcaa" ``` """ @@ -74,17 +63,15 @@ def __init__( self, k=5, seed=None, - jit_compile=True, - run_eagerly=False, ): + super().__init__() self.k = k self.seed = seed - super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly) - def get_next_token(self, next_token_probs): + def get_next_token(self, probabilities): # Filter out top-k tokens. top_k_pred, top_k_indices = tf.math.top_k( - next_token_probs, k=self.k, sorted=False + probabilities, k=self.k, sorted=False ) # Sample the next token from the probability distribution. next_token = tf.random.categorical( @@ -96,7 +83,6 @@ def get_next_token(self, next_token_probs): def get_config(self): config = super().get_config() - config.update( { "k": self.k, diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index 062f7af767..9efc45f41d 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -16,7 +16,6 @@ import numpy as np import tensorflow as tf from absl.testing import parameterized -from tensorflow import keras from keras_nlp.samplers.top_k_sampler import TopKSampler @@ -24,148 +23,88 @@ class TopKSamplerTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() - self.vocab_size = 10 - self.feature_size = 16 - - # Create a dummy model to predict the next token. - model = keras.Sequential( - [ - keras.Input(shape=[None]), - keras.layers.Embedding( - input_dim=self.vocab_size, - output_dim=self.feature_size, - ), - keras.layers.Dense(self.vocab_size), - keras.layers.Softmax(), - ] + # Use a simple alphabet of lowercase characters to [0, 26). + self.int_lookup = {i: chr(i + ord("a")) for i in range(26)} + self.char_lookup = {v: k for k, v in self.int_lookup.items()} + self.batch_size = 1 + self.length = 12 + self.vocab_size = len(self.int_lookup) + + def next(prompt, state, index): + # Return a distribution favoring the next char in state. + logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9 + return logits, state + + self.next = next + self.sampler = TopKSampler(k=5) + + def join_as_string(self, x): + return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] + + def test_stateless_call(self): + def next(prompt, state, index): + # Return a distribution favoring the first token in the vocab. 
+            logits = np.zeros((self.batch_size, self.vocab_size))
+            logits[:, 0] = 1e9
+            return tf.constant(logits), state
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=next,
+            prompt=prompt,
+            index=5,
         )
-
-        def token_probability_fn(inputs, mask):
-            return model(inputs)
-
-        self.token_probability_fn = token_probability_fn
-        self.sampler = TopKSampler(k=2)
-
-    def test_generate_with_1d_prompt(self):
-        inputs = tf.constant([1])
-
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [5])
-
-    def test_generate_with_2d_prompt(self):
-        inputs = tf.constant([[1], [1]])
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [2, 5])
-
-    def test_generate_with_list_prompt(self):
-        inputs = [[1], [1]]
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [2, 5])
-
-    def test_generate_with_ragged_prompt(self):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        inputs = tf.ragged.constant([[1], [2, 1, 2]])
-        outputs = self.sampler(
-            inputs,
-            token_probability_fn,
-            max_length=5,
-            from_logits=False,
+        self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"])
+
+    def test_stateful_call(self):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=self.next,
+            prompt=prompt,
+            state=state,
         )
-        self.assertEqual(outputs.shape, [2, 5])
+        self.assertEqual(self.join_as_string(output), ["sequentially"])
+
+    def test_early_stopping(self):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=self.next,
+            prompt=prompt,
+            state=state,
+            end_token_id=self.char_lookup["t"],
+        )
+        self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])
+
+    def test_outputs_in_top_k(self):
+        def next(prompt, state, index):
+            # Return a distribution where each id is progressively less likely.
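+            # Ids 0..k-1 always hold the k largest logits, so top-k sampling
+            # can only ever return tokens from that range.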
+            logits = tf.range(self.vocab_size, 0, -1, dtype="float32")
+            logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0)
+            return logits, state
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=next,
+            prompt=prompt,
+        )
+        output_ids = set(output[0].numpy())
+        self.assertContainsSubset(output_ids, range(5))

     @parameterized.named_parameters(
-        ("xla_graph", True, False),
-        ("non_xla_graph", False, False),
-        ("eager", False, True),
+        ("jit_compile_false", False), ("jit_compile_true", True)
     )
-    def test_assert_probability_distribution_generation_is_correct(
-        self, jit_compile, run_eagerly
-    ):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        batch_size = 10
-        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
-        max_length = 3
-
-        outputs_count = np.array([0, 0, 0, 0])
-        tf.random.set_seed(42)
-        sampler = TopKSampler(
-            k=2,
-            seed=42,
-            run_eagerly=jit_compile,
-            jit_compile=run_eagerly,
-        )
-        for _ in range(8):
-            outputs = sampler(
-                inputs,
-                token_probability_fn,
-                max_length=max_length,
-                from_logits=False,
-            )
-            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
-            for pred in flatten_predictions:
-                outputs_count[pred] += 1
-        self.assertAllClose(
-            outputs_count / np.sum(outputs_count),
-            [0.0, 0.0, 0.0, 1.0],
-            rtol=0.2,
-        )
+    def test_compilation(self, jit_compile):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])

-    def test_only_choose_from_top_k_tokens(self):
-        # Test that there are only the top-k tokens in the output.
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.4, 0.3, 0.2, 0.1]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        # Test that it only samples from top-k tokens.
-        for k in [1, 2, 3]:
-            inputs = tf.constant([[0, 0], [0, 0]])
-            sampler = TopKSampler(k=k)
-            for _ in range(10):
-                outputs = sampler(
-                    inputs,
-                    token_probability_fn,
-                    max_length=5,
-                    from_logits=False,
-                )
-                self.assertAllEqual(outputs < k, tf.ones_like(outputs))
-
-    def test_end_token_id(self):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        tf.random.set_seed(42)
-        sampler = TopKSampler(k=4, seed=42)
-        max_length = 4
-        inputs = tf.constant([[0, 1], [1, 2]])
-        outputs = sampler(
-            inputs,
-            token_probability_fn,
-            max_length=max_length,
-            end_token_id=2,
-            from_logits=False,
-        )
-        # end_token in prompt does not trigger truncation.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
-        self.assertAllEqual(outputs, expected_outputs)
-
-        outputs = sampler(
-            inputs,
-            token_probability_fn,
-            max_length=max_length,
-            end_token_id=3,
-            from_logits=False,
-        )
-        # Generated end_token will be truncated.
-        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
-        self.assertAllEqual(outputs, expected_outputs)
+        @tf.function(jit_compile=jit_compile)
+        def generate(prompt, state):
+            return self.sampler(self.next, prompt=prompt, state=state)
+
+        output = generate(prompt, state)
+        self.assertEqual(self.join_as_string(output), ["sequentially"])
diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py
index 5622541d1f..fa8a7a8779 100644
--- a/keras_nlp/samplers/top_p_sampler.py
+++ b/keras_nlp/samplers/top_p_sampler.py
@@ -17,14 +17,11 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import base_sampler_args_docstring
 from keras_nlp.samplers.sampler import call_args_docstring
 from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(
-    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
-)
+@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.TopPSampler")
 class TopPSampler(Sampler):
     """Top-P Sampler class.
@@ -38,37 +35,29 @@ class TopPSampler(Sampler):
     Args:
         p: float, the `p` value of top-p.
         seed: int, defaults to None. The random seed.
-        {{base_sampler_args}}

     Call Args:
         {{call_args}}

     Examples:
     ```python
-    VOCAB_SIZE = 10
-
-    # Create a dummy model to predict the next token.
-    model = keras.Sequential(
-        [
-            keras.Input(shape=[None]),
-            keras.layers.Embedding(
-                input_dim=VOCAB_SIZE,
-                output_dim=16,
-            ),
-            keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
-        ]
+    # Use a simple alphabet of lowercase characters with ids in [0, 26).
+    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
+    char_lookup = {v: k for k, v in int_lookup.items()}
+    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+
+    def next(prompt, state, index):
+        # A uniform distribution over our alphabet.
+        logits = tf.ones((batch_size, vocab_size))
+        return logits, state
+
+    output = keras_nlp.samplers.TopPSampler(p=0.1)(
+        next=next,
+        prompt=tf.fill((batch_size, length,), char_lookup['z']),
+        index=5,
     )
-
-    # Define a function that outputs the next token's probability for each token
-    # in the input sequence.
-    def token_probability_fn(inputs, mask):
-        return model(inputs)
-
-    prompt = tf.fill((8, 1), 1)
-
-    sampler = keras_nlp.samplers.TopPSampler(p=0.1)
-    # Print the generated sequence (token ids).
-    print(sampler(prompt, token_probability_fn, max_length=10))
+    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
+    # >>> "zzzzzbabcccb"
     ```
     """

     def __init__(
@@ -76,40 +65,37 @@ def __init__(
         self,
         p=0.1,
         seed=None,
-        jit_compile=True,
-        run_eagerly=False,
     ):
+        super().__init__()
         self.p = p
         self.seed = seed
-        super().__init__(jit_compile=jit_compile, run_eagerly=run_eagerly)

-    def get_next_token(self, next_token_probs):
+    def get_next_token(self, probabilities):
         # Sort preds in descending order.
         sorted_preds, sorted_indices = tf.math.top_k(
-            next_token_probs, k=tf.shape(next_token_probs)[1], sorted=True
+            probabilities, k=tf.shape(probabilities)[1], sorted=True
         )
         # Calculate cumulative probability distribution.
-        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
+        cumulative_probabilities = tf.math.cumsum(sorted_preds, axis=-1)
         # Create a mask for the tokens to keep.
-        keep_mask = cumulative_probs <= self.p
+        keep_mask = cumulative_probabilities <= self.p
         # Shift to include the last token that exceed p.
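+        # (Prepending `True` also guarantees at least one token is kept,
+        # even when no cumulative probability is below `p`.)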
         shifted_keep_mask = tf.concat(
             [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
         )
         # Filter out unmasked tokens and sample from filtered distribution.
-        probs = tf.where(
+        probabilities = tf.where(
             shifted_keep_mask,
             sorted_preds,
-            tf.zeros(tf.shape(next_token_probs), dtype=sorted_preds.dtype),
+            tf.zeros(tf.shape(probabilities), dtype=sorted_preds.dtype),
         )
         sorted_next_token = tf.random.categorical(
-            tf.math.log(probs), 1, seed=self.seed
+            tf.math.log(probabilities), 1, seed=self.seed
         )
         return tf.gather_nd(sorted_indices, sorted_next_token, batch_dims=1)

     def get_config(self):
         config = super().get_config()
-
         config.update(
             {
                 "p": self.p,
diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py
index 6cb56cf423..73efc485cd 100644
--- a/keras_nlp/samplers/top_p_sampler_test.py
+++ b/keras_nlp/samplers/top_p_sampler_test.py
@@ -16,7 +16,6 @@
 import numpy as np
 import tensorflow as tf
 from absl.testing import parameterized
-from tensorflow import keras

 from keras_nlp.samplers.top_p_sampler import TopPSampler

@@ -24,148 +23,86 @@ class TopPSamplerTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
         super().setUp()
-        self.vocab_size = 10
-        self.feature_size = 16
-
-        # Create a dummy model to predict the next token.
-        model = keras.Sequential(
-            [
-                keras.Input(shape=[None]),
-                keras.layers.Embedding(
-                    input_dim=self.vocab_size,
-                    output_dim=self.feature_size,
-                ),
-                keras.layers.Dense(self.vocab_size),
-                keras.layers.Softmax(),
-            ]
-        )
+        # Use a simple alphabet of lowercase characters with ids in [0, 26).
+        self.int_lookup = {i: chr(i + ord("a")) for i in range(26)}
+        self.char_lookup = {v: k for k, v in self.int_lookup.items()}
+        self.batch_size = 1
+        self.length = 12
+        self.vocab_size = len(self.int_lookup)
+
+        def next(prompt, state, index):
+            # Return a distribution favoring the next char in state.
+            logits = tf.one_hot(state[:, index], self.vocab_size) * 1e9
+            return logits, state
+
+        self.next = next
+        self.sampler = TopPSampler(p=0.1)

-        def token_probability_fn(inputs, mask):
-            return model(inputs)
+    def join_as_string(self, x):
+        return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
+
+    def test_stateless_call(self):
+        def next(prompt, state, index):
+            # Return a distribution favoring the first token in the vocab.
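+            # With all probability mass on token 0 ("a"), the top-p nucleus
+            # contains only "a", so generation past the prompt is deterministic.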
+            logits = np.zeros((self.batch_size, self.vocab_size))
+            logits[:, 0] = 1e9
+            return tf.constant(logits), state
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=next,
+            prompt=prompt,
+            index=5,
+        )
+        self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"])
+
+    def test_stateful_call(self):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=self.next,
+            prompt=prompt,
+            state=state,
+        )
+        self.assertEqual(self.join_as_string(output), ["sequentially"])
+
+    def test_early_stopping(self):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = self.sampler(
+            next=self.next,
+            prompt=prompt,
+            state=state,
+            end_token_id=self.char_lookup["t"],
+        )
+        self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])
+
-        self.token_probability_fn = token_probability_fn
+    def test_outputs_in_top_p(self):
+        def next(prompt, state, index):
+            logits = np.zeros((self.batch_size, self.vocab_size))
+            return tf.constant(logits), state

-    def test_generate_with_1d_prompt(self):
-        inputs = tf.constant([1])
-
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [5])
-
-    def test_generate_with_2d_prompt(self):
-        inputs = tf.constant([[1], [1]])
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [2, 5])
-
-    def test_generate_with_list_prompt(self):
-        inputs = [[1], [1]]
-        outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
-        self.assertEqual(outputs.shape, [2, 5])
-
-    def test_generate_with_ragged_prompt(self):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        inputs = tf.ragged.constant([[1], [2, 1, 2]])
-        outputs = self.sampler(
-            inputs,
-            token_probability_fn,
-            max_length=5,
-            from_logits=False,
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        output = TopPSampler(p=(2.0 / self.vocab_size))(
+            next=next,
+            prompt=prompt,
         )
-        self.assertEqual(outputs.shape, [2, 5])
+        output_ids = set(output[0].numpy())
+        self.assertContainsSubset(output_ids, range(3))

     @parameterized.named_parameters(
-        ("xla_graph", True, False),
-        ("non_xla_graph", False, False),
-        ("eager", False, True),
+        ("jit_compile_false", False), ("jit_compile_true", True)
     )
-    def test_assert_probability_distribution_generation_is_correct(
-        self, jit_compile, run_eagerly
-    ):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        batch_size = 10
-        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
-        max_length = 3
-
-        outputs_count = np.array([0, 0, 0, 0])
-        tf.random.set_seed(42)
-        sampler = TopPSampler(
-            p=0.1,
-            seed=42,
-            run_eagerly=jit_compile,
-            jit_compile=run_eagerly,
-        )
-        for _ in range(8):
-            outputs = sampler(
-                inputs,
-                token_probability_fn,
-                max_length=max_length,
-                from_logits=False,
-            )
-            flatten_predictions = tf.reshape(outputs[:, 1:], [-1])
-            for pred in flatten_predictions:
-                outputs_count[pred] += 1
-        self.assertAllClose(
-            outputs_count / np.sum(outputs_count),
-            [0.0, 0.0, 0.0, 1.0],
-            rtol=0.2,
-        )
+    def test_compilation(self, jit_compile):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])

-    def test_only_choose_from_top_p_tokens(self):
-        # Test that there are only the top-p tokens in the output.
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.4, 0.3, 0.2, 0.1]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        # Test that it only samples from top-p tokens.
-        for i, p in enumerate([0.399, 0.699, 0.899]):
-            inputs = tf.constant([[0, 0], [0, 0]])
-            sampler = TopPSampler(p=p)
-            for _ in range(10):
-                outputs = sampler(
-                    inputs,
-                    token_probability_fn,
-                    max_length=5,
-                    from_logits=False,
-                )
-                self.assertAllEqual(outputs <= i, tf.ones_like(outputs))
-
-    def test_end_token_id(self):
-        def token_probability_fn(inputs, mask):
-            batch_size, seq_length = tf.shape(inputs)[0], tf.shape(inputs)[1]
-            prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
-            return tf.tile(prob, [batch_size, seq_length, 1])
-
-        tf.random.set_seed(42)
-        sampler = TopPSampler(p=0.1, seed=42)
-        max_length = 4
-        inputs = tf.constant([[0, 1], [1, 2]])
-        outputs = sampler(
-            inputs,
-            token_probability_fn,
-            max_length=max_length,
-            end_token_id=2,
-            from_logits=False,
-        )
-        # end_token in prompt does not trigger truncation.
-        expected_outputs = tf.ragged.constant([[0, 1, 3, 3], [1, 2, 3, 3]])
-        self.assertAllEqual(outputs, expected_outputs)
-
-        outputs = sampler(
-            inputs,
-            token_probability_fn,
-            max_length=max_length,
-            end_token_id=3,
-            from_logits=False,
-        )
-        # Generated end_token will be truncated.
-        expected_outputs = tf.ragged.constant([[0, 1], [1, 2]])
-        self.assertAllEqual(outputs, expected_outputs)
+        @tf.function(jit_compile=jit_compile)
+        def generate(prompt, state):
+            return self.sampler(self.next, prompt=prompt, state=state)
+
+        output = generate(prompt, state)
+        self.assertEqual(self.join_as_string(output), ["sequentially"])
diff --git a/keras_nlp/utils/tf_utils.py b/keras_nlp/utils/tf_utils.py
index 1b4cbee7e1..c32b9f7115 100644
--- a/keras_nlp/utils/tf_utils.py
+++ b/keras_nlp/utils/tf_utils.py
@@ -66,6 +66,14 @@ def tensor_to_string_list(inputs):
     return _decode_strings_to_utf8(list_outputs)


+def truncate_at_token(inputs, token, mask):
+    """Truncate at first instance of `token`, ignoring `mask`."""
+    matches = (inputs == token) & (~mask)
+    end_indices = tf.cast(tf.math.argmax(matches, -1), "int32")
+    end_indices = tf.where(end_indices == 0, tf.shape(inputs)[-1], end_indices)
+    return tf.RaggedTensor.from_tensor(inputs, end_indices)
+
+
 def assert_tf_text_installed(symbol_name):
     """Detokenize and convert tensor to nested lists of python strings."""
     if tf_text is None: