From 443255b054c9d790e8a463e8dce5986172996be6 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 17:17:54 +0800 Subject: [PATCH 1/7] Simplifies attention implementation by removing flex attention support Removes flex attention forward function and its integration with BlockMask to streamline the attention mechanism. Updates flash attention import to use the more specific flash_dynamic_mask_attention_forward function instead of the generic auto function. Eliminates the complex prepare_dynamic_mask method that handled topk selection and masking logic, replacing it with a simpler direct bias expansion approach. This reduces code complexity while maintaining the core dynamic mask attention functionality. Changes the attention interface selection to prefer eager attention as fallback when flash attention is unavailable, improving compatibility across different environments. --- examples/modeling/modeling_doge.py | 131 ++++------------------------- 1 file changed, 15 insertions(+), 116 deletions(-) diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index 804e145..6708b02 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -44,13 +44,9 @@ from .configuration_doge import DogeConfig try: - from flash_dmattn import flash_dmattn_func_auto + from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward except ImportError: - def flash_dmattn_func_auto(*args, **kwargs): - raise ImportError( - "flash_dmattn is not installed. Please install it to use flash_dmattn_func_auto. " - "You can install it with `pip install flash-dmattn` or consult the documentation." - ) + flash_dynamic_mask_attention_forward = None if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask @@ -183,59 +179,6 @@ def eager_attention_forward( return attn_output, attn_weights -def flex_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Union[torch.Tensor, "BlockMask"], - scaling: Optional[float] = None, - softcap: Optional[float] = None, - head_mask: Optional[torch.Tensor] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - block_mask = None - causal_mask = None - if isinstance(attention_mask, BlockMask): - block_mask = attention_mask - else: - causal_mask = attention_mask - - if causal_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): - if softcap is not None: - score = softcap * torch.tanh(score / softcap) - if causal_mask is not None: - score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx] - if head_mask is not None: - score = score + head_mask[batch_idx][head_idx][0][0] - return score - - attn_output, attention_weights = compile_friendly_flex_attention( - query, - key, - value, - score_mod=score_mod, - block_mask=block_mask, - enable_gqa=True, - scale=scaling, - # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. - # For simplification, we thus always return it as no additional computations are introduced. - return_lse=True, - ) - # lse is returned in float32 - attention_weights = attention_weights.to(value.dtype) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attention_weights - - -ALL_ATTENTION_FUNCTIONS = AttentionInterface() -ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward - - class DogeAttention(nn.Module): def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): super().__init__() @@ -292,77 +235,33 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # calculate dynamic mask from value_states + # sampling dt_states from value_states to generate attention bias dt_states = self.dt_proj( value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) ) dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - attn_bias, attn_mask = self.prepare_dynamic_mask( - hidden_states=hidden_states, - dt_states=dt_states, - keep_window_size=self.keep_window_size, - attention_mask=attention_mask, - ) + attn_bias = dt_states[:, :, None, :].expand( + -1, -1, hidden_states.shape[1], -1 + ).to(hidden_states.dtype) # [batch_size, num_heads, query_len, key_len] - attention_interface: Callable = flash_dmattn_func_auto(backend="cuda") - query_states = query_states.transpose(1, 2).contiguous() # [B, H, Q_LEN, D] - key_states = key_states.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] - value_states = value_states.transpose(1, 2).contiguous() # [B, H, KV_LEN, D] + attention_interface: Callable = eager_attention_forward + if flash_dynamic_mask_attention_forward is not None: + attention_interface = flash_dynamic_mask_attention_forward - attn_output = attention_interface( + attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None, # attention_mask: batch, num_kv_heads, query_len, key_len + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=attn_mask, - attn_bias=attn_bias, - is_causal=self.is_causal, + attention_mask=attention_mask, + attention_bias=attn_bias, scale=self.scaling, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, None - - def prepare_dynamic_mask( - self, - hidden_states: torch.Tensor, - dt_states: torch.Tensor, - keep_window_size: int = 2048, - attention_mask: Optional[torch.Tensor] = None, - ): - """ - The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention. - - Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. - - Args: - hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision. - dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`. - keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. - attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`. - """ - min_dtype = torch.finfo(hidden_states.dtype).min - dtype = hidden_states.dtype - attn_bias = dt_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[1], -1 - ) # [batch_size, num_heads, query_len, key_len] - if attention_mask is not None and not isinstance(attention_mask, BlockMask): - if attention_mask.dtype == torch.bool: - attention_mask = torch.where( - attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype - ) - attn_bias = attn_bias.masked_fill(attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) - else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) - return attn_bias, attn_mask + return attn_output, attn_weights class DogeMLP(nn.Module): From 2bb5346cbd0950e7640919c8bd15d9e9318c5a12 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 17:18:17 +0800 Subject: [PATCH 2/7] Fixes indentation in attention interface assignment Corrects improper indentation that was causing the flash dynamic mask attention interface assignment to be misaligned with the surrounding code block structure. --- examples/modeling/modeling_doge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index 6708b02..d11b143 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -246,7 +246,7 @@ def forward( attention_interface: Callable = eager_attention_forward if flash_dynamic_mask_attention_forward is not None: - attention_interface = flash_dynamic_mask_attention_forward + attention_interface = flash_dynamic_mask_attention_forward attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None, # attention_mask: batch, num_kv_heads, query_len, key_len attn_output, attn_weights = attention_interface( From 2b5f44e621dc0447e08a4631ceac05e6721b8df1 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 17:25:55 +0800 Subject: [PATCH 3/7] Adds attention bias support to eager attention forward Enables the attention mechanism to accept and apply an optional attention bias tensor, allowing for more flexible attention patterns and improved model capabilities. The bias is added to attention weights before applying the attention mask, following standard transformer architecture practices. --- examples/modeling/modeling_doge.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index d11b143..bbbab3f 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -159,6 +159,7 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + attention_bias: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], @@ -167,10 +168,11 @@ def eager_attention_forward( value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_bias is not None: + attn_weights = attn_weights + attention_bias if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) From 4111ac07f1ead51e71c86b6e3fa52f1fc6ff84b6 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 22:31:12 +0800 Subject: [PATCH 4/7] Simplifies block size calculation for splitkv kernels Replaces conditional block size logic with fixed value of 64 to streamline the splitkv kernel configuration and eliminate branching based on head size. The previous conditional logic is preserved as a comment for reference. --- csrc/flash_api.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 7a0b4e7..93cde0e 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -311,7 +311,8 @@ std::tuple set_params_splitkv( ) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64); + const int block_n = 64; + // const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. From 9dd59258bae8ed3eabc3ad53abb37baace6824c7 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 22:33:12 +0800 Subject: [PATCH 5/7] Optimizes Flash Attention kernel configurations Standardizes block dimensions to 64x64 across all head dimensions and updates shared memory thresholds for better GPU utilization. Changes kernel selection logic to use consistent 164KB threshold and provides detailed CTA count documentation for different GPU architectures (sm86/89, A100, H100). Improves memory efficiency by using smaller block sizes with better occupancy characteristics and enables compact memory layout flags for older architectures. --- csrc/src/flash_fwd_launch_template.h | 71 +++++++++++++++++++++------- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index 76953d9..9d659ca 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -155,7 +155,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions - constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64); + constexpr static int kBlockN = 64; // Fixed for all head dimensions + // constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64); run_flash_splitkv_fwd, Is_causal>(params, stream); } @@ -171,11 +172,18 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 176 * 1024) { - run_flash_fwd, Is_causal>(params, stream); + if (max_smem_per_block >= 164 * 1024) { + // 28KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 8 CTAs in H100. + run_flash_fwd, Is_causal>(params, stream); + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + // 24KB, 4 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal>(params, stream); } + } template @@ -190,11 +198,18 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 224 * 1024) { - run_flash_fwd, Is_causal>(params, stream); - } else { + if (max_smem_per_block >= 164 * 1024) { // H100 and A100 + // 40KB, 2 CTAs in sm86 and sm 89, 4 CTAs in A100, 5 CTAs in H100. run_flash_fwd, Is_causal>(params, stream); + // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 112KB, N/A in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + } else { // sm86 and sm89 + // 32KB, 3 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal>(params, stream); } + } template @@ -209,9 +224,15 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 160 * 1024) { - run_flash_fwd, Is_causal>(params, stream); - } else { + if (max_smem_per_block >= 164 * 1024) { // H100 and A100 + // 52KB, 1 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal>(params, stream); + // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 136KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + } else { // sm86 and sm89 + // 40KB, 2 CTAs in sm86 and sm 89. run_flash_fwd, Is_causal>(params, stream); } } @@ -228,19 +249,28 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 192 * 1024) { - run_flash_fwd, Is_causal>(params, stream); - } else { - // For sm86 or sm89, 64 x 64 (48 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. - // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment - run_flash_fwd, Is_causal>(params, stream); + if (max_smem_per_block >= 164 * 1024) { // H100 and A100 + // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal>(params, stream); + // 96KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 160KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + } else { // sm86 and sm89 + // 48KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal>(params, stream); } } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; + // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. run_flash_fwd, Is_causal>(params, stream); + // 128KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 208KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); } template @@ -255,9 +285,14 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 224 * 1024) { + if (max_smem_per_block >= 112 * 1024) { // H100 and A100 + // 112KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. run_flash_fwd, Is_causal>(params, stream); - } else { + // 192KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + // 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100. + // run_flash_fwd, Is_causal>(params, stream); + } else { // sm86 and sm89 run_flash_fwd, Is_causal>(params, stream); } } From 1082e7223e597679c141b9bc6f7922b59fddbd64 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 22:37:09 +0800 Subject: [PATCH 6/7] Enables optimizations for flash attention kernels Activates kernel optimizations by setting optimization flags to true for both 128 and 256 head dimension configurations on sm86 and sm89 architectures. Adds memory usage comment for 256 head dimension case to document resource requirements. --- csrc/src/flash_fwd_launch_template.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index 9d659ca..2a7dd4a 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -258,7 +258,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { // run_flash_fwd, Is_causal>(params, stream); } else { // sm86 and sm89 // 48KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } } @@ -293,6 +293,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { // 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100. // run_flash_fwd, Is_causal>(params, stream); } else { // sm86 and sm89 + // 80KB, 1 CTAs in sm86 and sm 89. run_flash_fwd, Is_causal>(params, stream); } } From 2767510a7b5022ba882261ad2369458268a48388 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 3 Sep 2025 22:37:25 +0800 Subject: [PATCH 7/7] Optimizes flash attention backward kernel configurations Adjusts kernel parameters across different head dimensions to improve memory usage and performance on various GPU architectures. Updates shared memory requirements and CTA counts for better utilization on sm86, sm89, A100, and H100 GPUs. Enables double buffering and adjusts block sizes to reduce memory footprint while maintaining or improving performance across different hardware configurations. --- csrc/src/flash_bwd_launch_template.h | 70 ++++++++++++---------------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/csrc/src/flash_bwd_launch_template.h b/csrc/src/flash_bwd_launch_template.h index b27a78a..1bf12db 100644 --- a/csrc/src/flash_bwd_launch_template.h +++ b/csrc/src/flash_bwd_launch_template.h @@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 104 * 1024) { // H100 and A100 - // 104KB + // 104KB, 1 CTAs in A100, 2 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 96KB - // We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow - run_flash_bwd, Is_causal>(params, stream); + // 96KB, 2 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - // Changing AtomLayoutMdQ from 2 to 4 takes the same time - // This is slightly faster. We want to split M more so we need fewer registers to store LSE. if (max_smem_per_block >= 144 * 1024) { // H100 and A100 - // 144KB + // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close. + // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100. + // run_flash_bwd, Is_causal>(params, stream); + // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. + // run_flash_bwd, Is_causal>(params, stream); + // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); - // This has a lot of register spilling - // run_flash_bwd>(params, stream); } else { // sm86 and sm89 - // 88KB - run_flash_bwd, Is_causal>(params, stream); + // 72KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times } @@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { } // printf("max_smem_per_block = %d\n", max_smem_per_block); if (max_smem_per_block >= 116 * 1024) { // H100 and A100 - // 116KB + // 116KB, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 80KB - run_flash_bwd, Is_causal>(params, stream); + // 92KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - // run_flash_bwd>(params, stream); - // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). - // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. - // run_flash_bwd>(params, stream); - if (max_smem_per_block >= 224 * 1024) { // H100 - // 224KB - run_flash_bwd, Is_causal>(params, stream); - } else if (max_smem_per_block >= 144 * 1024) { // A100 - // 144KB + if (max_smem_per_block >= 144 * 1024) { // H100 and A100 + // 144KB, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 88KB - run_flash_bwd, Is_causal>(params, stream); + // 88KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 208 * 1024) { // H100 - // 208KB - run_flash_bwd, Is_causal>(params, stream); - } else if (max_smem_per_block >= 152 * 1024) { // A100 - // 152KB + if (max_smem_per_block >= 136 * 1024) { // H100 and A100 + // 136KB, 1 CTAs in A100, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 88KB - run_flash_bwd, Is_causal>(params, stream); + // 96KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } } @@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 200 * 1024) { // H100 - // 200KB + if (max_smem_per_block >= 176 * 1024) { // H100 + // 176KB, 1 CTAs in H100. run_flash_bwd, Is_causal>(params, stream); - } else if (max_smem_per_block >= 132 * 1024) { // A100 - // 132KB - run_flash_bwd, Is_causal>(params, stream); + } else if (max_smem_per_block >= 144 * 1024) { // A100 + // 144KB, 1 CTAs in A100. + run_flash_bwd, Is_causal>(params, stream); } else { // sm86 and sm89 - // 82KB - run_flash_bwd, Is_causal>(params, stream); + // 96KB, 1 CTAs in sm86 and sm 89. + run_flash_bwd, Is_causal>(params, stream); } }