61 changes: 35 additions & 26 deletions benchmarks/benchmark_grad.py
@@ -5,40 +5,51 @@

 def prepare_dynamic_mask(
     hidden_states: torch.Tensor,
-    dt_states: torch.Tensor,
+    zoh_states: torch.Tensor,
     keep_window_size: int = 2048,
     attention_mask: torch.Tensor | None = None,
 ):
     """
-    The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
+    Calculate dynamic attention mask to mask tokens for sparse attention.
 
-    Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
+    Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`.
 
     Args:
-        hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
-        dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
-        keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
-        attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        hidden_states: Input hidden states to determine dtype minimum value
+        zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
Copilot AI commented on Jul 30, 2025:

The parameter description lacks explanation of what 'zoh_states' represents functionally. Consider adding a brief description of its purpose in the dynamic masking process.

Suggested change:
-        zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
+        zoh_states: Binary tensor indicating zones of high attention, of shape
+            (batch_size, num_kv_heads, key_sequence_length). Used to guide the
+            dynamic masking process by identifying key positions to retain.
+        keep_window_size: Window size of tokens not dynamically masked
+        attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)
 
     Returns:
         tuple: (attn_bias, attn_mask)
     """
     min_dtype = torch.finfo(hidden_states.dtype).min
     dtype = hidden_states.dtype
-    attn_mask = dt_states[:, :, None, :].expand(
+    attn_bias = zoh_states[:, :, None, :].expand(
         -1, -1, hidden_states.shape[2], -1
-    )  # [batch_size, num_heads, query_len, key_len]
-    active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
+    )  # [batch_size, num_kv_heads, query_len, key_len]
 
     if attention_mask is not None:
         if attention_mask.dtype == torch.bool:
             attention_mask = torch.where(
-                attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
+                attention_mask,
+                torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
+                min_dtype
             )
-        attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
-    if attn_mask.shape[-1] > keep_window_size:
+        attn_bias = attn_bias.masked_fill(
+            attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+        )
+
+    if attn_bias.shape[-1] > keep_window_size:
         topk_indices = torch.topk(
-            attn_mask, keep_window_size, dim=-1, largest=True, sorted=False
+            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         ).indices
-        active_mask = active_mask.scatter(-1, topk_indices, 1.0)
-        attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
-    return attn_mask, active_mask
+        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+    else:
+        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+    return attn_bias, attn_mask
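For readers skimming the diff, a minimal sketch of the new return contract. This is an editor's illustration, not code from the PR; the shapes, the float32 dtype, and the assumption that the updated prepare_dynamic_mask above is in scope are all illustrative.

```python
import torch

# Illustrative shapes only (assumptions, not taken from the benchmark).
batch, num_kv_heads, query_len, key_len, head_dim = 1, 2, 16, 16, 4
hidden_states = torch.randn(batch, num_kv_heads, query_len, head_dim)
zoh_states = torch.rand(batch, num_kv_heads, key_len)

attn_bias, attn_mask = prepare_dynamic_mask(hidden_states, zoh_states, keep_window_size=8)

# attn_bias carries min_dtype at dropped positions, so softmax zeroes them out;
# attn_mask is 1.0 exactly at the keep_window_size retained positions per query.
assert attn_mask.sum(dim=-1).eq(8).all()
```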


@@ -60,7 +71,7 @@ def dynamic_mask_attention_cuda(

     dt_states = torch.matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
     dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_bias, attn_mask = prepare_dynamic_mask(
         query_states, dt_states, keep_window_size=keep_window_size, attention_mask=causal_mask
     )  # [batch_size, num_kv_heads, query_len, key_len]
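An aside on the `zoh_states` naming: the states passed into prepare_dynamic_mask are computed as exp(A * softplus(dt)) from the value projections, which resembles a zero-order-hold style gate. The stand-alone sketch below is an editor's illustration, not PR code; the shapes of `A` and `dt_proj` are assumptions.

```python
import torch
import torch.nn.functional as F

# Assumed shapes: A is a per-head vector, dt_proj maps flattened values to heads.
batch, num_kv_heads, key_len, head_dim = 1, 2, 8, 4
value_states = torch.randn(batch, num_kv_heads, key_len, head_dim)
dt_proj = torch.randn(num_kv_heads, num_kv_heads * head_dim)
A = torch.randn(num_kv_heads)

# softplus keeps dt positive; exp turns it into a positive importance score
# per key position, one score per kv head.
dt = torch.matmul(value_states.transpose(-2, -3).reshape(batch, key_len, -1), dt_proj.T)
zoh_states = torch.exp(A * F.softplus(dt)).transpose(-1, -2)  # (batch, num_kv_heads, key_len)
```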

@@ -76,16 +87,15 @@
             if len(non_mask_indices) == 0:
                 continue
 
-            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
-            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
-
-            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]
+            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]
+            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
+            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
Copilot AI commented on lines +90 to +92, Jul 30, 2025:

[nitpick] Moving the q_vec assignment before k_vecs and v_vecs creates an inconsistent ordering. The original placement after k_vecs and v_vecs was more logical, as it grouped the vector extractions together before their usage.

Suggested change:
-            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]
-            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
-            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
+            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
+            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
+            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]

             # QK dot product
             attn_weight = torch.sum(q_vec.unsqueeze(0) * k_vecs, dim=-1)
 
             # Apply scaling and dynamic_mask
-            attn_weight = attn_weight * scaling + attn_mask[b_idx, h_idx, q_idx, non_mask_indices]
+            attn_weight = attn_weight * scaling + attn_bias[b_idx, h_idx, q_idx, non_mask_indices]
 
             # Softmax
             attn_weight = F.softmax(attn_weight, dim=-1)
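A tiny numeric aside (editor's illustration, not PR code) on why adding the dtype minimum acts as a hard mask: after softmax, any position biased by min_dtype receives zero weight.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])
bias = torch.tensor([0.0, torch.finfo(torch.float32).min, 0.0])
# The biased logit underflows in exp(), so the middle weight is exactly 0.0.
print(F.softmax(logits + bias, dim=-1))
```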
@@ -118,12 +128,11 @@ def dynamic_mask_attention_python(

     dt_states = torch.matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
     dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_bias, _ = prepare_dynamic_mask(
         query_states, dt_states, keep_window_size=keep_window_size, attention_mask=causal_mask
     )  # [batch_size, num_kv_heads, query_len, key_len]
 
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # [batch_size, num_heads, query_len, key_len]
-    attn_weights = attn_weights * scaling + attn_mask  # Apply scaling and dynamic_mask
+    attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and dynamic_mask
 
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # [batch_size, num_heads, query_len, head_dim]
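Taken together, this dense-bias path and the gather-based loop above should agree up to floating-point error, since masked positions get zero softmax weight either way. A self-contained consistency sketch (editor's illustration; the single-query shapes, 0.125 scaling, and random top-k stand in for the benchmark's real inputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 8)   # one query, head_dim = 8
k = torch.randn(16, 8)  # 16 keys
v = torch.randn(16, 8)
min_val = torch.finfo(torch.float32).min

keep = torch.topk(torch.randn(16), 4).indices          # stand-in for the top-k kept positions
bias = torch.full((16,), min_val).scatter(0, keep, 0.0)

# Dense path: score all keys, mask dropped ones via the additive bias.
dense = F.softmax((q @ k.T) * 0.125 + bias, dim=-1) @ v

# Sparse path: gather only the kept keys/values, then attend.
w = F.softmax((q @ k[keep].T) * 0.125, dim=-1)
sparse = w @ v[keep]

torch.testing.assert_close(dense, sparse)
```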