Speed top-p sampler up by only sampling from top-k tokens #980
Conversation
Tested on CPU with 25 iterations: after the fix it takes 37s, compared to 55s before.
/gcbrun
Awesome! Dropped some comments.
keras_nlp/samplers/top_p_sampler.py
Outdated
Args:
    p: float, the `p` value of top-p.
    sample_from_top_k: int, defaults to None. If set, only sample from top
I would just call this `k` honestly (parallel with the same option in top-k), and in the description say that it is a heuristic cutoff point.
My thought is that having both `p` and `k` is a bit confusing to readers; here `p` is the main arg while `k` plays a supporting role, so using the `p` and `k` combo kind of suggests they are parallel to each other. No strong opinion here, and I agree `sample_from_top_k` isn't a great name.
)
self.assertEqual(self.join_as_string(output), ["sequentzzzzz"])

def test_sample_from_all_tokens(self): |
I don't think this test adds too much. If we want to add a test here, it should test that our cutoff point is working, e.g. `p=1.0`, `k=5`, uniform logits, assert that all outputs are in the first five vocab options.
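A rough sketch of that check (standalone, not the actual KerasNLP test utilities; the helper logic and names here are illustrative only):

```python
import tensorflow as tf

# Illustrative only: uniform logits, p=1.0, k=5. The top-p cutoff alone would
# admit every token, so any sampled id >= 5 would mean the k cutoff failed.
vocab_size, batch_size = 10, 64
logits = tf.zeros([batch_size, vocab_size])  # uniform distribution
top_k_logits, top_k_indices = tf.math.top_k(logits, k=5)
sampled = tf.random.categorical(top_k_logits, num_samples=1, seed=42)
token_ids = tf.gather(top_k_indices, sampled, batch_dims=1)
assert bool(tf.reduce_all(token_ids < 5))
```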
This test case is just a sanity check that the sampler call with `k=None` works. Agree the proposed test works better! Changed.
/gcbrun
Looks good! Think we should be more detailed in the docstring for this one.
keras_nlp/samplers/top_p_sampler.py
Outdated
Args:
    p: float, the `p` value of top-p.
    k: int, defaults to None. If set, only sample from top `k` tokens.
... If set, this argument defines a heuristic "top-k" cutoff applied before the "top-p" sampling. All logits not in the top `k` will be discarded, and the remaining logits will be sorted to find a cutoff point for `p`. Setting this argument can significantly speed up sampling by reducing the size of the sort.
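For readers following along, here is a rough sketch of what that heuristic cutoff amounts to. This is not the PR's actual implementation; the function name, defaults, and mask convention are made up, and it assumes a dense `[batch, vocab_size]` logits tensor. The speedup comes from running the sort and cumulative sum over `k` entries instead of the full vocabulary.

```python
import tensorflow as tf

def top_p_sample_with_k_cutoff(logits, p=0.9, k=10, seed=None):
    # Keep only the `k` largest logits. `tf.math.top_k` returns values already
    # sorted in descending order, so the top-p cutoff below operates on a
    # tensor of width `k` rather than the full vocabulary.
    top_k_logits, top_k_indices = tf.math.top_k(logits, k=k, sorted=True)
    probs = tf.nn.softmax(top_k_logits, axis=-1)
    cumulative_probs = tf.cumsum(probs, axis=-1)
    # Keep tokens whose cumulative probability is still within `p`, always
    # retaining at least the single most likely token.
    keep_mask = cumulative_probs <= p
    keep_mask = tf.concat(
        [tf.ones_like(keep_mask[..., :1]), keep_mask[..., 1:]], axis=-1
    )
    filtered_logits = tf.where(keep_mask, top_k_logits, float("-inf"))
    # Sample within the reduced distribution, then map back to vocab ids.
    sampled = tf.random.categorical(filtered_logits, num_samples=1, seed=seed)
    return tf.gather(top_k_indices, sampled, batch_dims=1)
```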
thanks!
Resolves #963