optimize sample_topp by filtering out small value elements up front #276
Conversation
This works because we know that in the worst case only one element will be selected, so the remaining (n-1) elements have to split the remaining (1-topp) probability mass. Any element with probability smaller than (1-topp)/(n-1) can therefore never be part of the top-p set and can be filtered out up front.
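For illustration, here is a minimal C sketch of that cut-off idea, in the spirit of llama2.c's `sample_topp` (the `ProbIndex` struct and parameter names follow llama2.c conventions, but this is a sketch of the approach, not the exact diff of this PR):

```c
#include <stdlib.h>

typedef struct {
    float prob;
    int index;
} ProbIndex;

// sort candidates by probability, descending
static int compare_desc(const void *a, const void *b) {
    const ProbIndex *pa = (const ProbIndex *)a;
    const ProbIndex *pb = (const ProbIndex *)b;
    if (pa->prob > pb->prob) return -1;
    if (pa->prob < pb->prob) return 1;
    return 0;
}

// Top-p (nucleus) sampling: sample from the smallest set of tokens whose
// cumulative probability exceeds topp. `probindex` is caller-provided
// scratch space of size n; `coin` is a uniform random number in [0, 1).
int sample_topp(const float *probabilities, int n, float topp,
                ProbIndex *probindex, float coin) {
    // In the worst case only one token survives the cut, so the other
    // (n-1) tokens share at most (1-topp) probability mass. Tokens below
    // that bound can never be part of the top-p set; drop them up front.
    const float cutoff = (1.0f - topp) / (n - 1);
    int n0 = 0;
    for (int i = 0; i < n; i++) {
        if (probabilities[i] >= cutoff) {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    qsort(probindex, n0, sizeof(ProbIndex), compare_desc);

    // truncate to the smallest prefix whose cumulative probability > topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1;
    for (int i = 0; i < n0; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break;
        }
    }

    // sample from the truncated distribution, rescaled by its total mass
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) return probindex[i].index;
    }
    return probindex[last_idx].index; // fallback for rounding errors
}
```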
Ah, I didn't see that #270 and #274 came first with the same idea :) For fun, I ran a code generation model on llama2.scala for a while and asked it for suggestions: https://gist.github.com/jrudolph/fb7641ba2406de705c5499280783b55c The suggested algorithms are sometimes of comically bad quality, but some of the ideas seem interesting:
I think after the filtering it doesn't matter much any more; after all, the speed improvement only matters for small models, since sampling speed depends only on vocabulary size, regardless of model size.
Ok, in Scala, scanning improves just the top-p selection step by another 10x (but that's also because the naive idiomatic sorting approach has a high abstraction overhead due to boxing).
Thank you for a nice PR!
Here's a small report on my experiments trying out different top-p algorithms: https://blog.virtual-void.net/2023/08/29/calculating-top-p/
Please also consider #313, constant cut-off.
Refs #246
With the cut-off described above, for p = 0.9 usually only 100-1000 tokens remain, speeding up the rest of the sampling process considerably (e.g. with a vocabulary of 32,000 tokens, the cutoff is 0.1/31999 ≈ 3e-6).
(In llama2.scala, I further improved on that by avoiding the sort in most cases. The observation is that the distribution looks power-law-like and only very few elements are ultimately selected, so iteratively scanning the array for the next-best element (a kind of selection sort) while keeping track of the cumulative probability turns out to be a slightly better solution still; see the sketch below.)
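Since the Scala code isn't shown here, the following is a hedged C sketch of that scanning idea (the function name and the fixed-size buffer are illustrative assumptions, not the llama2.scala implementation):

```c
// Sketch of the scan-based variant: instead of sorting all candidates,
// repeatedly scan for the largest remaining probability and accumulate
// it until the cumulative mass exceeds topp (effectively a partial
// selection sort). For power-law-like distributions only a few scans
// are needed. Note: this modifies `probabilities` in place; the fixed
// cap of 64 selected tokens is an illustrative assumption.
int sample_topp_scan(float *probabilities, int n, float topp, float coin) {
    int selected[64];            // indices of the top tokens, in order
    float probs[64];             // their probabilities
    int n_selected = 0;
    float cumulative_prob = 0.0f;

    while (cumulative_prob <= topp && n_selected < 64) {
        // linear scan for the largest not-yet-selected probability
        int best = 0;
        float best_prob = -1.0f;
        for (int i = 0; i < n; i++) {
            if (probabilities[i] > best_prob) {
                best_prob = probabilities[i];
                best = i;
            }
        }
        probabilities[best] = -1.0f; // mark as taken
        selected[n_selected] = best;
        probs[n_selected] = best_prob;
        n_selected++;
        cumulative_prob += best_prob;
    }

    // sample from the selected prefix, rescaled by its total mass
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i < n_selected; i++) {
        cdf += probs[i];
        if (r < cdf) return selected[i];
    }
    return selected[n_selected - 1]; // fallback for rounding errors
}
```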