System Info
- `transformers` version: 4.57.0.dev0
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
- Python version: 3.11.11
- Huggingface_hub version: 1.0.0.rc1
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: no
- GPU type: NVIDIA A100-SXM4-80GB
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Here's a minimal example showing that the sliding-window mask allows *more* key/query pairs to be attended to than the full mask, when it should allow strictly fewer.
```python
import torch
from transformers.generation.continuous_batching.continuous_api import build_attention_mask

seq_len = 16
sliding_window = 4
cumulative_seqlens_q = torch.tensor([0, seq_len])
cumulative_seqlens_k = torch.tensor([0, seq_len])

# Initialize masks to all ones -- attention is allowed everywhere.
window_mask = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32)
full_mask = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32)

# build_attention_mask converts this in place to 0 (allowed) and -inf (masked out).
build_attention_mask(window_mask, cumulative_seqlens_q, cumulative_seqlens_k, sliding_window=sliding_window)
build_attention_mask(full_mask, cumulative_seqlens_q, cumulative_seqlens_k, sliding_window=1)

# Entries that are still 0 allow the model to attend to that query/key pair.
print("Key/Query pairs the model can attend to (full mask):", (full_mask == 0).sum().item())
print("Key/Query pairs the model can attend to (window mask):", (window_mask == 0).sum().item())
```
This outputs:

```
Key/Query pairs the model can attend to (full mask): 136
Key/Query pairs the model can attend to (window mask): 202
```
Expected behavior
Using continuous batching with gpt-oss gives meaningless results, and I believe this is the cause. My understanding is that the root of the problem is this line, which resets mask elements from -inf back to 0, i.e. it marks those query/key pairs as attendable again. The sliding-window mask should instead be strictly more restrictive than the full causal mask.
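As a sanity check, here is a small sketch (independent of the `transformers` implementation, just counting pairs directly from the causal and sliding-window definitions) of what the counts above should look like for a single sequence of length 16 with a window of 4:

```python
# Count allowed query/key pairs under a causal mask, with and without a
# sliding window. A query q may attend to key k iff k <= q (causal), and
# additionally q - k < sliding_window when a window is applied.
seq_len = 16
sliding_window = 4

full_pairs = sum(
    1 for q in range(seq_len) for k in range(seq_len)
    if k <= q  # causal only
)
windowed_pairs = sum(
    1 for q in range(seq_len) for k in range(seq_len)
    if k <= q and q - k < sliding_window  # causal AND inside the window
)

print("full causal pairs:", full_pairs)      # 136 -- matches the full-mask count above
print("windowed pairs:", windowed_pairs)     # 58 -- strictly fewer, as expected
```

So the windowed mask should allow 58 pairs here, not 202; it currently allows even more pairs than the full causal mask.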