
Conversation

@Cyrilvallez (Member) commented on Sep 15, 2025

What does this PR do?

As per the title. As discussed quite a few times, SlidingWindowLayer and ChunkedSlidingLayer are exactly the same, except that the chunked version is more general: it can handle an arbitrary number of new tokens even after prefill (i.e. prefill caching, chat continuation, etc.).
This PR merges the two, keeping only the more general version, which broadens the scope of SlidingWindowLayer to the aforementioned use cases!
Static and Dynamic caches can thus now be used in exactly the same way, in full generality.

cc @gante @manueldeprada as well for visibility! Finally merging them!

I made sure the slow tests for fully sliding (Mistral) and hybrid (Gemma2) models still pass!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 459 to 475
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
         query_length = cache_position.shape[0]
-        first_cache_position = cache_position[0]
         sliding_window = self.max_cache_len
+        is_full = self.cumulative_length >= self.max_cache_len

-        kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
-        # This is the true general case for any Cache using local attention (sliding or chunked)
-        if first_cache_position >= sliding_window:
-            # Here the Cache is already full
+        kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
+        # The cache is already full
+        if is_full:
             kv_length = sliding_window + query_length - 1
-        elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
-            # Here the Cache becomes full with the new input
-            kv_length = first_cache_position + query_length
+        # Not yet full, but becoming full on this update
+        elif self.cumulative_length + query_length > sliding_window:
+            kv_length = self.cumulative_length + query_length
         else:
             # Here the Cache is still smaller than the local size, but we return the local size as it's static
             kv_length = sliding_window
         return kv_length, kv_offset

@Cyrilvallez (Member, Author):
This method is now much better, as we only work with Python scalars instead of a mix of scalars and tensors -> much easier downstream for the masks and compilation.
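
To make the scalar-only logic concrete, here is a minimal standalone sketch (not part of the PR) that mirrors the three branches of the merged method with plain integers; the free function and the example values below are purely illustrative:

def get_mask_sizes(cumulative_length: int, query_length: int, sliding_window: int) -> tuple[int, int]:
    # Mirrors the merged SlidingWindowLayer logic above, with sliding_window == max_cache_len
    is_full = cumulative_length >= sliding_window
    kv_offset = max(cumulative_length - sliding_window + 1, 0)
    if is_full:
        kv_length = sliding_window + query_length - 1
    elif cumulative_length + query_length > sliding_window:
        kv_length = cumulative_length + query_length
    else:
        kv_length = sliding_window
    return kv_length, kv_offset

assert get_mask_sizes(0, 3, 4) == (4, 0)   # prefill, cache still smaller than the window
assert get_mask_sizes(3, 3, 4) == (6, 0)   # cache becomes full with this update
assert get_mask_sizes(10, 1, 4) == (4, 7)  # already full, decoding a single token
assert get_mask_sizes(10, 3, 4) == (6, 7)  # already full, several new tokens (prefill caching / chat continuation)

All returned values are Python ints, which is the property the comment above highlights for the downstream mask code and compilation.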

@ArthurZucker (Collaborator) left a comment:
I would document why, but LGTM.

             layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
         elif layer_type == "chunked_attention":
-            layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
+            layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)

Collaborator:
You don't differentiate between chunked (block mask) and sliding because nothing changes from a cache point of view! Indeed, they should overall behave the same; the mask is what differs!

@Cyrilvallez (Member, Author):
Yes, will add a comment to explain here!
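
For readers following the exchange above, a toy sketch (not part of the PR) of the distinction being discussed: the cache update is identical, and only the mask built on top of it differs. The window/chunk size of 4 and the code below are purely illustrative.

import torch

size = 8
q = torch.arange(size).unsqueeze(-1)  # query positions as a column
k = torch.arange(size).unsqueeze(0)   # key positions as a row

# Sliding window of 4 (exclusive): a token sees itself plus the 3 previous tokens.
sliding_mask = (k <= q) & (k > q - 4)
# Chunked/block attention with chunk size 4: a token sees earlier tokens in its own chunk.
chunked_mask = (k <= q) & (k // 4 == q // 4)

print(sliding_mask.int())
print(chunked_mask.int())

In both cases a query attends to at most the 3 previous positions plus itself, so the cache only ever needs the most recent 4 key/value positions; the layer can therefore be shared, and only the mask function differs.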

@Cyrilvallez merged commit 087775d into main on Sep 16, 2025 (21 of 24 checks passed) and deleted the merge-cache-layers branch on September 16, 2025 at 09:41.

Comment on lines +423 to +424
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)

Contributor:

I never understood this 1: in the chunked prefill path; maybe it's worth writing an explanation.

When the cache is full and M > 1 new tokens arrive, we store the last self.max_cache_len tokens, but we return self.max_cache_len - 1 + M tokens (the newly arrived tokens plus self.keys[:, :, 1:, :]).

Why this removal of the first token?

@Cyrilvallez (Member, Author):

It's because the sliding_window is exclusive, i.e. a token can see at most itself and the last sliding_window - 1 tokens, so the first token of the cache is out of scope when we're already full and a new token arrives. Technically, we could cache only sliding_window - 1 tokens altogether, as I did in DynamicSlidingWindowLayer, which would maybe be the preferred way to avoid the additional slicing ops. But that's out of scope for this PR, and keeping the extra token technically allows us to roll back one token for some generation methods, so it's not entirely useless.
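
A toy sketch (not from the library) of this explanation; the tiny shapes and values are illustrative only:

import torch

sliding_window = 4
keys = torch.arange(4.0).view(1, 1, 4, 1)    # cached positions [0, 1, 2, 3]; the cache is full
new_keys = torch.tensor([[[[4.0], [5.0]]]])  # two new tokens arrive at once (e.g. prefix caching)

# Same pattern as the lines under discussion: drop the oldest cached position, append the new ones.
full_keys = torch.cat((keys[:, :, 1:, :], new_keys), dim=-2)
print(full_keys.squeeze())  # tensor([1., 2., 3., 4., 5.]) -> sliding_window - 1 + M positions returned

# With an exclusive window of 4, token 4 may attend to positions 1-4 and token 5 to positions 2-5,
# so position 0 can never be attended to again and is safe to slice away.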

Collaborator:

We can add a comment to be explicit!

@Cyrilvallez (Member, Author):

Yes, or we could indeed align with the Dynamic version and remove the extra token from the cache, which would technically be more efficient. But this branch is very rare in practice, as it is only hit when passing more than 1 token after the cache is already full, which is a specific case of prefix caching or similar.
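
A tiny illustration (not from the PR) of the trade-off mentioned here, with a purely illustrative window of 4:

sliding_window = 4
seen = list(range(6))                         # positions 0..5 have already been processed

static_style = seen[-sliding_window:]         # [2, 3, 4, 5] -> keeps one already-out-of-window position
dynamic_style = seen[-(sliding_window - 1):]  # [3, 4, 5]    -> exactly what the next token may attend to

# The next token (position 6) attends to itself plus positions 3, 4, 5 -> both layouts suffice.
# Rolling back to re-generate token 5 needs positions 2, 3, 4 -> only the static-style layout still holds 2.

The extra slot costs one position of memory plus the [:, :, 1:, :] slicing in the rare many-tokens-after-full branch, which is exactly the trade-off being weighed in this thread.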

ErfanBaghaei pushed a commit to ErfanBaghaei/transformers that referenced this pull request Sep 25, 2025

* merge

* get rid of tensors in get_mask_sizes!!

* remove branch

* add comment explanation

* re-add the class with deprecation cycle
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/transformers that referenced this pull request Oct 2, 2025
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025