diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index 03b390a3a6..a8664ba675 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -36,11 +36,14 @@ class BeamSampler(Sampler):
     Args:
         num_beams: int. The number of beams that should be kept at each
             time-step. `num_beams` should be strictly positive.
+        return_all_beams: bool. When set to `True`, the sampler will return
+            all beams and their log probability scores, not just the top beam.
 
     Call Args:
         {{call_args}}
 
     Examples:
+    Return only the beam with the highest accumulated probability.
     ```python
     # Use a simple alphabet of lowercase characters to [0, 26).
     int_lookup = {i: chr(i + ord('a')) for i in range(26)}
@@ -60,14 +63,41 @@ def next(prompt, state, index):
     print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
     # >>> "zzzzzaaaaaaa"
     ```
+    Return all beams and their log probabilities.
+    ```python
+    # Use a simple alphabet of lowercase characters to [0, 26).
+    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
+    char_lookup = {v: k for k, v in int_lookup.items()}
+    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+
+    def next(prompt, state, index):
+        # A uniform distribution over our alphabet.
+        logits = tf.ones((batch_size, vocab_size))
+        return logits, state
+
+    output = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
+        next=next,
+        prompt=tf.fill((batch_size, length,), char_lookup['z']),
+        index=5,
+    )
+
+    print(output[0].shape)
+    # >>> (1, 5, 12)
+    print(output[1].shape)
+    # >>> (1, 5)
+    print(["".join([int_lookup[i] for i in s]) for s in output[0][0].numpy()])
+    # >>> "zzzzzaaaaaaa"
+    ```
     """
 
     def __init__(
         self,
         num_beams=5,
+        return_all_beams=False,
     ):
         super().__init__()
         self.num_beams = num_beams
+        self.return_all_beams = return_all_beams
 
     def __call__(
         self,
@@ -161,17 +191,32 @@ def gather_beams(x):
             maximum_iterations=(max_length - index),
         )
 
-        # Gather the top beam at each batch index.
-        prompt, log_probs = unflatten_beams(prompt), unflatten_beams(log_probs)
-        top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis]
-        prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1)
-        return tf.squeeze(prompt, axis=1)
+        all_prompts = unflatten_beams(prompt)
+        all_log_probs = unflatten_beams(log_probs)
+
+        if self.return_all_beams:
+            sorted_indices = tf.argsort(
+                all_log_probs, axis=-1, direction="DESCENDING"
+            )
+            sorted_log_probs = tf.gather(
+                all_log_probs, sorted_indices, axis=-1, batch_dims=1
+            )
+            sorted_prompts = tf.gather(
+                all_prompts, sorted_indices, axis=1, batch_dims=1
+            )
+            return sorted_prompts, sorted_log_probs
+        else:
+            # Gather the top beam at each batch index.
+            top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
+            prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)
+            return tf.squeeze(prompt, axis=1)
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "num_beams": self.num_beams,
+                "return_all_beams": self.return_all_beams,
             }
         )
         return config
diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py
index 95290b8f1a..e20c66035e 100644
--- a/keras_nlp/samplers/beam_sampler_test.py
+++ b/keras_nlp/samplers/beam_sampler_test.py
@@ -37,6 +37,7 @@ def next(prompt, state, index):
 
         self.next = next
         self.sampler = BeamSampler(num_beams=5)
+        self.sampler_all_beams = BeamSampler(num_beams=5, return_all_beams=True)
 
     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
@@ -67,6 +68,27 @@ def test_stateful_call(self):
         )
         self.assertEqual(self.join_as_string(output), ["sequentially"])
 
+    def test_return_all_beams(self):
+        state_chars = list("sequentially")
+        state = tf.constant([[self.char_lookup[c] for c in state_chars]])
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+        sorted_prompts, sorted_log_probs = self.sampler_all_beams(
+            next=self.next,
+            prompt=prompt,
+            state=state,
+        )
+
+        self.assertEqual(
+            sorted_prompts.shape, (self.batch_size, 5, self.length)
+        )
+        self.assertEqual(sorted_log_probs.shape, (self.batch_size, 5))
+        self.assertTrue(
+            tf.reduce_all(sorted_log_probs[:, 1:] <= sorted_log_probs[:, :-1])
+        )
+        self.assertEqual(
+            self.join_as_string(sorted_prompts[:, 0, :]), ["sequentially"]
+        )
+
     def test_early_stopping(self):
         state_chars = list("sequentially")
         state = tf.constant([[self.char_lookup[c] for c in state_chars]])
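For reviewers unfamiliar with `batch_dims`, here is a minimal standalone sketch of the sort-and-gather pattern the new `return_all_beams` branch relies on. The shapes and values below are illustrative assumptions, not real sampler output; only the `tf.argsort` and `tf.gather` calls mirror the PR.

```python
import tensorflow as tf

# Toy stand-ins for the unflattened beam outputs:
# batch_size=2, num_beams=3, sequence length=4.
all_log_probs = tf.constant([[-2.0, -0.5, -1.0],
                             [-0.1, -3.0, -0.7]])
all_prompts = tf.reshape(tf.range(2 * 3 * 4), (2, 3, 4))

# Per batch entry, order beams from most to least probable.
sorted_indices = tf.argsort(all_log_probs, axis=-1, direction="DESCENDING")

# `batch_dims=1` makes each row of indices gather within its own batch entry.
sorted_log_probs = tf.gather(
    all_log_probs, sorted_indices, axis=-1, batch_dims=1
)
sorted_prompts = tf.gather(all_prompts, sorted_indices, axis=1, batch_dims=1)

print(sorted_log_probs.numpy())  # [[-0.5 -1.  -2. ] [-0.1 -0.7 -3. ]]
print(sorted_prompts.shape)      # (2, 3, 4)
```

Because the beams come back in descending order, `sorted_prompts[:, 0, :]` is exactly the beam the `else` branch returns, which is what the ordering and top-beam assertions in `test_return_all_beams` verify.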