
Generate: New Cache abstraction and Attention Sinks support #26681

Merged
merged 35 commits on Dec 8, 2023

Conversation

tomaarsen
Member

@tomaarsen tomaarsen commented Oct 9, 2023

Closes #26553

Hello!

What does this PR do?

I had a few hours on Saturday to work up a draft version of the updated KV caching mechanism as discussed in #26553. Ideally, this should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party package or in transformers directly.

The implementation doesn't work well yet, as the VRAM usage quickly shoots up after generating even just 8 tokens. This is probably due to a bug that I haven't had time to track down yet. There are a few comments I have on specific sections of code, so I'll write them below.

Goal for this draft

The intention for this draft is to continue discussion about whether this is moving in the right direction, and to determine the scope (e.g. do we want to include this updated Cache for all architectures that use KV caching?).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten
@gante
@LysandreJik
@Guangxuan-Xiao

  • Tom Aarsen

Contributor

@patrickvonplaten patrickvonplaten left a comment

Left some first nits, thanks for kick-starting this effort!

@tomaarsen
Member Author

@patrickvonplaten @gante
Based on your feedback, I intend to make the following changes (when I have a bit more time).

  1. move layer_idx to update(key_states, value_states, layer_idx) rather than storing it as a class attribute on Cache. This also involves adding layer_idx as class attributes to e.g. LlamaAttention and LlamaDecoderLayer. This also removes the set_layer_index magic from Cache.
  2. convert past_key_values to Cache instance (if not already) at the start of LlamaAttention.forward. This avoids all isinstance(...) calls in LlamaAttention, and removes the need for the black magicky __bool__.
  3. convert back to tuple of tuples when returning if some use_legacy_cache flag is True. (should this flag be propagated all the way up to LlamaModel?)
  4. use separate key_cache and value_cache dicts in the Cache instance for efficiency (removes a torch.cat) and simplicity.

Regarding #26681 (comment) I would have to experiment. There might be options, but they'll probably be slower than necessary.
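For illustration, a minimal sketch of what points 1 and 4 above could look like in code -- the class and method names mirror the discussion (DynamicCache, update(key_states, value_states, layer_idx)), but the details are assumptions, not the final implementation:

import torch
from typing import Dict, Tuple


class DynamicCache:
    """Sketch of the proposed cache: separate key/value stores keyed by layer index,
    with layer_idx passed into update() instead of being stored on the cache itself."""

    def __init__(self) -> None:
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx not in self.key_cache:
            # First forward pass for this layer: simply store the states.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Subsequent passes: grow the cache along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        # Return the full key/value states so the attention layer can use them directly.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]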

@freckletonj

past_key_values and caching will be an incredible feature.

Two PRs should be made irrelevant if this merges: #25086 #17574

@patrickvonplaten
Contributor

(quoting @tomaarsen's four-point plan above)

Sounds great!

@tomaarsen
Member Author

tomaarsen commented Oct 12, 2023

Addressed the various comments. Beyond that, I also made past_key_values.update(key, value, idx) return key, value, as you'd expect. Manual generation (i.e. repeatedly calling a LlamaForCausalLM instance with the past_key_values) works well; I even see a ~2% speedup, but don't quote me on that. model.generate doesn't work yet, because the use_legacy_cache flag defaults to None (i.e. falsy) and generate can't handle the cache yet. @patrickvonplaten @gante Should we go for:

  1. Immediately update model.generate to work with Cache instances or,
  2. Insert use_legacy_cache=True as default for model.generate until this can be removed in some later PR?

This is all under the assumption that the PR is heading in the right direction 😄
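For context, "manual generation" above refers to a loop along these lines (a minimal sketch; the checkpoint name is only an example, and past_key_values can be either the legacy tuples or a Cache instance):

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # any causal LM checkpoint works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
generated = input_ids
past_key_values = None  # the cache gets filled on the first forward pass

for _ in range(10):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values  # reuse the cache on the next step
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)
    input_ids = next_token  # only the new token is fed in once the cache is primed

print(tokenizer.decode(generated[0]))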

As a heads up, the sink cache does not work yet. I still need to do some experiments to see if I can store rotated keys and back-rotate + forward-rotate them when the cached keys are requested, rather than storing non-rotated keys. That is what I'll work on next.

That leaves me with an additional question: each architecture requires slightly different key rotations. I'd like to implement this in a sufficiently adaptable way, e.g. allowing architecture-specific functionality in src/transformers/models/<architecture>/cache_utils.py. However, it is still unclear to me how classes or functions in that file should relate to the cache classes in src/transformers/cache_utils.py. In short: how do we allow an architecture to have, e.g., a slightly different implementation of SinkCache or DynamicCache?
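The back-rotate + forward-rotate idea collapses into a single rotation by the position delta. A rough sketch, assuming the usual rotate_half RoPE convention and that every remaining cached key shifts back by the same number of positions (illustrative only, not the eventual SinkCache code):

import torch


def rotate_half(x):
    # Standard RoPE helper: pairs dimension i with i + d/2 and maps (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rerotate_cached_keys(cached_keys, cos, sin, shift):
    # cached_keys: [batch, num_heads, seq_len, head_dim], already rotated at their old positions
    # cos, sin:    RoPE tables of shape [max_positions, head_dim]
    # shift:       how many positions every cached key moves back after eviction
    #
    # Back-rotating by the old position and forward-rotating by the new one is a single
    # rotation by -shift * theta, i.e. multiply by cos(shift) and -sin(shift).
    cos_shift = cos[shift]  # [head_dim], broadcasts over batch, heads and sequence
    sin_shift = sin[shift]
    return cached_keys * cos_shift - rotate_half(cached_keys) * sin_shift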

  • Tom Aarsen

@gante
Member

gante commented Oct 13, 2023

Hey @tomaarsen 👋

generate compatibility

Re generate compatibility: usually, I'm pro small PRs, tackling one problem at a time. However, since generate is the main gate to LLMs in transformers, and caching is only needed for auto-regressive generation, day 0 support is important -- at least greedy_search and sample should be operational by the time this PR can be merged. Otherwise, we might get ourselves in a position where we realize we need to rewrite a significant part of Cache/generate to enable them together, which is very undesirable. However, I'm pro having this PR enable the new cache on a single model (Llama) and taking care of the other models later as needed 🤗 I can also give a hand on the generate side once we are happy with the rest of the PR.

model-specific cache changes

I do expect some models to require different code to convert back and forth from the legacy cache format (not all models have the same cache format; ~5% of them have slight differences, e.g. see this in BLOOM). You also write that RoPE-based models may need custom logic.

The model-specific cache modification part is an important design decision that will have an impact over many versions of transformers to come, so I'd also like to hear your thoughts @tomaarsen @patrickvonplaten

EDIT: what follows below is my "plan B" suggestion; read my next comment for "plan A" :D
To me, that suggests five things:

  1. a model may need to subclass the base Cache to implement its own nuances regarding the cache format
  2. despite the above, the cache operations and output format are standardized for each type of cache, so we would benefit from a strong base class (as we do in the config files or in the base pretrained class).
  3. I suspect the model-specific changes for the cache would be small, so we could keep them in the modeling file. We can always move to a new file if needed :)
  4. Because of 1), each model would have to implement a thin wrapper class around the instantiable cache classes, even if the base class is fully compatible. This also goes in line with our philosophy in transformers, where each model implements its own model-specific operations.
  5. Because of 4), in practice, each model defines its own cache. This means users are free to write their own custom caches -- power to the users 💪

@gante
Member

gante commented Oct 13, 2023

After some shower thoughts, I've come up with an alternative plan for the model-specific cache modification problem -- one that is easier to implement and would result in more readable code.

Instead of the base Cache holding the code to convert to and from the legacy format (which then requires subclassing at each model, if it has a different legacy cache format), the conversion logic could be held in model-specific functions to convert to and from the legacy format. In other words, each model would implement a to_legacy_cache and a from_legacy_cache, and the different types of Cache would be immediately available to the model with no further modifications!
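As a sketch of that idea (reusing the illustrative DynamicCache from earlier in the thread; names and shapes are assumptions, not the merged API), the per-model helpers could look roughly like this:

import torch
from typing import Optional, Tuple

LegacyCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]  # one (key, value) pair per layer


def from_legacy_cache(past_key_values: Optional[LegacyCache]) -> DynamicCache:
    # Wrap the legacy tuple-of-tuples format into a Cache instance.
    cache = DynamicCache()
    if past_key_values is not None:
        for layer_idx, (key_states, value_states) in enumerate(past_key_values):
            cache.update(key_states, value_states, layer_idx)
    return cache


def to_legacy_cache(cache: DynamicCache) -> LegacyCache:
    # Convert back, so callers relying on the old format keep working.
    return tuple(
        (cache.key_cache[layer_idx], cache.value_cache[layer_idx])
        for layer_idx in range(len(cache.key_cache))
    )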

@patrickvonplaten
Contributor

Instead of the base Cache holding the code to convert to and from the legacy format (which then requires subclassing at each model, if it has a different legacy cache format)

Agree here! Think it would be better if we only have a few selected cache classes that work for all models. The functions from_legacy_cache and to_legacy_cache indeed need to be model specific, so I think they can be just stand-alone functions in each modeling_....py file.

=> Thus, I think we can:

  • a) Have all Cache classes in a general generation/cache.py file. These general cache implementations should be identical for all models
  • b) For backwards compatibility, each model needs specific from_legacy_cache and to_legacy_cache functions, but they don't need to be part of the Cache class IMO; they can just be stand-alone functions in each model class so that they are easier to deprecate going forward.

@tomaarsen
Member Author

I just noticed that Joao's message has been edited, and I missed Patrick's message, so my response doesn't make much sense anymore - I deleted it.

I also prefer the from_legacy_cache and to_legacy_cache implementations. I'll get on it.

@liangan1

@gante @tomaarsen This is a really good abstraction of the kv_cache that enables window contexts and stays compatible with the legacy kv_cache. But regarding the memory footprint, torch.cat is still needed when updating the cache, and reorder_cache using index_select is also still there for beam search.
Do you have a plan to use a pre-allocated buffer to avoid torch.cat? The pre-allocated buffer should be compatible with the semantics of your Cache: for the cache.update operation, token slots from the pre-allocated buffer would be used to store the key/value token states.
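To make the suggestion concrete, here is a rough sketch of a per-layer pre-allocated buffer that writes into fixed token slots instead of calling torch.cat (the class name, buffer shapes and defaults are assumptions for illustration only):

import torch


class PreallocatedLayerCache:
    # Fixed-size key/value buffers for a single layer; new states are copied into
    # the next free slots, so the cache never reallocates or concatenates tensors.

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim, dtype=torch.float16, device="cpu"):
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.key_buffer = torch.zeros(shape, dtype=dtype, device=device)
        self.value_buffer = torch.zeros(shape, dtype=dtype, device=device)
        self.seen_tokens = 0  # number of filled slots

    def update(self, key_states, value_states):
        num_new = key_states.shape[-2]
        start, end = self.seen_tokens, self.seen_tokens + num_new
        # Write the new states into pre-allocated slots -- no torch.cat involved.
        self.key_buffer[:, :, start:end] = key_states
        self.value_buffer[:, :, start:end] = value_states
        self.seen_tokens = end
        # Return views over the filled part of the buffers for the attention computation.
        return self.key_buffer[:, :, :end], self.value_buffer[:, :, :end]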

@tomaarsen
Member Author

I removed the commits regarding changing the order of the caches based on @gante's recommendation. He is working on more changes for this PR here: tomaarsen#1

@gante
Member

gante commented Nov 10, 2023

@liangan1 both are issues in the near-future roadmap, with this cache abstraction being a requirement! 🙌

First we will work on pre-allocated buffers (a subclass of Cache), which result in significant speedups (especially with torch.compile). Then we will simplify beam search into an XLA-friendly method, just like in our TF and JAX implementations.

@liangan1

liangan1 commented Nov 14, 2023

(quoting @gante's roadmap comment above)

Wow, cool. Can you share more details about your roadmap, e.g. the pre-allocated buffers? In fact, we have also implemented an indirect-access kv_cache in the Intel Extension for PyTorch optimizations for LLMs, which can be used for both greedy and beam search, and we would be pleased to contribute code once I know your design for the pre-allocated buffer.

Comment on lines +293 to +295
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
Contributor

@stas00 stas00 Dec 13, 2023

this warning needs some TLC - the current "will to errors" doesn't parse. Thank you!

did you mean "will lead to errors" perhaps? i.e. missing "lead"

Member Author

You're right! I'll take care of it. Thanks @stas00

Member

@gante gante Dec 15, 2023

@tomaarsen FYI, I'm adding this fix to the next cache PR (adding support for encoder-decoder models), so no need to open a PR :)

@stas00 thanks for flagging! 🤗

Comment on lines +841 to +842
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
Collaborator

@tomaarsen When does this case happen?

@fayejf

fayejf commented Apr 9, 2024

@tomaarsen Great work! It looks like the current implementation cannot handle the case input_length > window_length, at least for some models (Mistral).
Am I missing anything? Thanks!

cache = SinkCache(window_length=1024, num_sink_tokens=4)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, 
                         use_cache=True,
                         past_key_values=cache, 
                         pad_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

Suppose inputs is of torch.Size([1, 15759]). It will raise
RuntimeError: shape '[-1, 15759]' is invalid for input of size 1024
when it tries to reshape position_ids:

File ~/anaconda3/envs/llm/lib/python3.8/site-packages/transformers/models/mistral/modeling_mistral.py:984, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    982     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    983 else:
--> 984     position_ids = position_ids.view(-1, seq_length).long()

@ArthurZucker
Collaborator

Not sure it was tested for Mistral! Can you open a separate issue with a full reproducer? 🤗

@MoonRide303

@fayejf You might want to look at this paper, too - handling extremely long inputs in a streaming fashion: Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention.

A 1B model fine-tuned on passkey instances of up to 5K sequence length solved the 1M-length problem.

@explanare

@ArthurZucker It seems that even with Llama2, passing in SinkCache to generate causes errors.

I'm using transformers 4.39.3 and the Llama2 model was loaded using the following code:

import torch
from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"
model = LlamaForCausalLM.from_pretrained(
      model_id, low_cpu_mem_usage=True, device_map='auto',
      torch_dtype=torch.bfloat16)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

The SinkCache was passed to generate as in @fayejf's script. I am not sure if this is the correct way to use SinkCache:

from transformers import SinkCache

prefix = 'Hello world!'
inputs = tokenizer(prefix, return_tensors='pt').to(model.device)

cache = SinkCache(window_length=1024, num_sink_tokens=4)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=256,
                         use_cache=True,
                         past_key_values=cache,
                         pad_token_id=tokenizer.pad_token_id)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

The code caused TypeError: object of type 'SinkCache' has no len() as a result of this DynamicCache.from_legacy_cache call (see the stack trace below). Looks like you are familiar with the StaticCache stuff, any suggestions on how to get around this? Thanks in advance!

File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:977, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    975 if use_cache:  # kept for BC (cache positions)
    976     if not isinstance(past_key_values, StaticCache):
--> 977         past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    978         past_seen_tokens = past_key_values.get_seq_length()
    980 if cache_position is None:

File /miniconda3/envs/pytorch2/lib/python3.10/site-packages/transformers/cache_utils.py:181, in DynamicCache.from_legacy_cache(cls, past_key_values)
    179 cache = cls()
    180 if past_key_values is not None:
--> 181     for layer_idx in range(len(past_key_values)):
    182         key_states, value_states = past_key_values[layer_idx]
    183         cache.update(key_states, value_states, layer_idx)

@fayejf

fayejf commented Apr 15, 2024

(quoting @MoonRide303's comment above)

@MoonRide303 Thanks for sharing! Very interesting paper!

@fayejf

fayejf commented Apr 15, 2024

@ArthurZucker I was on 4.38.2 and haven't had a chance to test with the latest release, but it looks like @explanare has seen similar issues. I also tried with a Llama model and saw the legacy cache error as well.

I'm not sure if I'm using SinkCache correctly; I didn't find any docs for it, unfortunately. :(
I followed the approach in the test file.

@ys-zong

ys-zong commented Apr 17, 2024

(quoting @explanare's report above)

I got the same error with Llama2 using 4.38.2, following the code here. I also tried transformers==4.39.0 and got the same error.

File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 982, in forward
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  File "/opt/conda/envs/dn/lib/python3.10/site-packages/transformers/cache_utils.py", line 166, in from_legacy_cache
    for layer_idx in range(len(past_key_values)):
TypeError: object of type 'SinkCache' has no len()

@ArthurZucker
Collaborator

cc @gante as well, since we broke it

@ArthurZucker
Collaborator

This might be fixed by #30476. Only tuples should be going down that path.

@gante
Member

gante commented Apr 30, 2024

@fayejf @ys-zong @ArthurZucker

SinkCache is operational on main with Mistral and a few other models (especially non-RoPE models). SinkCache with Llama and Llama-dependent models is fixed in this PR, which will hopefully get merged soon 🤗

Apologies for this temporary issue, we had to break a few eggs (sink cache) to make an omelet (torch.compile support) :D

Edit: merged, let me know if you run into issues!

Development

Successfully merging this pull request may close these issues.

Implement StreamingLLM/Windowed Attention with Attention Sinks