SlidingWindowCache: reduce differences to other Cache classes #30970
Conversation
""" | ||
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. | ||
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`, | ||
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, |
`sliding_window` is the config attribute name, not `sliding_window_size`
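For context, here is a minimal sketch of the rolling-window update the docstring describes; the helper name, shapes, and the single-token assumption are mine, not the library's code:

```python
import torch

# Illustrative sketch (not the library's exact code) of the update rule the
# docstring describes, assuming a single new token per step once the window is
# full: drop the oldest entry, then write the new token into the freed last slot.
def update_window(k_cache, key_state, cache_position, sliding_window):
    # k_cache: (batch, heads, sliding_window, head_dim)
    # key_state: (batch, heads, 1, head_dim), cache_position: (1,)
    if cache_position[-1] >= sliding_window - 1:
        # Window is full: shift left by one so slot `sliding_window - 1` is free.
        k_cache = torch.roll(k_cache, shifts=-1, dims=2)
        cache_position = cache_position.clamp(max=sliding_window - 1)
    k_cache[:, :, cache_position] = key_state
    return k_cache
```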
need_new_cache = (
    not hasattr(self, "_cache")
    or (not isinstance(self._cache, cache_cls))
-    or self._cache.max_batch_size < max_batch_size
+    or self._cache.max_batch_size != max_batch_size
(unrelated to the sliding window cache) this was incorrect: we need a new cache object whenever the batch size differs
that's a nice catch!
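To make the fix concrete, a hedged sketch of the corrected check (the function and attribute names here are illustrative, not the exact source):

```python
# Sketch of the corrected reuse check (names are illustrative, not the exact
# library code): any mismatch in batch size forces a new cache object, because
# the cache tensors are pre-allocated with a fixed batch dimension; the old
# `<` comparison would wrongly reuse a cache built for a larger batch.
def needs_new_cache(existing_cache, cache_cls, max_batch_size):
    return (
        existing_cache is None
        or not isinstance(existing_cache, cache_cls)
        or existing_cache.max_batch_size != max_batch_size
    )
```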
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py (outdated, resolved)
Overall good for me, but wondering where these graph breaks come from?
# assume this will be called only in the first generation step
# `cache_position` will be used in other cases
return 0
# `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
Where are these graph breaks from? (Did this not work before?) Because it's equivalent but less fast, no?
There is an extra `zero_()` involved here to make cudagraphs happy. I believe we should not change the address of the tensor during compiling, and direct assignment violates that. In `StaticCache` there is no problem because `k_out[:, :, cache_position] = key_states` does not change the address of `k_out`; but if we want a 4D instead of 5D cache, direct assignment would just substitute the original tensor in the layers list, causing an address change.
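A small illustration of that point, with made-up shapes rather than the cache's actual code:

```python
import torch

# Illustrative sketch (made-up shapes, not the cache's actual code) of why the
# in-place update matters: CUDA graphs captured by
# `torch.compile(mode="reduce-overhead")` replay fixed tensor addresses, so the
# per-layer cache tensor must be mutated, not replaced in the list.
key_cache = [torch.zeros(1, 8, 4096, 128) for _ in range(2)]  # one 4D tensor per layer

def update_in_place(layer_idx, new_window):
    # `zero_()` + `+=` rewrites the contents but keeps the same storage/address.
    key_cache[layer_idx].zero_()
    key_cache[layer_idx] += new_window

def update_by_assignment(layer_idx, new_window):
    # Replaces the list entry with a different tensor (different address),
    # invalidating what a captured CUDA graph expects to find there.
    key_cache[layer_idx] = new_window
```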
Ahhh yeah, which is why you did not have this
If there is no tradeoff to using this (run make bench + test on A100 as well), fine; otherwise not fine, but add a comment to say why.
@ArthurZucker @zhenglongjiepheonix the implementation from this PR is also faster 🙌 Setup:
from transformers import AutoTokenizer, MistralForCausalLM
import torch
import time

prompts = ["My favourite condiment is " * 100]

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
print(inputs.input_ids.shape)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

for i in range(5):
    start = time.time()
    generated_ids = model.generate(
        **inputs, max_new_tokens=128, do_sample=False, cache_implementation="sliding_window"
    )
    assert generated_ids.shape[1] == 128 + inputs.input_ids.shape[1]
    print(f"Time: {time.time() - start:.2f}s")

👉 static cache: 76.2 tok/s

Could it be because there are fewer slicing OPs? (before, we had to slice the 5D cache into a 4D tensor at every layer)
Yes, slicing can be time-consuming. I have tested on my side, and in your setting your implementation indeed saves about 1 ms per token. I think it's good if we don't have to slice every time by using
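For illustration, a rough sketch of the slicing difference being discussed here, with made-up shapes rather than the model's actual code:

```python
import torch

# Rough illustration (made-up shapes, not the model code) of the slicing
# difference discussed above.
num_layers, batch, heads, window, head_dim = 2, 1, 8, 4096, 128

# Old 5D layout: one big tensor, sliced into a 4D tensor for each layer on
# every forward pass; that slice is an extra op in the compiled graph.
cache_5d = torch.zeros(num_layers, batch, heads, window, head_dim)
k_layer = cache_5d[1]

# New layout: a Python list of pre-allocated 4D tensors, so getting a layer's
# cache is plain list indexing with no tensor op at all.
cache_4d = [torch.zeros(batch, heads, window, head_dim) for _ in range(num_layers)]
k_layer = cache_4d[1]
```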
LGTM! Let's merge, @gante is not here
…(huggingface#30970)
* tmp commit
* sliding window with fewer differences
* make fixup + rebase
* missing overwrite
What does this PR do?

Follow-up to #30642: this PR aims at reducing the differences between `SlidingWindowCache` and `StaticCache`, such that long-term maintenance becomes easier. Fewer attributes/functions = less cognitive overload and fewer bugs 🤗

More specifically:
👉 no need for attributes regarding the sliding window (it is a form of maximum cache size, for which there was an attribute)
👉 list of 4D tensors holding the cache, as opposed to 5D tensors (to keep the same data format as in other caches)
👉 inherits from `StaticCache`, as most of the `__init__` and other boilerplate functions are identical

Slow Mistral tests were run locally, all green ✅
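As a rough outline of the structure the points above describe (hypothetical code with an assumed `StaticCache` constructor signature, not the PR's literal implementation):

```python
from transformers import StaticCache

# Hypothetical outline of the structure described above, not the PR's literal
# code (the StaticCache constructor signature is assumed): reuse StaticCache's
# __init__ and its per-layer list of 4D tensors, and only bound the
# pre-allocated length by the model's sliding window.
class SlidingWindowCacheSketch(StaticCache):
    def __init__(self, config, max_batch_size, max_cache_len, device=None, dtype=None):
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=min(max_cache_len, config.sliding_window),
            device=device,
            dtype=dtype,
        )
```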
cc @zhenglongjiepheonix I meant to request these changes in the PR linked above, but I was slow to review 😛