Merged
Commits (31)
db66596  added temperature param and a small test (TheAthleticCoder, Mar 31, 2023)
478b912  added config changes (TheAthleticCoder, Mar 31, 2023)
d6d9d7e  changed approach (TheAthleticCoder, Mar 31, 2023)
c397077  made requested changes (TheAthleticCoder, Apr 1, 2023)
9a4d332  added topP (TheAthleticCoder, Apr 1, 2023)
9042c74  updated the topP test (TheAthleticCoder, Apr 1, 2023)
e1ef638  updated beam sampler (TheAthleticCoder, Apr 1, 2023)
3121186  modified test for beam sampler (TheAthleticCoder, Apr 1, 2023)
0f414e6  fixed temperature (TheAthleticCoder, Apr 6, 2023)
c56189f  new tests (TheAthleticCoder, Apr 6, 2023)
815f0e8  new tests (TheAthleticCoder, Apr 6, 2023)
edbd950  added new test (TheAthleticCoder, Apr 6, 2023)
2cb9b30  added new test (TheAthleticCoder, Apr 6, 2023)
e90774a  attempt (TheAthleticCoder, Apr 6, 2023)
cda6cc9  edited temperature values (TheAthleticCoder, Apr 6, 2023)
2d46d1e  changed temp further (TheAthleticCoder, Apr 6, 2023)
40980f7  changed prompt (TheAthleticCoder, Apr 6, 2023)
488385a  changed state (TheAthleticCoder, Apr 6, 2023)
8a137d6  changed state (TheAthleticCoder, Apr 6, 2023)
56148b4  more attempts (TheAthleticCoder, Apr 6, 2023)
1f0b4f9  revert the tests (TheAthleticCoder, Apr 6, 2023)
b01b0b1  revert the tests (TheAthleticCoder, Apr 6, 2023)
30a850a  added required changes (TheAthleticCoder, Apr 13, 2023)
2de1a28  fixed greedy sampler (TheAthleticCoder, Apr 13, 2023)
4ab17a1  corrected top p test (TheAthleticCoder, Apr 13, 2023)
6d67192  Merge branch 'master' into issue947 (TheAthleticCoder, Apr 13, 2023)
b5ddab4  pushed changes to random sampler (TheAthleticCoder, Apr 13, 2023)
b2e6adc  updated test (TheAthleticCoder, Apr 13, 2023)
4591cf8  updated test (TheAthleticCoder, Apr 13, 2023)
184d1b8  chained up kwargs to super (TheAthleticCoder, Apr 13, 2023)
8da1a43  Merge branch 'master' into issue947 (mattdangerw, Apr 14, 2023)
5 changes: 3 additions & 2 deletions keras_nlp/samplers/beam_sampler.py
@@ -97,8 +97,9 @@ def __init__(
         self,
         num_beams=5,
         return_all_beams=False,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.num_beams = num_beams
         self.return_all_beams = return_all_beams

@@ -156,7 +157,7 @@ def body(prompt, cache, index, log_probs):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
             vocab_size = tf.shape(logits)[-1]
-            probs = keras.activations.softmax(logits)
+            probs = keras.activations.softmax(logits / self.temperature)

             # Compute the running log-likelihood of each new candidate.
             next_log_probs = tf.math.log(probs) + log_probs[..., tf.newaxis]
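Note on the beam change: the temperature-scaled softmax feeds the running log-likelihood of each beam, so temperature widens or narrows the score gaps between candidate beams. A minimal standalone sketch of that update (made-up values, not part of the diff):

```python
import tensorflow as tf
from tensorflow import keras

# One beam over a three-token vocab; `log_probs` is the beam's running score.
logits = tf.constant([[2.0, 1.0, 0.5]])
log_probs = tf.constant([-0.7])
temperature = 0.5  # < 1.0 sharpens the distribution

probs = keras.activations.softmax(logits / temperature)
# Score each candidate continuation: beam score + log P(next token),
# mirroring the `next_log_probs` line in the hunk above.
next_log_probs = tf.math.log(probs) + log_probs[..., tf.newaxis]
print(next_log_probs.numpy())
```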
2 changes: 1 addition & 1 deletion keras_nlp/samplers/beam_sampler_test.py
@@ -38,7 +38,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = BeamSampler(num_beams=5)
+        self.sampler = BeamSampler(num_beams=5, temperature=1.0)
         self.sampler_all_beams = BeamSampler(num_beams=5, return_all_beams=True)

     def join_as_string(self, x):
7 changes: 5 additions & 2 deletions keras_nlp/samplers/greedy_sampler.py
@@ -55,8 +55,11 @@ def next(prompt, cache, index):
     ```
     """

-    def __init__(self):
-        super().__init__()
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)

     def get_next_token(self, probabilities):
         return tf.argmax(probabilities, axis=-1)
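Worth noting why `GreedySampler` needs no temperature handling of its own: dividing logits by a positive temperature is a monotonic rescaling, so the `tf.argmax` above picks the same token at any temperature; the `**kwargs` chaining only lets the shared constructor argument through. A quick sketch of the invariance (assumed values, not from the PR):

```python
import tensorflow as tf
from tensorflow import keras

logits = tf.constant([[0.5, 2.0, 1.0]])
for temperature in [0.1, 1.0, 10.0]:
    probs = keras.activations.softmax(logits / temperature)
    # Index 1 wins at every temperature; only the margin changes.
    print(temperature, tf.argmax(probs, axis=-1).numpy())
```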
2 changes: 1 addition & 1 deletion keras_nlp/samplers/greedy_sampler_test.py
@@ -37,7 +37,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = GreedySampler()
+        self.sampler = GreedySampler(temperature=1.0)

     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
3 changes: 2 additions & 1 deletion keras_nlp/samplers/random_sampler.py
@@ -62,8 +62,9 @@ def next(prompt, state, index):
     def __init__(
         self,
         seed=None,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.seed = seed

     def get_next_token(self, probabilities):
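With `**kwargs` chained up, `temperature` can now be passed straight to `RandomSampler`'s constructor. A hypothetical usage sketch (assuming the public `keras_nlp.samplers` export):

```python
from keras_nlp.samplers import RandomSampler

# `temperature` rides through **kwargs into the base Sampler.__init__.
sampler = RandomSampler(seed=42, temperature=0.7)
print(sampler.temperature)  # 0.7
```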
18 changes: 17 additions & 1 deletion keras_nlp/samplers/random_sampler_test.py
@@ -38,7 +38,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = RandomSampler()
+        self.sampler = RandomSampler(temperature=1.0)

     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
@@ -71,6 +71,22 @@ def test_stateful_call(self):
         )
         self.assertEqual(self.join_as_string(output), ["sequentially"])

+    def test_temperature(self):
+        def next(prompt, cache, index):
+            # Dummy hidden states.
+            hidden_states = tf.ones([self.batch_size, 5])
+            logits = tf.range(self.vocab_size, 0, -1, dtype=tf.float32)
+            logits = tf.reshape(logits[tf.newaxis, :], (self.batch_size, -1))
+            return tf.constant(logits), hidden_states, cache
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+
+        output = RandomSampler(temperature=1e-5)(
+            next=next,
+            prompt=prompt,
+        )
+        self.assertAllEqual(output, tf.zeros_like(output))
+
     def test_early_stopping(self):
         cache_chars = list("sequentially")
         cache = tf.constant([[self.char_lookup[c] for c in cache_chars]])
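Why the new `test_temperature` is deterministic: the dummy `next` returns strictly descending logits, so token 0 always carries the most mass, and at `temperature=1e-5` the scaled softmax is numerically one-hot at index 0, making random sampling effectively greedy (hence the all-zeros assertion). A standalone sketch of that collapse (vocab of four, assumed values):

```python
import tensorflow as tf
from tensorflow import keras

# Descending logits, mirroring tf.range(self.vocab_size, 0, -1) in the test.
logits = tf.constant([[4.0, 3.0, 2.0, 1.0]])
probs = keras.activations.softmax(logits / 1e-5)
print(probs.numpy())  # ~[[1. 0. 0. 0.]] -- all probability mass on token 0
```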
17 changes: 13 additions & 4 deletions keras_nlp/samplers/sampler.py
@@ -48,6 +48,11 @@
 class Sampler:
     """Base sampler class.

+    Args:
+        temperature: float. Optional. Defaults to `1.0`. Used to control the
+            randomness of the sampling; the higher the temperature, the
+            more diverse the samples.
+
     Call Args:
         {{call_args}}

@@ -84,6 +89,12 @@ def next(prompt, cache, index):
     ```
     """

+    def __init__(
+        self,
+        temperature=1.0,
+    ):
+        self.temperature = temperature
+
     def __call__(
         self,
         next,
@@ -115,7 +126,7 @@ def cond(prompt, cache, index):
         def body(prompt, cache, index):
             # Compute the softmax distribution for the next token.
             logits, _, cache = next(prompt, cache, index)
-            probabilities = keras.activations.softmax(logits)
+            probabilities = keras.activations.softmax(logits / self.temperature)
             # Compute the next token.
             next_token = self.get_next_token(probabilities)
             # Don't overwrite anywhere mask is True.
@@ -140,11 +151,9 @@ def body(prompt, cache, index):

     def get_next_token(self, probabilities):
         """Get the next token.
-
         Args:
             probabilities: a Tensor, the probability distribution for next
                 token over all vocab tokens.
-
         Get the next token based on given probability distribution over tokens.
         Subclasses must implement this method.
         """
@@ -155,4 +164,4 @@ def from_config(cls, config):
         return cls(**config)

     def get_config(self):
-        return {}
+        return {"temperature": self.temperature}
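The whole feature reduces to the one-line change in `body` above: logits are divided by `temperature` before the softmax. A minimal sketch of how the scaling reshapes the distribution (plain TensorFlow, not part of the PR):

```python
import tensorflow as tf
from tensorflow import keras

logits = tf.constant([[2.0, 1.0, 0.5]])  # batch of one, vocab of three

for temperature in [0.1, 1.0, 10.0]:
    probs = keras.activations.softmax(logits / temperature)
    print(temperature, probs.numpy().round(3))

# temperature << 1 piles mass onto the argmax (near-greedy);
# temperature == 1 leaves the distribution as-is;
# temperature >> 1 flattens it toward uniform (more diverse samples).
```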
3 changes: 2 additions & 1 deletion keras_nlp/samplers/top_k_sampler.py
@@ -64,8 +64,9 @@ def __init__(
         self,
         k=5,
         seed=None,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.k = k
         self.seed = seed

2 changes: 1 addition & 1 deletion keras_nlp/samplers/top_k_sampler_test.py
@@ -37,7 +37,7 @@ def next(prompt, cache, index):
             return logits, hidden_states, cache

         self.next = next
-        self.sampler = TopKSampler(k=5)
+        self.sampler = TopKSampler(k=5, temperature=1.0)

     def join_as_string(self, x):
         return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]
3 changes: 2 additions & 1 deletion keras_nlp/samplers/top_p_sampler.py
@@ -73,8 +73,9 @@ def __init__(
         p=0.1,
         k=None,
         seed=None,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.p = p
         self.k = k
         self.seed = seed
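Because the base class scales the logits before `get_next_token` is called, nucleus filtering now operates on the temperature-scaled distribution, so the two knobs compose. A hypothetical usage sketch (assuming the public `keras_nlp.samplers` export):

```python
from keras_nlp.samplers import TopPSampler

# Sample from the smallest token set covering p=0.9 of the
# temperature-scaled distribution.
sampler = TopPSampler(p=0.9, temperature=0.8)
print(sampler.p, sampler.temperature)  # 0.9 0.8
```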
16 changes: 16 additions & 0 deletions keras_nlp/samplers/top_p_sampler_test.py
@@ -122,6 +122,22 @@ def next(prompt, cache, index):
         output_ids = set(output[0].numpy())
         self.assertContainsSubset(output_ids, range(3))

+    def test_temperature(self):
+        def next(prompt, cache, index):
+            # Dummy hidden states.
+            hidden_states = tf.ones([self.batch_size, 5])
+            logits = tf.range(self.vocab_size, 0, -1, dtype=tf.float32)
+            logits = tf.reshape(logits[tf.newaxis, :], (self.batch_size, -1))
+            return tf.constant(logits), hidden_states, cache
+
+        prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
+
+        output = TopPSampler(p=0.5, temperature=1e-9)(
+            next=next,
+            prompt=prompt,
+        )
+        self.assertAllEqual(output, tf.zeros_like(output))
+
     @parameterized.named_parameters(
         ("jit_compile_false", False), ("jit_compile_true", True)
     )