Experiment: append current KV last #10
base: main
Conversation
Maybe you already know, but AFAIK SDPA does not currently dispatch to flash attention / memory-efficient attention when an attention_mask is passed (see pytorch/pytorch#96099 (comment) & https://huggingface.slack.com/archives/C046RST834Y/p1679584564826879), so that may be the reason for the slowdowns.
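A minimal way to check the dispatch behavior on a given build (a sketch, assuming a CUDA device and the `torch.backends.cuda.sdp_kernel` context manager available in PyTorch 2.0/2.1; exact behavior depends on the version):

```python
import torch
import torch.nn.functional as F

# Decode-like shapes; values are arbitrary.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool)

# Allow only the flash kernel: with an attn_mask this may raise
# "No available kernel", i.e. flash attention is not used for the masked call.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print("flash kernel accepted the attn_mask")
    except RuntimeError as e:
        print("flash-only + attn_mask not supported here:", e)
```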
No more issues with the `VariableBuilder.__init__()` error?
Nope (and I didn't even change the PT version 👀 )
@gante (I am not talking about an actual error, but about the error that only shows up in logging.DEBUG mode)
@fxmarty I don't see it (but maybe I'm not looking in the right place)
Using the plot facilities from #12 (and using the plots in that PR as a reference for the performance):
[plots: batch size sweep · prompt length sweep · performance conclusions]
Interesting, great work! I assume that for batch size > 1 we may become more and more compute-bound, and so having a huge, mostly empty KV cache used in the actual compute is not that great?
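A back-of-envelope illustration of that concern (numbers are made up): with a fully preallocated cache, the decode-step attention matmuls run over the whole cache length rather than over the tokens actually present.

```python
# Rough attention-FLOP count for one decode step (q_len = 1), counting only the
# QK^T and attn @ V matmuls: ~2 * (2 * kv_len * head_dim) per head.
batch, n_heads, head_dim = 8, 32, 128
max_len, n_valid = 2048, 128  # preallocated cache length vs. tokens actually in the cache

def attn_flops(kv_len):
    return 2 * 2 * batch * n_heads * kv_len * head_dim

print(attn_flops(max_len) / attn_flops(n_valid))  # 16.0x more attention compute per step
```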
I got the error log using `torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, inductor=logging.DEBUG)` (after `import logging`) and inspecting the logs.
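For reference, a self-contained version of that snippet (PyTorch 2.x):

```python
import logging
import torch

# Turn on verbose TorchDynamo / AOTAutograd / Inductor logging; the VariableBuilder
# message mentioned above only shows up at DEBUG verbosity.
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, inductor=logging.DEBUG)
```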
My thoughts as well. Dynamic slicing (as in …)


This is an experimental PR that rewrites the generation code to mimic what we do in XLA JAX/TF.
It allows a fixed cache size and avoids dynamic slicing before `torch.nn.functional.scaled_dot_product_attention`, allowing compilation with `fullgraph=True`. The trick consists in keeping the prefill step with dynamic shapes (where dynamic slicing is not needed); all subsequent steps append the latest KV values at the last position of the fixed-shape tensors. Between each generation step, the values are moved into the right position, so we can keep appending new values at the end (a schematic sketch of this cache update is shown below).
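To make the idea concrete, here is a minimal, hypothetical sketch of a fixed-length KV cache that always writes the newest key/value at the last index and shifts the cache between steps; tensor names and shapes are illustrative, not the PR's actual code.

```python
import torch

# Illustrative only: a fixed-length KV cache where the newest entry always lands
# at the last index, so the decode step never needs dynamic slicing.
batch, n_heads, max_len, head_dim = 1, 4, 8, 16
k_cache = torch.zeros(batch, n_heads, max_len, head_dim)
v_cache = torch.zeros(batch, n_heads, max_len, head_dim)

def decode_step(k_cache, v_cache, k_new, v_new):
    # Shift the cache left by one position (torch.roll wraps index 0 around to the
    # end, but that slot is immediately overwritten), then write the current step's
    # KV at the fixed last position.
    k_cache = torch.roll(k_cache, shifts=-1, dims=2)
    v_cache = torch.roll(v_cache, shifts=-1, dims=2)
    k_cache[:, :, -1] = k_new
    v_cache[:, :, -1] = v_new
    return k_cache, v_cache

# One decode step: new key/value for a single token.
k_new = torch.randn(batch, n_heads, head_dim)
v_new = torch.randn(batch, n_heads, head_dim)
k_cache, v_cache = decode_step(k_cache, v_cache, k_new, v_new)
```

Because the fixed-shape cache still contains padding slots, an attention mask has to flag which positions are valid, which is why SDPA now receives a mask in this approach.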
Learnings
- PT version: `torch==2.1.0.dev20230621+cu118`
- Compiles with `fullgraph=True`, keeping the correctness of the results;
- Slower than `main`, especially when compiled: ~3% slower without compilation, ~10% slower with compilation;
- Part of the slowdown vs. `main` is visible in the profiler (comparing `python scripts/run_llama.py --model huggingface/llama-7b --preallocate --profile` runs): `scaled_dot_product_attention`, because it now needs the attention mask, takes an extra >100 us/layer, or >3 ms/token (see the profiling sketch below).
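Not the PR's actual benchmark, but a generic sketch of how the per-op cost of SDPA with and without a mask can be compared using `torch.profiler` (shapes and iteration count are arbitrary; assumes a CUDA device):

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function

# Decode-like shapes: 1 query token attending over a preallocated 2048-token cache.
q = torch.randn(1, 32, 1, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 2048, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 32, 2048, 128, device="cuda", dtype=torch.float16)
mask = torch.zeros(1, 1, 1, 2048, device="cuda", dtype=torch.float16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        with record_function("sdpa_with_mask"):
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        with record_function("sdpa_no_mask"):
            F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()

# The two labeled regions show up as separate rows, so their CUDA time can be compared.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```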