From 119f1a6680591c91c18ddef7cec9db3475932712 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 15 Sep 2025 19:15:25 +0200 Subject: [PATCH 1/5] merge --- src/transformers/__init__.py | 2 - src/transformers/cache_utils.py | 112 +++++++++----------------------- 2 files changed, 30 insertions(+), 84 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a8b19ef967f2..3671afcfa638 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -379,7 +379,6 @@ "DynamicLayer", "StaticLayer", "SlidingWindowLayer", - "ChunkedSlidingLayer", "QuantoQuantizedLayer", "HQQQuantizedLayer", "Cache", @@ -583,7 +582,6 @@ if TYPE_CHECKING: # All modeling imports from .cache_utils import Cache as Cache - from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer from .cache_utils import DynamicCache as DynamicCache from .cache_utils import DynamicLayer as DynamicLayer from .cache_utils import EncoderDecoderCache as EncoderDecoderCache diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d7b0fe6e1f83..2c53ed6466e4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -395,7 +395,8 @@ def update( cache_position = cache_kwargs.get("cache_position") - is_full = self.cumulative_length >= self.max_cache_len + cumulative_length = self.cumulative_length + is_full = cumulative_length >= self.max_cache_len # Update it now that we saved the value above self.cumulative_length += key_states.shape[-2] @@ -407,88 +408,29 @@ def update( # Return the full states here return key_states, value_states - # Here we only assume decoding stage, i.e. 1 token at a time - if is_full: - # Roll all values to the left by 1 position - new_keys = self.keys.roll(-1, dims=-2) - new_values = self.values.roll(-1, dims=-2) - # Overwrite the last position with new states - # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855) - index = torch.tensor([-1], dtype=int, device=self.device) - new_keys[:, :, index] = key_states - new_values[:, :, index] = value_states - - # Copy back into `self` (do not just assign again) in order to keep the static dynamo address - self.keys.copy_(new_keys) - self.values.copy_(new_values) - else: - try: - self.keys.index_copy_(2, cache_position, key_states) - self.values.index_copy_(2, cache_position, value_states) - except NotImplementedError: - self.keys[:, :, cache_position] = key_states - self.values[:, :, cache_position] = value_states - - return self.keys, self.values - - def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: - """Return the length and offset of the cache, used to generate the attention mask""" - query_length = cache_position.shape[0] - first_cache_position = cache_position[0] - - kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0) - # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns - kv_length = max(query_length, self.max_cache_len) - return kv_length, kv_offset - - def get_seq_length(self) -> int: - """Returns the sequence length of the cached states.""" - return self.cumulative_length - - -class ChunkedSlidingLayer(SlidingWindowLayer): - """ - An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4. 
- """ - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Update the key and value caches in-place, and return the necessary keys and value states. - - Args: - key_states (`torch.Tensor`): The new key states to cache. - value_states (`torch.Tensor`): The new value states to cache. - cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. - - Returns: - tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. - """ - # Lazy initialization - if self.keys is None: - self.lazy_initialization(key_states) - - cache_position = cache_kwargs.get("cache_position") - - cumulative_length = self.cumulative_length - is_full = cumulative_length >= self.max_cache_len - # Update it now that we saved the value above - self.cumulative_length += key_states.shape[-2] - if is_full: - full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address - # in memory (the values are the same as the full states, but not the address!!) + # In general, we should use a much simpler `cat` here as well, independently of the states size. However, + # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details if key_states.shape[-2] == 1: - self.keys.copy_(full_key_states) - self.values.copy_(full_value_states) + # Roll all values to the left by 1 position + new_keys = self.keys.roll(-1, dims=-2) + new_values = self.values.roll(-1, dims=-2) + # Overwrite the last position with new states + # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855) + index = torch.tensor([-1], dtype=int, device=self.device) + new_keys[:, :, index] = key_states + new_values[:, :, index] = value_states + + # Copy back into `self` (do not just assign again) in order to keep the static dynamo address + self.keys.copy_(new_keys) + self.values.copy_(new_values) + # Very important to return the `self` tensors here, as they have the static dynamo address return self.keys, self.values + # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...) 
+ else: + full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) + # Not yet full, but becoming full on this update elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: # Fast prefill path, no need to cat() in this case, as the cache is currently empty if cumulative_length == 0: @@ -504,12 +446,14 @@ def update( except NotImplementedError: self.keys[:, :, cache_position] = key_states self.values[:, :, cache_position] = value_states + + # Very important to return the `self` tensors here, as they have the static dynamo address return self.keys, self.values + # We only cache the last `sliding_window` tokens self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :]) self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :]) # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context - # which is outside the window return full_key_states, full_value_states def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: @@ -531,6 +475,10 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: kv_length = sliding_window return kv_length, kv_offset + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + class QuantizedLayer(DynamicLayer): """ @@ -1141,7 +1089,7 @@ def __init__( if layer_type == "sliding_attention": layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window) elif layer_type == "chunked_attention": - layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size) + layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size) else: layer = StaticLayer(max_cache_len=max_cache_len) layers.append(layer) From 6d4be2c362d2a9e1739bbc68f0962774b2557819 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 15 Sep 2025 19:39:35 +0200 Subject: [PATCH 2/5] get rid of tensors in get_mask_sizes!! 
--- src/transformers/cache_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2c53ed6466e4..c3b2a5be0678 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -431,7 +431,7 @@ def update( full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2) full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2) # Not yet full, but becoming full on this update - elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len: + elif cumulative_length + key_states.shape[2] > self.max_cache_len: # Fast prefill path, no need to cat() in this case, as the cache is currently empty if cumulative_length == 0: full_key_states = key_states @@ -459,17 +459,16 @@ def update( def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: """Return the length and offset of the cache, used to generate the attention mask""" query_length = cache_position.shape[0] - first_cache_position = cache_position[0] sliding_window = self.max_cache_len + is_full = self.cumulative_length >= self.max_cache_len - kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0) - # This is the true general case for any Cache using local attention (sliding or chunked) - if first_cache_position >= sliding_window: - # Here the Cache is already full + kv_offset = max(self.cumulative_length - sliding_window + 1, 0) + # The cache is already full + if is_full: kv_length = sliding_window + query_length - 1 - elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window: - # Here the Cache becomes full with the new input - kv_length = first_cache_position + query_length + # Not yet full, but becoming full on this update + elif self.cumulative_length + query_length > sliding_window: + kv_length = self.cumulative_length + query_length else: # Here the Cache is still smaller than the local size, but we return the local size as it's static kv_length = sliding_window From 4bfb1d4f1e90e7bd45d67f18763099bd2a12f1b3 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 16 Sep 2025 10:29:30 +0200 Subject: [PATCH 3/5] remove branch --- src/transformers/cache_utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c3b2a5be0678..2846562a39ae 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -400,14 +400,6 @@ def update( # Update it now that we saved the value above self.cumulative_length += key_states.shape[-2] - # Handle prefill phase when prompt length > sliding_window_size. - # Note that we store cropped key/value states in the cache but return the full key/value states. - if cache_position.shape[0] > self.max_cache_len: - self.keys.copy_(key_states[:, :, -self.max_cache_len :, :]) - self.values.copy_(value_states[:, :, -self.max_cache_len :, :]) - # Return the full states here - return key_states, value_states - if is_full: # In general, we should use a much simpler `cat` here as well, independently of the states size. 
However, # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details From 7dc84945ee0b1a395895fc5a3a5aa2fec6a7b389 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 16 Sep 2025 11:20:43 +0200 Subject: [PATCH 4/5] add comment explanation --- src/transformers/cache_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2846562a39ae..3949f30ca451 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -962,6 +962,8 @@ def __init__( layer_types = layer_types[: -config.num_kv_shared_layers] for layer_type in layer_types: + # From a cache point of view, both sliding and chunked are the same in how they should behave and how many + # states they should return - only the mask changes to make them different at the end! if layer_type in ("sliding_attention", "chunked_attention"): layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) else: @@ -1080,6 +1082,8 @@ def __init__( if layer_type == "sliding_attention": layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window) elif layer_type == "chunked_attention": + # From a cache point of view, both sliding and chunked are the same in how they should behave and how many + # states they should return - only the mask changes to make them different at the end! layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size) else: layer = StaticLayer(max_cache_len=max_cache_len) From e891959c94b83353f1bc55aa448633ae9ed0f753 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 16 Sep 2025 11:30:00 +0200 Subject: [PATCH 5/5] re-add the class with deprecation cycle --- src/transformers/__init__.py | 2 ++ src/transformers/cache_utils.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3671afcfa638..a8b19ef967f2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -379,6 +379,7 @@ "DynamicLayer", "StaticLayer", "SlidingWindowLayer", + "ChunkedSlidingLayer", "QuantoQuantizedLayer", "HQQQuantizedLayer", "Cache", @@ -582,6 +583,7 @@ if TYPE_CHECKING: # All modeling imports from .cache_utils import Cache as Cache + from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer from .cache_utils import DynamicCache as DynamicCache from .cache_utils import DynamicLayer as DynamicLayer from .cache_utils import EncoderDecoderCache as EncoderDecoderCache diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3949f30ca451..e519db4d8f2d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1357,6 +1357,15 @@ def is_compileable(self) -> bool: ### Deprecated classes +class ChunkedSlidingLayer(SlidingWindowLayer): + def __init__(self, max_cache_len: int, sliding_window: int): + logger.warning_once( + "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 " + "Use `SlidingWindowLayer` instead, which has the exact same functionalities." + ) + super().__init__(max_cache_len, sliding_window) + + class OffloadedCache(DynamicCache): def __init__(self) -> None: logger.warning_once(
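
As a rough illustration of the sliding-window cache semantics this series converges on, the sketch below is a minimal, self-contained re-implementation in plain PyTorch. It is not the transformers code: the class name `ToySlidingLayer` and its attributes are hypothetical, and the dynamo-specific tensor indexing is omitted. It only mirrors the behaviour described in the patches: roll-in-place for single-token decoding on a full cache, cat-and-crop when the window is exceeded (returning the full states so no context is lost), and `get_mask_sizes` computed with plain Python ints.

import torch


class ToySlidingLayer:
    """Illustrative sketch only; not the transformers implementation."""

    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len   # sliding window / chunk size
        self.cumulative_length = 0           # total tokens seen so far
        self.keys = None
        self.values = None

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        # Called before update(): returns plain ints, no tensor ops (as in PATCH 2/5)
        query_length = cache_position.shape[0]
        kv_offset = max(self.cumulative_length - self.max_cache_len + 1, 0)
        if self.cumulative_length >= self.max_cache_len:
            # Cache already full
            kv_length = self.max_cache_len + query_length - 1
        elif self.cumulative_length + query_length > self.max_cache_len:
            # Becoming full on this update
            kv_length = self.cumulative_length + query_length
        else:
            # Still smaller than the window, but the buffer size is static
            kv_length = self.max_cache_len
        return kv_length, kv_offset

    def update(self, key_states, value_states, cache_position):
        if self.keys is None:
            # Lazy init of fixed-size buffers (static shapes for compiled graphs)
            batch, heads, _, dim = key_states.shape
            self.keys = key_states.new_zeros((batch, heads, self.max_cache_len, dim))
            self.values = value_states.new_zeros((batch, heads, self.max_cache_len, dim))

        prev_length = self.cumulative_length
        self.cumulative_length += key_states.shape[-2]
        is_full = prev_length >= self.max_cache_len

        if is_full and key_states.shape[-2] == 1:
            # Decoding on a full cache: shift left by one and overwrite the last slot,
            # copying back in place so the returned buffers keep a stable address
            self.keys.copy_(self.keys.roll(-1, dims=-2))
            self.values.copy_(self.values.roll(-1, dims=-2))
            self.keys[:, :, -1:] = key_states
            self.values[:, :, -1:] = value_states
            return self.keys, self.values

        if is_full:
            # Full cache but several new tokens (prefix caching, chat continuation, ...)
            full_keys = torch.cat((self.keys[:, :, 1:], key_states), dim=-2)
            full_values = torch.cat((self.values[:, :, 1:], value_states), dim=-2)
        elif prev_length + key_states.shape[-2] > self.max_cache_len:
            # Not yet full, but becoming full on this update
            full_keys = torch.cat((self.keys[:, :, :prev_length], key_states), dim=-2)
            full_values = torch.cat((self.values[:, :, :prev_length], value_states), dim=-2)
        else:
            # Still smaller than the window: plain in-place write at cache_position
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
            return self.keys, self.values

        # Only the last `max_cache_len` tokens are kept in the buffers, but the full
        # states are returned so no context outside the window is lost on this pass
        self.keys.copy_(full_keys[:, :, -self.max_cache_len:])
        self.values.copy_(full_values[:, :, -self.max_cache_len:])
        return full_keys, full_values

The in-place `copy_` calls mirror the series' concern with keeping the cached buffers at a fixed address under torch.compile; the real layer additionally indexes with a tensor to work around the dynamo issue referenced above (pytorch/pytorch#159855).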