Significant Increase in Computation Time When Using Attention Mask in SDPA Attention #36584

@tartarleft

Description

System Info

  • transformers version: 4.46.3
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.10
  • Python version: 3.8.18
  • Huggingface_hub version: 0.25.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: DEEPSPEED
    - use_cpu: False
    - debug: False
    - num_processes: 8
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
  • PyTorch version (GPU?): 2.4.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: False
  • Using GPU in script?: True
  • GPU type: NVIDIA A800-SXM4-40GB

Who can help?

@ylacombe, @eustlb

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi,

I am experiencing a significant increase in computation time when passing an attention mask to WhisperSdpaAttention in the transformers library. I am not sure whether this is expected behavior or a potential bug. Below is the code I used to test this:

import torch
import time
from transformers.models.whisper.modeling_whisper import WhisperSdpaAttention

def build_mask(x, x_lens):
    # Build an additive SDPA mask: 0.0 for valid keys, -inf for padded keys.
    batch_size = x_lens.size(0)
    max_seq_len = x_lens.max()
    # Position indices of shape (batch_size, max_seq_len)
    seq_range = (
        torch.arange(
            0,
            max_seq_len,
            dtype=x_lens.dtype,
            device=x_lens.device,
        )
        .unsqueeze(0)
        .expand(batch_size, max_seq_len)
    )
    lengths_expand = x_lens.unsqueeze(1).expand(batch_size, max_seq_len)
    # True where the position lies beyond the sequence length, i.e. padding
    padding_mask = seq_range >= lengths_expand

    # Broadcast to the 4D shape SDPA expects: (batch, 1, tgt_len, src_len)
    audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
        batch_size, 1, max_seq_len, max_seq_len
    )
    # Convert the boolean mask to an additive float mask in the input dtype
    audio_attention_mask = audio_attention_mask_.to(
        dtype=x.dtype,
        device=x_lens.device,
    )
    audio_attention_mask[audio_attention_mask_] = float("-inf")

    return audio_attention_mask

device = torch.device("cuda:0")
x = torch.randn(2, 200, 128).half().to(device)
x_lens = torch.tensor([200, 160]).long().to(device)

attn1 = WhisperSdpaAttention(embed_dim=128, num_heads=1, is_causal=False)
attn1.to(device).half()

with torch.no_grad():
    begin = time.time()
    z = attn1(x)
    print("sdpa without mask: ", time.time() - begin)

    begin = time.time()
    mask = build_mask(x, x_lens).to(device)
    out = attn1(x, attention_mask=mask)
    print("sdpa with mask: ", time.time() - begin)

The output times are as follows:

sdpa without mask:  0.028657197952270508
sdpa with mask:  0.13893771171569824
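
Note that these are single-call wall-clock times without an explicit torch.cuda.synchronize(), and the second measurement also includes building the mask itself, so some of the gap may be launch or setup overhead. A rough sketch of how the same comparison could be timed more carefully (time_attn is just a local helper for this test, reusing attn1, x, and mask from the script above):

def time_attn(attn, x, attention_mask=None, iters=20):
    # Warm-up runs so one-time CUDA initialization is not measured
    with torch.no_grad():
        for _ in range(3):
            attn(x, attention_mask=attention_mask)
        torch.cuda.synchronize()
        begin = time.time()
        for _ in range(iters):
            attn(x, attention_mask=attention_mask)
        torch.cuda.synchronize()
    return (time.time() - begin) / iters

print("sdpa without mask (avg): ", time_attn(attn1, x))
print("sdpa with mask (avg):    ", time_attn(attn1, x, attention_mask=mask))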

Expected behavior

As you can see, the computation time increases significantly when an attention mask is used. Could you please let me know if this is expected behavior or if there might be an issue with the implementation?
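
My guess is that passing a non-None float mask prevents SDPA from dispatching to the flash-attention kernel and forces a fallback to a slower backend, but I have not confirmed this. A small check one could run to see which PyTorch SDPA backends accept the masked call (shapes chosen to match the repro above; this calls torch.nn.functional.scaled_dot_product_attention directly rather than the Whisper module):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = torch.device("cuda:0")
# (batch, num_heads, seq_len, head_dim), matching the repro above
q = k = v = torch.randn(2, 1, 200, 128, dtype=torch.float16, device=device)
# Additive float mask like the one build_mask() produces (all zeros here, i.e. nothing masked)
mask = torch.zeros(2, 1, 200, 200, dtype=torch.float16, device=device)

for label, attn_mask in [("no mask", None), ("float mask", mask)]:
    for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
        try:
            # Restrict SDPA to a single backend and see whether the call succeeds
            with sdpa_kernel([backend]):
                F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            status = "ok"
        except RuntimeError:
            status = "not supported"
        print(f"{label:10s} | {backend.name:19s} | {status}")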

Thank you!
