From 927861349e0a4ed7bfc27f9c8ae2a8a16822d6c6 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Thu, 23 Mar 2023 20:05:25 +0530 Subject: [PATCH 01/13] Returning all Beams and Probs as well as adding a Testing Unit --- keras_nlp/samplers/beam_sampler.py | 9 ++++++++- keras_nlp/samplers/beam_sampler_test.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 03b390a3a6..4dbd50f366 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -65,9 +65,11 @@ def next(prompt, state, index): def __init__( self, num_beams=5, + return_all_beams=False, ): super().__init__() self.num_beams = num_beams + self.return_all_beams = return_all_beams def __call__( self, @@ -165,13 +167,18 @@ def gather_beams(x): prompt, log_probs = unflatten_beams(prompt), unflatten_beams(log_probs) top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis] prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1) - return tf.squeeze(prompt, axis=1) + + if self.return_all_beams: + return tf.squeeze(prompt, axis=1), prompt, log_probs + else: + return tf.squeeze(prompt, axis=1) def get_config(self): config = super().get_config() config.update( { "num_beams": self.num_beams, + "return_all_beams": self.return_all_beams, } ) return config diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 95290b8f1a..8d08fa47d6 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -67,6 +67,22 @@ def test_stateful_call(self): ) self.assertEqual(self.join_as_string(output), ["sequentially"]) + def test_return_all_beams(self): + state_chars = list("sequentially") + state = tf.constant([[self.char_lookup[c] for c in state_chars]]) + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + state=state, + return_all_beams=True, + ) + self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) + self.assertLen(output[1], 5) + self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) + self.assertLen(output[2], 5) + self.assertEqual(output[2], [1.0] * 5) + def test_early_stopping(self): state_chars = list("sequentially") state = tf.constant([[self.char_lookup[c] for c in state_chars]]) From c295fd9ac6c028850eee53e36970592857618b23 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Thu, 23 Mar 2023 21:08:51 +0530 Subject: [PATCH 02/13] modified test 1 --- keras_nlp/samplers/beam_sampler_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 8d08fa47d6..3a8b6831ff 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -78,10 +78,10 @@ def test_return_all_beams(self): return_all_beams=True, ) self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) - self.assertLen(output[1], 5) - self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) - self.assertLen(output[2], 5) - self.assertEqual(output[2], [1.0] * 5) + # self.assertLen(output[1], 5) + # self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) + # self.assertLen(output[2], 5) + # self.assertEqual(output[2], [1.0] * 5) def test_early_stopping(self): state_chars = list("sequentially") From ccb723433ecfd1f6409e213c024fa3156ab34630 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 01:22:46 +0530 Subject: [PATCH 03/13] overriding changes --- keras_nlp/samplers/beam_sampler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 4dbd50f366..53c73454b1 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -79,7 +79,13 @@ def __call__( index=0, mask=None, end_token_id=None, + return_all_beams=None, ): + if return_all_beams is None: + return_all_beams = self.return_all_beams + else: + return_all_beams = bool(return_all_beams) + batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. index = tf.cast(index, max_length.dtype) @@ -168,7 +174,7 @@ def gather_beams(x): top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis] prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1) - if self.return_all_beams: + if return_all_beams: return tf.squeeze(prompt, axis=1), prompt, log_probs else: return tf.squeeze(prompt, axis=1) From 608d5586680db00e8a4810195ce0962cb08bdec6 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 02:10:06 +0530 Subject: [PATCH 04/13] added more tests since pipeline works :) --- keras_nlp/samplers/beam_sampler_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 3a8b6831ff..f321de5123 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -78,9 +78,9 @@ def test_return_all_beams(self): return_all_beams=True, ) self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) - # self.assertLen(output[1], 5) - # self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) - # self.assertLen(output[2], 5) + self.assertLen(output[1], 5) + self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) + self.assertLen(output[2], 5) # self.assertEqual(output[2], [1.0] * 5) def test_early_stopping(self): From 7cc329722dd9615761c83d11d5f193b71fe82171 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 03:25:21 +0530 Subject: [PATCH 05/13] fixed test cases and variable names --- keras_nlp/samplers/beam_sampler.py | 10 ++++++---- keras_nlp/samplers/beam_sampler_test.py | 6 ++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 53c73454b1..4674264539 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -170,12 +170,14 @@ def gather_beams(x): ) # Gather the top beam at each batch index. - prompt, log_probs = unflatten_beams(prompt), unflatten_beams(log_probs) - top_beams = tf.math.argmax(log_probs, axis=-1)[:, tf.newaxis] - prompt = tf.gather(prompt, top_beams, axis=1, batch_dims=1) + all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams( + log_probs + ) + top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis] + prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1) if return_all_beams: - return tf.squeeze(prompt, axis=1), prompt, log_probs + return tf.squeeze(prompt, axis=1), all_prompts, all_log_probs else: return tf.squeeze(prompt, axis=1) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index f321de5123..17b9c0cdd9 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -78,10 +78,8 @@ def test_return_all_beams(self): return_all_beams=True, ) self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) - self.assertLen(output[1], 5) - self.assertEqual(self.join_as_string(output[1]), ["sequentially"] * 5) - self.assertLen(output[2], 5) - # self.assertEqual(output[2], [1.0] * 5) + self.assertEqual(output[1].shape, (self.batch_size, 5, self.length)) + self.assertEqual(output[2].shape, (self.batch_size, 5, self.length)) def test_early_stopping(self): state_chars = list("sequentially") From 6084f7d189b3dd664ad87fb7af81c8e32f4c46b5 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 03:52:27 +0530 Subject: [PATCH 06/13] fixed log_prob dimension test --- keras_nlp/samplers/beam_sampler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 17b9c0cdd9..a1a7ba0a19 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -79,7 +79,7 @@ def test_return_all_beams(self): ) self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) self.assertEqual(output[1].shape, (self.batch_size, 5, self.length)) - self.assertEqual(output[2].shape, (self.batch_size, 5, self.length)) + self.assertEqual(output[2].shape, (self.batch_size, 5)) def test_early_stopping(self): state_chars = list("sequentially") From d1c7d90a88924d110c3228ac702d9e3663b63836 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 10:36:54 +0530 Subject: [PATCH 07/13] tried removing call argument --- keras_nlp/samplers/beam_sampler.py | 12 ++++++------ keras_nlp/samplers/beam_sampler_test.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 4674264539..6e1ea5dc37 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -79,12 +79,12 @@ def __call__( index=0, mask=None, end_token_id=None, - return_all_beams=None, + # return_all_beams=None, ): - if return_all_beams is None: - return_all_beams = self.return_all_beams - else: - return_all_beams = bool(return_all_beams) + # if return_all_beams is None: + # return_all_beams = self.return_all_beams + # else: + # return_all_beams = bool(return_all_beams) batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. @@ -176,7 +176,7 @@ def gather_beams(x): top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis] prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1) - if return_all_beams: + if self.return_all_beams: return tf.squeeze(prompt, axis=1), all_prompts, all_log_probs else: return tf.squeeze(prompt, axis=1) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index a1a7ba0a19..1736ad839e 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -37,6 +37,7 @@ def next(prompt, state, index): self.next = next self.sampler = BeamSampler(num_beams=5) + self.sampler_all_beams = BeamSampler(num_beams=5, return_all_beams=True) def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] @@ -71,11 +72,10 @@ def test_return_all_beams(self): state_chars = list("sequentially") state = tf.constant([[self.char_lookup[c] for c in state_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) - output = self.sampler( + output = self.sampler_all_beams( next=self.next, prompt=prompt, state=state, - return_all_beams=True, ) self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) self.assertEqual(output[1].shape, (self.batch_size, 5, self.length)) From 726ee8d820d00425130a1d2187303a0f086958d8 Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Fri, 24 Mar 2023 18:28:49 +0530 Subject: [PATCH 08/13] modified documentation temporarily --- keras_nlp/samplers/beam_sampler.py | 32 +++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 6e1ea5dc37..1e8457cd03 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -36,11 +36,14 @@ class BeamSampler(Sampler): Args: num_beams: int. The number of beams that should be kept at each time-step. `num_beams` should be strictly positive. + return_all_beams: bool. When set to `True`, the sampler will return the top prompt, + all prompts and their respective probabilities score. Call Args: {{call_args}} Examples: + Example 1: ```python # Use a simple alphabet of lowercase characters to [0, 26). int_lookup = {i: chr(i + ord('a')) for i in range(26)} @@ -60,6 +63,30 @@ def next(prompt, state, index): print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) # >>> "zzzzzaaaaaaa" ``` + Example 2: + ```python + # Use a simple alphabet of lowercase characters to [0, 26). + int_lookup = {i: chr(i + ord('a')) for i in range(26)} + char_lookup = {v: k for k, v in int_lookup.items()} + batch_size, length, vocab_size = 1, 12, len(int_lookup) + + def next(prompt, state, index): + # A uniform distribution over our alphabet. + logits = tf.ones((batch_size, vocab_size)) + return logits, state + + output = keras_nlp.samplers.BeamSampler(return_all_beams=True)( + next=next, + prompt=tf.fill((batch_size, length,), char_lookup['z']), + index=5, + ) + print(["".join([int_lookup[i] for i in s]) for s in output[0].numpy()]) + print(output[1].shape) + print(output[2].shape) + # >>> "zzzzzaaaaaaa" + # >>> (1, 5, 12) + # >>> (1, 5) + ``` """ def __init__( @@ -79,12 +106,7 @@ def __call__( index=0, mask=None, end_token_id=None, - # return_all_beams=None, ): - # if return_all_beams is None: - # return_all_beams = self.return_all_beams - # else: - # return_all_beams = bool(return_all_beams) batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. From 1248b7f4146ec58a1ea2653710aa17b2182b585d Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Sat, 25 Mar 2023 12:19:33 +0530 Subject: [PATCH 09/13] sorting the prompts and scores --- keras_nlp/samplers/beam_sampler.py | 10 +++++++--- keras_nlp/samplers/beam_sampler_test.py | 7 ++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 1e8457cd03..2b20e41547 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -36,7 +36,7 @@ class BeamSampler(Sampler): Args: num_beams: int. The number of beams that should be kept at each time-step. `num_beams` should be strictly positive. - return_all_beams: bool. When set to `True`, the sampler will return the top prompt, + return_all_beams: bool. When set to `True`, the sampler will return the top prompt, all prompts and their respective probabilities score. Call Args: @@ -107,7 +107,6 @@ def __call__( mask=None, end_token_id=None, ): - batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1] # Make sure max length and start index are the same dtype. index = tf.cast(index, max_length.dtype) @@ -199,7 +198,12 @@ def gather_beams(x): prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1) if self.return_all_beams: - return tf.squeeze(prompt, axis=1), all_prompts, all_log_probs + sorted_indices = tf.argsort( + all_log_probs, axis=-1, direction="DESCENDING" + ) + sorted_log_probs = tf.gather(all_log_probs, sorted_indices, axis=-1) + sorted_prompts = tf.gather(all_prompts, sorted_indices, axis=1) + return sorted_prompts, sorted_log_probs else: return tf.squeeze(prompt, axis=1) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 1736ad839e..ce6abf6505 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -77,9 +77,10 @@ def test_return_all_beams(self): prompt=prompt, state=state, ) - self.assertEqual(self.join_as_string(output[0]), ["sequentially"]) - self.assertEqual(output[1].shape, (self.batch_size, 5, self.length)) - self.assertEqual(output[2].shape, (self.batch_size, 5)) + + self.assertEqual(output[0].shape, (self.batch_size, 5, self.length)) + self.assertEqual(output[1].shape, (self.batch_size, 5)) + self.assertTrue(tf.reduce_all(output[1][:, 1:] <= output[1][:, :-1])) def test_early_stopping(self): state_chars = list("sequentially") From 6106943770b2aee68ecd7c0511b07e50a64016af Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Sat, 25 Mar 2023 13:24:49 +0530 Subject: [PATCH 10/13] added batchdim as 1 to fix the extra dim error --- keras_nlp/samplers/beam_sampler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 2b20e41547..a715ed0543 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -201,8 +201,12 @@ def gather_beams(x): sorted_indices = tf.argsort( all_log_probs, axis=-1, direction="DESCENDING" ) - sorted_log_probs = tf.gather(all_log_probs, sorted_indices, axis=-1) - sorted_prompts = tf.gather(all_prompts, sorted_indices, axis=1) + sorted_log_probs = tf.gather( + all_log_probs, sorted_indices, axis=-1, batch_dims=1 + ) + sorted_prompts = tf.gather( + all_prompts, sorted_indices, axis=1, batch_dims=1 + ) return sorted_prompts, sorted_log_probs else: return tf.squeeze(prompt, axis=1) From 96f4656c991f7674ea3b01e625349c7a8bb8982b Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Sat, 25 Mar 2023 15:30:38 +0530 Subject: [PATCH 11/13] added 1 more test, changed documentation --- keras_nlp/samplers/beam_sampler.py | 9 +++++---- keras_nlp/samplers/beam_sampler_test.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index a715ed0543..eff45c9479 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -80,12 +80,13 @@ def next(prompt, state, index): prompt=tf.fill((batch_size, length,), char_lookup['z']), index=5, ) - print(["".join([int_lookup[i] for i in s]) for s in output[0].numpy()]) - print(output[1].shape) - print(output[2].shape) - # >>> "zzzzzaaaaaaa" + + print(output[0].shape) # >>> (1, 5, 12) + print(output[1].shape) # >>> (1, 5) + print(["".join([int_lookup[i] for i in s]) for s in output[0][0].numpy()]) + # >>> "zzzzzaaaaaaa" ``` """ diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index ce6abf6505..9074ae54f1 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -81,6 +81,9 @@ def test_return_all_beams(self): self.assertEqual(output[0].shape, (self.batch_size, 5, self.length)) self.assertEqual(output[1].shape, (self.batch_size, 5)) self.assertTrue(tf.reduce_all(output[1][:, 1:] <= output[1][:, :-1])) + self.assertEqual( + self.join_as_string(output[0][:, 0, :]), ["sequentially"] + ) def test_early_stopping(self): state_chars = list("sequentially") From 45e811774bc39a52c40e4216d3bf459432723abd Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Sun, 26 Mar 2023 19:10:05 +0530 Subject: [PATCH 12/13] made the style changes --- keras_nlp/samplers/beam_sampler.py | 15 +++++++-------- keras_nlp/samplers/beam_sampler_test.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index eff45c9479..854caf3bc9 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -43,7 +43,7 @@ class BeamSampler(Sampler): {{call_args}} Examples: - Example 1: + 1. Return only the beam with the highest accumulated probability. ```python # Use a simple alphabet of lowercase characters to [0, 26). int_lookup = {i: chr(i + ord('a')) for i in range(26)} @@ -63,7 +63,7 @@ def next(prompt, state, index): print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) # >>> "zzzzzaaaaaaa" ``` - Example 2: + 2. Return all beams and their probabilities. ```python # Use a simple alphabet of lowercase characters to [0, 26). int_lookup = {i: chr(i + ord('a')) for i in range(26)} @@ -191,12 +191,8 @@ def gather_beams(x): maximum_iterations=(max_length - index), ) - # Gather the top beam at each batch index. - all_prompts, all_log_probs = unflatten_beams(prompt), unflatten_beams( - log_probs - ) - top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis] - prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1) + all_prompts = unflatten_beams(prompt) + all_log_probs = unflatten_beams(log_probs) if self.return_all_beams: sorted_indices = tf.argsort( @@ -210,6 +206,9 @@ def gather_beams(x): ) return sorted_prompts, sorted_log_probs else: + # Gather the top beam at each batch index. + top_beams = tf.math.argmax(all_log_probs, axis=-1)[:, tf.newaxis] + prompt = tf.gather(all_prompts, top_beams, axis=1, batch_dims=1) return tf.squeeze(prompt, axis=1) def get_config(self): diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index 9074ae54f1..e20c66035e 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -72,17 +72,21 @@ def test_return_all_beams(self): state_chars = list("sequentially") state = tf.constant([[self.char_lookup[c] for c in state_chars]]) prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) - output = self.sampler_all_beams( + sorted_prompts, sorted_log_probs = self.sampler_all_beams( next=self.next, prompt=prompt, state=state, ) - self.assertEqual(output[0].shape, (self.batch_size, 5, self.length)) - self.assertEqual(output[1].shape, (self.batch_size, 5)) - self.assertTrue(tf.reduce_all(output[1][:, 1:] <= output[1][:, :-1])) self.assertEqual( - self.join_as_string(output[0][:, 0, :]), ["sequentially"] + sorted_prompts.shape, (self.batch_size, 5, self.length) + ) + self.assertEqual(sorted_log_probs.shape, (self.batch_size, 5)) + self.assertTrue( + tf.reduce_all(sorted_log_probs[:, 1:] <= sorted_log_probs[:, :-1]) + ) + self.assertEqual( + self.join_as_string(sorted_prompts[:, 0, :]), ["sequentially"] ) def test_early_stopping(self): From 63cf74665f25a422b23fcc5afbe196af2682714b Mon Sep 17 00:00:00 2001 From: TheAthleticCoder Date: Tue, 28 Mar 2023 10:08:28 +0530 Subject: [PATCH 13/13] fixed minor comment --- keras_nlp/samplers/beam_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 854caf3bc9..a8664ba675 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -43,7 +43,7 @@ class BeamSampler(Sampler): {{call_args}} Examples: - 1. Return only the beam with the highest accumulated probability. + Return only the beam with the highest accumulated probability. ```python # Use a simple alphabet of lowercase characters to [0, 26). int_lookup = {i: chr(i + ord('a')) for i in range(26)} @@ -63,7 +63,7 @@ def next(prompt, state, index): print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) # >>> "zzzzzaaaaaaa" ``` - 2. Return all beams and their probabilities. + Return all beams and their probabilities. ```python # Use a simple alphabet of lowercase characters to [0, 26). int_lookup = {i: chr(i + ord('a')) for i in range(26)}