From 20e58716261fa11e9cbe493934815c8128717fe7 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 10 Apr 2023 18:48:21 -0700 Subject: [PATCH 1/3] Speed top-p sampler up by only sampling from top-k tokens --- keras_nlp/samplers/top_p_sampler.py | 24 +++++++++++++++++++----- keras_nlp/samplers/top_p_sampler_test.py | 22 ++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index f38331b9a3..656a80781a 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -34,6 +34,8 @@ class TopPSampler(Sampler): Args: p: float, the `p` value of top-p. + sample_from_top_k: int, defaults to None. If set, only sample from top + `sample_from_top_k` tokens. This makes top-p sampler faster. seed: int, defaults to None. The random seed. Call Args: @@ -65,17 +67,28 @@ def next(prompt, cache, index): def __init__( self, p=0.1, + sample_from_top_k=40, seed=None, ): super().__init__() self.p = p + self.sample_from_top_k = sample_from_top_k self.seed = seed def get_next_token(self, probabilities): - # Sort preds in descending order. - sorted_preds, sorted_indices = tf.math.top_k( - probabilities, k=tf.shape(probabilities)[1], sorted=True - ) + if self.sample_from_top_k is not None: + k = tf.math.minimum( + self.sample_from_top_k, tf.shape(probabilities)[1] + ) + # Filter out top-k tokens. + sorted_preds, sorted_indices = tf.math.top_k( + probabilities, k=k, sorted=True + ) + else: + # Sort preds in descending order. + sorted_preds, sorted_indices = tf.math.top_k( + probabilities, k=tf.shape(probabilities)[1], sorted=True + ) # Calculate cumulative probability distribution. cumulative_probabilities = tf.math.cumsum(sorted_preds, axis=-1) # Create a mask for the tokens to keep. @@ -88,7 +101,7 @@ def get_next_token(self, probabilities): probabilities = tf.where( shifted_keep_mask, sorted_preds, - tf.zeros(tf.shape(probabilities), dtype=sorted_preds.dtype), + tf.zeros(tf.shape(sorted_preds), dtype=sorted_preds.dtype), ) sorted_next_token = tf.random.categorical( tf.math.log(probabilities), 1, seed=self.seed @@ -100,6 +113,7 @@ def get_config(self): config.update( { "p": self.p, + "sample_from_top_k": self.sample_from_top_k, "seed": self.seed, } ) diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index 68afbb283a..b8105943b0 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -88,6 +88,28 @@ def test_early_stopping(self): ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) + def test_sample_from_all_tokens(self): + def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([self.batch_size, 5]) + # Return a distribution favoring the first token in the vocab. + logits = ( + tf.one_hot( + tf.zeros(self.batch_size, dtype=tf.int32), + self.vocab_size, + ) + * 1e9 + ) + return logits, hidden_states, cache + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = TopPSampler(p=0.1, sample_from_top_k=None)( + next=next, + prompt=prompt, + index=5, + ) + self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) + def test_outputs_in_top_p(self): def next(prompt, cache, index): # Dummy hidden states. From 208c07d68f66652448dfd7d3cb099c28a3d0a117 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 11 Apr 2023 17:36:19 -0700 Subject: [PATCH 2/3] address comments --- keras_nlp/samplers/top_p_sampler.py | 30 ++++++++++-------------- keras_nlp/samplers/top_p_sampler_test.py | 19 +++++++-------- 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 656a80781a..6e471b1398 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -34,8 +34,8 @@ class TopPSampler(Sampler): Args: p: float, the `p` value of top-p. - sample_from_top_k: int, defaults to None. If set, only sample from top - `sample_from_top_k` tokens. This makes top-p sampler faster. + k: int, defaults to None. If set, only sample from top `k` tokens. + This makes top-p sampler faster by avoiding sorting all tokens. seed: int, defaults to None. The random seed. Call Args: @@ -67,28 +67,22 @@ def next(prompt, cache, index): def __init__( self, p=0.1, - sample_from_top_k=40, + k=None, seed=None, ): super().__init__() self.p = p - self.sample_from_top_k = sample_from_top_k + self.k = k self.seed = seed def get_next_token(self, probabilities): - if self.sample_from_top_k is not None: - k = tf.math.minimum( - self.sample_from_top_k, tf.shape(probabilities)[1] - ) - # Filter out top-k tokens. - sorted_preds, sorted_indices = tf.math.top_k( - probabilities, k=k, sorted=True - ) - else: - # Sort preds in descending order. - sorted_preds, sorted_indices = tf.math.top_k( - probabilities, k=tf.shape(probabilities)[1], sorted=True - ) + cutoff = tf.shape(probabilities)[1] + if self.k is not None: + # If `k` is set, only sample from top `k` tokens. + cutoff = tf.math.minimum(cutoff, self.k) + sorted_preds, sorted_indices = tf.math.top_k( + probabilities, k=cutoff, sorted=True + ) # Calculate cumulative probability distribution. cumulative_probabilities = tf.math.cumsum(sorted_preds, axis=-1) # Create a mask for the tokens to keep. @@ -113,7 +107,7 @@ def get_config(self): config.update( { "p": self.p, - "sample_from_top_k": self.sample_from_top_k, + "k": self.k, "seed": self.seed, } ) diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index b8105943b0..f06563ac58 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -88,27 +88,24 @@ def test_early_stopping(self): ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) - def test_sample_from_all_tokens(self): + def test_only_sample_from_top_k_tokens(self): def next(prompt, cache, index): # Dummy hidden states. hidden_states = tf.ones([self.batch_size, 5]) - # Return a distribution favoring the first token in the vocab. - logits = ( - tf.one_hot( - tf.zeros(self.batch_size, dtype=tf.int32), - self.vocab_size, - ) - * 1e9 - ) + # Return a distribution where each id is progressively less likely. + logits = tf.range(self.vocab_size, 0, -1, dtype="float32") + logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0) return logits, hidden_states, cache prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) - output = TopPSampler(p=0.1, sample_from_top_k=None)( + output = TopPSampler(p=1, k=5)( next=next, prompt=prompt, index=5, ) - self.assertEqual(self.join_as_string(output), ["zzzzzaaaaaaa"]) + generated_str = self.join_as_string(output[:, 5:])[0] + token_set = set(generated_str) + self.assertContainsSubset(token_set, set("abcde")) def test_outputs_in_top_p(self): def next(prompt, cache, index): From a536bec891f14bce0181ab1ca9f38a15ea3b2ba3 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Wed, 12 Apr 2023 17:47:41 -0700 Subject: [PATCH 3/3] better docstring --- keras_nlp/samplers/top_p_sampler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 6e471b1398..ec5c4bec8e 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -34,8 +34,12 @@ class TopPSampler(Sampler): Args: p: float, the `p` value of top-p. - k: int, defaults to None. If set, only sample from top `k` tokens. - This makes top-p sampler faster by avoiding sorting all tokens. + k: int, defaults to None. If set, this argument defines a + heuristic "top-k" cutoff applied before the "top-p" sampling. All + logits not in the top `k` will be discarded, and the remaining + logits will be sorted to find a cutoff point for `p`. Setting this + arg can significantly speed sampling up by reducing the number + of tokens to sort. seed: int, defaults to None. The random seed. Call Args: