diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index 4d50ce4dfe..67a0ce8014 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -92,8 +92,8 @@ def sample(
         mask,
         num_steps,
         from_logits=True,
+        end_token_id=None,
         cache=None,
-        token_probs=None,
     ):
         """Sampling logic implementation.
 
@@ -139,7 +139,6 @@ def one_step(beams, beams_prob, length, mask):
             if from_logits:
                 preds = keras.activations.softmax(preds, axis=-1)
             # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.
-            preds = tf.reshape(preds, shape=[batch_size, -1])
 
             cum_probs = tf.math.log(preds) + tf.repeat(
diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py
index b92a460b0f..cc0b2cb0f6 100644
--- a/keras_nlp/samplers/beam_sampler_test.py
+++ b/keras_nlp/samplers/beam_sampler_test.py
@@ -38,7 +38,6 @@ def setUp(self):
                     output_dim=self.feature_size,
                 ),
                 keras.layers.Dense(self.vocab_size),
-                keras.layers.Softmax(),
             ]
         )
 
diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py
index 1d1eb6aee9..8515416e7d 100644
--- a/keras_nlp/samplers/greedy_sampler_test.py
+++ b/keras_nlp/samplers/greedy_sampler_test.py
@@ -94,7 +94,7 @@ def token_probability_fn(inputs, mask):
 
     def test_end_token_id(self):
         def token_probability_fn(inputs, mask):
-            batch_size = inputs.shape[0]
+            batch_size = tf.shape(inputs)[0]
             prob = tf.constant([[[0.0, 0.0, 0.0, 1.0]]])
             return tf.repeat(
                 tf.repeat(prob, batch_size, axis=0), max_length, axis=1
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index a21093aa24..8a763ef435 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -236,8 +236,9 @@ def __call__(
             token_probability_fn,
             mask,
             max_length - shortest_prompt_len,
-            cache=cache,
             from_logits=from_logits,
+            end_token_id=end_token_id,
+            cache=cache,
         )
         # Mask out tokens after `end_token_id`.
         if end_token_id is not None:
@@ -269,6 +270,7 @@ def sample(
         mask,
         num_steps,
         from_logits=True,
+        end_token_id=None,
         cache=None,
     ):
         """Sampling logic implementation.
@@ -284,6 +286,9 @@ def sample(
             from_logits: bool, defaults to True. Indicate if the
                 `token_probability_fn` returns logits. If False,
                 `token_probability_fn` returns probability distributions.
+            end_token_id: int, defaults to None. The token marking the end
+                of the sequence; once encountered, generation is finished
+                for that exact sequence.
            cache: a dense int tensor, the cache used in decoding. The cache
                 stores the key and value of each
                 `keras_nlp.layers.CachedMultiHeadAttention` layer to make the
@@ -299,8 +304,9 @@ def sample(
         # The index of the last non-padding token in prompt. Since all sequences
         # are aligned to the right side, the index is the same for all.
         current_index = max_length - num_steps
+        original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32)
 
-        def one_step(
+        def body(
             current_index,
             prompt,
             mask,
@@ -316,7 +322,10 @@ def one_step(
                 )
                 next_token_probs = tf.squeeze(probs, axis=1)
             else:
-                probs = token_probability_fn(prompt, mask)
+                probs = token_probability_fn(
+                    prompt,
+                    mask,
+                )
                 next_token_probs = tf.gather(
                     probs,
                     tf.repeat(current_index - 1, batch_size),
@@ -343,24 +352,32 @@ def one_step(
             current_index = tf.add(current_index, 1)
             if cache is None:
                 return current_index, prompt, mask
-            return [current_index, prompt, mask, cache]
+            return current_index, prompt, mask, cache
+
+        def cond(current_index, prompt, mask, cache=None):
+            if end_token_id is None:
+                return True
+            end_token_seen = (prompt == end_token_id) & (
+                original_padding_mask == 0
+            )
+            sequence_done = tf.reduce_any(end_token_seen, axis=-1)
+            all_done = tf.reduce_all(sequence_done)
+            return not all_done
 
         if cache is None:
             _, prompt, _ = tf.while_loop(
-                cond=lambda current_index, prompt, mask: tf.less(
-                    current_index, max_length
-                ),
-                body=one_step,
-                loop_vars=[current_index, prompt, mask],
+                cond=cond,
+                body=body,
+                loop_vars=(current_index, prompt, mask),
+                maximum_iterations=num_steps,
             )
             return prompt
 
         # Run a while loop till `max_length` of tokens has been generated.
         _, prompt, _, _ = tf.while_loop(
-            cond=lambda current_index, prompt, mask, cache: tf.less(
-                current_index, max_length
-            ),
-            body=one_step,
-            loop_vars=[current_index, prompt, mask, cache],
+            cond=cond,
+            body=body,
+            loop_vars=(current_index, prompt, mask, cache),
+            maximum_iterations=num_steps,
         )
         return prompt
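Note on the `sampler.py` change above: instead of the old `tf.less(current_index, max_length)` lambda, the loop condition can now terminate as soon as every sequence has produced `end_token_id` in a position that was padding in the original prompt, while `maximum_iterations=num_steps` still enforces the old length cap. The standalone sketch below (illustrative only; the toy `next_token` model, token values, and shapes are assumptions, not keras_nlp code) shows the same `tf.while_loop` early-stopping pattern:

import tensorflow as tf

end_token_id = 3
pad_token_id = 0
max_length = 8

# Two right-padded prompts of length 3; positions 3..7 are still to be generated.
prompt = tf.constant(
    [
        [5, 6, 7, 0, 0, 0, 0, 0],
        [5, 3, 7, 0, 0, 0, 0, 0],  # an end token already inside the prompt
    ]
)
mask = tf.cast(prompt != pad_token_id, tf.int32)
original_padding_mask = tf.identity(mask)  # snapshot taken before decoding starts
current_index = tf.constant(3)
num_steps = max_length - 3


def next_token(prompt, index):
    # Stand-in for a model call; it always emits the end token so the loop
    # stops after a single iteration in this toy example.
    return tf.fill([tf.shape(prompt)[0]], end_token_id)


def body(current_index, prompt):
    # Write the sampled token at `current_index` for every sequence.
    batch_size = tf.shape(prompt)[0]
    token = next_token(prompt, current_index)
    indices = tf.stack(
        [tf.range(batch_size), tf.fill([batch_size], current_index)], axis=1
    )
    prompt = tf.tensor_scatter_nd_update(prompt, indices, token)
    return current_index + 1, prompt


def cond(current_index, prompt):
    # Count only end tokens written into positions that were padding in the
    # original prompt, so a prompt that happens to contain `end_token_id`
    # does not stop generation immediately.
    end_token_seen = (prompt == end_token_id) & (original_padding_mask == 0)
    sequence_done = tf.reduce_any(end_token_seen, axis=-1)
    return tf.logical_not(tf.reduce_all(sequence_done))


current_index, prompt = tf.while_loop(
    cond=cond,
    body=body,
    loop_vars=(current_index, prompt),
    maximum_iterations=num_steps,  # still caps generation at `num_steps` tokens
)
print(prompt.numpy())  # each row gains one end token, then the loop exits early

Keeping a snapshot of the padding mask is what lets an `end_token_id` that already appears inside a prompt be ignored, so generation only halts on newly produced end tokens.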