61 changes: 24 additions & 37 deletions benchmarks/backward_equivalence.py
@@ -54,7 +54,7 @@ def prepare_dynamic_mask(
hidden_states: torch.Tensor,
zoh_states: torch.Tensor,
keep_window_size: int = 2048,
attention_mask: torch.Tensor | None = None,
cache_position: torch.Tensor = None,
):
"""
Calculate dynamic attention mask to mask tokens for sparse attention.
@@ -65,28 +65,23 @@ def prepare_dynamic_mask(
hidden_states: Input hidden states to determine dtype minimum value
zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
keep_window_size: Window size of tokens not dynamically masked
attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)
cache_position: Optional cache position for causal masking

Returns:
tuple: (attn_bias, attn_mask)
"""
min_dtype = torch.finfo(hidden_states.dtype).min
dtype = hidden_states.dtype
min_dtype = torch.finfo(dtype).min
attn_bias = zoh_states[:, :, None, :].expand(
-1, -1, hidden_states.shape[2], -1
) # [batch_size, num_kv_heads, query_len, key_len]
).to(dtype) # [batch_size, num_kv_heads, query_len, key_len]

if attention_mask is not None:
if attention_mask.dtype == torch.bool:
attention_mask = torch.where(
attention_mask,
torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
min_dtype
)
if cache_position is not None:
attn_bias = attn_bias.masked_fill(
attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1),
min_dtype
)

if attn_bias.shape[-1] > keep_window_size:
topk_values, topk_indices = torch.topk(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
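
The diff view truncates this hunk at the top-k call. For orientation, here is a minimal standalone sketch of the masking path the new code implements; the function name, tensor shapes, and the scatter step that turns topk_indices into a boolean mask are assumptions for illustration, not the repository's exact code.

```python
import torch

def sketch_prepare_dynamic_mask(zoh_states, cache_position, keep_window_size=2048):
    # zoh_states: [batch, num_kv_heads, key_len]; cache_position: [query_len]
    dtype = zoh_states.dtype
    min_dtype = torch.finfo(dtype).min
    query_len, key_len = cache_position.shape[0], zoh_states.shape[-1]

    # Broadcast the per-key bias over the query dimension.
    attn_bias = zoh_states[:, :, None, :].expand(-1, -1, query_len, -1).to(dtype)

    # Causal masking straight from cache_position: key j is visible to the
    # query at absolute position cache_position[i] only when j <= cache_position[i].
    attn_bias = attn_bias.masked_fill(
        torch.arange(key_len, device=attn_bias.device) > cache_position.reshape(-1, 1),
        min_dtype,
    )

    # Keep the keep_window_size highest-bias keys per query (assumed scatter step).
    if key_len > keep_window_size:
        _, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool).scatter(-1, topk_indices, True)
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool)
    return attn_bias, attn_mask
```

Compared with the removed attention_mask path, the causal structure is recovered from cache_position alone, so callers no longer need to materialize a [batch_size, 1, query_len, key_len] mask up front.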
@@ -150,7 +145,7 @@ def dynamic_mask_attention_python(
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
@@ -165,7 +160,7 @@ def dynamic_mask_attention_python(
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
@@ -188,7 +183,7 @@ def dynamic_mask_attention_python(
query_states,
zoh_states,
keep_window_size,
causal_mask if is_causal else None
cache_position if is_causal else None
)
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()
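
Each backend wrapper keeps a separate handle on the bias (attn_bias_leaf) and calls retain_grad() before running attention; that is what makes dbias available for comparison later, since PyTorch only populates .grad on leaf tensors by default. A minimal illustration of the mechanism, independent of the benchmark:

```python
import torch

x = torch.randn(4, requires_grad=True)
bias = x * 2.0            # non-leaf tensor: .grad stays None by default
bias.retain_grad()        # ask autograd to keep its gradient
loss = (bias ** 2).sum()
loss.backward()
print(bias.grad)          # populated only because of retain_grad()
```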
@@ -218,7 +213,7 @@ def dynamic_mask_attention_cuda(
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
@@ -233,7 +228,7 @@ def dynamic_mask_attention_cuda(
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
@@ -256,7 +251,7 @@ def dynamic_mask_attention_cuda(
query_states,
zoh_states,
keep_window_size,
causal_mask if is_causal else None
cache_position if is_causal else None
) # [batch_size, num_kv_heads, query_len, key_len]
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()
@@ -294,7 +289,7 @@ def dynamic_mask_attention_triton(
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
@@ -309,7 +304,7 @@ def dynamic_mask_attention_triton(
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
@@ -336,7 +331,7 @@ def dynamic_mask_attention_triton(
query_states,
zoh_states,
keep_window_size,
causal_mask if is_causal else None
cache_position if is_causal else None
) # [batch_size, num_kv_heads, query_len, key_len]
attn_bias_leaf = attn_bias
attn_bias_leaf.retain_grad()
@@ -378,7 +373,7 @@ def dynamic_mask_attention_flex(
dt_proj: torch.Tensor,
A: torch.Tensor,
scaling: float,
causal_mask: torch.Tensor,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
@@ -393,7 +388,7 @@ def dynamic_mask_attention_flex(
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
scaling: Attention scaling factor
causal_mask: Causal attention mask
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking
@@ -416,7 +411,7 @@ def dynamic_mask_attention_flex(
query_states,
zoh_states,
keep_window_size,
causal_mask if is_causal else None
cache_position if is_causal else None
) # [batch_size, num_kv_heads, query_len, key_len]
attn_bias.retain_grad()

@@ -673,16 +668,8 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
)
A = torch.randn(num_kv_heads, device=device, dtype=dtype, requires_grad=True)

# Create custom causal mask with cache position
# Create cache position
cache_position = torch.arange(key_len - query_len, key_len, device=device)
min_type = torch.finfo(value_states.dtype).min
causal_mask = torch.full(
(query_len, key_len), fill_value=min_type,
device=device, dtype=value_states.dtype
)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# Set scaling factor and keep window size
scaling = head_dim ** -0.5
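
The explicit (query_len, key_len) causal mask previously built in this hunk is removed in favor of passing cache_position through to prepare_dynamic_mask. For this benchmark's layout, where the queries occupy the last query_len positions of the sequence, the two encode the same constraint; a small standalone sanity check, illustrative only:

```python
import torch

query_len, key_len = 4, 10
cache_position = torch.arange(key_len - query_len, key_len)

# Old construction: additive mask, non-zero where attention is disallowed.
min_value = torch.finfo(torch.float32).min
old = torch.full((query_len, key_len), min_value)
old = torch.triu(old, diagonal=1)
old = old * (torch.arange(key_len) > cache_position.reshape(-1, 1))
old_disallowed = old != 0

# New construction: boolean comparison applied inside prepare_dynamic_mask.
new_disallowed = torch.arange(key_len) > cache_position.reshape(-1, 1)

assert torch.equal(old_disallowed, new_disallowed)
```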
@@ -705,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
start_time = time.time()
attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
query_python, key_python, value_python, dt_proj_python, A_python,
scaling, causal_mask, dout.clone(), keep_window_size, is_causal
scaling, cache_position, dout.clone(), keep_window_size, is_causal
)
torch.cuda.synchronize()
py_time = time.time() - start_time
@@ -722,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
start_time = time.time()
attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
scaling, causal_mask, dout.clone(), keep_window_size, is_causal
scaling, cache_position, dout.clone(), keep_window_size, is_causal
)
torch.cuda.synchronize()
cuda_time = time.time() - start_time
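
Both reference paths are timed with wall-clock time bracketed by torch.cuda.synchronize(); because CUDA kernels launch asynchronously, the synchronize is what makes the elapsed time meaningful. A generic helper sketch of the same pattern, not part of the benchmark:

```python
import time
import torch

def time_cuda(fn, *args, **kwargs):
    # Flush pending kernels, run, then wait again before reading the clock
    # so the measurement covers the actual GPU execution time.
    torch.cuda.synchronize()
    start = time.time()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return out, time.time() - start
```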
@@ -787,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
if not is_close and max_dbias_diff > 1e-2:
print(" ⚠️ Difference too large, stopping subsequent tests.")
break
del query_states, key_states, value_states, dt_proj, A, causal_mask, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
torch.cuda.empty_cache()
gc.collect()
torch.cuda.synchronize()