diff --git a/README.md b/README.md index 89c6082..c163452 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A ## Key Features ### 🎯 Core Kernel Advantages -- **Mask & Bias Support**: Native support for `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped attention mask and attention bias tensors +- **Mask & Bias Support**: Native support for `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped attention mask and attention bias tensors - **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks - **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training @@ -236,9 +236,9 @@ Flash-DMA integrates the efficient memory access patterns of Flash Attention wit ### Core Technology Integration -- **🎯 Native Mask & Bias Support**: Kernels directly process `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped tensors +- **🎯 Native Mask & Bias Support**: Kernels directly process `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped tensors - **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks -- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation (dbias) supporting end-to-end differentiable training +- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation supporting end-to-end differentiable training ### Key Optimization Strategies diff --git a/README_zh.md b/README_zh.md index 8e16c41..2550652 100644 --- a/README_zh.md +++ b/README_zh.md @@ -18,7 +18,7 @@ Flash-DMA 是一个高性能的注意力实现,将 Flash Attention 的内存 ## 主要特性 ### 🎯 核心内核优势 -- **Mask & Bias 支持**: 原生支持 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的 attention_mask 和 attention_bias 张量 +- **Mask & Bias 支持**: 原生支持 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的 attention_mask 和 attention_bias 张量 - **智能计算跳过**: 基于 attention_mask 的 block-level 自动跳过机制,完全跳过全零 mask 区块的计算和内存访问 - **完整梯度支持**: 内置 attention_bias 的完整梯度计算路径,支持端到端训练 @@ -236,7 +236,7 @@ Flash-DMA 通过将 Flash Attention 的高效内存访问模式与动态掩码 ### 核心技术融合 -- **🎯 Mask & Bias 原生支持**: 内核直接处理 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的张量 +- **🎯 Mask & Bias 原生支持**: 内核直接处理 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的张量 - **⚡ Block-level 智能跳过**: 基于 mask 的统一 OR-reduction 跳过逻辑,完全避免全零区块的计算和内存访问 - **🔄 完整梯度链路**: 内置 attention bias 梯度计算,支持端到端可微分训练 diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index fd8ebef..2149a8e 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -50,104 +50,66 @@ flex_dmattn_func = None -def prepare_dynamic_mask( - hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, -): +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ - Calculate dynamic attention mask to mask tokens for sparse attention. + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
+ Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. +def prepare_mask( + hidden_states: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, +): + """ Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. 
- - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - Transform from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Python reference implementation of dynamic mask attention backward pass. @@ -156,11 +118,10 @@ def dynamic_mask_attention_python( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -174,29 +135,27 @@ def dynamic_mask_attention_python( key_states_leaf = key_states value_states_leaf = value_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask function to process dynamic mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None + attn_bias, + causal_mask if is_causal else None, + window_size, ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() - - # Sparse attention weight calculation + key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) - attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh - attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization - attn_outputs = torch.matmul(attn_weights, value_states) - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to 
[batch, query_len, num_heads, head_dim] + # Sparse attention weight calculation + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights + attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization + attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] # Backward pass attn_outputs.sum().backward() @@ -208,12 +167,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention backward pass. @@ -222,11 +180,10 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -239,36 +196,31 @@ def dynamic_mask_attention_cuda( key_states_leaf = key_states value_states_leaf = value_states - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query=query_states, # q: [batch, query_len, num_heads, head_dim] - key=key_states, # k: [batch, key_len, num_kv_heads, head_dim] - value=value_states, # v: [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] - is_causal=is_causal, # causal masking - softmax_scale=scaling, # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, softcap=0.0, - deterministic=False, + deterministic=True, 
return_attn_probs=False ) @@ -282,12 +234,11 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention backward pass. @@ -296,11 +247,10 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -317,16 +267,12 @@ def dynamic_mask_attention_triton( key_states_leaf = key_states value_states_leaf = value_states - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias_leaf = attn_bias attn_bias_leaf.retain_grad() @@ -336,7 +282,7 @@ def dynamic_mask_attention_triton( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format + # Ensure correct data types and memory layout for Triton function query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] @@ -345,13 +291,13 @@ def dynamic_mask_attention_triton( # Call the Triton implementation attn_outputs = triton_dmattn_func( - query=query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key=key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value=value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) # Backward pass @@ -364,12 +310,11 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention backward pass. 
@@ -378,11 +323,10 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -395,16 +339,12 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) attn_bias.retain_grad() # Repeat KV for multi-head attention (GQA support) @@ -413,18 +353,22 @@ def dynamic_mask_attention_flex( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) # Backward pass @@ -599,35 +543,33 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 128, False), (1, 2, 1, 4096, 4096, 128, True), - # # Head dim 192 - # Not enough shared memory for head_dim=192 in bwd yet - # (1, 2, 1, 128, 128, 192, False), - # (1, 2, 1, 128, 128, 192, True), - # (1, 2, 1, 256, 256, 192, False), - # (1, 2, 1, 256, 256, 192, True), - # (1, 2, 1, 512, 512, 192, False), - # (1, 2, 1, 512, 512, 192, True), - # (1, 2, 1, 1024, 1024, 192, False), - # (1, 2, 1, 1024, 1024, 192, True), - # (1, 2, 1, 2048, 2048, 192, False), - # (1, 2, 1, 2048, 2048, 192, True), - # (1, 2, 1, 4096, 4096, 192, 
False), - # (1, 2, 1, 4096, 4096, 192, True), + # Head dim 192 + (1, 2, 1, 128, 128, 192, False), + (1, 2, 1, 128, 128, 192, True), + (1, 2, 1, 256, 256, 192, False), + (1, 2, 1, 256, 256, 192, True), + (1, 2, 1, 512, 512, 192, False), + (1, 2, 1, 512, 512, 192, True), + (1, 2, 1, 1024, 1024, 192, False), + (1, 2, 1, 1024, 1024, 192, True), + (1, 2, 1, 2048, 2048, 192, False), + (1, 2, 1, 2048, 2048, 192, True), + (1, 2, 1, 4096, 4096, 192, False), + (1, 2, 1, 4096, 4096, 192, True), # Head dim 256 - # Not enough shared memory for head_dim=256 in bwd yet - # (1, 2, 1, 128, 128, 256, False), - # (1, 2, 1, 128, 128, 256, True), - # (1, 2, 1, 256, 256, 256, False), - # (1, 2, 1, 256, 256, 256, True), - # (1, 2, 1, 512, 512, 256, False), - # (1, 2, 1, 512, 512, 256, True), - # (1, 2, 1, 1024, 1024, 256, False), - # (1, 2, 1, 1024, 1024, 256, True), - # (1, 2, 1, 2048, 2048, 256, False), - # (1, 2, 1, 2048, 2048, 256, True), - # (1, 2, 1, 4096, 4096, 256, False), - # (1, 2, 1, 4096, 4096, 256, True), + (1, 2, 1, 128, 128, 256, False), + (1, 2, 1, 128, 128, 256, True), + (1, 2, 1, 256, 256, 256, False), + (1, 2, 1, 256, 256, 256, True), + (1, 2, 1, 512, 512, 256, False), + (1, 2, 1, 512, 512, 256, True), + (1, 2, 1, 1024, 1024, 256, False), + (1, 2, 1, 1024, 1024, 256, True), + (1, 2, 1, 2048, 2048, 256, False), + (1, 2, 1, 2048, 2048, 256, True), + (1, 2, 1, 4096, 4096, 256, False), + (1, 2, 1, 4096, 4096, 256, True), ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -668,48 +610,48 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=dtype, requires_grad=True ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, - device=device, dtype=dtype, requires_grad=True + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=dtype, requires_grad=True) - - # Create cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 1024 + window_size = 1024 # Clone inputs for Python implementation query_python = query_states.clone().detach().requires_grad_(True) key_python = key_states.clone().detach().requires_grad_(True) value_python = value_states.clone().detach().requires_grad_(True) - dt_proj_python = dt_proj.clone().detach().requires_grad_(True) - A_python = A.clone().detach().requires_grad_(True) + attn_bias_python = attn_bias.clone().detach().requires_grad_(True) + causal_mask_python = causal_mask.clone().detach() # Run Python implementation start_time = time.time() attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python( - query_python, key_python, value_python, dt_proj_python, A_python, - scaling, cache_position, keep_window_size, is_causal + query_python, key_python, value_python, + attn_bias_python, causal_mask_python, + scaling, window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time - - + # Clone inputs for CUDA implementation query_cuda = query_states.clone().detach().requires_grad_(True) key_cuda = key_states.clone().detach().requires_grad_(True) value_cuda = value_states.clone().detach().requires_grad_(True) - dt_proj_cuda = 
dt_proj.clone().detach().requires_grad_(True) - A_cuda = A.clone().detach().requires_grad_(True) - + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + # Run CUDA implementation start_time = time.time() attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda( - query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, + attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) torch.cuda.synchronize() cuda_time = time.time() - start_time @@ -774,7 +716,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): if not is_close and max_dbias_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda + del query_states, key_states, value_states, attn_bias, causal_mask, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda torch.cuda.empty_cache() gc.collect() torch.cuda.synchronize() @@ -872,7 +814,3 @@ def main(): if __name__ == "__main__": main() - - - - diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py index 03d5018..82deb8c 100644 --- a/benchmarks/backward_performance.py +++ b/benchmarks/backward_performance.py @@ -72,124 +72,100 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_dynamic_mask( +def prepare_mask( hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, ): """ - Calculate dynamic attention mask to mask tokens for sparse attention. - - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. 
- Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def scaled_dot_product_attention_backward( +def scaled_dot_product_attention_backward_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - scaling: float, + attn_bias: torch.Tensor, causal_mask: torch.Tensor, - is_causal=True, + scaling: float, + window_size: int, + is_causal: bool, ): """ - SDPA baseline backward pass implementation. + CUDA implementation of SDPA baseline. 
Args: query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - causal_mask: Causal attention mask is_causal: Whether to apply causal masking Returns: - tuple: (output_tensor, timing_ms) or ("OOM", 0) if out of memory + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, _, query_len, _ = query_states.shape - _, _, key_len, _ = key_states.shape - if query_len > 32768 and key_len > 32768: - return "OOM", 0 + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) + + # Repeat KV for multi-head attention (GQA support) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + attn_bias = attn_bias.masked_fill(~attn_mask, torch.finfo(query_states.dtype).min).contiguous() try: - # Forward pass - SDPA expects q, k, v in [batch, num_heads, seq_len, head_dim] format attn_outputs = F.scaled_dot_product_attention( - query_states, # [batch, num_heads, query_len, head_dim] - key_states, # [batch, num_kv_heads, key_len, head_dim] - value_states, # [batch, num_kv_heads, key_len, head_dim] - attn_mask=causal_mask, - softmax_scale=scaling, - # is_causal=is_causal if query_len == key_len else False, + query_states, + key_states, + value_states, + attn_mask=attn_bias, + scale=scaling, + # is_causal=is_causal, enable_gqa=True ) - # Transpose to match expected output format + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - + torch.cuda.synchronize() start_time = time.time() @@ -209,12 +185,11 @@ def dynamic_mask_attention_backward_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention backward pass. 
@@ -223,11 +198,10 @@ def dynamic_mask_attention_backward_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -236,33 +210,27 @@ def dynamic_mask_attention_backward_cuda( if flash_dmattn_func is None: return "Not Available", 0 - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] try: - # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query=query_states, # q: [batch, query_len, num_heads, head_dim] - key=key_states, # k: [batch, key_len, num_kv_heads, head_dim] - value=value_states, # v: [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # bias: [batch, num_kv_heads, query_len, key_len] - is_causal=is_causal, # causal masking - softmax_scale=scaling, # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, softcap=0.0, deterministic=False, return_attn_probs=False @@ -287,12 +255,11 @@ def dynamic_mask_attention_backward_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention backward pass. 
@@ -301,11 +268,10 @@ def dynamic_mask_attention_backward_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -318,40 +284,35 @@ def dynamic_mask_attention_backward_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - - # Call the Triton implementation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Triton function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: attn_outputs = triton_dmattn_func( - query=query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key=key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value=value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query=query_states, + key=key_states, + value=value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + 
is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -373,12 +334,11 @@ def dynamic_mask_attention_backward_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention backward pass. @@ -387,11 +347,10 @@ def dynamic_mask_attention_backward_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_length, key_length] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -404,36 +363,35 @@ def dynamic_mask_attention_backward_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - - # Call the Flex Attention implementation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, 
key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -470,7 +428,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 Benchmark backward attention performance for a given configuration. Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) test_type: Type of test to run ('all', 'sdpa', 'cuda', 'triton', 'flex', etc.) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs @@ -478,7 +436,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data (requires_grad=True for backward pass) @@ -494,21 +452,12 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16, requires_grad=True ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, - device=device, dtype=torch.bfloat16, requires_grad=True + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16, requires_grad=True) - - # Create custom causal mask with cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - min_type = torch.finfo(value_states.dtype).min - causal_mask = torch.full( - (query_len, key_len), fill_value=min_type, - device=device, dtype=value_states.dtype - ) - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor from config @@ -543,13 +492,16 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - - result = scaled_dot_product_attention_backward( - q_clone, k_clone, v_clone, scaling, causal_mask, is_causal + query_sdpa = query_states.clone().detach().requires_grad_(True) + key_sdpa = key_states.clone().detach().requires_grad_(True) + value_sdpa = value_states.clone().detach().requires_grad_(True) + attn_bias_sdpa = attn_bias.clone().detach().requires_grad_(True) + causal_mask_sdpa = causal_mask.clone().detach() + + result = scaled_dot_product_attention_backward_cuda( + query_sdpa, key_sdpa, value_sdpa, + attn_bias_sdpa, causal_mask_sdpa, + scaling, window_size, is_causal ) if result[0] == "OOM": results['sdpa_backward_status'] = 'OOM' @@ -562,13 +514,16 @@ def 
benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - - result = scaled_dot_product_attention_backward( - q_clone, k_clone, v_clone, scaling, causal_mask, is_causal + query_sdpa = query_states.clone().detach().requires_grad_(True) + key_sdpa = key_states.clone().detach().requires_grad_(True) + value_sdpa = value_states.clone().detach().requires_grad_(True) + attn_bias_sdpa = attn_bias.clone().detach().requires_grad_(True) + causal_mask_sdpa = causal_mask.clone().detach() + + result = scaled_dot_product_attention_backward_cuda( + query_sdpa, key_sdpa, value_sdpa, + attn_bias_sdpa, causal_mask_sdpa, + scaling, window_size, is_causal ) if result[0] == "OOM": @@ -591,16 +546,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_cuda = query_states.clone().detach().requires_grad_(True) + key_cuda = key_states.clone().detach().requires_grad_(True) + value_cuda = value_states.clone().detach().requires_grad_(True) + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_cuda( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_cuda_backward_status'] = result[0] @@ -613,16 +567,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_cuda = query_states.clone().detach().requires_grad_(True) + key_cuda = key_states.clone().detach().requires_grad_(True) + value_cuda = value_states.clone().detach().requires_grad_(True) + attn_bias_cuda = attn_bias.clone().detach().requires_grad_(True) + causal_mask_cuda = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_cuda( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_cuda, key_cuda, value_cuda, attn_bias_cuda, causal_mask_cuda, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -645,16 +598,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - 
dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_triton( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_triton, key_triton, value_triton, attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_triton_backward_status'] = result[0] @@ -667,16 +619,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_triton( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_triton, key_triton, value_triton, attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -699,16 +650,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Warmup runs for _ in range(warmup_runs): - # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_flex = query_states.clone().detach().requires_grad_(True) + key_flex = key_states.clone().detach().requires_grad_(True) + value_flex = value_states.clone().detach().requires_grad_(True) + attn_bias_flex = attn_bias.clone().detach().requires_grad_(True) + causal_mask_flex = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_flex( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_flex, key_flex, value_flex, attn_bias_flex, causal_mask_flex, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_flex_backward_status'] = result[0] @@ -722,15 +672,15 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5 # Actual benchmark runs for _ in range(num_runs): # Clone inputs for each run - q_clone = query_states.clone().detach().requires_grad_(True) - k_clone = key_states.clone().detach().requires_grad_(True) - v_clone = value_states.clone().detach().requires_grad_(True) - dt_clone = dt_proj.clone().detach().requires_grad_(True) - a_clone = A.clone().detach().requires_grad_(True) - + query_flex = 
query_states.clone().detach().requires_grad_(True) + key_flex = key_states.clone().detach().requires_grad_(True) + value_flex = value_states.clone().detach().requires_grad_(True) + attn_bias_flex = attn_bias.clone().detach().requires_grad_(True) + causal_mask_flex = causal_mask.clone().detach() + result = dynamic_mask_attention_backward_flex( - q_clone, k_clone, v_clone, dt_clone, a_clone, - scaling, cache_position, keep_window_size, is_causal + query_flex, key_flex, value_flex, attn_bias_flex, causal_mask_flex, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -776,7 +726,7 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) configs = [ # Vary sequence length (1, 2, 1, 256, 256, 64, 1024, True), @@ -805,7 +755,7 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= (1, 2, 1, 16384, 16384, 96, 1024, True), (1, 2, 1, 16384, 16384, 128, 1024, True), - # Vary keep_window_size + # Vary window_size (1, 2, 1, 16384, 16384, 64, 32, True), (1, 2, 1, 16384, 16384, 64, 64, True), (1, 2, 1, 16384, 16384, 64, 128, True), @@ -830,8 +780,8 @@ def run_backward_performance_benchmark(test_type='all', num_runs=3, warmup_runs= all_results.append(results) # Format configuration string - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config - config_str = f"B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{keep_window_size} {'C' if is_causal else 'N'}" + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config + config_str = f"B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{window_size} {'C' if is_causal else 'N'}" # Calculate averages and format results sdpa_avg = f"{sum(results['sdpa_backward_times'])/len(results['sdpa_backward_times']):.2f}ms" if results['sdpa_backward_times'] else results['sdpa_backward_status'] diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index 1da6f2a..8baff70 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -50,104 +50,66 @@ flex_dmattn_func = None -def prepare_dynamic_mask( - hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, -): +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) """ - Calculate dynamic attention mask to mask tokens for sparse attention. + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. 
+def prepare_mask( + hidden_states: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, +): + """ Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. - - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
- Transform from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def dynamic_mask_attention_python( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Python reference implementation of dynamic mask attention. @@ -156,11 +118,10 @@ def dynamic_mask_attention_python( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -171,26 +132,25 @@ def dynamic_mask_attention_python( num_queries_per_kv = num_heads // num_kv_heads - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask function to process dynamic mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None + attn_bias, + causal_mask if is_causal else None, + window_size, ) - - # Sparse attention weight calculation + key_states = repeat_kv(key_states, num_queries_per_kv) value_states = repeat_kv(value_states, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) - attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh - attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization - attn_outputs = torch.matmul(attn_weights, value_states) - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + + # Sparse attention weight calculation + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) # Dot product weights + attn_weights = attn_weights * scaling + attn_bias # Apply scaling and bias + attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf')) # Apply mask + attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization + attn_outputs = torch.matmul(attn_weights, value_states) # Weighted sum of values + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] return attn_outputs @@ -199,13 +159,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - 
keep_window_size=2048, - is_causal=True, - return_softmax=False + window_size: int, + is_causal: bool, ): """ CUDA implementation of dynamic mask attention. @@ -214,13 +172,11 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking - return_softmax: Whether to return softmax weights Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] @@ -228,35 +184,30 @@ def dynamic_mask_attention_cuda( if flash_dmattn_func is None: raise RuntimeError("flash_dmattn_func not available") - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] # Call the flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query_states, # [batch, query_len, num_heads, head_dim] - key_states, # [batch, key_len, num_kv_heads, head_dim] - value_states, # [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, is_causal=is_causal, softmax_scale=scaling, softcap=0.0, deterministic=True, - return_attn_probs=return_softmax + return_attn_probs=False, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -266,12 +217,11 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Triton implementation of dynamic mask attention. 
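For orientation, a sketch of the calling convention the CUDA wrapper above relies on: q/k/v are passed in `[batch, seqlen, heads, head_dim]` layout, while `attn_mask`/`attn_bias` stay in `[batch, {1|num_kv_heads|num_heads}, query_len, key_len]`. The import path is an assumption and may differ per install; shapes are illustrative and a CUDA build of the extension is required.

```python
import torch
from flash_dmattn import flash_dmattn_func  # assumed import path; adjust to your install

b, h, h_kv, q_len, k_len, d = 1, 8, 2, 128, 128, 64
device, dtype = torch.device("cuda"), torch.bfloat16

q = torch.randn(b, q_len, h, d, device=device, dtype=dtype)     # [batch, query_len, num_heads, head_dim]
k = torch.randn(b, k_len, h_kv, d, device=device, dtype=dtype)  # [batch, key_len, num_kv_heads, head_dim]
v = torch.randn(b, k_len, h_kv, d, device=device, dtype=dtype)
attn_bias = torch.randn(b, h_kv, q_len, k_len, device=device, dtype=dtype)
attn_mask = torch.ones(b, h_kv, q_len, k_len, device=device, dtype=torch.bool)

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,      # bool, broadcastable over batch / heads / query_len
    attn_bias=attn_bias,      # same dtype as q/k/v
    is_causal=True,
    softmax_scale=d ** -0.5,
    softcap=0.0,
    deterministic=True,
    return_attn_probs=False,
)  # -> [batch, query_len, num_heads, head_dim]
```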
@@ -280,11 +230,10 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -297,16 +246,12 @@ def dynamic_mask_attention_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) @@ -323,13 +268,13 @@ def dynamic_mask_attention_triton( # Call the Triton implementation attn_outputs = triton_dmattn_func( - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -339,12 +284,11 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, - is_causal=True, + window_size: int, + is_causal: bool, ): """ Flex Attention implementation of dynamic mask attention. 
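A condensed sketch of the tensor plumbing the Triton wrapper above performs before launching the kernel: KV heads, mask, and bias are expanded to the full query-head count via `repeat_kv`, and q/k/v are moved into contiguous `[batch, seqlen, num_heads, head_dim]` layout. Names mirror the wrapper; sizes are illustrative and the kernel call itself is omitted.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same expand + reshape trick as the helper above: (b, h_kv, s, d) -> (b, h_kv * n_rep, s, d).
    b, h_kv, s, d = x.shape
    if n_rep == 1:
        return x
    return x[:, :, None, :, :].expand(b, h_kv, n_rep, s, d).reshape(b, h_kv * n_rep, s, d)

b, h, h_kv, q_len, k_len, d = 1, 8, 2, 64, 64, 32
n_rep = h // h_kv

query = torch.randn(b, h, q_len, d)     # [batch, num_heads, query_len, head_dim]
key = torch.randn(b, h_kv, k_len, d)    # [batch, num_kv_heads, key_len, head_dim]
value = torch.randn(b, h_kv, k_len, d)
attn_mask = torch.ones(b, h_kv, q_len, k_len, dtype=torch.bool)
attn_bias = torch.randn(b, h_kv, q_len, k_len)

# GQA: broadcast KV heads (and per-KV-head mask/bias) up to num_heads.
key, value = repeat_kv(key, n_rep), repeat_kv(value, n_rep)
attn_mask, attn_bias = repeat_kv(attn_mask, n_rep), repeat_kv(attn_bias, n_rep)

# Triton/Flex kernels take q/k/v as [batch, seqlen, num_heads, head_dim], contiguous.
query, key, value = (t.transpose(1, 2).contiguous() for t in (query, key, value))
assert query.shape == (b, q_len, h, d) and attn_mask.shape == (b, h, q_len, k_len)
```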
@@ -353,11 +297,10 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_kv_heads, query_len, key_len] + causal_mask: [batch_size, 1, query_len, key_len] or None scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window + window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: @@ -370,16 +313,12 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Repeat KV for multi-head attention (GQA support) key_states = repeat_kv(key_states, num_queries_per_kv) @@ -387,18 +326,22 @@ def dynamic_mask_attention_flex( attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -611,18 +554,18 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): torch.cuda.synchronize() batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config - + # Progress indicator progress_filled = "█" * (i + 1) progress_empty = "░" * (len(test_configs) - i - 1) progress_bar = f"[{progress_filled}{progress_empty}]" - + print(f"\n🧪 Test configuration {i+1}/{len(test_configs)} {progress_bar}") print(f" 📊 batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") print(f" 📏 
query_len={query_len}, key_len={key_len}, head_dim={head_dim}") print(f" 🔒 is_causal={is_causal}") print(f" 🎯 Accuracy threshold: {accuracy_threshold*100:.1f}%") - + # Create random input data query_states = torch.randn( batch_size, num_heads, query_len, head_dim, @@ -636,40 +579,39 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + # Set scaling factor and keep window size scaling = head_dim ** -0.5 - keep_window_size = 1024 + window_size = 1024 # Run Python implementation start_time = time.time() py_output = dynamic_mask_attention_python( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() py_time = time.time() - start_time - + # Run CUDA implementation start_time = time.time() cuda_output = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, + window_size, is_causal ) torch.cuda.synchronize() cuda_time = time.time() - start_time - - + + # Analyze differences py_output_copy = py_output.clone() cuda_output_copy = cuda_output.clone() @@ -692,7 +634,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95): if not is_close and max_diff > 1e-2: print(" ⚠️ Difference too large, stopping subsequent tests.") break - del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, cuda_output, py_output_copy, cuda_output_copy + del query_states, key_states, value_states, attn_bias, causal_mask, py_output, cuda_output, py_output_copy, cuda_output_copy torch.cuda.empty_cache() gc.collect() torch.cuda.synchronize() @@ -774,8 +716,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): (1, 2, 1, 4096, 4096, 128, False), # Not support head_dim > 128 in triton yet - # (1, 2, 1, 128, 128, 128, True), - # (1, 2, 1, 128, 128, 128, False), + # (1, 2, 1, 128, 128, 256, True), + # (1, 2, 1, 128, 128, 256, False), # (1, 2, 1, 256, 256, 256, True), # (1, 2, 1, 256, 256, 256, False), # (1, 2, 1, 512, 512, 256, True), @@ -825,25 +767,24 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create cache position - cache_position = torch.arange(0, query_len + 0, device=device) + cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor and keep window size scaling = head_dim ** 
-0.5 -        keep_window_size = 64 +        window_size = 1024          # Run Python implementation         start_time = time.time()         py_output = dynamic_mask_attention_python(             query_states, key_states, value_states, -            dt_proj, A, scaling, cache_position, -            keep_window_size, is_causal +            attn_bias, causal_mask, scaling, +            window_size, is_causal         )         torch.cuda.synchronize()         py_time = time.time() - start_time @@ -853,8 +794,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):         try:             triton_output = dynamic_mask_attention_triton(                 query_states, key_states, value_states, -                dt_proj, A, scaling, cache_position, -                keep_window_size, is_causal +                attn_bias, causal_mask, scaling, +                window_size, is_causal             )             torch.cuda.synchronize()             triton_time = time.time() - start_time @@ -896,8 +837,8 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):             if triton_max_diff > 1e-2:                 print("   ⚠️ Difference too large, stopping subsequent tests.")                 break -          -        del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, py_output_copy +          +        del query_states, key_states, value_states, attn_bias, causal_mask, py_output, py_output_copy         if triton_output is not None:             del triton_output, triton_output_copy         torch.cuda.empty_cache() @@ -1031,25 +972,24 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):             batch_size, num_kv_heads, key_len, head_dim,             device=device, dtype=torch.bfloat16         )         -        dt_proj = torch.randn( -            num_kv_heads, num_kv_heads * head_dim, +        attn_bias = torch.randn( +            batch_size, num_kv_heads, query_len, key_len,             device=device, dtype=torch.bfloat16         )         -        A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) -         -        # Create cache position -        cache_position = torch.arange(0, query_len + 0, device=device) +        cache_position = torch.arange(key_len - query_len, key_len, device=device) +        causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) +        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)          # Set scaling factor and keep window size         scaling = head_dim ** -0.5 -        keep_window_size = 64 +        window_size = 1024          # Run Python implementation         start_time = time.time()         py_output = dynamic_mask_attention_python(             query_states, key_states, value_states, -            dt_proj, A, scaling, cache_position, -            keep_window_size, is_causal +            attn_bias, causal_mask, scaling, +            window_size, is_causal         )         torch.cuda.synchronize()         py_time = time.time() - start_time @@ -1059,8 +999,8 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):         try:             flex_output = dynamic_mask_attention_flex(                 query_states, key_states, value_states, -                dt_proj, A, scaling, cache_position, -                keep_window_size, is_causal +                attn_bias, causal_mask, scaling, +                window_size, is_causal             )             torch.cuda.synchronize()             flex_time = time.time() - start_time @@ -1102,8 +1042,8 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):             if flex_max_diff > 1e-2:                 print("   ⚠️ Difference too large, stopping subsequent tests.")                 break -          -        del query_states, key_states, value_states, dt_proj, A, cache_position, py_output, py_output_copy +          +        del query_states, key_states, value_states, attn_bias, causal_mask, py_output, py_output_copy         if flex_output is not None:             del flex_output, flex_output_copy         torch.cuda.empty_cache() @@ -1203,4 +1143,4 @@ def main():   if __name__ == "__main__": -    main() \ No newline at end of file +    main() diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 0d48c1a..5730e0e 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -72,86 +72,51 @@ def 
repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def prepare_dynamic_mask( +def prepare_mask( hidden_states: torch.Tensor, - zoh_states: torch.Tensor, - keep_window_size: int = 2048, - cache_position: torch.Tensor = None, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor = None, + window_size: int = None, ): """ - Calculate dynamic attention mask to mask tokens for sparse attention. - - Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. - Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) - keep_window_size: Window size of tokens not dynamically masked - cache_position: Optional cache position for causal masking + attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length) + causal_mask: Optional causal mask to apply + window_size: Window size of tokens not masked Returns: tuple: (attn_bias, attn_mask) """ dtype = hidden_states.dtype min_dtype = torch.finfo(dtype).min - attn_bias = zoh_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[2], -1 - ).to(dtype) # [batch_size, num_kv_heads, query_len, key_len] - - if cache_position is not None: - attn_bias = attn_bias.masked_fill( - torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1), - min_dtype - ) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk) - attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype) + if attn_bias.shape[-1] > window_size: + if causal_mask is not None: + topk_values, topk_indices = torch.topk( + attn_bias.masked_fill(~causal_mask, min_dtype).detach(), + window_size, dim=-1, largest=True, sorted=False + ) + else: + topk_values, topk_indices = torch.topk( + attn_bias, + window_size, dim=-1, largest=True, sorted=False + ) + attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype) else: - attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) + attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device) return attn_bias, attn_mask -def calculate_zoh_states(value_states, dt_proj, A): - """ - Calculate zoh states for dynamic mask attention. 
- - Args: - value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask - - Returns: - zoh_states: [batch_size, num_kv_heads, key_len] - """ - batch_size, _, key_len, _ = value_states.shape - - # Transpose and reshape value_states, then matrix multiply with dt_proj.T - dt_result = torch.matmul( - value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), - dt_proj.T - ) - - # Apply softplus activation and coefficient A - dt_states = torch.exp(F.softplus(dt_result) * A) - zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zoh_states - - def scaled_dot_product_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - scaling: float, + attn_bias: torch.Tensor, causal_mask: torch.Tensor, - is_causal=True, + scaling: float, + window_size: int, + is_causal: bool, ): """ CUDA implementation of SDPA baseline. @@ -160,24 +125,36 @@ def scaled_dot_product_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - causal_mask: Causal attention mask is_causal: Whether to apply causal masking Returns: - attn_outputs or "OOM" if out of memory + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ - _, _, query_len, _ = query_states.shape - _, _, key_len, _ = key_states.shape - if query_len > 32768 and key_len > 32768: - return "OOM" + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) + + # Repeat KV for multi-head attention (GQA support) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() + attn_bias = attn_bias.masked_fill(~attn_mask, torch.finfo(query_states.dtype).min).contiguous() try: - # Only measure the core attention computation torch.cuda.synchronize() start_time = time.time() @@ -185,17 +162,17 @@ def scaled_dot_product_attention_cuda( query_states, key_states, value_states, - attn_mask=causal_mask, - softmax_scale=scaling, - # is_causal=is_causal if query_len == key_len else False, - enable_gqa=True + attn_mask=attn_bias, + scale=scaling, + # is_causal=is_causal, + enable_gqa=True, ) torch.cuda.synchronize() end_time = time.time() - attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] - return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + attn_outputs = attn_outputs.transpose(1, 2).contiguous() + return attn_outputs, (end_time - start_time) * 1000 except torch.cuda.OutOfMemoryError: return "OOM", 0 @@ -204,13 +181,11 @@ def dynamic_mask_attention_cuda( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - 
cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, - return_softmax=False ): """ CUDA implementation of dynamic mask attention. @@ -219,33 +194,26 @@ def dynamic_mask_attention_cuda( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking - return_softmax: Whether to return softmax weights Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if flash_dmattn_func is None: return "Not Available", 0 - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) - - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( + attn_bias, attn_mask = prepare_mask( query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] + attn_bias, + causal_mask if is_causal else None, + window_size, + ) # Ensure correct data types and memory layout for CUDA function - # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim] @@ -254,24 +222,23 @@ def dynamic_mask_attention_cuda( torch.cuda.synchronize() start_time = time.time() - # Call the new flash_dmattn_func interface attn_outputs = flash_dmattn_func( - query_states, # [batch, query_len, num_heads, head_dim] - key_states, # [batch, key_len, num_kv_heads, head_dim] - value_states, # [batch, key_len, num_kv_heads, head_dim] - attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] - attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, is_causal=is_causal, softmax_scale=scaling, softcap=0.0, deterministic=False, - return_attn_probs=return_softmax + return_attn_probs=False ) torch.cuda.synchronize() end_time = time.time() - return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + return attn_outputs, (end_time - start_time) * 1000 except torch.cuda.OutOfMemoryError: return "OOM", 0 @@ -280,11 +247,10 @@ def dynamic_mask_attention_triton( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, ): """ @@ -294,15 +260,14 @@ def dynamic_mask_attention_triton( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads 
* head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if triton_dmattn_func is None: return "Not Available", 0 @@ -311,44 +276,38 @@ def dynamic_mask_attention_triton( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - - # Only measure the core Triton kernel computation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Triton function + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + try: torch.cuda.synchronize() start_time = time.time() - # Call the Triton implementation attn_outputs = triton_dmattn_func( - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_heads, head_dim] - attn_mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - attn_bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - is_causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + 
is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -363,11 +322,10 @@ def dynamic_mask_attention_flex( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, - dt_proj: torch.Tensor, - A: torch.Tensor, + attn_bias: torch.Tensor, + causal_mask: torch.Tensor, scaling: float, - cache_position: torch.Tensor, - keep_window_size=2048, + window_size=2048, is_causal=True, ): """ @@ -377,15 +335,14 @@ def dynamic_mask_attention_flex( query_states: [batch_size, num_heads, query_len, head_dim] key_states: [batch_size, num_kv_heads, key_len, head_dim] value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] + attn_bias: [batch_size, num_heads, query_length, key_length] + causal_mask: [batch_size, 1, query_length, key_length] or None + window_size: Number of tokens to keep in attention window scaling: Attention scaling factor - cache_position: Cache position for causal masking - keep_window_size: Number of tokens to keep in attention window is_causal: Whether to apply causal masking Returns: - attn_outputs: [batch_size, query_len, num_heads, head_dim] + tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0) """ if flex_dmattn_func is None: return "Not Available", 0 @@ -394,40 +351,39 @@ def dynamic_mask_attention_flex( _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads - try: - # Calculate zoh_states - zoh_states = calculate_zoh_states(value_states, dt_proj, A) + attn_bias, attn_mask = prepare_mask( + query_states, + attn_bias, + causal_mask if is_causal else None, + window_size, + ) - # Use prepare_dynamic_mask to get the processed attention mask - attn_bias, attn_mask = prepare_dynamic_mask( - query_states, - zoh_states, - keep_window_size, - cache_position if is_causal else None - ) # [batch_size, num_kv_heads, query_len, key_len] - - # Repeat KV for multi-head attention (GQA support) - key_states = repeat_kv(key_states, num_queries_per_kv) - value_states = repeat_kv(value_states, num_queries_per_kv) - attn_mask = repeat_kv(attn_mask, num_queries_per_kv) - attn_bias = repeat_kv(attn_bias, num_queries_per_kv) - - # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format - # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format - - # Only measure the core Flex Attention computation + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Ensure correct data types and memory layout for Flex function + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + attn_mask = attn_mask.contiguous() + attn_bias = attn_bias.contiguous() + + try: torch.cuda.synchronize() start_time = time.time() # Call the Flex Attention implementation attn_outputs = flex_dmattn_func( - query_states.transpose(1, 2), # q: [batch, query_len, num_heads, head_dim] - key_states.transpose(1, 2), # k: [batch, key_len, num_heads, head_dim] - value_states.transpose(1, 2), # v: [batch, key_len, num_heads, head_dim] - attn_mask=attn_mask, # attn_mask: [batch, num_heads, query_len, key_len] - attn_bias=attn_bias, # attn_bias: [batch, num_heads, query_len, key_len] - 
is_causal=is_causal, # is_causal: whether to apply causal masking - softmax_scale=scaling # scaling factor + query_states, + key_states, + value_states, + attn_mask=attn_mask, + attn_bias=attn_bias, + is_causal=is_causal, + softmax_scale=scaling, ) torch.cuda.synchronize() @@ -457,14 +413,14 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ Benchmark attention performance for a given configuration. Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data @@ -480,21 +436,12 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ batch_size, num_kv_heads, key_len, head_dim, device=device, dtype=torch.bfloat16 ) - dt_proj = torch.randn( - num_kv_heads, num_kv_heads * head_dim, + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, device=device, dtype=torch.bfloat16 ) - A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) - - # Create custom causal mask with cache position cache_position = torch.arange(key_len - query_len, key_len, device=device) - min_type = torch.finfo(value_states.dtype).min - causal_mask = torch.full( - (query_len, key_len), fill_value=min_type, - device=device, dtype=value_states.dtype - ) - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) # Set scaling factor from config @@ -531,7 +478,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = scaled_dot_product_attention_cuda( query_states, key_states, value_states, - scaling, causal_mask, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": results['sdpa_forward_status'] = 'OOM' @@ -546,7 +494,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = scaled_dot_product_attention_cuda( query_states, key_states, value_states, - scaling, causal_mask, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": @@ -571,8 +520,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] == "OOM": results['fdma_cuda_forward_status'] = 'OOM' @@ -587,8 +536,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + 
scaling, window_size, is_causal ) if result[0] == "OOM": @@ -613,8 +562,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_triton( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_triton_forward_status'] = result[0] @@ -629,8 +578,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_triton( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -655,8 +604,8 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(warmup_runs): result = dynamic_mask_attention_flex( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, + scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: results['fdma_flex_forward_status'] = result[0] @@ -671,8 +620,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_ for _ in range(num_runs): result = dynamic_mask_attention_flex( query_states, key_states, value_states, - dt_proj, A, scaling, cache_position, - keep_window_size, is_causal + attn_bias, causal_mask, scaling, window_size, is_causal ) if result[0] in ["OOM", "Not Available"]: @@ -718,43 +666,43 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal) configs = [ # Vary sequence length - (1, 2, 1, 256, 256, 128, 1024, True), - (1, 2, 1, 512, 512, 128, 1024, True), - (1, 2, 1, 1024, 1024, 128, 1024, True), - (1, 2, 1, 2048, 2048, 128, 1024, True), - (1, 2, 1, 4096, 4096, 128, 1024, True), - (1, 2, 1, 8192, 8192, 128, 1024, True), - (1, 2, 1, 16384, 16384, 128, 1024, True), - (1, 2, 1, 32768, 32768, 128, 1024, True), + (1, 2, 1, 256, 256, 64, 1024, True), + (1, 2, 1, 512, 512, 64, 1024, True), + (1, 2, 1, 1024, 1024, 64, 1024, True), + (1, 2, 1, 2048, 2048, 64, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (1, 2, 1, 8192, 8192, 64, 1024, True), + (1, 2, 1, 16384, 16384, 64, 1024, True), + (1, 2, 1, 32768, 32768, 64, 1024, True), # Inference - (1, 2, 1, 1, 256, 128, 1024, True), - (1, 2, 1, 1, 512, 128, 1024, True), - (1, 2, 1, 1, 1024, 128, 1024, True), - (1, 2, 1, 1, 2048, 128, 1024, True), - (1, 2, 1, 1, 4096, 128, 1024, True), - (1, 2, 1, 1, 8192, 128, 1024, True), - (1, 2, 1, 1, 16384, 128, 1024, True), - (1, 2, 1, 1, 32768, 128, 1024, True), - (1, 2, 1, 1, 65536, 128, 1024, True), - (1, 2, 1, 1, 131072, 128, 1024, True), - (1, 2, 1, 1, 262144, 128, 1024, True), - (1, 2, 1, 1, 524288, 128, 1024, True), + (1, 2, 1, 1, 256, 64, 1024, True), + (1, 2, 1, 1, 512, 64, 1024, True), + (1, 2, 1, 1, 1024, 64, 1024, True), + (1, 2, 1, 1, 2048, 64, 1024, True), + (1, 2, 1, 1, 4096, 64, 1024, True), + (1, 2, 1, 1, 8192, 64, 1024, True), + (1, 2, 1, 1, 16384, 64, 1024, True), + (1, 2, 1, 1, 32768, 64, 1024, True), + (1, 2, 1, 
1, 65536, 64, 1024, True), + (1, 2, 1, 1, 131072, 64, 1024, True), + (1, 2, 1, 1, 262144, 64, 1024, True), + (1, 2, 1, 1, 524288, 64, 1024, True), # Vary batch size - (1, 2, 1, 4096, 4096, 32, 1024, True), - (2, 2, 1, 4096, 4096, 32, 1024, True), - (4, 2, 1, 4096, 4096, 32, 1024, True), - (8, 2, 1, 4096, 4096, 32, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (2, 2, 1, 4096, 4096, 64, 1024, True), + (4, 2, 1, 4096, 4096, 64, 1024, True), + (8, 2, 1, 4096, 4096, 64, 1024, True), # Vary head count - (1, 1, 1, 4096, 4096, 32, 1024, True), - (1, 2, 1, 4096, 4096, 32, 1024, True), - (1, 4, 1, 4096, 4096, 32, 1024, True), - (1, 8, 2, 4096, 4096, 32, 1024, True), + (1, 1, 1, 4096, 4096, 64, 1024, True), + (1, 2, 1, 4096, 4096, 64, 1024, True), + (1, 4, 1, 4096, 4096, 64, 1024, True), + (1, 8, 2, 4096, 4096, 64, 1024, True), # Vary head dimension (1, 2, 1, 4096, 4096, 32, 1024, True), @@ -764,18 +712,18 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): (1, 2, 1, 4096, 4096, 192, 1024, True), (1, 2, 1, 4096, 4096, 256, 1024, True), - # Vary keep_window_size - (1, 2, 1, 32768, 32768, 128, 32, True), - (1, 2, 1, 32768, 32768, 128, 64, True), - (1, 2, 1, 32768, 32768, 128, 128, True), - (1, 2, 1, 32768, 32768, 128, 256, True), - (1, 2, 1, 32768, 32768, 128, 512, True), - (1, 2, 1, 32768, 32768, 128, 1024, True), - (1, 2, 1, 32768, 32768, 128, 2048, True), - (1, 2, 1, 32768, 32768, 128, 4096, True), - (1, 2, 1, 32768, 32768, 128, 8192, True), - (1, 2, 1, 32768, 32768, 128, 16384, True), - (1, 2, 1, 32768, 32768, 128, 32768, True), + # Vary window_size + (1, 2, 1, 32768, 32768, 64, 32, True), + (1, 2, 1, 32768, 32768, 64, 64, True), + (1, 2, 1, 32768, 32768, 64, 128, True), + (1, 2, 1, 32768, 32768, 64, 256, True), + (1, 2, 1, 32768, 32768, 64, 512, True), + (1, 2, 1, 32768, 32768, 64, 1024, True), + (1, 2, 1, 32768, 32768, 64, 2048, True), + (1, 2, 1, 32768, 32768, 64, 4096, True), + (1, 2, 1, 32768, 32768, 64, 8192, True), + (1, 2, 1, 32768, 32768, 64, 16384, True), + (1, 2, 1, 32768, 32768, 64, 32768, True), ] print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):") @@ -785,7 +733,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): all_results = [] for config in configs: - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, window_size, is_causal = config results = benchmark_attention_performance(config, test_type, num_runs, warmup_runs) all_results.append(results) @@ -824,7 +772,7 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): speedup_strs[impl_key] = "N/A" # Format output with shorter config string - config_short = f" B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{keep_window_size} " + config_short = f" B{batch_size} Hq{num_heads} Hkv{num_kv_heads} Q{query_len} K{key_len} D{head_dim} W{window_size} " if not is_causal: config_short += "N" else: diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 0cc7cdf..e6a8b7a 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -76,10 +76,10 @@ void set_params_fprop( params.k_head_stride = k.stride(-2); params.v_row_stride = v.stride(-3); params.v_head_stride = v.stride(-2); - params.mask_head_stride = has_mask ? mask.stride(-3) : 0; - params.mask_row_stride = has_mask ? mask.stride(-2) : 0; - params.bias_head_stride = has_bias ? 
bias.stride(-3) : 0; - params.bias_row_stride = has_bias ? bias.stride(-2) : 0; + params.mask_head_stride = has_mask ? (mask.size(-3) == 1 ? 0 : mask.stride(-3)) : 0; + params.mask_row_stride = has_mask ? (mask.size(-2) == 1 ? 0 : mask.stride(-2)) : 0; + params.bias_head_stride = has_bias ? (bias.size(-3) == 1 ? 0 : bias.stride(-3)) : 0; + params.bias_row_stride = has_bias ? (bias.size(-2) == 1 ? 0 : bias.stride(-2)) : 0; params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); @@ -87,8 +87,8 @@ void set_params_fprop( params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.mask_batch_stride = has_mask ? mask.stride(0) : 0; - params.bias_batch_stride = has_bias ? bias.stride(0) : 0; + params.mask_batch_stride = has_mask ? (mask.size(0) == 1 ? 0 : mask.stride(0)) : 0; + params.bias_batch_stride = has_bias ? (bias.size(0) == 1 ? 0 : bias.stride(0)) : 0; params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; @@ -227,15 +227,15 @@ void set_params_dgrad( params.dk_head_stride = dk.stride(-2); params.dv_row_stride = dv.stride(-3); params.dv_head_stride = dv.stride(-2); - params.dbias_head_stride = has_bias ? dbias.stride(-3) : 0; - params.dbias_row_stride = has_bias ? dbias.stride(-2) : 0; + params.dbias_head_stride = has_bias ? (dbias.size(-3) == 1 ? 0 : dbias.stride(-3)) : 0; + params.dbias_row_stride = has_bias ? (dbias.size(-2) == 1 ? 0 : dbias.stride(-2)) : 0; if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); - params.dbias_batch_stride = has_bias ? dbias.stride(0) : 0; + params.dbias_batch_stride = has_bias ? (dbias.size(0) == 1 ? 
0 : dbias.stride(0)) : 0; } params.dq_accum_ptr = dq_accum_d; @@ -353,8 +353,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &mask_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x round_multiple(seqlen_k, 128) + std::optional &bias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x round_multiple(seqlen_k, 128) std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, @@ -387,11 +387,8 @@ mha_fwd( mask = mask_.value(); TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); + TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k_rounded)"); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (mask.dim() == 3) { - // Add a dummy dimension for seqlen_q - mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { mask = torch::empty({0}, opts); } @@ -401,11 +398,8 @@ mha_fwd( bias = bias_.value(); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); + TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k_rounded)"); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (bias.dim() == 3) { - // Add a dummy dimension for seqlen_q - bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { bias = torch::empty({0}, opts); } @@ -420,16 +414,27 @@ mha_fwd( const int num_heads_k = k.size(2); int num_heads_mask = has_mask ? mask.size(1) : 1; int num_heads_bias = has_bias ? bias.size(1) : 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (has_mask) { + TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size"); TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q"); + TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded"); } if (has_bias) { + TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size"); TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q"); + TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded"); } // causal=true is the same as causal=false in this case @@ -439,27 +444,29 @@ mha_fwd( // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; - const int orig_num_heads_mask = num_heads_mask; - const int orig_num_heads_bias = num_heads_bias; + const int batch_size_mask_og = has_mask ? mask.size(0) : batch_size; + const int batch_size_bias_og = has_bias ? bias.size(0) : batch_size; + const int num_heads_mask_og = num_heads_mask; + const int num_heads_bias_og = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); if (has_mask) { - mask = num_heads_mask == 1 - ? mask.expand({batch_size, 1, ngroups, seqlen_k}) - : ( - num_heads_mask == num_heads_k - ? mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}) - : mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) - ); + if (num_heads_mask == 1) { + mask = mask.reshape({batch_size_mask_og, 1, 1, seqlen_k_rounded}); + } else if (num_heads_mask == num_heads_k) { + mask = mask.reshape({batch_size_mask_og, num_heads_k, 1, seqlen_k_rounded}); + } else if (num_heads_mask == num_heads) { + mask = mask.reshape({batch_size_mask_og, num_heads_k, ngroups, seqlen_k_rounded}); + } } if (has_bias) { - bias = num_heads_bias == 1 - ? bias.expand({batch_size, 1, ngroups, seqlen_k}) - : ( - num_heads_bias == num_heads_k - ? bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}) - : bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) - ); + if (num_heads_bias == 1) { + bias = bias.reshape({batch_size_bias_og, 1, 1, seqlen_k_rounded}); + } else if (num_heads_bias == num_heads_k) { + bias = bias.reshape({batch_size_bias_og, num_heads_k, 1, seqlen_k_rounded}); + } else if (num_heads_bias == num_heads) { + bias = bias.reshape({batch_size_bias_og, num_heads_k, ngroups, seqlen_k_rounded}); + } } num_heads_mask = has_mask ? 
((num_heads_mask == num_heads) ? num_heads_k : num_heads_mask) : 1; num_heads_bias = has_bias ? ((num_heads_bias == num_heads) ? num_heads_k : num_heads_bias) : 1; @@ -485,11 +492,6 @@ mha_fwd( out = torch::empty_like(q); } - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; @@ -541,16 +543,13 @@ mha_fwd( q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); if (has_mask) { - mask = (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) - ? mask.narrow(2, 0, 1) - : mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + mask = mask.reshape({batch_size_mask_og, num_heads_mask_og, 1, seqlen_k_rounded}); } if (has_bias) { - bias = (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) - ? bias.narrow(2, 0, 1) - : bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + bias = bias.reshape({batch_size_bias_og, num_heads_bias_og, 1, seqlen_k_rounded}); } } + return {out, softmax_lse, p}; } @@ -796,14 +795,14 @@ mha_bwd( const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k - const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const std::optional &mask_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} + const std::optional &bias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &dbias_, // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1} const float softmax_scale, const bool is_causal, const float softcap, @@ -846,11 +845,8 @@ mha_bwd( mask = mask_.value(); TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); + TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)"); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (mask.dim() == 3) { - // Add a dummy dimension for seqlen_q - mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); - } } else { mask = torch::empty({0}, opts); } @@ -860,11 +856,8 @@ mha_bwd( bias = bias_.value(); TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); + TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)"); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - if (bias.dim() == 3) { - // Add a dummy 
@@ -796,14 +795,14 @@ mha_bwd(
     const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
     const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
-    const std::optional<at::Tensor> &mask_,   // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
-    const std::optional<at::Tensor> &bias_,   // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+    const std::optional<at::Tensor> &mask_,   // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
+    const std::optional<at::Tensor> &bias_,   // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
     const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &softmax_lse,            // b x h x seqlen_q
     std::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
     std::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
     std::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
-    std::optional<at::Tensor> &dbias_,        // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+    std::optional<at::Tensor> &dbias_,        // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
     const float softmax_scale,
     const bool is_causal,
     const float softcap,
@@ -846,11 +845,8 @@ mha_bwd(
         mask = mask_.value();
         TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
         CHECK_DEVICE(mask);
+        TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (mask.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         mask = torch::empty({0}, opts);
     }
@@ -860,11 +856,8 @@ mha_bwd(
         bias = bias_.value();
         TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
         CHECK_DEVICE(bias);
+        TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (bias.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         bias = torch::empty({0}, opts);
     }
@@ -879,29 +872,39 @@ mha_bwd(
     const int num_heads_k = k.size(2);
     int num_heads_mask = has_mask ? mask.size(1) : 1;
     int num_heads_bias = has_bias ? bias.size(1) : 1;
+    int batch_size_mask = has_mask ? mask.size(0) : batch_size;
+    int batch_size_bias = has_bias ? bias.size(0) : batch_size;
+    int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
+    int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+    if (has_mask) {
+        TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h");
+        TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q");
+        TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded");
     }
     if (has_bias) {
+        TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h");
+        TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q");
+        TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded");
     }
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
-    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
-    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
-
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
-    
+
     at::Tensor dq, dk, dv, dbias;
     if (dq_.has_value()) {
         dq = dq_.value();
@@ -935,30 +938,14 @@ mha_bwd(
             dbias = dbias_.value();
             TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q");
             CHECK_DEVICE(dbias);
+            TORCH_CHECK(dbias.dim() == 4, "dbias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
             TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension");
-            if (dbias.dim() == 4) {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k);
-            } else {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k);
-            }
+            TORCH_CHECK(dbias.size(0) == batch_size || dbias.size(0) == 1, "Batch dimension in dbias must be 1 or equal to batch size");
+            TORCH_CHECK(dbias.size(1) == num_heads || dbias.size(1) == num_heads_k || dbias.size(1) == 1, "Number of heads in dbias must be 1, h_k or h");
+            TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
+            TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
         } else {
-            if (bias.dim() == 4) {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts);
-                }
-            } else {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts);
-                }
-            }
+            dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
         }
     } else {
        dbias = torch::empty({0}, opts);
    }
@@ -991,8 +978,8 @@ mha_bwd(
        : dv;
    dbias_expanded = has_bias
        ? (
-            (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3)  // MQA / GQA or bias has no seqlen_q dimension
-                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts)
+            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)  // MQA / GQA, or bias is broadcast over batch or seqlen_q
+                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
                : dbias
        )
        : torch::empty({0}, opts);
@@ -1047,24 +1034,19 @@ mha_bwd(
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }
-    // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads
+    // For MQA/GQA, or when dbias is broadcast over batch or seqlen_q, we also need to sum dbias across the head groups, batch, and seqlen_q
    if (has_bias) {
-        bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3;
-        if (num_heads_bias != num_heads) {
-            if (sum_seqlen_q) {
-                dbias_expanded = at::sum(
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
-            } else {
-                at::sum_out(
-                    dbias,
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
+        if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+            at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+        } else {
+            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+            if (seqlen_q_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {2}, true);
            }
-        }
-        if (sum_seqlen_q) {
-            // We need to sum across the seqlen_q dimension
-            at::sum_out(dbias, dbias_expanded, {2});
+            if (batch_size_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {0}, true);
+            }
+            dbias.copy_(dbias_expanded);
        }
    }
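The reduction above has a compact PyTorch analogue that can double as a reference when testing: accumulate dbias at full `(batch, num_heads, seqlen_q, seqlen_k)` resolution, then sum over every dimension the bias was broadcast across. The sketch assumes the same head layout as the kernel (query heads grouped contiguously per KV head); `reduce_dbias` is a hypothetical name used only for illustration, not an exported function.

```python
import torch


def reduce_dbias(dbias_full: torch.Tensor, bias_shape: torch.Size, num_heads: int) -> torch.Tensor:
    # dbias_full: gradient accumulated at full resolution (batch, num_heads, seqlen_q, seqlen_k).
    # bias_shape: the (possibly broadcast) shape the bias was originally passed with.
    batch_bias, heads_bias, seqlen_q_bias, _ = bias_shape
    batch, _, seqlen_q, seqlen_k = dbias_full.shape
    # Sum over the query-head groups when the bias was per-KV-head or single-head.
    dbias = dbias_full.reshape(batch, heads_bias, num_heads // heads_bias, seqlen_q, seqlen_k).sum(2)
    # Sum over any remaining dimension the bias broadcast across.
    if seqlen_q_bias == 1:
        dbias = dbias.sum(2, keepdim=True)
    if batch_bias == 1:
        dbias = dbias.sum(0, keepdim=True)
    return dbias


if __name__ == "__main__":
    dbias_full = torch.randn(2, 8, 4, 16)
    print(reduce_dbias(dbias_full, torch.Size((1, 2, 1, 16)), num_heads=8).shape)  # torch.Size([1, 2, 1, 16])
```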
diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index 893c638..663dc2e 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -416,14 +416,17 @@ def forward(
             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
-        seqlen_k_og = k.shape[1]
-        if seqlen_k_og % 8 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-        if mask is not None:
-            mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
-        if bias is not None:
-            bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)
+        seqlen_k_rounded = round_multiple(k.shape[1], 128)
+        if mask is not None and mask.shape[-1] != seqlen_k_rounded:
+            if mask.shape[-1] == 1:
+                mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded)
+            else:
+                mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]])
+        if bias is not None and bias.shape[-1] != seqlen_k_rounded:
+            if bias.shape[-1] == 1:
+                bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded)
+            else:
+                bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]])
 
         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
@@ -443,7 +446,6 @@
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
-        ctx.seqlen_k_og = seqlen_k_og
 
         out = out_padded[..., :head_size_og]
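Taken together with the changes above, the user-visible effect is that broadcastable masks and biases can be passed directly and the wrapper expands or zero-pads the key dimension to the rounded length before calling the kernel. The snippet below is a usage sketch only, not a test from the repository: the top-level import path and the `is_causal` keyword are assumptions, the `attn_mask`/`attn_bias` keyword names follow the docstring further down, and a CUDA build of the extension is required to actually run it.

```python
import torch

from flash_dmattn import flash_dmattn_func  # import path assumed

batch, heads, heads_k, seqlen, head_dim = 2, 8, 2, 1024, 128
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads_k, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads_k, head_dim, device=device, dtype=dtype)

# Broadcast shapes: one boolean mask row and one bias row per KV head,
# shared across the batch and across all query positions.
attn_mask = torch.ones(1, heads_k, 1, seqlen, device=device, dtype=torch.bool)
attn_bias = torch.randn(1, heads_k, 1, seqlen, device=device, dtype=dtype, requires_grad=True)

out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
out.sum().backward()
print(attn_bias.grad.shape)  # expected: torch.Size([1, 2, 1, 1024])
```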
@@ -488,11 +490,8 @@ def backward(
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        if ctx.seqlen_k_og % 8 != 0:
-            dk = dk[:, : ctx.seqlen_k_og, :, :]
-            dv = dv[:, : ctx.seqlen_k_og, :, :]
-            if dbias is not None:
-                dbias = dbias[..., : ctx.seqlen_k_og]
+        if dbias is not None:
+            dbias = dbias[..., : k.shape[1]]
 
         return dq, dk, dv, None, dbias, None, None, None, None, None, None
@@ -646,10 +645,10 @@ def flash_dmattn_func(
         key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
         value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
         attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to apply to the attention scores.
+            shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to add to the attention scores.
+            shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores.
             If None, no bias is applied.
         softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim).