diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index d7b0fe6e1f83..e519db4d8f2d 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -372,85 +372,6 @@ def __init__(self, max_cache_len: int, sliding_window: int):
         super().__init__(max_cache_len=effective_max_cache_len)
         self.cumulative_length = 0
 
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Update the key and value caches in-place, and return the necessary keys and value states.
-
-        Args:
-            key_states (`torch.Tensor`): The new key states to cache.
-            value_states (`torch.Tensor`): The new value states to cache.
-            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
-
-        Returns:
-            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
-        """
-        # Lazy initialization
-        if self.keys is None:
-            self.lazy_initialization(key_states)
-
-        cache_position = cache_kwargs.get("cache_position")
-
-        is_full = self.cumulative_length >= self.max_cache_len
-        # Update it now that we saved the value above
-        self.cumulative_length += key_states.shape[-2]
-
-        # Handle prefill phase when prompt length > sliding_window_size.
-        # Note that we store cropped key/value states in the cache but return the full key/value states.
-        if cache_position.shape[0] > self.max_cache_len:
-            self.keys.copy_(key_states[:, :, -self.max_cache_len :, :])
-            self.values.copy_(value_states[:, :, -self.max_cache_len :, :])
-            # Return the full states here
-            return key_states, value_states
-
-        # Here we only assume decoding stage, i.e. 1 token at a time
-        if is_full:
-            # Roll all values to the left by 1 position
-            new_keys = self.keys.roll(-1, dims=-2)
-            new_values = self.values.roll(-1, dims=-2)
-            # Overwrite the last position with new states
-            # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
-            index = torch.tensor([-1], dtype=int, device=self.device)
-            new_keys[:, :, index] = key_states
-            new_values[:, :, index] = value_states
-
-            # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
-            self.keys.copy_(new_keys)
-            self.values.copy_(new_values)
-        else:
-            try:
-                self.keys.index_copy_(2, cache_position, key_states)
-                self.values.index_copy_(2, cache_position, value_states)
-            except NotImplementedError:
-                self.keys[:, :, cache_position] = key_states
-                self.values[:, :, cache_position] = value_states
-
-        return self.keys, self.values
-
-    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
-        """Return the length and offset of the cache, used to generate the attention mask"""
-        query_length = cache_position.shape[0]
-        first_cache_position = cache_position[0]
-
-        kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0)
-        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
-        kv_length = max(query_length, self.max_cache_len)
-        return kv_length, kv_offset
-
-    def get_seq_length(self) -> int:
-        """Returns the sequence length of the cached states."""
-        return self.cumulative_length
-
-
-class ChunkedSlidingLayer(SlidingWindowLayer):
-    """
-    An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.
-    """
-
     def update(
         self,
         key_states: torch.Tensor,
@@ -480,16 +401,29 @@ def update(
         self.cumulative_length += key_states.shape[-2]
         if is_full:
-            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
-            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
-            # Fast decoding path -> here as the effective size is still sliding window, it is extremely important
-            # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed address
-            # in memory (the values are the same as the full states, but not the address!!)
+            # In general, we should use a much simpler `cat` here as well, independently of the states size. However,
+            # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
             if key_states.shape[-2] == 1:
-                self.keys.copy_(full_key_states)
-                self.values.copy_(full_value_states)
+                # Roll all values to the left by 1 position
+                new_keys = self.keys.roll(-1, dims=-2)
+                new_values = self.values.roll(-1, dims=-2)
+                # Overwrite the last position with new states
+                # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
+                index = torch.tensor([-1], dtype=int, device=self.device)
+                new_keys[:, :, index] = key_states
+                new_values[:, :, index] = value_states
+
+                # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
+                self.keys.copy_(new_keys)
+                self.values.copy_(new_values)
+                # Very important to return the `self` tensors here, as they have the static dynamo address
                 return self.keys, self.values
-        elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len:
+            # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
+            else:
+                full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
+                full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
+        # Not yet full, but becoming full on this update
+        elif cumulative_length + key_states.shape[2] > self.max_cache_len:
             # Fast prefill path, no need to cat() in this case, as the cache is currently empty
             if cumulative_length == 0:
                 full_key_states = key_states
                 full_value_states = value_states
@@ -504,33 +438,38 @@ def update(
             except NotImplementedError:
                 self.keys[:, :, cache_position] = key_states
                 self.values[:, :, cache_position] = value_states
+
+            # Very important to return the `self` tensors here, as they have the static dynamo address
             return self.keys, self.values
 
+        # We only cache the last `sliding_window` tokens
         self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
         self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
         # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
-        # which is outside the window
         return full_key_states, full_value_states
 
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
         query_length = cache_position.shape[0]
-        first_cache_position = cache_position[0]
         sliding_window = self.max_cache_len
+        is_full = self.cumulative_length >= self.max_cache_len
 
-        kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
-        # This is the true general case for any Cache using local attention (sliding or chunked)
-        if first_cache_position >= sliding_window:
-            # Here the Cache is already full
+        kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
+        # The cache is already full
+        if is_full:
             kv_length = sliding_window + query_length - 1
-        elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
-            # Here the Cache becomes full with the new input
-            kv_length = first_cache_position + query_length
+        # Not yet full, but becoming full on this update
+        elif self.cumulative_length + query_length > sliding_window:
+            kv_length = self.cumulative_length + query_length
         else:
             # Here the Cache is still smaller than the local size, but we return the local size as it's static
             kv_length = sliding_window
         return kv_length, kv_offset
 
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        return self.cumulative_length
+
 
 class QuantizedLayer(DynamicLayer):
     """
@@ -1023,6 +962,8 @@ def __init__(
                 layer_types = layer_types[: -config.num_kv_shared_layers]
 
             for layer_type in layer_types:
+                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
+                # states they should return - only the mask changes to make them different at the end!
                 if layer_type in ("sliding_attention", "chunked_attention"):
                     layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
                 else:
@@ -1141,7 +1082,9 @@ def __init__(
             if layer_type == "sliding_attention":
                 layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
            elif layer_type == "chunked_attention":
-                layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
+                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
+                # states they should return - only the mask changes to make them different at the end!
+                layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
             else:
                 layer = StaticLayer(max_cache_len=max_cache_len)
             layers.append(layer)
@@ -1414,6 +1357,15 @@ def is_compileable(self) -> bool:
 ### Deprecated classes
 
 
+class ChunkedSlidingLayer(SlidingWindowLayer):
+    def __init__(self, max_cache_len: int, sliding_window: int):
+        logger.warning_once(
+            "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59. "
+            "Use `SlidingWindowLayer` instead, which has the exact same functionalities."
+        )
+        super().__init__(max_cache_len, sliding_window)
+
+
 class OffloadedCache(DynamicCache):
     def __init__(self) -> None:
         logger.warning_once(
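
Note (not part of the patch): the snippet below is a small standalone illustration of the roll-and-overwrite decode step that this diff moves from `ChunkedSlidingLayer` into `SlidingWindowLayer`, and of why the result is copied back with `copy_` instead of rebinding the attribute — the cache buffer must keep its storage address for dynamo / CUDA graphs. The tensor shapes and names here are illustrative only.

```python
import torch

# Pretend the sliding window (4 slots) is already full with tokens 0..3,
# laid out as [batch, num_heads, window, head_dim].
window = torch.arange(4, dtype=torch.float32).view(1, 1, 4, 1)
new_token = torch.full((1, 1, 1, 1), 4.0)

ptr_before = window.data_ptr()

# Shift everything left by one slot, then overwrite the last slot with the new token.
# (The patch indexes with a tensor rather than a plain -1 to work around
# https://github.com/pytorch/pytorch/issues/159855.)
shifted = window.roll(-1, dims=-2)
index = torch.tensor([-1])
shifted[:, :, index] = new_token

# Copy back in place instead of rebinding `window`, so the buffer keeps the same
# underlying storage that a compiled graph may have captured.
window.copy_(shifted)

print(window.flatten().tolist())        # [1.0, 2.0, 3.0, 4.0] -> oldest token evicted
print(window.data_ptr() == ptr_before)  # True -> same underlying storage
```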