Adding a temperature argument to the base sampler class and related tests #951
Conversation
Thank you!
keras_nlp/samplers/top_k_sampler.py (outdated)

```python
self,
k=5,
seed=None,
temperature=1.0,
```
If you do the changes suggested above, you won't have to do anything beyond passing `**kwargs` through.
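A minimal sketch of the suggested pattern, using simplified stand-ins for the real keras_nlp classes (the actual base class carries more machinery): the subclass forwards extra keyword arguments so that only the base `Sampler` needs to know about `temperature`.

```python
# Sketch only: simplified stand-ins for the real keras_nlp classes.
class Sampler:
    def __init__(self, temperature=1.0):
        self.temperature = temperature


class TopKSampler(Sampler):
    def __init__(self, k=5, seed=None, **kwargs):
        # `temperature` (and any future base-class argument) rides
        # along in **kwargs, so this subclass never mentions it.
        super().__init__(**kwargs)
        self.k = k
        self.seed = seed


sampler = TopKSampler(k=5, temperature=1e-5)
```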
```python
self.next = next
self.sampler = TopKSampler(k=5)
self.sampler_temperature = TopKSampler(k=5, temperature=1e-5)
```
Just define this in the test below, as it is only used for one test.
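Concretely, something like this sketch (where `TopKSamplerTest` is a hypothetical name for the test class, and `TopKSampler` comes from the module under test):

```python
import tensorflow as tf


class TopKSamplerTest(tf.test.TestCase):
    def test_temperature(self):
        # Construct the low-temperature sampler inside the one test
        # that uses it, rather than keeping it on `self` in setUp.
        sampler = TopKSampler(k=5, temperature=1e-5)
        ...
```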
```python
    next=next,
    prompt=prompt,
)
output_ids = set(output[0].numpy())
```
No need to do `set` here, just `self.assertAllEqual(output, [0] * self.length)` or something like that.
/gcbrun
This looks great! I think we just need to merge this with the latest changes on master.
For beam search, top-k, greedy, and contrastive, we can probably just pass temperature=1.0 in setup for the test sampler, as the output won't change for these "top" samplers. That way our test files stay short.
For top-p and random, we can include a simple test that asserts that a very low temperature will lead to deterministic output.
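For context on why the output won't change: dividing logits by a positive temperature rescales the distribution but preserves the ranking of token ids, so samplers that only look at the top of the ranking (greedy, beam, top-k, contrastive) pick the same tokens. A quick standalone demonstration (not from this PR):

```python
import tensorflow as tf

logits = tf.constant([[4.0, 2.0, 1.0]])
for temperature in [0.1, 1.0, 10.0]:
    probs = tf.nn.softmax(logits / temperature)
    # The probabilities sharpen or flatten, but the argmax (and the full
    # ordering of ids) is identical for every positive temperature.
    print(temperature, tf.argmax(probs, axis=-1).numpy(), probs.numpy().round(3))
```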
```python
output_ids = set(output[0].numpy())
self.assertContainsSubset(output_ids, range(3))

def test_outputs_in_top_p_with_temperature(self):
```
I think we could make this simpler; we are just trying to test that temperature is working. Something like this (I haven't run it)...
```python
def test_temperature(self):
    def next(prompt, state, index):
        # Return a distribution where each id is progressively less likely.
        logits = tf.range(self.vocab_size, 0, -1, dtype="float32")
        logits = tf.repeat(logits[tf.newaxis, :], self.batch_size, axis=0)
        return tf.constant(logits), state

    prompt = tf.fill((self.batch_size, self.length), self.char_lookup["z"])
    # Use a very low temperature, so we always sample the first token.
    output = TopPSampler(p=0.5, temperature=1e-6)(
        next=next,
        prompt=prompt,
    )
    self.assertAllEqual(output, tf.zeros_like(prompt))
```
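For context on why `1e-6` forces deterministic output: as the temperature approaches zero, the softmax over `logits / temperature` collapses onto the largest logit, so even a stochastic sampler effectively becomes greedy. A rough illustration (a hypothetical helper, not necessarily the exact spot where keras_nlp applies temperature):

```python
import tensorflow as tf

def scaled_probabilities(logits, temperature):
    # As temperature -> 0, this approaches a one-hot on the max logit.
    return tf.nn.softmax(logits / temperature)

logits = tf.constant([[5.0, 4.0, 3.0]])
print(scaled_probabilities(logits, 1.0).numpy())   # roughly [[0.665 0.245 0.09]]
print(scaled_probabilities(logits, 1e-6).numpy())  # effectively [[1. 0. 0.]]
```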
Awesome, this looks great! Found one last small comment, and this needs to be merged or rebased with the master branch. Then let's land!
Thanks! Will resolve merge conflicts as I merge this.
@mattdangerw Should I add the temperature arg to the contrastive sampler as well in this thread? Or will we be opening another issue for that?
Follow-up is fine, I will merge this now.
/gcbrun
Partially resolves #947
Currently working on `sampler.py`, and added a small unit test to the `top_k` file.