Prefix Caching Does Not Require Prompt Template and Other Ambiguities #1619

Open
ethan-digi opened this issue May 16, 2024 · 5 comments

@ethan-digi

I'm having difficulty understanding how to approach prompt-related performance tuning because of ambiguity about how TensorRT-LLM's prefix caching works.

My current understanding of prefix caching as implemented in TensorRT-LLM, based on the description in issue #620, is the following: blocks shared between prompts are reused in later inferences that use those prompts.

The question is, how does this even differ from normal KV caching? Is there some special logic to identify frequently reused blocks? My understanding is that, unless a block is evicted from the cache, it automatically remains in memory to be reused, meaning prefix caching should function by default when using PagedAttention. I suppose the change that enabled prefix caching could simply have been disabling whatever automatic eviction policies were used before.

What makes this even more confusing is that the PagedAttention paper describes prefix sharing as follows:

  For this type of application, many user prompts share a prefix, thus the LLM service provider can store the KV cache of the prefix in advance to reduce the redundant computation spent on the prefix ... this can be conveniently achieved by reserving a set of physical blocks for a set of predefined shared prefixes by the LLM service provider

This would indicate that, in order to receive the benefits of prefix caching, we need to submit a prompt template, which is not mentioned in the documentation. So it would seem there's an automatic system instead.

More broadly, I'm trying to understand how slight changes to the tokens in a prompt will affect performance.

Consider a prompt structured as: "Please ask [user] how their day is going. Be sure to greet them by name". If [user] changes every request, will the entire prompt be thrown out and regenerated? Or, supposing for the sake of example that blocks contain 2 tokens and each word is one token, would everything except "[user] how" be kept, and only that block reloaded?

@thorjohnsen

If you have a prompt structured as "Please ask [user] how their day is going", and [user] changes every request, only the tokens "Please ask" will potentially be reused. This is subject to the KV cache block size. The default block size is 64, but it can be any power of two <= 128. In order to reuse a block, that block must be full and all tokens in the block must be a perfect match.
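
To make the matching rule concrete, here is a minimal Python sketch of block-aligned reuse; it is illustrative only (the function is hypothetical, not TensorRT-LLM code):

```python
# Minimal sketch of block-aligned prefix reuse (illustrative, not
# TensorRT-LLM internals). A cached block is reusable only if it is
# full and every token in it matches the new request.

def reusable_prefix_tokens(cached_tokens, new_tokens, block_size=64):
    """Count leading tokens of new_tokens whose cached KV blocks can be reused."""
    reusable = 0
    n_full_blocks = min(len(cached_tokens), len(new_tokens)) // block_size
    for b in range(n_full_blocks):
        start, end = b * block_size, (b + 1) * block_size
        if cached_tokens[start:end] != new_tokens[start:end]:
            break  # once a block diverges, later blocks cannot be reused
        reusable = end
    return reusable

# Toy version of the example above: block_size=2, one token per word.
cached = ["Please", "ask", "Alice", "how", "their", "day", "is", "going"]
new    = ["Please", "ask", "Bob",   "how", "their", "day", "is", "going"]
print(reusable_prefix_tokens(cached, new, block_size=2))  # 2 -> only "Please ask"
```

Note that blocks after the first mismatch are not reused even when their tokens happen to match ("their day is going" above), consistent with the answer that only "Please ask" is reusable.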

@thorjohnsen

I am unaware of any feature that allows a client to reserve cache pages to permanently store a set of predefined shared prefixes. Regular prefix caching will work for any shared prefix, but it is subject to a few limitations that might reduce how much benefit you see:

  1. Only full blocks that are a perfect match will be reused. For instance, if the block size is 64 and your shared prefix has 180 tokens, only the first 128 tokens will be reused (see the arithmetic sketch after this list).
  2. Reusable KV cache blocks are evicted when the memory is needed for other purposes. On a system with high load, eviction is quite likely to occur, especially if some time passes between each use of the shared prefix.
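
The arithmetic behind the first limitation, as a quick sketch:

```python
# Only whole blocks are reusable: with block_size=64 and a 180-token
# shared prefix, two full blocks qualify and the trailing 52 tokens
# are recomputed.
block_size = 64
shared_prefix_len = 180
reusable = (shared_prefix_len // block_size) * block_size
print(reusable)  # 128
```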

@thorjohnsen

We can add a feature to reserve memory for a shared prefix if client(s) request it. It would be added as one more input, probably named "prefix_length": an int32 value indicating that the first "prefix_length" tokens of the request are part of a shared prefix. This would have to be an "advisory" field, so we would reserve the space if we can. Besides making the prefix more permanent by delaying eviction, we could remove the "full block" limitation so that the entire prefix could be shared, not just the part that fits into a whole number of blocks.
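
A hypothetical sketch of a request carrying the proposed field; note that neither "prefix_length" nor this request shape exists in TensorRT-LLM today:

```python
# Hypothetical sketch of the proposed advisory input. "prefix_length"
# does not exist in TensorRT-LLM today; it is the field suggested above.

def build_request(shared_prefix_ids, suffix_ids):
    # The first len(shared_prefix_ids) tokens are flagged as a shared
    # prefix, hinting the runtime to delay eviction of their KV blocks
    # (and potentially to reuse the partial trailing block as well).
    return {
        "input_ids": shared_prefix_ids + suffix_ids,
        "prefix_length": len(shared_prefix_ids),  # proposed int32 advisory field
    }
```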

@ethan-digi

ethan-digi commented May 28, 2024

@thorjohnsen thank you very much for your response. I think a prefix_length field, potentially with the full-block limitation removed, would be extremely useful. It would essentially be a way to sacrifice caching of blocks that are highly unlikely to be reused (e.g., the end of the prompt) in exchange for guaranteed caching of tokens that are known to be reused. I would like to assist with it, but I don't think I have the time at the moment.


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.

@github-actions github-actions bot added the stale label Jun 28, 2024