Merged
154 changes: 71 additions & 83 deletions benchmarks/backward_equivalence.py
@@ -19,6 +19,8 @@
import gc
import sys

from flash_sparse_attn.utils.mask import create_mask

# Import the compiled CUDA extension
try:
from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


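Aside: repeat_kv expands each of the num_kv_heads key/value heads n_rep times so grouped-query attention can be computed with ordinary per-head matmuls. The snippet below is only an illustrative shape check; the toy shapes are not from the benchmark and it assumes the repeat_kv defined above is in scope.

```python
import torch

# Toy GQA setup: 8 query heads share 2 KV heads, so each KV head is repeated 4 times.
k = torch.randn(1, 2, 16, 64)      # [batch, num_kv_heads, seq_len, head_dim]
k_rep = repeat_kv(k, 4)            # uses the repeat_kv defined above
print(k_rep.shape)                 # torch.Size([1, 8, 16, 64])
```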
def prepare_mask(
hidden_states: torch.Tensor,
attn_bias: torch.Tensor,
causal_mask: torch.Tensor = None,
window_size: int = None,
):
"""
Args:
hidden_states: Input hidden states to determine dtype minimum value
attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
causal_mask: Optional causal mask to apply
window_size: Window size of tokens not masked

Returns:
tuple: (attn_bias, attn_mask)
"""
dtype = hidden_states.dtype
min_dtype = torch.finfo(dtype).min

if attn_bias.shape[-1] > window_size:
if causal_mask is not None:
topk_values, topk_indices = torch.topk(
attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
window_size, dim=-1, largest=True, sorted=False
)
else:
topk_values, topk_indices = torch.topk(
attn_bias,
window_size, dim=-1, largest=True, sorted=False
)
attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
else:
attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
return attn_bias, attn_mask


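The deleted prepare_mask helper above is replaced throughout this file by create_mask from flash_sparse_attn.utils.mask. The library implementation is not part of this diff; the sketch below is only a guess at what a `type="topk"` mask builder with this call signature might do, reconstructed from the deleted helper and the call sites that follow. Every name and behavioral detail here is an assumption, not the library's actual code.

```python
import torch

def create_mask_sketch(attention_bias, attention_mask, batch_size, query_len,
                       key_len, window_size, min_dtype, type="topk"):
    """Hypothetical stand-in for flash_sparse_attn.utils.mask.create_mask."""
    # batch_size / query_len / key_len are accepted only for interface parity in this sketch.
    if type != "topk":
        raise NotImplementedError("only the topk variant is sketched here")
    if attention_bias.shape[-1] <= window_size:
        # Nothing to prune: keep the causal mask if given, otherwise no mask at all.
        return attention_mask if attention_mask is not None else None
    bias = attention_bias
    if attention_mask is not None:
        bias = bias.masked_fill(~attention_mask, min_dtype)
    # Keep the window_size largest bias entries per query position.
    topk_values, topk_indices = torch.topk(
        bias.detach(), window_size, dim=-1, largest=True, sorted=False
    )
    # Positions that were selected but carry the masked-out value stay False.
    return torch.zeros_like(bias, dtype=torch.bool).scatter_(
        -1, topk_indices, topk_values != min_dtype
    )
```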
def dynamic_mask_attention_python(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -127,32 +93,38 @@ def dynamic_mask_attention_python(
Returns:
tuple: (attn_outputs, dq, dk, dv, dbias)
"""
_, num_heads, _, _ = query_states.shape
_, num_kv_heads, _, _ = key_states.shape
batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads

attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=causal_mask if is_causal else None,
batch_size=batch_size,
query_len=query_len,
key_len=key_len,
window_size=window_size,
min_dtype=torch.finfo(query_states.dtype).min,
type="topk"
)

query_states_leaf = query_states
key_states_leaf = key_states
value_states_leaf = value_states

attn_bias, attn_mask = prepare_mask(
query_states,
attn_bias,
causal_mask if is_causal else None,
window_size,
)
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()

key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

# Sparse attention weight calculation
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights
attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias
attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask
if attn_mask is not None:
attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask
attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization
attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values
attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim]
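In equation form (notation added here, not from the source), the reference path above computes

$$\mathrm{attn\_outputs} = \mathrm{softmax}\big(s\,QK^{\top} + B + M\big)\,V, \qquad M_{ij} = \begin{cases} 0 & \text{if } \mathrm{attn\_mask}_{ij} \\ -\infty & \text{otherwise,} \end{cases}$$

where $s$ is `scaling`, $B$ is the head-repeated `attn_bias`, and the mask term is skipped entirely when `attn_mask` is `None`.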
@@ -192,16 +164,25 @@ def dynamic_mask_attention_cuda(
if flash_sparse_attn_func is None:
raise ImportError("CUDA implementation not available")

batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads
Copilot AI (Nov 13, 2025): Variable num_queries_per_kv is not used.
Suggested change (remove the line):
    num_queries_per_kv = num_heads // num_kv_heads

attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=causal_mask if is_causal else None,
batch_size=batch_size,
query_len=query_len,
key_len=key_len,
window_size=window_size,
min_dtype=torch.finfo(query_states.dtype).min,
type="topk"
)

query_states_leaf = query_states
key_states_leaf = key_states
value_states_leaf = value_states

attn_bias, attn_mask = prepare_mask(
query_states,
attn_bias,
causal_mask if is_causal else None,
window_size,
)
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()

@@ -259,29 +240,28 @@ def dynamic_mask_attention_triton(
if triton_sparse_attn_func is None:
raise RuntimeError("Triton implementation not available")

_, num_heads, _, _ = query_states.shape
_, num_kv_heads, _, _ = key_states.shape
batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads

attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=causal_mask if is_causal else None,
batch_size=batch_size,
query_len=query_len,
key_len=key_len,
window_size=window_size,
min_dtype=torch.finfo(query_states.dtype).min,
type="topk"
)

query_states_leaf = query_states
key_states_leaf = key_states
value_states_leaf = value_states

attn_bias, attn_mask = prepare_mask(
query_states,
attn_bias,
causal_mask if is_causal else None,
window_size,
)
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()

# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

# Ensure correct data types and memory layout for Triton function
query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim]
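A note on the memory-layout comments above: transpose(1, 2) only swaps strides and does not move any data, which is why the Flex path below additionally calls .contiguous() when a dense [batch, seq_len, num_heads, head_dim] layout is required. A minimal illustration with toy shapes (not benchmark code):

```python
import torch

x = torch.randn(2, 8, 128, 64)            # [batch, num_heads, seq_len, head_dim]
y = x.transpose(1, 2)                     # [batch, seq_len, num_heads, head_dim], same storage
print(y.is_contiguous())                  # False — only the strides changed
print(y.contiguous().is_contiguous())     # True — data copied into the new layout
```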
@@ -333,30 +313,38 @@ def dynamic_mask_attention_flex(
if flex_sparse_attn_func is None:
raise RuntimeError("Flex Attention implementation not available")

_, num_heads, _, _ = query_states.shape
_, num_kv_heads, _, _ = key_states.shape
batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads

attn_bias, attn_mask = prepare_mask(
query_states,
attn_bias,
causal_mask if is_causal else None,
window_size,
attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=causal_mask if is_causal else None,
batch_size=batch_size,
query_len=query_len,
key_len=key_len,
window_size=window_size,
min_dtype=torch.finfo(query_states.dtype).min,
type="topk"
)
attn_bias.retain_grad()

query_states_leaf = query_states
key_states_leaf = key_states
value_states_leaf = value_states
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()

# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
attn_bias = repeat_kv(attn_bias, num_queries_per_kv)

# Ensure correct data types and memory layout for Flex function
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim]
attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k]
attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k]

# Call the Flex Attention implementation
attn_outputs = flex_sparse_attn_func(
@@ -372,7 +360,7 @@
# Backward pass
attn_outputs.sum().backward()

return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad
return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad

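The changed return statement above matters because query_states, key_states, and value_states are re-bound to transposed (and, in the Flex path, contiguous) copies before the kernel call; those are non-leaf tensors and do not populate .grad, so gradients must be read from the original leaf tensors and from attn_bias_leaf, which opts in via retain_grad(). A small standalone illustration of that PyTorch behavior, independent of the benchmark:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)    # leaf tensor
y = x.transpose(0, 1).contiguous()           # non-leaf result of autograd ops
y.retain_grad()                              # opt in, as the benchmark does for attn_bias_leaf
y.sum().backward()
print(x.grad.shape)                          # torch.Size([2, 3]) — populated on the leaf
print(y.grad.shape)                          # torch.Size([3, 2]) — only because of retain_grad()
```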

def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95):
@@ -609,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
device=device, dtype=dtype, requires_grad=True
)
attn_bias = torch.randn(
batch_size, num_kv_heads, query_len, key_len,
batch_size, num_kv_heads, 1, key_len,
device=device, dtype=torch.bfloat16
)
cache_position = torch.arange(key_len - query_len, key_len, device=device)
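Both backward-equivalence tests now allocate attn_bias with a query dimension of 1, i.e. a single bias row per head shared by every query position. A minimal sketch of how such a bias combines with [batch, heads, query_len, key_len] scores under standard PyTorch broadcasting (toy shapes, not from the benchmark):

```python
import torch

batch, heads, query_len, key_len = 1, 2, 4, 8
scores = torch.randn(batch, heads, query_len, key_len)
bias = torch.randn(batch, heads, 1, key_len)   # one row per head, as in the updated tests
print((scores + bias).shape)                   # torch.Size([1, 2, 4, 8]) — row broadcast over queries
```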
@@ -843,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
device=device, dtype=dtype, requires_grad=True
)
attn_bias = torch.randn(
batch_size, num_kv_heads, query_len, key_len,
batch_size, num_kv_heads, 1, key_len,
device=device, dtype=torch.bfloat16
)
cache_position = torch.arange(key_len - query_len, key_len, device=device)