Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions keras_nlp/samplers/top_p_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class TopPSampler(Sampler):

Args:
p: float, the `p` value of top-p.
k: int, defaults to None. If set, this argument defines a
heuristic "top-k" cutoff applied before the "top-p" sampling. All
logits not in the top `k` will be discarded, and the remaining
logits will be sorted to find a cutoff point for `p`. Setting this
arg can significantly speed sampling up by reducing the number
of tokens to sort.
seed: int, defaults to None. The random seed.

Call Args:
Expand Down Expand Up @@ -65,16 +71,21 @@ def next(prompt, cache, index):
def __init__(
    self,
    p=0.1,
    k=None,
    seed=None,
):
    """Create a top-p sampler.

    Args:
        p: float, the cumulative-probability cutoff for top-p sampling.
        k: int or None. When set, an optional top-k pre-filter applied
            before the top-p cutoff.
        seed: int or None. Random seed for sampling.
    """
    super().__init__()
    # Store the configuration verbatim; it is echoed back in `get_config`.
    self.seed = seed
    self.k = k
    self.p = p

def get_next_token(self, probabilities):
# Sort preds in descending order.
cutoff = tf.shape(probabilities)[1]
if self.k is not None:
# If `k` is set, only sample from top `k` tokens.
cutoff = tf.math.minimum(cutoff, self.k)
sorted_preds, sorted_indices = tf.math.top_k(
probabilities, k=tf.shape(probabilities)[1], sorted=True
probabilities, k=cutoff, sorted=True
)
# Calculate cumulative probability distribution.
cumulative_probabilities = tf.math.cumsum(sorted_preds, axis=-1)
Expand All @@ -88,7 +99,7 @@ def get_next_token(self, probabilities):
probabilities = tf.where(
shifted_keep_mask,
sorted_preds,
tf.zeros(tf.shape(probabilities), dtype=sorted_preds.dtype),
tf.zeros(tf.shape(sorted_preds), dtype=sorted_preds.dtype),
)
sorted_next_token = tf.random.categorical(
tf.math.log(probabilities), 1, seed=self.seed
Expand All @@ -100,6 +111,7 @@ def get_config(self):
config.update(
{
"p": self.p,
"k": self.k,
"seed": self.seed,
}
)
Expand Down
19 changes: 19 additions & 0 deletions keras_nlp/samplers/top_p_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ def test_early_stopping(self):
)
self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])

def test_only_sample_from_top_k_tokens(self):
    """With `k=5` and `p=1`, sampled ids must come from the 5 most likely tokens."""

    def next_fn(prompt, cache, index):
        # Dummy hidden states.
        hidden_states = tf.ones([self.batch_size, 5])
        # Logits count down from vocab_size to 1, so lower token ids are
        # strictly more likely; broadcast one row per batch element.
        descending = tf.range(self.vocab_size, 0, -1, dtype="float32")
        logits = tf.repeat(descending[tf.newaxis, :], self.batch_size, axis=0)
        return logits, hidden_states, cache

    prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
    sampler = TopPSampler(p=1, k=5)
    output = sampler(
        next=next_fn,
        prompt=prompt,
        index=5,
    )
    generated_str = self.join_as_string(output[:, 5:])[0]
    # Every generated character should map to one of the top-5 ids ("a".."e").
    self.assertContainsSubset(set(generated_str), set("abcde"))

def test_outputs_in_top_p(self):
def next(prompt, cache, index):
# Dummy hidden states.
Expand Down