Improve inference speed of Santacoder and Starcoder (and others) #376

Open
jlamypoirier opened this issue May 29, 2023 · 7 comments

@jlamypoirier
Contributor

jlamypoirier commented May 29, 2023

I did some extensive investigation, testing and benchmarking, and determined that the following is needed to speed up inference for the Bigcode models (and most text-gen-inference models):

  1. Use FlashAttention for prefill only. This is recommended by the authors because the FlashAttention kernel relies on a high query length to achieve good parallelization, and because FlashAttention needs a lot of extra work on the inputs/outputs/KV caches for each token.
  2. Vectorize as many pre/post-processing operations as possible, i.e. avoid loops (especially for cuda ops). The warpers / logit processors have already been vectorized in feat(server): support vectorized warpers in flash causal lm #317, and the rest of causal_lm has a prototype implementation in [Prototype] Vectorized causal lm #272 (flash_causal_lm is harder to vectorize, but according to the point above causal_lm should be preferable).
  3. Perform some form of KV cache pre-allocation and key length padding to a multiple of 8. A complete, static pre-allocated tensor adds complications because of the need to concatenate/filter batches, but it's easy to pre-allocate only a few tokens in advance, so that the slow concatenation runs only once every N tokens instead of on every token. (Again, this is not doable with FlashAttention.) Padding the key length to a multiple of 8 also provides a large speedup, so N=8 is a bare minimum (though higher is better).
  4. Compute the details (logprobs, prefill data, etc.) only when requested (Make generation details optional #288). These take a lot of time and force computing the whole model head (see 5. below), but the results are almost always thrown away.
  5. Compute the model head only for the last token in prefill (unless we do need them for details). This saves some time and more importantly avoids a memory bottleneck.
  6. Use deterministic generation only when a seed is provided. Otherwise, sampling needs to be done in a loop because Pytorch doesn't support vectorized generators.
  7. Trim the python code. Avoid any unnecessary function call (use inline when possible), attribute getting, etc., as these end up contributing a lot to the CPU latency. Avoid subclassing nn.Module because it adds a lot of bloat (hooks) on __call__ and getattr. In tests I was able to reduce the santacoder min latency by more than 20% in this way.
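
The pre-allocation idea in point 3 can be illustrated with a minimal sketch. It assumes a hypothetical `GrowableKVCache` class that tracks only lengths and capacities; a real cache would hold key/value tensors and copy them on each growth. The names and the block size of 8 are chosen for illustration and are not taken from the PRs mentioned above.

```python
def pad_to_multiple(length: int, multiple: int = 8) -> int:
    """Round `length` up to the next multiple of `multiple`."""
    return -(-length // multiple) * multiple


class GrowableKVCache:
    """Hypothetical bookkeeping-only cache (no tensors are stored).

    The capacity is always a multiple of `block`, so the key length seen by
    the attention kernel stays padded, and the expensive re-allocation
    happens once per `block` tokens instead of once per token.
    """

    def __init__(self, prompt_length: int, block: int = 8):
        self.block = block
        self.length = prompt_length
        self.capacity = pad_to_multiple(prompt_length, block)
        self.reallocations = 0

    def append_token(self) -> None:
        if self.length == self.capacity:
            # In a real cache this is the slow concatenate/copy step.
            self.capacity += self.block
            self.reallocations += 1
        self.length += 1


cache = GrowableKVCache(prompt_length=13)
for _ in range(20):
    cache.append_token()
print(cache.length, cache.capacity, cache.reallocations)
```

With a block of 8, generating 20 tokens triggers only a handful of growth steps rather than one per token; a larger block trades memory for fewer copies.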

Future work (more investigation needed):

  1. Try and compare more fused kernels. For fused softmax compare Jit (used in [Prototype] Vectorized causal lm #272) and Megatron's implementation (probably better). Compare fused and standard layer norm (results below seem to go against fused). Try fused dense (with gelu) in MLP (or try Jit?)
  2. Reduce memory allocations by pre-allocating and/or reusing tensors. The main obstacle is that many operations still don't support the out argument, so some (easy) cpp work would be needed.
  3. Write the cpu-intensive part (Block) in cpp. This would not be too hard and would help a lot with the latency for smaller models, but may not be needed if cuda graphs are used.
  4. Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they add a big speedup for Santacoder (and a small one for Starcoder), but they add complications on batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter. It should be ok since the (Santacoder) decode latency is mostly independent of the batch size anyway.
  5. Look more into tensor parallelism. I know it's already implemented in text-gen-inference, but I haven't looked into it myself.
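
The "few pre-determined batch sizes" option in point 4 could look roughly like the sketch below. `bucket_batch_size` and `GraphCache` are hypothetical names, and the cached objects are plain placeholders: no actual CUDA graph capture is performed here, only the bucketing logic is shown.

```python
from typing import Callable, Dict


def bucket_batch_size(batch_size: int) -> int:
    """Round a batch size up to the next power of 2."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return 1 << (batch_size - 1).bit_length()


class GraphCache:
    """Hypothetical cache holding one captured graph per batch-size bucket."""

    def __init__(self, capture: Callable[[int], object]):
        self._capture = capture  # stand-in for the real graph capture
        self._graphs: Dict[int, object] = {}

    def get(self, batch_size: int) -> object:
        bucket = bucket_batch_size(batch_size)
        if bucket not in self._graphs:
            # Capture lazily, once per bucket rather than once per batch size.
            self._graphs[bucket] = self._capture(bucket)
        return self._graphs[bucket]


# Placeholder capture function; a real one would record a CUDA graph.
cache = GraphCache(capture=lambda bucket: f"graph[bs={bucket}]")
for bs in (3, 4, 5, 7, 8):
    cache.get(bs)
print(sorted(cache._graphs))
```

Batches smaller than their bucket would need padding rows, which is the trade-off accepted in exchange for a static KV cache layout.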
@jlamypoirier
Contributor Author

jlamypoirier commented May 29, 2023

Some benchmarking results, comparing several implementations:

  1. flash: flash_santacoder, the current implementation.
  2. causal: The gpt_bigcode model from HF transformers, run with causal_lm.
  3. vector: The gpt_bigcode model from HF transformers, run with vectorized_causal_lm from [Prototype] Vectorized causal lm #272. (Opt. 2 above).
  4. bigcode: The gpt_bigcode model from the Bigcode transformers repo, with minor adaptations and trimming to work with text-gen-inference and vectorized_causal_lm (Opt. 1, 3, 4, 5, 6)
  5. bigcode2: bigcode with some additional optimizations taken from flash_santacoder, mainly the FastLinear and FastLayerNorm layers. Also some simplifications on the attention mask.
  6. bigcode3: bigcode2 with a trimmed python code (Opt. 7)

Note: flash and causal are based on commit 5a58226 (May 16th) so may be missing the latest optimizations.
Also note: curves are smoothed out, otherwise they oscillate wildly without key length padding (causal and vector)

Santacoder decode

  • For batch size=1, CPU is always the bottleneck. flash is the fastest, and there is a huge difference between bigcode1/2/3. Megatron's fused softmax might bring bigcode3 and flash nearly on par (I still expect flash to be faster because it has fewer kernels)
  • flash and causal are really bad at high batch size, especially for long sequences. This is attributable to non-vectorized operations and the poor performance of FlashAttention.
  • vector already brings down the batch size overhead to a minimum.
  • bigcode1/2/3 show additional improvements.
  • Surprisingly, bigcode2/3 are slower than bigcode for bs=256 and large sequences. Attributable to sub-optimal fused layer norm?

santacoder_bs_1_tok_2040_decode_step_5_10
santacoder_bs_32_tok_2040_decode_step_5_10
santacoder_bs_256_tok_2040_decode_step_5_10

Santacoder prefill

  • causal and vector are really bad (no FlashAttention)
  • flash is not that great either, which seems attributable to the large amount of processing in generate_token.
  • bigcode1/2/3 work the best and are very similar (except for bs=1 when CPU-bound). bigcode2/3 are marginally better in general (because of fused layer norm?)

santacoder_bs_1_tok_2040_prefill_step_11_10
santacoder_bs_32_tok_2040_prefill_step_11_10
santacoder_bs_256_tok_2040_prefill_step_11_10

@jlamypoirier
Contributor Author

jlamypoirier commented May 29, 2023

Starcoder decode

  • Similar to Santacoder, but flash is already inefficient at a batch size of 1, often even worse than causal.
  • Latency for small batch sizes is bottlenecked by reading the weights (15.5e9 params * 2 B/param / 2039e9 B/s = 15.2 ms), so tensor parallelism would likely reduce it.
  • causal goes crazy for large sequences, not sure why.
  • Again, bigcode2/3 are worse than bigcode, suspecting the fused layer norm.
  • For batch size 256, the times at small seqlen are higher than for smaller batch sizes, suggesting reading the weights is no longer the bottleneck.
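
The weight-read bound quoted above can be reproduced with a short calculation. The `tensor_parallel` argument is an idealized assumption added for illustration: it supposes the weights are split evenly across GPUs and ignores communication cost, so it only gives a lower bound.

```python
def weight_read_latency_ms(
    n_params: float,
    bytes_per_param: float,
    bandwidth_bytes_per_s: float,
    tensor_parallel: int = 1,
) -> float:
    """Lower bound on per-token decode latency from reading the weights once."""
    bytes_per_gpu = n_params * bytes_per_param / tensor_parallel
    return bytes_per_gpu / bandwidth_bytes_per_s * 1e3


# Numbers from the bullet above: 15.5e9 params, 2 bytes each, 2039e9 B/s.
single_gpu = weight_read_latency_ms(15.5e9, 2, 2039e9)
# Idealized two-way split (assumption, not a measurement).
two_gpus = weight_read_latency_ms(15.5e9, 2, 2039e9, tensor_parallel=2)
print(f"{single_gpu:.1f} ms on one GPU, {two_gpus:.1f} ms per GPU with 2-way split")
```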

starcoder_bs_1_tok_8190_decode_step_11_10
starcoder_bs_32_tok_8190_decode_step_11_10
starcoder_bs_256_tok_8190_decode_step_11_10

Starcoder prefill

  • Similar to Santacoder.
  • bigcode2/3 are marginally faster than bigcode but run out of memory faster.

starcoder_bs_1_tok_8190_prefill_step_29_1
starcoder_bs_32_tok_8190_prefill_step_29_1

@huyphan168

@jlamypoirier Thanks for the great investigation.
"""Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they add a big speedup for Santacoder (and a small one for Starcoder), but they add complications on batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter. It should be ok since the (Santacoder) decode latency is mostly independent of the batch size anyway."""

Can you show me where you implemented the cuda graphs with dynamic size for SantaCoder? I wonder how it is implemented.

@jlamypoirier
Contributor Author

@jlamypoirier Thanks for the great investigation. """Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they add a big speedup for Santacoder (and a small one for Starcoder), but they add complications on batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter. It should be ok since the (Santacoder) decode latency is mostly independent of the batch size anyway."""

Can you show me where you implemented the cuda graphs with dynamic size for SantaCoder? I wonder how it is implemented.

Sorry for the late response, you can find my (messy) implementation in https://github.com/bigcode-project/transformers/blob/main/src/transformers/models/gpt_bigcode/inference_runner.py. Note that this version supports dynamic key lengths but not dynamic batch sizes.

@aliswel-mt

aliswel-mt commented Jul 3, 2023

@jlamypoirier Amazing reports! May I ask whether sequence length indicates max_new_token? I got pretty high latency (about 4s) for starcoder when I set max_new_token=128.

@jlamypoirier
Contributor Author

@jlamypoirier Amazing reports! May I ask whether sequence length indicates max_new_token? I got pretty high latency (about 4s) for starcoder when I set max_new_token=128.

It's the time to generate one token at a given sequence length. For the full time you need to add the prefill time for the context length and the decode time for each length in range(context_length, context_length + max_new_tokens).
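
As a rough sketch of that sum, assuming two hypothetical latency functions (`prefill_ms` and `decode_ms`) that would in practice be read off the curves above; the constant values used here are made-up placeholders, not measurements:

```python
from typing import Callable


def total_generation_ms(
    context_length: int,
    max_new_tokens: int,
    prefill_ms: Callable[[int], float],
    decode_ms: Callable[[int], float],
) -> float:
    """Prefill once for the context, then one decode step per new token."""
    total = prefill_ms(context_length)
    for key_length in range(context_length, context_length + max_new_tokens):
        total += decode_ms(key_length)
    return total


# Placeholder latency models (hypothetical values for illustration only).
example = total_generation_ms(
    context_length=100,
    max_new_tokens=128,
    prefill_ms=lambda n: 50.0,
    decode_ms=lambda k: 10.0,
)
print(f"{example:.0f} ms")
```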

@truenorth8

truenorth8 commented Oct 14, 2023

@jlamypoirier These are great suggestions. Have any of these found their way upstream?
If not, is your version available anywhere?

edit: especially curious about

Compute the model head only for the last token in prefill (unless we do need them for details). This saves some time and more importantly avoids a memory bottleneck.
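
To give a sense of the memory side of the point quoted above, here is a back-of-the-envelope sketch. The vocabulary size of 49152, the 2-byte (fp16) logits, and the batch and sequence sizes are assumptions picked for illustration, not figures taken from the benchmarks in this thread.

```python
def logits_bytes(batch_size: int, n_positions: int, vocab_size: int,
                 bytes_per_value: int = 2) -> int:
    """Size of the logits tensor produced by the model head."""
    return batch_size * n_positions * vocab_size * bytes_per_value


BATCH = 32       # assumed batch size
SEQ_LEN = 8190   # assumed prefill length
VOCAB = 49152    # assumed vocabulary size

# Head applied to every prefill position vs. only the last one.
all_positions = logits_bytes(BATCH, SEQ_LEN, VOCAB)
last_position = logits_bytes(BATCH, 1, VOCAB)

print(f"all positions: {all_positions / 1e9:.2f} GB")
print(f"last position: {last_position / 1e6:.2f} MB")
```

Under these assumptions the full-sequence logits are larger by a factor equal to the prefill length, which is why they only need to be materialized when details such as prefill logprobs are requested.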
