Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the mask"""
kv_offset = 0
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length()
kv_length = query_length + past_seen_tokens
kv_length = self.get_seq_length() + query_length
return kv_length, kv_offset

def get_seq_length(self) -> int:
Expand Down Expand Up @@ -212,14 +211,13 @@ def update(
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
is_full = self.cumulative_length >= self.sliding_window

kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)

if self.get_seq_length() >= self.sliding_window:
kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
if is_full:
kv_length = self.sliding_window - 1 + query_length
else:
kv_length = self.get_seq_length() + query_length
kv_length = self.cumulative_length + query_length

return kv_length, kv_offset

Expand Down Expand Up @@ -461,9 +459,10 @@ def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
# Not yet full, but becoming full on this update
elif self.cumulative_length + query_length > sliding_window:
kv_length = self.cumulative_length + query_length
# Here the Cache is still smaller than the local size, but we return the local size as it's static
else:
# Here the Cache is still smaller than the local size, but we return the local size as it's static
kv_length = sliding_window

return kv_length, kv_offset

def get_seq_length(self) -> int:
Expand Down