[cache] Merge static sliding and static chunked layer #40893
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 794483c to 620d4aa
Force-pushed from 620d4aa to 6d4be2c
```diff
 def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
     """Return the length and offset of the cache, used to generate the attention mask"""
     query_length = cache_position.shape[0]
     first_cache_position = cache_position[0]
     sliding_window = self.max_cache_len
+    is_full = self.cumulative_length >= self.max_cache_len

-    kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
-    # This is the true general case for any Cache using local attention (sliding or chunked)
-    if first_cache_position >= sliding_window:
-        # Here the Cache is already full
+    kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
+    # The cache is already full
+    if is_full:
         kv_length = sliding_window + query_length - 1
-    elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
-        # Here the Cache becomes full with the new input
-        kv_length = first_cache_position + query_length
+    # Not yet full, but becoming full on this update
+    elif self.cumulative_length + query_length > sliding_window:
+        kv_length = self.cumulative_length + query_length
     else:
         # Here the Cache is still smaller than the local size, but we return the local size as it's static
         kv_length = sliding_window
     return kv_length, kv_offset
```
This method is now much better, as we only work with Python scalars, as opposed to a mix of scalars and tensors -> much easier downstream with the masks and compilation.
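A small sketch of what that buys us (a standalone helper, not the library code; it simply mirrors the merged logic above with the cache state passed in explicitly), showing that both returned sizes stay plain Python ints over a prefill + decode loop:

```python
def mask_sizes(cumulative_length: int, query_length: int, sliding_window: int) -> tuple[int, int]:
    # Mirror of the merged get_mask_sizes logic, with the cache state as an argument
    is_full = cumulative_length >= sliding_window
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    if is_full:
        # Already full: the window plus the overhanging new tokens
        kv_length = sliding_window + query_length - 1
    elif cumulative_length + query_length > sliding_window:
        # Becoming full on this update
        kv_length = cumulative_length + query_length
    else:
        # Still smaller than the window, but the layer is static so we expose the full window
        kv_length = sliding_window
    return kv_length, kv_offset

seen = 0
for new_tokens in [3, 1, 1, 1, 1]:  # a prefill of 3 tokens, then 1-by-1 decoding
    print(mask_sizes(seen, new_tokens, sliding_window=4))
    seen += new_tokens
# (4, 0), (4, 0), (4, 1), (4, 2), (4, 3) -> pure ints, nothing tensor-shaped to specialize on
```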
I would document why, but LGTM.
```diff
         layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
     elif layer_type == "chunked_attention":
-        layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
+        layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
```
You don't differentiate between chunked (block mask) and sliding because nothing changes from a cache's point of view! Indeed, they should behave the same overall; the mask is what differs!
Yes, will add a comment to explain here!
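A minimal sketch of that point (the helper names here are mine, not library code): both layer types can share a cache that keeps the last `window` key/value states; only the mask built on top of those positions differs, sliding vs block/chunked.

```python
import torch

def sliding_mask(q_pos: torch.Tensor, kv_pos: torch.Tensor, window: int) -> torch.Tensor:
    # Each query attends itself and the `window - 1` previous positions
    diff = q_pos[:, None] - kv_pos[None, :]
    return (diff >= 0) & (diff < window)

def chunked_mask(q_pos: torch.Tensor, kv_pos: torch.Tensor, chunk: int) -> torch.Tensor:
    # Each query attends only earlier positions inside its own chunk (block mask)
    causal = q_pos[:, None] >= kv_pos[None, :]
    same_chunk = (q_pos[:, None] // chunk) == (kv_pos[None, :] // chunk)
    return causal & same_chunk

pos = torch.arange(6)
print(sliding_mask(pos, pos, window=3).int())
print(chunked_mask(pos, pos, chunk=3).int())
```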
```python
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
```
I never understood this `1:` in ChunkedPrefill, maybe worth writing an explanation. When the cache is full and M > 1 new tokens arrive, we store the last `self.max_cache_len` tokens, but we return `self.max_cache_len - 1 + M` tokens (the arrived tokens plus `self.keys[:, :, 1:, :]`). Why this removal of the first token?
It's because the sliding_window is exclusive, i.e. a token can see at most itself and the last `sliding_window - 1` tokens, so the first token of the cache is out of scope when we're already full and a new token arrives. Technically, we could cache only `sliding_window - 1` tokens altogether, as I did in `DynamicSlidingWindowLayer`, which would maybe be the preferred way to avoid the additional slicing ops. But that's out of scope for this PR, and keeping the extra token technically allows us to roll back one token for some generation methods, so it's not entirely useless.
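A tiny numeric sketch of that exclusivity (the values are made up for illustration): with `sliding_window = 4`, the token at position 10 attends positions 7..10, i.e. itself plus the last 3 cached tokens, so once the full cache holds positions 6..9, position 6 can never be attended again, hence the `self.keys[:, :, 1:, :]` slice.

```python
sliding_window = 4
cached_positions = [6, 7, 8, 9]   # a full cache of size sliding_window
new_position = 10
# A position p is visible to the new token iff it is at most sliding_window - 1 steps back
visible = [p for p in cached_positions + [new_position] if new_position - p < sliding_window]
assert visible == [7, 8, 9, 10]   # the first cached slot (position 6) is out of scope
```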
We can have a comment to be explicit!
Yes, or we could indeed align with the Dynamic version and remove the extra token from the cache, which would technically be more efficient. But this branch is very rare in practice, as it's only used when passing more than one token after the cache is already full, which is a specific case of prefix caching or similar.
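For reference, a hedged sketch of what this rare branch does (the helper and shapes are illustrative, not the exact library code): with a full cache of size `max_cache_len` and `M > 1` incoming tokens, we return `max_cache_len - 1 + M` states for attention but keep only the last `max_cache_len` of them as the new cache content.

```python
import torch

def full_cache_update(keys, values, key_states, value_states, max_cache_len):
    # States returned for attention: the cached window minus its first (out-of-scope)
    # slot, followed by the M new tokens -> max_cache_len - 1 + M states
    full_keys = torch.cat((keys[:, :, 1:, :], key_states), dim=-2)
    full_values = torch.cat((values[:, :, 1:, :], value_states), dim=-2)
    # States kept in the cache: only the last max_cache_len of them
    new_keys = full_keys[:, :, -max_cache_len:, :]
    new_values = full_values[:, :, -max_cache_len:, :]
    return full_keys, full_values, new_keys, new_values

B, H, W, D, M = 1, 2, 8, 16, 3
full_keys, _, new_keys, _ = full_cache_update(
    torch.randn(B, H, W, D), torch.randn(B, H, W, D),
    torch.randn(B, H, M, D), torch.randn(B, H, M, D), max_cache_len=W,
)
print(full_keys.shape, new_keys.shape)  # torch.Size([1, 2, 10, 16]) torch.Size([1, 2, 8, 16])
```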
What does this PR do?
As per the title. As discussed quite a few times, the two layers are exactly the same, except the Chunked version is more general, as it can handle an arbitrary number of new tokens even after prefill (i.e. prefill caching, chat continuation, etc.).
This PR merges them both, keeping only the more general version, which broadens the scope of SlidingWindowLayer usage to the aforementioned use cases!
Thus, Static and Dynamic caches can now be used in exactly the same way, in all generality.
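A hedged usage sketch of the kind of flow this unlocks (the model id and generation kwargs are illustrative, not taken from this PR): a static sliding-window cache receiving a multi-token continuation after it has already been filled.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # fully sliding-window model, as in the slow tests
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# First turn: prefill + decode with a static (sliding-window) cache
inputs = tok("User: Hello, how are you?\nAssistant:", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static", return_dict_in_generate=True)

# Second turn: feed several new tokens at once on top of the existing cache
# (in practice the static cache must be allocated large enough for the whole conversation).
# Before this PR, the static SlidingWindowLayer could not handle such a multi-token update
# after prefill; now it behaves like the former chunked layer did.
text = tok.batch_decode(out.sequences)[0] + "\nUser: And what about tomorrow?\nAssistant:"
inputs = tok(text, return_tensors="pt").to(model.device)
out = model.generate(**inputs, past_key_values=out.past_key_values, max_new_tokens=32, return_dict_in_generate=True)
```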
cc @gante @manueldeprada as well for viz! Finally merging them!
I made sure slow tests on fully sliding (Mistral) and hybrid (Gemma2) are still fine!