diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index 03b390a3a6..0e1d7d3808 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -132,6 +132,9 @@ def body(prompt, state, index, log_probs):
             )
             beam_indices = indices // vocab_size
             next_token = flatten_beams(indices % vocab_size)
+            # Ensure shape is `[None]`, otherwise it causes issues after
+            # converting to TFLite.
+            next_token = tf.ensure_shape(next_token, [None])
             # We need `ensure_shape` as `top_k` will change the static shape.
             next_log_probs = flatten_beams(next_log_probs)
             log_probs = tf.ensure_shape(next_log_probs, log_probs.shape)
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index dde0598f68..39cb80b776 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -110,11 +110,13 @@ def body(prompt, state, index):
             # Compute the softmax distribution for the next token.
             logits, state = next(prompt, state, index)
             probabilities = keras.activations.softmax(logits)
-            # Compute the next token.
             next_token = self.get_next_token(probabilities)
             # Don't overwrite anywhere mask is True.
             next_token = tf.cast(next_token, prompt.dtype)
+            # Ensure shape is `[None]`, otherwise it causes issues after
+            # converting to TFLite.
+            next_token = tf.ensure_shape(next_token, [None])
             next_token = tf.where(mask[:, index], prompt[:, index], next_token)
             # Update the prompt with the next token.
             next_token = next_token[:, tf.newaxis]