diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 7c79f7dd4548..1e08144c414d 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -122,8 +122,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the mask"""
         kv_offset = 0
         query_length = cache_position.shape[0]
-        past_seen_tokens = self.get_seq_length()
-        kv_length = query_length + past_seen_tokens
+        kv_length = self.get_seq_length() + query_length
         return kv_length, kv_offset
 
     def get_seq_length(self) -> int:
@@ -212,14 +211,13 @@ def update(
 
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
         query_length = cache_position.shape[0]
-        first_cache_position = cache_position[0]
+        is_full = self.cumulative_length >= self.sliding_window
 
-        kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
-
-        if self.get_seq_length() >= self.sliding_window:
+        kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
+        if is_full:
             kv_length = self.sliding_window - 1 + query_length
         else:
-            kv_length = self.get_seq_length() + query_length
+            kv_length = self.cumulative_length + query_length
 
         return kv_length, kv_offset
@@ -461,9 +459,10 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         # Not yet full, but becoming full on this update
         elif self.cumulative_length + query_length > sliding_window:
             kv_length = self.cumulative_length + query_length
+        # Here the Cache is still smaller than the local size, but we return the local size as it's static
         else:
-            # Here the Cache is still smaller than the local size, but we return the local size as it's static
             kv_length = sliding_window
+
         return kv_length, kv_offset
 
     def get_seq_length(self) -> int:
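
For readers following the patch outside the file, here is a minimal, self-contained sketch of the sliding-window arithmetic introduced in the second hunk. The class, its constructor, and the update helper are hypothetical scaffolding, not the transformers implementation; only the (kv_length, kv_offset) formulas mirror the diff, which replaces the tensor-valued torch.clamp offset (derived from cache_position[0]) with plain Python arithmetic on cumulative_length.

import torch


class SlidingWindowMaskSizes:
    """Hypothetical stand-in for a sliding-window cache; tracks only what
    the mask-size computation needs."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window
        self.cumulative_length = 0  # total tokens ever pushed into the cache

    def update(self, num_new_tokens: int) -> None:
        # Placeholder for the real cache's update() bookkeeping.
        self.cumulative_length += num_new_tokens

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        query_length = cache_position.shape[0]
        is_full = self.cumulative_length >= self.sliding_window

        # Offset of the first key still held in the cache, as a Python int
        # rather than the 0-dim tensor torch.clamp used to produce.
        kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
        if is_full:
            # A full cache keeps sliding_window - 1 past tokens plus the new ones.
            kv_length = self.sliding_window - 1 + query_length
        else:
            kv_length = self.cumulative_length + query_length
        return kv_length, kv_offset


# Example: window of 4, 10 tokens already cached, decoding 1 new token.
cache = SlidingWindowMaskSizes(sliding_window=4)
cache.update(10)
print(cache.get_mask_sizes(torch.arange(10, 11)))  # (4, 7): 3 past keys + 1 query, offset 7

The design point of the change is the same in all three hunks: deriving the mask sizes from cumulative_length (a Python int the cache already tracks) instead of get_seq_length() or cache_position keeps the returned kv_length and kv_offset as plain ints with no data-dependent tensor ops, which is friendlier to graph capture.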