src/transformers/cache_utils.py (146 changes: 49 additions, 97 deletions)
@@ -372,85 +372,6 @@ def __init__(self, max_cache_len: int, sliding_window: int):
super().__init__(max_cache_len=effective_max_cache_len)
self.cumulative_length = 0

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.

Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if self.keys is None:
self.lazy_initialization(key_states)

cache_position = cache_kwargs.get("cache_position")

is_full = self.cumulative_length >= self.max_cache_len
# Update it now that we saved the value above
self.cumulative_length += key_states.shape[-2]

# Handle prefill phase when prompt length > sliding_window_size.
# Note that we store cropped key/value states in the cache but return the full key/value states.
if cache_position.shape[0] > self.max_cache_len:
self.keys.copy_(key_states[:, :, -self.max_cache_len :, :])
self.values.copy_(value_states[:, :, -self.max_cache_len :, :])
# Return the full states here
return key_states, value_states

# Here we only assume decoding stage, i.e. 1 token at a time
if is_full:
# Roll all values to the left by 1 position
new_keys = self.keys.roll(-1, dims=-2)
new_values = self.values.roll(-1, dims=-2)
# Overwrite the last position with new states
# (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
index = torch.tensor([-1], dtype=int, device=self.device)
new_keys[:, :, index] = key_states
new_values[:, :, index] = value_states

# Copy back into `self` (do not just assign again) in order to keep the static dynamo address
self.keys.copy_(new_keys)
self.values.copy_(new_values)
else:
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states

return self.keys, self.values
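
Illustration only (not part of the diff): a rough usage sketch of the update API above. The shapes (batch 1, 1 KV head, head_dim 8) and window size are hypothetical.

import torch
from transformers.cache_utils import SlidingWindowLayer

layer = SlidingWindowLayer(max_cache_len=16, sliding_window=4)

# Prefill with 6 tokens (more than the window): the full states come back,
# while only the last 4 positions are stored in the cache.
k = torch.randn(1, 1, 6, 8)
v = torch.randn(1, 1, 6, 8)
keys, values = layer.update(k, v, cache_kwargs={"cache_position": torch.arange(6)})
print(keys.shape)  # torch.Size([1, 1, 6, 8])

# Decode 1 token with a full cache: the fixed-size window buffer comes back.
k1, v1 = torch.randn(1, 1, 1, 8), torch.randn(1, 1, 1, 8)
keys, values = layer.update(k1, v1, cache_kwargs={"cache_position": torch.tensor([6])})
print(keys.shape)  # torch.Size([1, 1, 4, 8])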

def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]

kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
kv_length = max(query_length, self.max_cache_len)
return kv_length, kv_offset

def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
return self.cumulative_length


class ChunkedSlidingLayer(SlidingWindowLayer):
"""
An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.
"""

def update(
self,
key_states: torch.Tensor,
@@ -480,16 +401,29 @@ def update(
self.cumulative_length += key_states.shape[-2]

if is_full:
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
# Fast decoding path -> here as the effective size is still sliding window, it is extremely important
# to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address
# in memory (the values are the same as the full states, but not the address!!)
# In general, we should use a much simpler `cat` here as well, independently of the states size. However,
# dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
if key_states.shape[-2] == 1:
self.keys.copy_(full_key_states)
self.values.copy_(full_value_states)
# Roll all values to the left by 1 position
new_keys = self.keys.roll(-1, dims=-2)
new_values = self.values.roll(-1, dims=-2)
# Overwrite the last position with new states
# (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
index = torch.tensor([-1], dtype=int, device=self.device)
new_keys[:, :, index] = key_states
new_values[:, :, index] = value_states

# Copy back into `self` (do not just assign again) in order to keep the static dynamo address
self.keys.copy_(new_keys)
self.values.copy_(new_values)
# Very important to return the `self` tensors here, as they have the static dynamo address
return self.keys, self.values
elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len:
# Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
else:
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
Comment on lines +423 to +424

Contributor:

I never understood this `1:` in the chunked prefill path; maybe it's worth writing an explanation.

When the cache is full and M > 1 new tokens arrive, we store the last self.max_cache_len tokens but return self.max_cache_len - 1 + M tokens (the arrived tokens plus self.keys[:, :, 1:, :]).

Why this removal of the first token?

Member Author:

It's because the sliding window is exclusive, i.e. a token can see at most itself and the last sliding_window - 1 tokens, so the first token of the cache is out of scope when we're already full and a new token arrives. Technically, we could cache only sliding_window - 1 tokens altogether, as I did in DynamicSlidingWindowLayer, which may be the preferred way to avoid the additional slicing ops. But that's out of scope for this PR, and keeping the extra token can let us roll back one token for some generation methods, so it's not entirely useless.

Collaborator:

We can add a comment to be explicit!

Member Author:

Yes, or we could align with the Dynamic version and remove the extra token from the cache, which would technically be more efficient. But this branch is very rare in practice, as it's only taken when passing more than one token after the cache is already full, which is a specific case of prefix caching or similar.
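
Illustration only (not part of the diff): a minimal sketch of the windowing arithmetic discussed above, assuming a hypothetical window W = 4 and a cache that is already full with positions 0..3.

import torch

W = 4                       # sliding_window (exclusive: a token sees itself + W - 1 predecessors)
cached = torch.arange(4)    # positions currently held by the full cache: [0, 1, 2, 3]
new = torch.tensor([4, 5])  # M = 2 new tokens arriving

# The earliest position the first new token (position 4) can still attend to:
earliest_visible = 4 - (W - 1)  # -> 1, so cached position 0 is out of scope

# Hence the update returns cached[1:] plus the new tokens, i.e. W - 1 + M = 5 states:
returned = torch.cat((cached[1:], new))
print(returned)  # tensor([1, 2, 3, 4, 5])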

# Not yet full, but becoming full on this update
elif cumulative_length + key_states.shape[2] > self.max_cache_len:
# Fast prefill path, no need to cat() in this case, as the cache is currently empty
if cumulative_length == 0:
full_key_states = key_states
@@ -504,33 +438,38 @@ def update(
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states

# Very important to return the `self` tensors here, as they have the static dynamo address
return self.keys, self.values

# We only cache the last `sliding_window` tokens
self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
# we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
# which is outside the window
return full_key_states, full_value_states

def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
sliding_window = self.max_cache_len
is_full = self.cumulative_length >= self.max_cache_len

kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
# This is the true general case for any Cache using local attention (sliding or chunked)
if first_cache_position >= sliding_window:
# Here the Cache is already full
kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
# The cache is already full
if is_full:
kv_length = sliding_window + query_length - 1
elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
# Here the Cache becomes full with the new input
kv_length = first_cache_position + query_length
# Not yet full, but becoming full on this update
elif self.cumulative_length + query_length > sliding_window:
kv_length = self.cumulative_length + query_length
else:
# Here the Cache is still smaller than the local size, but we return the local size as it's static
kv_length = sliding_window
return kv_length, kv_offset
Comment on lines 451 to 467
Member Author:

This method is now much better, as we only work with Python scalars instead of a mix of scalars and tensors, which is much easier downstream with the masks and compilation.
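
Illustration only (not part of the diff): a hypothetical standalone re-implementation of the scalar-only logic, showing that both returned sizes are plain Python ints.

import torch

def mask_sizes(cumulative_length: int, cache_position: torch.Tensor, sliding_window: int):
    query_length = cache_position.shape[0]
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    if cumulative_length >= sliding_window:
        # The cache is already full
        kv_length = sliding_window + query_length - 1
    elif cumulative_length + query_length > sliding_window:
        # Not yet full, but becoming full on this update
        kv_length = cumulative_length + query_length
    else:
        # Still smaller than the window; return the window size as it's static
        kv_length = sliding_window
    return kv_length, kv_offset

# Decoding 1 token with a full cache (cumulative_length = 10, window = 4):
print(mask_sizes(10, torch.tensor([10]), 4))  # (4, 7) -> no tensors involved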


def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
return self.cumulative_length


class QuantizedLayer(DynamicLayer):
"""
@@ -1023,6 +962,8 @@ def __init__(
layer_types = layer_types[: -config.num_kv_shared_layers]

for layer_type in layer_types:
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
# states they should return - only the mask changes to make them different at the end!
if layer_type in ("sliding_attention", "chunked_attention"):
layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
else:
@@ -1141,7 +1082,9 @@ def __init__(
if layer_type == "sliding_attention":
layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
elif layer_type == "chunked_attention":
layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
# states they should return - only the mask changes to make them different at the end!
layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
Collaborator:

You don't differentiate between chunked (block mask) and sliding because nothing changes from a cache point of view! Indeed, they should behave the same overall; the mask is what differs!

Member Author:

Yes, I will add a comment to explain that here!
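
Illustration only (not part of the diff): the cached states are identical for both layer types; only the mask predicate differs. A minimal sketch with hypothetical positions and a window/chunk size of 4.

import torch

q = torch.arange(8).unsqueeze(1)  # query positions (column)
k = torch.arange(8).unsqueeze(0)  # key positions (row)
size = 4

causal = k <= q
# Sliding attention: a query sees itself and the previous size - 1 tokens.
sliding_mask = causal & (k > q - size)
# Chunked attention: a query sees only keys inside its own chunk of `size` tokens.
chunked_mask = causal & (k // size == q // size)

print(sliding_mask.int())
print(chunked_mask.int())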

else:
layer = StaticLayer(max_cache_len=max_cache_len)
layers.append(layer)
@@ -1414,6 +1357,15 @@ def is_compileable(self) -> bool:
### Deprecated classes


class ChunkedSlidingLayer(SlidingWindowLayer):
def __init__(self, max_cache_len: int, sliding_window: int):
logger.warning_once(
"`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59. "
"Use `SlidingWindowLayer` instead, which has the exact same functionality."
)
super().__init__(max_cache_len, sliding_window)


class OffloadedCache(DynamicCache):
def __init__(self) -> None:
logger.warning_once(