Improve inference speed of Santacoder and Starcoder (and others) #376
Comments
Some benchmarking results, comparing several implementations:

[Figure: Santacoder decode]
[Figure: Santacoder prefill]
[Figure: Starcoder decode]
[Figure: Starcoder prefill]
@jlamypoirier Thanks for the great investigation. Can you show me where you implemented the cuda graphs with dynamic size for SantaCoder? I wonder how it is implemented.
Sorry for the late response, you can find my (messy) implementation in https://github.com/bigcode-project/transformers/blob/main/src/transformers/models/gpt_bigcode/inference_runner.py. Note that this version supports dynamic key lengths but not dynamic batch sizes.
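For readers who want the shape of the idea without digging through that file, here is a minimal, hypothetical sketch of capturing one decode step as a CUDA graph with static buffers and replaying it; the stand-in `torch.nn.Linear` layer, buffer names, and sizes are invented for illustration and this is not the code in inference_runner.py. One way to get dynamic key lengths with a fixed batch size is to capture a separate graph per padded key length and pick the right one at runtime.

```python
import torch

# Made-up sketch: capture one decode step as a CUDA graph with static
# input/output buffers, then replay it with new data.
device = "cuda"
batch_size, hidden = 4, 256
layer = torch.nn.Linear(hidden, hidden, device=device)  # stand-in for a real decoder step

static_in = torch.zeros(batch_size, hidden, device=device)

# Warm-up on a side stream before capture, as the CUDA graph docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    layer(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(g):
    static_out = layer(static_in)  # output buffer lives in the graph's memory pool

def decode_step(x: torch.Tensor) -> torch.Tensor:
    static_in.copy_(x)  # write the new inputs into the captured buffer
    g.replay()          # re-runs the recorded kernels with no Python/launch overhead
    return static_out.clone()

print(decode_step(torch.randn(batch_size, hidden, device=device)).shape)
```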
@jlamypoirier Amazing reports!! May I ask whether sequence length indicates max_new_tokens? I got pretty high latency (about 4s) for starcoder when I set max_new_tokens=128.
It's the time to generate one token. For the full time you need to add the prefill for the context length plus a decode step for each position in range(context_length, context_length + max_new_tokens).
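To make that decomposition concrete, here is a rough back-of-the-envelope sketch; the prefill and per-token timings are invented placeholders, not numbers taken from the charts above.

```python
# Back-of-the-envelope decomposition with made-up timings:
# total ≈ prefill(context_length) + sum of per-token decode times over
# range(context_length, context_length + max_new_tokens).
context_length, max_new_tokens = 512, 128

def prefill_time(context_length):
    return 0.08  # hypothetical seconds for the prefill pass

def decode_time(key_length):
    return 0.025 + 1e-5 * key_length  # hypothetical per-token decode latency

total = prefill_time(context_length) + sum(
    decode_time(k) for k in range(context_length, context_length + max_new_tokens)
)
print(f"~{total:.2f}s to generate {max_new_tokens} tokens")  # ends up around 4s with these numbers
```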
@jlamypoirier These are great suggestions. Have any of these found their way upstream?

edit: especially curious about
I did some extensive investigation, testing and benchmarking, and determined that the following is needed to speed up inference for the Bigcode models (and most of the text-gen-inference models):

1. Use `FlashAttention` for prefill only (see the sketch after this list). This is recommended by the authors because the `FlashAttention` kernel relies on a high query length to achieve good parallelization, and because FlashAttention needs a lot of extra work on the inputs/outputs/KV caches for each token.
2. Vectorize the generation code: `causal_lm` has a prototype implementation in [Prototype] Vectorized causal lm #272 (`flash_causal_lm` is harder to vectorize, but according to the point above `causal_lm` should be preferable.)
3. Pad the key length to a multiple of N (this does not apply to the prefill, which uses `FlashAttention`). Padding the key length to a multiple of 8 also provides a high speedup, so N=8 is a bare minimum (though higher is better).
4. Compute `details` (logprobs, prefill data, etc.) only when requested (Make generation details optional #288). These take a lot of time and force computing the whole model head (see 5. below), but the results are almost always thrown away.
5. Compute the model head only for the last token when not returning `details`. This saves some time and, more importantly, avoids a memory bottleneck.
6. Avoid `nn.Module` because it adds a lot of bloat (hooks) on `__call__` and `getattr`. In tests I was able to reduce the santacoder min latency by more than 20% in this way.
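Below is a condensed, hypothetical sketch of points 1, 3 and 5, using a single attention head with random weights. `torch.nn.functional.scaled_dot_product_attention` stands in for a dedicated FlashAttention kernel, and the `prefill`/`decode` helpers, cache layout and padding constant are illustrative only, not the text-generation-inference code.

```python
import torch
import torch.nn.functional as F

PAD_TO = 8                      # pad key length to a multiple of this
head_dim, vocab = 64, 49152

def pad_length(n, multiple=PAD_TO):
    return ((n + multiple - 1) // multiple) * multiple

def prefill(q, k, v, lm_head, kv_cache):
    # q, k, v: (batch, seq, head_dim). A flash-style fused kernel shines here
    # because the query length equals the full prompt length.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    kv_cache["k"][:, : k.shape[1]] = k   # write into the pre-allocated cache
    kv_cache["v"][:, : v.shape[1]] = v
    kv_cache["len"] = k.shape[1]
    return out[:, -1:] @ lm_head         # model head on the last token only

def decode(q, k_new, v_new, lm_head, kv_cache):
    # q, k_new, v_new: (batch, 1, head_dim). Standard attention over the cache;
    # the key length is padded to a multiple of 8 and unused slots are masked out.
    i = kv_cache["len"]
    kv_cache["k"][:, i] = k_new[:, 0]
    kv_cache["v"][:, i] = v_new[:, 0]
    kv_cache["len"] = i + 1
    key_len = pad_length(i + 1)
    k, v = kv_cache["k"][:, :key_len], kv_cache["v"][:, :key_len]
    mask = torch.arange(key_len, device=q.device) <= i     # hide padded slots
    scores = (q @ k.transpose(1, 2)) / head_dim**0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    out = scores.softmax(-1) @ v
    return out @ lm_head                  # decode already has a single position

batch, prompt_len, max_len = 2, 5, 64
cache = {
    "k": torch.zeros(batch, pad_length(max_len), head_dim),
    "v": torch.zeros(batch, pad_length(max_len), head_dim),
    "len": 0,
}
lm_head = torch.randn(head_dim, vocab)
q = k = v = torch.randn(batch, prompt_len, head_dim)
print(prefill(q, k, v, lm_head, cache).shape)    # (2, 1, vocab)
q1 = k1 = v1 = torch.randn(batch, 1, head_dim)
print(decode(q1, k1, v1, lm_head, cache).shape)  # (2, 1, vocab)
```

In this toy version the decode path never touches the flash-style kernel, the KV cache is written in place into a pre-allocated buffer, and the vocabulary projection runs on a single position per sequence.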
Future work (more investigation needed):

- Avoid extra tensor allocations by writing into pre-allocated buffers; some kernels do not support the `out` argument, so some (easy) cpp work would be needed.
- Implement the model (or at least the `Block`) in cpp. This would not be too hard and would help a lot with the latency for smaller models, but may not be needed if cuda graphs are used.
- Use cuda graphs. These make `concatenate`/`filter` harder due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every `filter` (see the sketch below); it should be ok since the (Santacoder) decode latency is mostly independent of the batch size anyway.
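As a small illustration of that last bullet, here is a hypothetical helper that pads the live batch up to the nearest pre-determined size (powers of 2) so the same static buffers (for example those captured in a cuda graph) can be reused instead of reshuffling data on every `filter`; the bucket values and helper names are made up.

```python
import torch

# Hypothetical batch-size bucketing: always decode with one of a few fixed batch sizes.
BUCKETS = (1, 2, 4, 8, 16, 32)

def bucket_batch_size(n: int) -> int:
    # Smallest allowed batch size that fits the live requests (fall back to n).
    return next((b for b in BUCKETS if b >= n), n)

def pad_batch(input_ids: torch.Tensor):
    n = input_ids.shape[0]
    b = bucket_batch_size(n)
    if b == n:
        return input_ids, n
    pad = input_ids.new_zeros(b - n, *input_ids.shape[1:])  # dummy rows, sliced off after decode
    return torch.cat([input_ids, pad]), n

ids = torch.randint(0, 100, (5, 1))
padded, live = pad_batch(ids)
print(padded.shape, live)  # torch.Size([8, 1]) 5 -> run the size-8 decode, keep rows [:live]
```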