System Info
- `transformers` version: 4.46.3
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.10
- Python version: 3.8.18
- Huggingface_hub version: 0.25.2
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - use_cpu: False
  - debug: False
  - num_processes: 8
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
- PyTorch version (GPU?): 2.4.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: False
- Using GPU in script?: True
- GPU type: NVIDIA A800-SXM4-40GB
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi,
I am seeing a significant increase in computation time when passing an attention mask to `WhisperSdpaAttention` in the `transformers` library. I am not sure whether this is expected behavior or a potential bug. Below is the code I used to test this:
```python
import time

import torch

from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention


def build_mask(x, x_lens):
    batch_size = x_lens.size(0)
    max_seq_len = x_lens.max()
    # Create a sequence tensor of shape (batch_size, max_seq_len)
    seq_range = (
        torch.arange(
            0,
            max_seq_len,
            dtype=x_lens.dtype,
            device=x_lens.device,
        )
        .unsqueeze(0)
        .expand(batch_size, max_seq_len)
    )
    lengths_expand = x_lens.unsqueeze(1).expand(batch_size, max_seq_len)
    # Create mask: True at padding positions
    padding_mask = seq_range >= lengths_expand
    # Broadcast to the (batch, 1, tgt_len, src_len) shape SDPA expects
    audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
        batch_size, 1, max_seq_len, max_seq_len
    )
    # Additive float mask: 0.0 where attended, -inf at padding
    audio_attention_mask = audio_attention_mask_.to(
        dtype=x.dtype,
        device=x_lens.device,
    )
    audio_attention_mask[audio_attention_mask_] = float("-inf")
    return audio_attention_mask


device = torch.device("cuda:0")
x = torch.randn(2, 200, 128).half().to(device)
x_lens = torch.tensor([200, 160]).long().to(device)

attn1 = WhisperSdpaAttention(embed_dim=128, num_heads=1, is_causal=False)
attn1.to(device).half()

with torch.no_grad():
    begin = time.time()
    z = attn1(x)
    print("sdpa without mask: ", time.time() - begin)

    begin = time.time()
    mask = build_mask(x, x_lens).to(device)
    out = attn1(x, attention_mask=mask)
    print("sdpa with mask: ", time.time() - begin)
```
The output times are as follows:

```
sdpa without mask:  0.028657197952270508
sdpa with mask:  0.13893771171569824
```
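As an aside, part of the gap could be measurement noise: `time.time()` is read without synchronizing the CUDA stream, so the first call also pays one-off launch and warm-up costs. A minimal sketch of a more careful measurement (warm-up iterations plus `torch.cuda.synchronize()`, reusing `attn1`, `x`, and `mask` from the script above):

```python
import time

import torch


def bench(fn, iters=50, warmup=10):
    # Warm up so one-off compilation/launch costs are excluded
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        begin = time.time()
        for _ in range(iters):
            fn()
        # Wait for all queued kernels before reading the clock
        torch.cuda.synchronize()
    return (time.time() - begin) / iters


print("sdpa without mask:", bench(lambda: attn1(x)))
print("sdpa with mask:   ", bench(lambda: attn1(x, attention_mask=mask)))
```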
Expected behavior
As shown above, the computation time increases roughly fivefold when an attention mask is used. Could you please let me know whether this is expected behavior or whether there might be an issue with the implementation?
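In case it helps triage: my understanding is that `torch.nn.functional.scaled_dot_product_attention` can only dispatch to the flash-attention kernel when no explicit `attn_mask` is passed, so the masked call falls back to a slower backend. A minimal sketch to check this (assuming PyTorch >= 2.3 for `torch.nn.attention.sdpa_kernel`; shapes chosen to match the repro above):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = torch.device("cuda:0")
# (batch, num_heads, seq_len, head_dim), matching the repro
q = torch.randn(2, 1, 200, 128, device=device, dtype=torch.half)
k = torch.randn(2, 1, 200, 128, device=device, dtype=torch.half)
v = torch.randn(2, 1, 200, 128, device=device, dtype=torch.half)
mask = torch.zeros(2, 1, 200, 200, device=device, dtype=torch.half)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    F.scaled_dot_product_attention(q, k, v)  # dispatches fine without a mask
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as e:
        # The flash backend does not accept an explicit attn_mask, so
        # restricting SDPA to that backend fails here
        print("flash backend rejected the mask:", e)
```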
Thank you!