
Conversation

@gante
Collaborator

@gante gante commented Jun 25, 2023

⚠️ do not merge!

This is an experimental PR that rewrites the generation code to mimic what we do in XLA JAX/TF.

It allows a fixed cache size and avoids dynamic slicing before torch.nn.functional.scaled_dot_product_attention, allowing compilation with fullgraph=True. The trick is to keep the prefill step with dynamic shapes (where dynamic slicing is not needed), while all subsequent steps append the latest KV values at the last position of the fixed-shape tensors. Between generation steps, the cached values are shifted into the right position, so we can keep appending new values at the end.
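
Roughly, the idea looks like the sketch below (illustrative only; the names, layout, and shift logic are a simplified reading of the approach, not the PR's actual code):

import torch

def prefill_cache(key_states, value_states, max_cache_len):
    # Prefill: dynamic shapes are fine here, no slicing into a fixed cache yet.
    bsz, num_heads, prompt_len, head_dim = key_states.shape
    k_cache = key_states.new_zeros((bsz, num_heads, max_cache_len, head_dim))
    v_cache = value_states.new_zeros((bsz, num_heads, max_cache_len, head_dim))
    # Place the prompt KV at the end of the fixed-shape cache.
    k_cache[:, :, -prompt_len:] = key_states
    v_cache[:, :, -prompt_len:] = value_states
    return k_cache, v_cache

def decode_step(k_cache, v_cache, new_key, new_value):
    # Shift everything one position to the left, then write the new KV entry
    # in the last position -- the tensor shapes never change between steps,
    # which is what makes fullgraph=True compilation possible.
    k_cache = torch.roll(k_cache, shifts=-1, dims=2)
    v_cache = torch.roll(v_cache, shifts=-1, dims=2)
    k_cache[:, :, -1:] = new_key
    v_cache[:, :, -1:] = new_value
    return k_cache, v_cache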

Learnings

PT version: torch==2.1.0.dev20230621+cu118

  • We can compile this version with fullgraph=True, keeping the correctness of the results;
  • It is slower than the existing main, especially when compiled: ~3% slower without compilation, ~10% slower with compilation;
  • Causes for the slowdown vs main, as seen in the profiler (comparing python scripts/run_llama.py --model huggingface/llama-7b --preallocate --profile runs; see the sketch after this list):
    • there is extra logic between generation steps, accounting for an extra >1 ms/token
    • scaled_dot_product_attention, because it now needs the attention mask, takes an extra >100 µs/layer, i.e. >3 ms/token
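
For reference, the compile + profile setup boils down to something like the following (illustrative only; the real flags and model loading live in scripts/run_llama.py):

import torch
from torch.profiler import profile, ProfilerActivity

@torch.compile(fullgraph=True)  # any graph break now raises instead of silently falling back
def toy_decode_step(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 128, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    toy_decode_step(q, k, v)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))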

@fxmarty
Owner

fxmarty commented Jun 26, 2023

scaled_dot_product_attention, because it now needs the attention mask, takes an extra >100 µs/layer, i.e. >3 ms/token

Maybe you already know, but AFAIK SDPA does not currently dispatch to flash attention / memory-efficient attention when an attention_mask is passed, see pytorch/pytorch#96099 (comment) & https://huggingface.slack.com/archives/C046RST834Y/p1679584564826879, so that may be the reason for the slowdown.
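
For what it's worth, this is easy to check in isolation. The sketch below assumes a CUDA device and the dispatch behavior of current nightlies (which may change): with only the flash backend enabled, the masked call errors out instead of dispatching.

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool)

# Force the flash attention backend only: the call with an attn_mask should
# raise ("no available kernel"), while the causal call without a mask runs.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)      # OK, flash attention
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)      # RuntimeError: no kernel supports the mask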

allowing compilation with fullgraph=True

No more issues with the VariableBuilder.__init__() error?

@gante
Collaborator Author

gante commented Jun 27, 2023

No more issues with the VariableBuilder.__init__() error?

Nope (and I didn't even change the PT version 👀 )

@fxmarty
Owner

fxmarty commented Jun 27, 2023

@gante (I am not talking about an actual error, but about the one that only shows up in logging.DEBUG mode)

@gante
Collaborator Author

gante commented Jun 27, 2023

@fxmarty I don't see it (but maybe I'm not looking in the right place)

@gante
Collaborator Author

gante commented Jun 27, 2023

Using the plotting facilities from #12 (and taking the plots in that PR as the reference for main's performance):

batch size sweep

[plot: llama_sweep_58bb48d_batch]

prompt length sweep

[plot: llama_sweep_58bb48d_length]

performance conclusions

  • For batch size = 1 it's roughly on par with main
  • For batch size > 1 it's slower than main
  • This PR has lower memory consumption, especially when compiled (with batch size = 4: 16223 MB vs 16895 MB on main)

@fxmarty
Owner

fxmarty commented Jun 28, 2023

Interesting, great work! I assume that for batch size > 1 we become more and more compute-bound, so running the actual compute over a huge, mostly empty KV cache is not that great?

@fxmarty
Owner

fxmarty commented Jun 28, 2023

@fxmarty I don't see it (but maybe I'm not looking in the right place)

I have the error log using

import logging
import torch

torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, inductor=logging.DEBUG)

and inspecting logs from python run_llama.py --model huggingface/llama-7b --preallocate --compile static &> compile.log

I have a TypeError: VariableBuilder.__init__() got an unexpected keyword argument 'guards' in the logs.

@gante
Collaborator Author

gante commented Jun 28, 2023

I assume for batch size > 1 we may get more and more compute bound and so having a huge empty KV cache used for actual compute is not that great?

My thoughts as well. Dynamic slicing (as in main) forces us out of fullgraph compilation, but it has its benefits. Perhaps the difference will diminish as PT keeps improving the compiler 🤔
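
To make the trade-off concrete, a main-style decode step looks roughly like this (hypothetical sketch, not the actual code in main):

import torch

def decode_step_dynamic(k_cache, v_cache, new_key, new_value):
    # main-style cache: concatenate the new entry, so the sequence dimension
    # grows every step and the compiled graph sees a new shape each time
    # (recompilations / graph breaks, but no wasted compute on padding).
    k_cache = torch.cat([k_cache, new_key], dim=2)
    v_cache = torch.cat([v_cache, new_value], dim=2)
    return k_cache, v_cache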
