Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ class BeamSampler(Sampler):
Args:
num_beams: int. The number of beams that should be kept at each
time-step. `num_beams` should be strictly positive.
        return_all_beams: bool. When set to `True`, the sampler will return
            all beams, sorted by descending accumulated log-probability,
            along with their respective log-probability scores.

Call Args:
{{call_args}}

Examples:
Return only the beam with the highest accumulated probability.
```python
# Use a simple alphabet of lowercase characters to [0, 26).
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
Expand All @@ -60,14 +63,41 @@ def next(prompt, state, index):
print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
# >>> "zzzzzaaaaaaa"
```
Return all beams and their probabilities.
```python
# Use a simple alphabet of lowercase characters to [0, 26).
int_lookup = {i: chr(i + ord('a')) for i in range(26)}
char_lookup = {v: k for k, v in int_lookup.items()}
batch_size, length, vocab_size = 1, 12, len(int_lookup)

def next(prompt, state, index):
# A uniform distribution over our alphabet.
logits = tf.ones((batch_size, vocab_size))
return logits, state

output = keras_nlp.samplers.BeamSampler(return_all_beams=True)(
next=next,
prompt=tf.fill((batch_size, length,), char_lookup['z']),
index=5,
)

print(output[0].shape)
# >>> (1, 5, 12)
print(output[1].shape)
# >>> (1, 5)
print(["".join([int_lookup[i] for i in s]) for s in output[0][0].numpy()])
# >>> "zzzzzaaaaaaa"
```
"""

def __init__(
    self,
    num_beams=5,
    return_all_beams=False,
):
    """Initialize the beam sampler.

    Args:
        num_beams: int. The number of beams kept at each time-step.
            Must be strictly positive.
        return_all_beams: bool. If `True`, `__call__` returns all beams
            (sorted by accumulated log-probability) and their scores,
            instead of only the most probable beam.
    """
    super().__init__()
    self.num_beams = num_beams
    self.return_all_beams = return_all_beams

def __call__(
self,
Expand Down Expand Up @@ -161,17 +191,32 @@ def gather_beams(x):
maximum_iterations=(max_length - index),
)

# Gather the top beam at each batch index.
prompt, log_probs = unflatten_beams(prompt), unflatten_beams(log_probs)
top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis]
prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1)
return tf.squeeze(prompt, axis=1)
all_prompts = unflatten_beams(prompt)
all_log_probs = unflatten_beams(log_probs)

if self.return_all_beams:
sorted_indices = tf.argsort(
all_log_probs, axis=-1, direction="DESCENDING"
)
sorted_log_probs = tf.gather(
all_log_probs, sorted_indices, axis=-1, batch_dims=1
)
sorted_prompts = tf.gather(
all_prompts, sorted_indices, axis=1, batch_dims=1
)
return sorted_prompts, sorted_log_probs
else:
# Gather the top beam at each batch index.
top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis]
prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1)
return tf.squeeze(prompt, axis=1)

def get_config(self):
    """Return the serializable config, including the beam options."""
    base_config = super().get_config()
    # Merge the base sampler config with this sampler's own settings.
    return {
        **base_config,
        "num_beams": self.num_beams,
        "return_all_beams": self.return_all_beams,
    }
22 changes: 22 additions & 0 deletions keras_nlp/samplers/beam_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def next(prompt, state, index):

self.next = next
self.sampler = BeamSampler(num_beams=5)
self.sampler_all_beams = BeamSampler(num_beams=5, return_all_beams=True)

def join_as_string(self, x):
    """Decode a batch of token-id sequences into a list of strings."""
    decoded = []
    for row in x.numpy():
        chars = [self.int_lookup[token] for token in row]
        decoded.append("".join(chars))
    return decoded
Expand Down Expand Up @@ -67,6 +68,27 @@ def test_stateful_call(self):
)
self.assertEqual(self.join_as_string(output), ["sequentially"])

def test_return_all_beams(self):
    """All beams and scores are returned, sorted best-first."""
    state_chars = list("sequentially")
    state = tf.constant([[self.char_lookup[c] for c in state_chars]])
    prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
    sorted_prompts, sorted_log_probs = self.sampler_all_beams(
        next=self.next,
        prompt=prompt,
        state=state,
    )

    # All `num_beams` beams and their scores should be returned.
    self.assertEqual(
        sorted_prompts.shape, (self.batch_size, 5, self.length)
    )
    self.assertEqual(sorted_log_probs.shape, (self.batch_size, 5))
    # Beams must be ordered by descending accumulated log-probability.
    self.assertTrue(
        tf.reduce_all(sorted_log_probs[:, 1:] <= sorted_log_probs[:, :-1])
    )
    # The top beam should decode the stateful target sequence.
    self.assertEqual(
        self.join_as_string(sorted_prompts[:, 0, :]), ["sequentially"]
    )

def test_early_stopping(self):
state_chars = list("sequentially")
state = tf.constant([[self.char_lookup[c] for c in state_chars]])
Expand Down