diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index cbb6b730f4..2830640d7e 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -97,8 +97,9 @@ def __init__(
         self,
         num_beams=5,
         return_all_beams=False,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.num_beams = num_beams
         self.return_all_beams = return_all_beams

@@ -156,7 +157,7 @@ def body(prompt, cache, index, log_probs):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             vocab_size = tf.shape(logits)[-1]
-            probs = keras.activations.softmax(logits)
+            probs = keras.activations.softmax(logits / self.temperature)
             # Compute the running log-likelihood of each new candidate.
             next_log_probs = tf.math.log(probs) + log_probs[..., tf.newaxis]
diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py
index ff0d671834..979a300c14 100644
--- a/keras_nlp/samplers/beam_sampler_test.py
+++ b/keras_nlp/samplers/beam_sampler_test.py
@@ -38,7 +38,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = BeamSampler(num_beams=5)
+        self.sampler = BeamSampler(num_beams=5, temperature=1.0)
         self.sampler_all_beams = BeamSampler(num_beams=5, return_all_beams=True)

     def join_as_string(self, x):
diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py
index 9d4cf7389e..6064270ed8 100644
--- a/keras_nlp/samplers/greedy_sampler.py
+++ b/keras_nlp/samplers/greedy_sampler.py
@@ -55,8 +55,11 @@ def next(prompt, cache, index):
     ```
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)

     def get_next_token(self, probabilities):
         return tf.argmax(probabilities, axis=-1)
diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py
index f45902f525..e612277e86 100644
--- a/keras_nlp/samplers/greedy_sampler_test.py
+++ b/keras_nlp/samplers/greedy_sampler_test.py
@@ -37,7 +37,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = GreedySampler()
+        self.sampler = GreedySampler(temperature=1.0)

     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py
index 9d2c731757..a0a945cd05 100644
--- a/keras_nlp/samplers/random_sampler.py
+++ b/keras_nlp/samplers/random_sampler.py
@@ -62,8 +62,9 @@ def next(prompt, state, index):
     def __init__(
         self,
         seed=None,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.seed = seed

     def get_next_token(self, probabilities):
diff --git a/keras_nlp/samplers/random_sampler_test.py b/keras_nlp/samplers/random_sampler_test.py
index 2c08b09780..7ae739fa9b 100644
--- a/keras_nlp/samplers/random_sampler_test.py
+++ b/keras_nlp/samplers/random_sampler_test.py
@@ -38,7 +38,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = RandomSampler()
+        self.sampler = RandomSampler(temperature=1.0)

     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
@@ -71,6 +71,22 @@ def test_stateful_call(self):
         )
         self.assertEqual(self.join_as_string(output), ["sequentially"])

+    def test_temperature(self):
+        def next(prompt, cache, index):
+            # Dummy hidden states.
+            hidden_states = tf.ones([self.batch_size, 5])
+            logits = tf.range(self.vocab_size, 0, -1, dtype=tf.float32)
+            logits = tf.reshape(logits[tf.newaxis, :], (self.batch_size, -1))
+            return tf.constant(logits), hidden_states, cache
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+
+        output = RandomSampler(temperature=1e-5)(
+            next=next,
+            prompt=prompt,
+        )
+        self.assertAllEqual(output, tf.zeros_like(output))
+
     def test_early_stopping(self):
         cache_chars = list("sequentially")
         cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 7ccb35b288..e1c5ae5874 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -48,6 +48,11 @@ class Sampler:
     """Base sampler class.

+    Args:
+        temperature: float. Optional. Defaults to `1.0`. Used to control
+            the randomness of the sampling. The higher the temperature, the
+            more diverse the samples.
+
     Call Args:
         {{call_args}}

@@ -84,6 +89,12 @@ def next(prompt, cache, index):
     ```
     """

+    def __init__(
+        self,
+        temperature=1.0,
+    ):
+        self.temperature = temperature
+
     def __call__(
         self,
         next,
@@ -115,7 +126,7 @@ def cond(prompt, cache, index):
         def body(prompt, cache, index):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
-            probabilities = keras.activations.softmax(logits)
+            probabilities = keras.activations.softmax(logits / self.temperature)
             # Compute the next token.
             next_token = self.get_next_token(probabilities)
             # Don't overwrite anywhere mask is True.
@@ -140,11 +151,9 @@ def body(prompt, cache, index):

     def get_next_token(self, probabilities):
         """Get the next token.
-
         Args:
             probabilities: a Tensor, the probability distribution for next
                 token over all vocab tokens.
-
         Get the next token based on given probability distribution over tokens.
         Subclasses must implement this method.
""" @@ -155,4 +164,4 @@ def from_config(cls, config): return cls(**config) def get_config(self): - return {} + return {"temperature": self.temperature} diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 68f369a996..a3369c4bc0 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -64,8 +64,9 @@ def __init__( self, k=5, seed=None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.k = k self.seed = seed diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index 10ce77b956..8f4fe8bd20 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -37,7 +37,7 @@ def next(prompt, cache, index): return logits, hidden_states, cache self.next = next - self.sampler = TopKSampler(k=5) + self.sampler = TopKSampler(k=5, temperature=1.0) def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()] diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index ec5c4bec8e..2ef90d3b50 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -73,8 +73,9 @@ def __init__( p=0.1, k=None, seed=None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.p = p self.k = k self.seed = seed diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index f06563ac58..728d9a62c1 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -122,6 +122,22 @@ def next(prompt, cache, index): output_ids = set(output[0].numpy()) self.assertContainsSubset(output_ids, range(3)) + def test_temperature(self): + def next(prompt, cache, index): + # Dummy hidden states. + hidden_states = tf.ones([self.batch_size, 5]) + logits = tf.range(self.vocab_size, 0, -1, dtype=tf.float32) + logits = tf.reshape(logits[tf.newaxis, :], (self.batch_size, -1)) + return tf.constant(logits), hidden_states, cache + + prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"]) + + output = TopPSampler(p=0.5, temperature=1e-9)( + next=next, + prompt=prompt, + ) + self.assertAllEqual(output, tf.zeros_like(output)) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) )