[cache] Merge static sliding and static chunked layer #40893
Changes from all commits
@@ -372,85 +372,6 @@ def __init__(self, max_cache_len: int, sliding_window: int):
        super().__init__(max_cache_len=effective_max_cache_len)
        self.cumulative_length = 0

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if self.keys is None:
            self.lazy_initialization(key_states)

        cache_position = cache_kwargs.get("cache_position")

        is_full = self.cumulative_length >= self.max_cache_len
        # Update it now that we saved the value above
        self.cumulative_length += key_states.shape[-2]

        # Handle prefill phase when prompt length > sliding_window_size.
        # Note that we store cropped key/value states in the cache but return the full key/value states.
        if cache_position.shape[0] > self.max_cache_len:
            self.keys.copy_(key_states[:, :, -self.max_cache_len :, :])
            self.values.copy_(value_states[:, :, -self.max_cache_len :, :])
            # Return the full states here
            return key_states, value_states

        # Here we only assume decoding stage, i.e. 1 token at a time
        if is_full:
            # Roll all values to the left by 1 position
            new_keys = self.keys.roll(-1, dims=-2)
            new_values = self.values.roll(-1, dims=-2)
            # Overwrite the last position with new states
            # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
            index = torch.tensor([-1], dtype=int, device=self.device)
            new_keys[:, :, index] = key_states
            new_values[:, :, index] = value_states

            # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
            self.keys.copy_(new_keys)
            self.values.copy_(new_values)
        else:
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states

        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        query_length = cache_position.shape[0]
        first_cache_position = cache_position[0]

        kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0)
        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
        kv_length = max(query_length, self.max_cache_len)
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        return self.cumulative_length


class ChunkedSlidingLayer(SlidingWindowLayer):
    """
    An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.
    """

    def update(
        self,
        key_states: torch.Tensor,
@@ -480,16 +401,29 @@ def update(
        self.cumulative_length += key_states.shape[-2]

        if is_full:
            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
            # Fast decoding path -> here as the effective size is still sliding window, it is extremely important
            # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address
            # in memory (the values are the same as the full states, but not the address!!)
            # In general, we should use a much simpler `cat` here as well, independently of the states size. However,
            # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
            if key_states.shape[-2] == 1:
                self.keys.copy_(full_key_states)
                self.values.copy_(full_value_states)
                # Roll all values to the left by 1 position
                new_keys = self.keys.roll(-1, dims=-2)
                new_values = self.values.roll(-1, dims=-2)
                # Overwrite the last position with new states
                # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
                index = torch.tensor([-1], dtype=int, device=self.device)
                new_keys[:, :, index] = key_states
                new_values[:, :, index] = value_states

                # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
                self.keys.copy_(new_keys)
                self.values.copy_(new_values)
                # Very important to return the `self` tensors here, as they have the static dynamo address
                return self.keys, self.values
        elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len:
            # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
            else:
                full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
        # Not yet full, but becoming full on this update
        elif cumulative_length + key_states.shape[2] > self.max_cache_len:
            # Fast prefill path, no need to cat() in this case, as the cache is currently empty
            if cumulative_length == 0:
                full_key_states = key_states
@@ -504,33 +438,38 @@ def update(
            except NotImplementedError:
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states

            # Very important to return the `self` tensors here, as they have the static dynamo address
            return self.keys, self.values

        # We only cache the last `sliding_window` tokens
        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
        # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
        # which is outside the window
        return full_key_states, full_value_states

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        query_length = cache_position.shape[0]
        first_cache_position = cache_position[0]
        sliding_window = self.max_cache_len
        is_full = self.cumulative_length >= self.max_cache_len

        kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
        # This is the true general case for any Cache using local attention (sliding or chunked)
        if first_cache_position >= sliding_window:
            # Here the Cache is already full
        kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
        # The cache is already full
        if is_full:
            kv_length = sliding_window + query_length - 1
        elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
            # Here the Cache becomes full with the new input
            kv_length = first_cache_position + query_length
        # Not yet full, but becoming full on this update
        elif self.cumulative_length + query_length > sliding_window:
            kv_length = self.cumulative_length + query_length
        else:
            # Here the Cache is still smaller than the local size, but we return the local size as it's static
            kv_length = sliding_window
        return kv_length, kv_offset
Review comment on lines 451 to 467: This method is now much better, as we only work with Python scalars as opposed to a mix of scalars/tensors -> much easier downstream with the masks and compilation.
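For illustration only (a minimal sketch, not code from this PR or the transformers API): because the rewritten branches depend only on Python integers, they can be read as a standalone function and checked by hand.

```python
# Sketch of the three cases above, using plain Python integers (assumed inputs, not the real signature)
def mask_sizes(cumulative_length: int, query_length: int, sliding_window: int) -> tuple[int, int]:
    is_full = cumulative_length >= sliding_window
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    if is_full:
        # The cache is already full
        kv_length = sliding_window + query_length - 1
    # Not yet full, but becoming full on this update
    elif cumulative_length + query_length > sliding_window:
        kv_length = cumulative_length + query_length
    else:
        # Still smaller than the window, but the returned size stays static
        kv_length = sliding_window
    return kv_length, kv_offset

# With sliding_window=8: an empty cache and a 5-token prefill gives (8, 0); 4 more tokens on top of
# the 5 cached ones gives (9, 0); decoding 1 token with 9 tokens already seen gives (8, 2).
print(mask_sizes(0, 5, 8), mask_sizes(5, 4, 8), mask_sizes(9, 1, 8))
```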
    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        return self.cumulative_length


class QuantizedLayer(DynamicLayer):
    """
@@ -1023,6 +962,8 @@ def __init__(
            layer_types = layer_types[: -config.num_kv_shared_layers]

        for layer_type in layer_types:
            # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
            # states they should return - only the mask changes to make them different at the end!
            if layer_type in ("sliding_attention", "chunked_attention"):
                layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
            else:
@@ -1141,7 +1082,9 @@ def __init__(
            if layer_type == "sliding_attention":
                layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
            elif layer_type == "chunked_attention":
                layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                # states they should return - only the mask changes to make them different at the end!
                layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
Review comment: You don't differentiate between chunked (block mask) and sliding because nothing changes from a cache point of view! Indeed, they should overall behave the same; the mask is what differs!

Reply: Yes, will add a comment to explain here!
            else:
                layer = StaticLayer(max_cache_len=max_cache_len)
            layers.append(layer)
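As a rough illustration of the review comment above (my own sketch, not code from this PR): both layer types store and return states identically; only the mask predicate built on top of the cache differs, roughly along these lines.

```python
# Hypothetical per-position mask predicates, only to show where sliding and chunked attention diverge;
# the real mask construction lives in the masking utilities, not in the cache layers.
def sliding_window_allowed(q_idx: int, kv_idx: int, sliding_window: int) -> bool:
    # Causal, and the key must lie within the last `sliding_window` positions (window is exclusive)
    return kv_idx <= q_idx and q_idx - kv_idx < sliding_window

def chunked_attention_allowed(q_idx: int, kv_idx: int, chunk_size: int) -> bool:
    # Causal, and the key must fall in the same fixed-size chunk as the query
    return kv_idx <= q_idx and kv_idx // chunk_size == q_idx // chunk_size
```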
@@ -1414,6 +1357,15 @@ def is_compileable(self) -> bool:
### Deprecated classes


class ChunkedSlidingLayer(SlidingWindowLayer):
    def __init__(self, max_cache_len: int, sliding_window: int):
        logger.warning_once(
            "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 "
            "Use `SlidingWindowLayer` instead, which has the exact same functionalities."
        )
        super().__init__(max_cache_len, sliding_window)


class OffloadedCache(DynamicCache):
    def __init__(self) -> None:
        logger.warning_once(
Review comment: I never understood this `1:` in ChunkedPrefill, maybe worth writing an explanation. When full and M>1 new tokens arrive, we store the last `self.max_cache_len` tokens but we return `self.max_cache_len - 1 + M` tokens (the arrived tokens + `self.keys[:, :, 1:, :]`). Why this removal of the first token?

Reply: It's because the sliding window is exclusive, i.e. a token can see at most itself and the last `sliding_window - 1` tokens, so the first token of the cache is out-of-scope when we're already full and a new token arrives. Technically, we could cache only `sliding_window - 1` tokens altogether, as I did in `DynamicSlidingWindowLayer`, which would maybe be the preferred way to avoid the additional slicing ops. But it's out-of-scope for this PR, and it can technically allow us to roll back one token for some generation methods, so it's not entirely useless.

Reply: We can have a comment to be explicit!

Reply: Yes, or we could indeed align with the Dynamic version and remove the extra token from the cache, which would technically be more efficient - but this branch is very rare in practice, as it's only used when passing more than 1 token after the cache is already full, which is a specific case of prefix caching or similar.
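To make the exchange above concrete (a worked example under the same assumptions, not code from the PR): with `sliding_window = 4` and a full cache holding positions 5-8, a new token at position 9 can only attend to positions 6-9, so the first cached position is already out of the window; dropping it via `self.keys[:, :, 1:, :]` and appending the M new tokens therefore returns the `sliding_window - 1 + M` states that are still attendable.

```python
import torch

# Toy shape check for the discussion above (batch 1, 1 head, head_dim 1 for readability).
sliding_window = 4                                         # a token sees itself + the previous 3 tokens
cached = torch.arange(5, 9).reshape(1, 1, 4, 1).float()    # full cache: positions 5, 6, 7, 8
new = torch.arange(9, 11).reshape(1, 1, 2, 1).float()      # M = 2 new tokens: positions 9, 10

# Drop the first cached position (outside every new token's window) and append the new ones
full = torch.cat((cached[:, :, 1:, :], new), dim=-2)
print(full[0, 0, :, 0].tolist())                              # [6.0, 7.0, 8.0, 9.0, 10.0]
print(full.shape[-2] == sliding_window - 1 + new.shape[-2])   # True: 5 == 3 + 2
```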