Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/transformers/masking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,16 @@ def sdpa_mask_recent_torch(
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None

# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
# used for slicing without data-dependent slicing
mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)

# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
Expand Down Expand Up @@ -485,6 +495,14 @@ def sdpa_mask_older_torch(
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None

# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)

# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
Expand Down