
SlidingWindowCache: reduce differences to other Cache classes #30970

Merged · 4 commits merged into huggingface:main on Jun 3, 2024

Conversation

gante (Member) commented May 22, 2024

What does this PR do?

Follow-up to #30642: this PR aims to reduce the differences between SlidingWindowCache and StaticCache, so that long-term maintenance becomes easier. Fewer attributes/functions = less cognitive overload and fewer bugs 🤗

More specifically:
👉 no need for attributes regarding the sliding window (it is a form of maximum cache size, for which there was an attribute)
👉 list of 4D tensors holding the cache, as opposed to 5D tensors (to keep the same data format as in other caches)
👉 inherits from StaticCache, as most of the __init__ and other boilerplate functions are identical
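
A minimal sketch of the intended layout (hypothetical, simplified names; not the actual transformers implementation): the sliding window is just the maximum cache length, and the sliding-window cache reuses the static cache's per-layer list of 4D tensors instead of carrying its own attributes and a 5D buffer.

import torch


class StaticCacheSketch:
    def __init__(self, max_batch_size, num_heads, max_cache_len, head_dim, num_layers, dtype=torch.float16):
        self.max_batch_size = max_batch_size
        self.max_cache_len = max_cache_len
        shape = (max_batch_size, num_heads, max_cache_len, head_dim)  # one 4D tensor per layer
        self.key_cache = [torch.zeros(shape, dtype=dtype) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, dtype=dtype) for _ in range(num_layers)]


class SlidingWindowCacheSketch(StaticCacheSketch):
    def __init__(self, max_batch_size, num_heads, sliding_window, head_dim, num_layers, dtype=torch.float16):
        # no dedicated sliding-window attributes: the window is simply the max cache length
        super().__init__(max_batch_size, num_heads, sliding_window, head_dim, num_layers, dtype)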

Slow Mistral tests were run locally, all green ✅

cc @zhenglongjiepheonix I meant to request these changes in the PR linked above, but I was slow to review 😛

gante requested a review from ArthurZucker on May 22, 2024 at 14:28
"""
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sliding_window is the config attribute name, not sliding_window_size
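
For context, a simplified sketch of the rolling behaviour the docstring describes (hypothetical helper with a simplified boundary condition; not the actual update method): while the cache still has free slots, new states are written at `cache_position`; once the window is full, the cache is rolled so the oldest entry is dropped and the newest key/value states land in the last slot.

import torch

def sliding_window_write(cache, new_states, cache_position, sliding_window):
    # cache: (batch, num_heads, sliding_window, head_dim); new_states: one decoding step
    if cache_position < sliding_window:
        cache[:, :, cache_position] = new_states           # still room: write in place
    else:
        cache.copy_(torch.roll(cache, shifts=-1, dims=2))   # drop the oldest slot
        cache[:, :, -1] = new_states                        # newest states go last
    return cache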

Review comment on the cache re-use check:

    need_new_cache = (
        not hasattr(self, "_cache")
        or (not isinstance(self._cache, cache_cls))
-       or self._cache.max_batch_size < max_batch_size
+       or self._cache.max_batch_size != max_batch_size

gante (Member, author): (unrelated to the sliding window cache) this was incorrect; we need a new cache object whenever the batch size is different.

A contributor replied: that's a nice catch!
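
A standalone rendering of the corrected condition (hypothetical function; the real check lives inside the model's cache-setup code): a new cache is needed whenever the existing one is missing, of the wrong class, or built for a different batch size, not only when the requested batch size is larger.

def needs_new_cache(existing_cache, cache_cls, max_batch_size):
    # assumed, simplified version of the check shown in the diff above
    return (
        existing_cache is None
        or not isinstance(existing_cache, cache_cls)
        or existing_cache.max_batch_size != max_batch_size  # `!=`, not `<`
    )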


ArthurZucker (Collaborator) left a review: Overall good for me, but wondering where these graph breaks come from?

Review thread on the new cache-update code:

    # assume this will be called only in the first generation step
    # `cache_position` will be used in other cases
    return 0

    # `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
ArthurZucker (Collaborator): Where are these graph breaks from? (Did this not work before?) Because it's equivalent but slower, no?

zhenglongjiepheonix (Contributor): There is an extra zero involved here to make cudagraphs happy. I believe we should not change the address of the tensor during compilation, and direct assignment violates that. In StaticCache there is no problem because `k_out[:, :, cache_position] = key_states` does not change the address of `k_out`; but if we want a 4D instead of a 5D cache, direct assignment would substitute the original tensor in the layers list, causing an address change.
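
A small, self-contained illustration of the point above (an assumed example, not the PR's actual code): re-binding the list entry gives the layer a tensor at a new storage address, while `zero_()` followed by `+=` writes into the same storage, which is what CUDA graphs captured under `reduce-overhead` expect.

import torch

layers = [torch.zeros(2, 4)]
new_states = torch.ones(2, 4)

# direct assignment: the list slot now points at a different tensor / storage
old_ptr = layers[0].data_ptr()
layers[0] = new_states.clone()
assert layers[0].data_ptr() != old_ptr

# in-place `zero_()` + `+=`: same tensor, same storage address (cudagraph-friendly)
layers[0] = torch.zeros(2, 4)
old_ptr = layers[0].data_ptr()
layers[0].zero_()
layers[0] += new_states
assert layers[0].data_ptr() == old_ptr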

ArthurZucker (Collaborator): Ahhh yeah, which is why you did not have this.

ArthurZucker (Collaborator): If there is no tradeoff to using this (make bench + test on A100 as well), fine; otherwise not fine, but add a comment to say why.

gante (Member, author) commented May 24, 2024

@ArthurZucker @zhenglongjiepheonix the implementation from this PR is also faster 🙌

Setup:

  • A100 80GB
  • input length=502
  • max_new_tokens=128
  • compiling forward but calling generate (i.e. there is some overhead from calling the uncompiled generate)
  • model: mistralai/Mistral-7B-v0.1
Code:
from transformers import AutoTokenizer, MistralForCausalLM
import torch
import time

prompts = ["My favourite condiment is " * 100]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
print(inputs.input_ids.shape)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

for i in range(5):
    start = time.time()
    generated_ids = model.generate(
        **inputs, max_new_tokens=128, do_sample=False, cache_implementation="sliding_window"
    )
    assert generated_ids.shape[1] == 128 + inputs.input_ids.shape[1]
    print(f"Time: {time.time() - start:.2f}s")

👉 static cache: 76.2 tok/s
👉 original sliding window: 70.7 tok/s
👉 this PR's sliding window: 74.9 tok/s

Could it be because there are fewer slicing OPs? (before, we had to slice the 5D cache into a 4D tensor at every layer)

zhenglongjiepheonix (Contributor) commented May 27, 2024

> @ArthurZucker @zhenglongjiepheonix the implementation from this PR is also faster 🙌 [...]
> 👉 static cache: 76.2 tok/s 👉 original sliding window: 70.7 tok/s 👉 this PR's sliding window: 74.9 tok/s
> Could it be because there are fewer slicing OPs? (before, we had to slice the 5D cache into a 4D tensor at every layer)

Yes, slicing can be time-consuming. I have tested on my side, and in your setting your implementation indeed saves about 1 ms per token. I think it's good if we don't have to slice every time by using the zero-and-add update.
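
Roughly, the difference being discussed (illustrative toy shapes only, not the real cache sizes or code): with a single 5D buffer, every layer has to slice out its own 4D view on each forward call, whereas a list of 4D tensors is just indexed.

import torch

num_layers, batch, heads, window, head_dim = 4, 1, 4, 256, 64

# old layout: one 5D tensor, sliced into a 4D view per layer on every call
cache_5d = torch.zeros(num_layers, batch, heads, window, head_dim)
k_layer = cache_5d[2]                 # extra slicing op for layer 2

# new layout: a Python list of per-layer 4D tensors, plain list indexing
cache_4d = [torch.zeros(batch, heads, window, head_dim) for _ in range(num_layers)]
k_layer = cache_4d[2]                 # no tensor slicing needed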

gante requested a review from ArthurZucker on May 29, 2024 at 15:19
ArthurZucker (Collaborator) left a review: LGTM! Let's merge, @gante is not here.

ArthurZucker merged commit d475f76 into huggingface:main on Jun 3, 2024 · 26 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on Jun 11, 2024 (…gface#30970):

* tmp commit
* sliding window with fewer differences
* make fixup + rebase
* missing overwrite
gante deleted the sliding_window branch on Jun 13, 2024 at 16:19.
gante mentioned this pull request on Jun 13, 2024.