diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 7a0b4e7..93cde0e 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -311,7 +311,8 @@ std::tuple set_params_splitkv(
 ) {
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
+    const int block_n = 64;
+    // const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
diff --git a/csrc/src/flash_bwd_launch_template.h b/csrc/src/flash_bwd_launch_template.h
index b27a78a..1bf12db 100644
--- a/csrc/src/flash_bwd_launch_template.h
+++ b/csrc/src/flash_bwd_launch_template.h
@@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 104 * 1024) { // H100 and A100
-        // 104KB
+        // 104KB, 1 CTA in A100, 2 CTAs in H100.
         run_flash_bwd, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 96KB
-        // We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow
-        run_flash_bwd, Is_causal>(params, stream);
+        // 96KB, 2 CTAs in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
 }
@@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-    // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
     if (max_smem_per_block >= 144 * 1024) { // H100 and A100
-        // 144KB
+        // In fwd, multi-CTA configurations are faster, but in bwd their speeds are very close.
+        // 56KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 4 CTAs in H100.
+        // run_flash_bwd, Is_causal>(params, stream);
+        // 72KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 3 CTAs in H100.
+        // run_flash_bwd, Is_causal>(params, stream);
+        // 144KB, N/A in sm86 and sm89, 1 CTA in A100, 1 CTA in H100.
         run_flash_bwd, Is_causal>(params, stream);
-        // This has a lot of register spilling
-        // run_flash_bwd>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd, Is_causal>(params, stream);
+        // 72KB, 1 CTA in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
 }
@@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     if (max_smem_per_block >= 116 * 1024) { // H100 and A100
-        // 116KB
+        // 116KB, 1 CTA in A100, 1 CTA in H100.
         run_flash_bwd, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 80KB
-        run_flash_bwd, Is_causal>(params, stream);
+        // 92KB, 1 CTA in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
 }
@@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    // run_flash_bwd>(params, stream);
-    // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
-    // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-    // run_flash_bwd>(params, stream);
-    if (max_smem_per_block >= 224 * 1024) { // H100
-        // 224KB
-        run_flash_bwd, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 144 * 1024) { // A100
-        // 144KB
+    if (max_smem_per_block >= 144 * 1024) { // H100 and A100
+        // 144KB, 1 CTA in A100, 1 CTA in H100.
         run_flash_bwd, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd, Is_causal>(params, stream);
+        // 88KB, 1 CTA in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
 }
@@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 208 * 1024) { // H100
-        // 208KB
-        run_flash_bwd, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 152 * 1024) { // A100
-        // 152KB
+    if (max_smem_per_block >= 136 * 1024) { // H100 and A100
+        // 136KB, 1 CTA in A100, 1 CTA in H100.
         run_flash_bwd, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd, Is_causal>(params, stream);
+        // 96KB, 1 CTA in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
 }
@@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 200 * 1024) { // H100
-        // 200KB
+    if (max_smem_per_block >= 176 * 1024) { // H100
+        // 176KB, 1 CTA in H100.
         run_flash_bwd, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 132 * 1024) { // A100
-        // 132KB
-        run_flash_bwd, Is_causal>(params, stream);
+    } else if (max_smem_per_block >= 144 * 1024) { // A100
+        // 144KB, 1 CTA in A100.
+        run_flash_bwd, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 82KB
-        run_flash_bwd, Is_causal>(params, stream);
+        // 96KB, 1 CTA in sm86 and sm89.
+        run_flash_bwd, Is_causal>(params, stream);
     }
 }
diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
index 76953d9..2a7dd4a 100644
--- a/csrc/src/flash_fwd_launch_template.h
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -155,7 +155,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
-    constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = 64; // Fixed for all head dimensions
+    // constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd, Is_causal>(params, stream);
 }
@@ -171,11 +172,18 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 176 * 1024) {
-        run_flash_fwd, Is_causal>(params, stream);
+    if (max_smem_per_block >= 164 * 1024) {
+        // 28KB, 3 CTAs in sm86 and sm89, 5 CTAs in A100, 8 CTAs in H100.
+        run_flash_fwd, Is_causal>(params, stream);
+        // 48KB, 2 CTAs in sm86 and sm89, 3 CTAs in A100, 4 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+        // 88KB, 1 CTA in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
     } else {
-        run_flash_fwd, Is_causal>(params, stream);
+        // 24KB, 4 CTAs in sm86 and sm89.
+        run_flash_fwd, Is_causal>(params, stream);
     }
+
 }
 
 template
@@ -190,11 +198,18 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 224 * 1024) {
-        run_flash_fwd, Is_causal>(params, stream);
-    } else {
+    if (max_smem_per_block >= 164 * 1024) { // H100 and A100
+        // 40KB, 2 CTAs in sm86 and sm89, 4 CTAs in A100, 5 CTAs in H100.
         run_flash_fwd, Is_causal>(params, stream);
+        // 64KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 3 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+        // 112KB, N/A in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+    } else { // sm86 and sm89
+        // 32KB, 3 CTAs in sm86 and sm89.
+        run_flash_fwd, Is_causal>(params, stream);
     }
+
 }
 
 template
@@ -209,9 +224,15 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 160 * 1024) {
-        run_flash_fwd, Is_causal>(params, stream);
-    } else {
+    if (max_smem_per_block >= 164 * 1024) { // H100 and A100
+        // 52KB, 1 CTA in sm86 and sm89, 3 CTAs in A100, 4 CTAs in H100.
+        run_flash_fwd, Is_causal>(params, stream);
+        // 80KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 2 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+        // 136KB, N/A in sm86 and sm89, 1 CTA in A100, 1 CTA in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+    } else { // sm86 and sm89
+        // 40KB, 2 CTAs in sm86 and sm89.
         run_flash_fwd, Is_causal>(params, stream);
     }
 }
@@ -228,11 +249,15 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 192 * 1024) {
-        run_flash_fwd, Is_causal>(params, stream);
-    } else {
-        // For sm86 or sm89, 64 x 64 (48 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
-        // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
+    if (max_smem_per_block >= 164 * 1024) { // H100 and A100
+        // 64KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 3 CTAs in H100.
+        run_flash_fwd, Is_causal>(params, stream);
+        // 96KB, 1 CTA in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+        // 160KB, N/A in sm86 and sm89, 1 CTA in A100, 1 CTA in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+    } else { // sm86 and sm89
+        // 48KB, 2 CTAs in sm86 and sm89.
         run_flash_fwd, Is_causal>(params, stream);
     }
 }
@@ -240,7 +265,12 @@
 template
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
+    // 88KB, 1 CTA in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
     run_flash_fwd, Is_causal>(params, stream);
+    // 128KB, N/A in sm86 and sm89, 1 CTA in A100, 1 CTA in H100.
+    // run_flash_fwd, Is_causal>(params, stream);
+    // 208KB, N/A in sm86 and sm89, N/A in A100, 1 CTA in H100.
+    // run_flash_fwd, Is_causal>(params, stream);
 }
 
 template
@@ -255,9 +285,15 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 224 * 1024) {
+    if (max_smem_per_block >= 112 * 1024) { // H100 and A100
+        // 112KB, N/A in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
         run_flash_fwd, Is_causal>(params, stream);
-    } else {
+        // 192KB, N/A in sm86 and sm89, N/A in A100, 1 CTA in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+        // 256KB, N/A in sm86 and sm89, N/A in A100, N/A in H100.
+        // run_flash_fwd, Is_causal>(params, stream);
+    } else { // sm86 and sm89
+        // 80KB, 1 CTA in sm86 and sm89.
         run_flash_fwd, Is_causal>(params, stream);
     }
 }
diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py
index 804e145..bbbab3f 100644
--- a/examples/modeling/modeling_doge.py
+++ b/examples/modeling/modeling_doge.py
@@ -44,13 +44,9 @@ from .configuration_doge import DogeConfig
 
 try:
-    from flash_dmattn import flash_dmattn_func_auto
+    from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
 except ImportError:
-    def flash_dmattn_func_auto(*args, **kwargs):
-        raise ImportError(
-            "flash_dmattn is not installed. Please install it to use flash_dmattn_func_auto. "
-            "You can install it with `pip install flash-dmattn` or consult the documentation."
-        )
+    flash_dynamic_mask_attention_forward = None
 
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
@@ -163,6 +159,7 @@ def eager_attention_forward(
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
+    attention_bias: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
     **kwargs: Unpack[TransformersKwargs],
@@ -171,10 +168,11 @@
     value_states = repeat_kv(value, module.num_key_value_groups)
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_bias is not None:
+        attn_weights = attn_weights + attention_bias
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
-
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -183,59 +181,6 @@
     return attn_output, attn_weights
 
 
-def flex_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Union[torch.Tensor, "BlockMask"],
-    scaling: Optional[float] = None,
-    softcap: Optional[float] = None,
-    head_mask: Optional[torch.Tensor] = None,
-    **kwargs,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    block_mask = None
-    causal_mask = None
-    if isinstance(attention_mask, BlockMask):
-        block_mask = attention_mask
-    else:
-        causal_mask = attention_mask
-
-    if causal_mask is not None:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
-    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
-        if softcap is not None:
-            score = softcap * torch.tanh(score / softcap)
-        if causal_mask is not None:
-            score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
-        if head_mask is not None:
-            score = score + head_mask[batch_idx][head_idx][0][0]
-        return score
-
-    attn_output, attention_weights = compile_friendly_flex_attention(
-        query,
-        key,
-        value,
-        score_mod=score_mod,
-        block_mask=block_mask,
-        enable_gqa=True,
-        scale=scaling,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
-        return_lse=True,
-    )
-    # lse is returned in float32
-    attention_weights = attention_weights.to(value.dtype)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attention_weights
-
-
-ALL_ATTENTION_FUNCTIONS = AttentionInterface()
-ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
-
-
 class DogeAttention(nn.Module):
     def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
         super().__init__()
@@ -292,77 +237,33 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # calculate dynamic mask from value_states
+        # sample dt_states from value_states to generate the attention bias
        dt_states = self.dt_proj(
            value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
        )
        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias, attn_mask = self.prepare_dynamic_mask(
-            hidden_states=hidden_states,
-            dt_states=dt_states,
-            keep_window_size=self.keep_window_size,
-            attention_mask=attention_mask,
-        )
+        attn_bias = dt_states[:, :, None, :].expand(
+            -1, -1, hidden_states.shape[1], -1
+        ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
 
-        attention_interface: Callable = flash_dmattn_func_auto(backend="cuda")
-        query_states = query_states.transpose(1, 2).contiguous()  # [B, H, Q_LEN, D]
-        key_states = key_states.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
-        value_states = value_states.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
+        attention_interface: Callable = eager_attention_forward
+        if flash_dynamic_mask_attention_forward is not None:
+            attention_interface = flash_dynamic_mask_attention_forward
 
-        attn_output = attention_interface(
+        attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None  # attention_mask: batch, num_kv_heads, query_len, key_len
+        attn_output, attn_weights = attention_interface(
+            self,
             query_states,
             key_states,
             value_states,
-            attn_mask=attn_mask,
-            attn_bias=attn_bias,
-            is_causal=self.is_causal,
+            attention_mask=attention_mask,
+            attention_bias=attn_bias,
             scale=self.scaling,
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
-        return attn_output, None
-
-    def prepare_dynamic_mask(
-        self,
-        hidden_states: torch.Tensor,
-        dt_states: torch.Tensor,
-        keep_window_size: int = 2048,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        """
-        The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
-
-        Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
-
-        Args:
-            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
-            dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
-            keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
-            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
- """ - min_dtype = torch.finfo(hidden_states.dtype).min - dtype = hidden_states.dtype - attn_bias = dt_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[1], -1 - ) # [batch_size, num_heads, query_len, key_len] - if attention_mask is not None and not isinstance(attention_mask, BlockMask): - if attention_mask.dtype == torch.bool: - attention_mask = torch.where( - attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype - ) - attn_bias = attn_bias.masked_fill(attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype) - if attn_bias.shape[-1] > keep_window_size: - topk_values, topk_indices = torch.topk( - attn_bias, keep_window_size, dim=-1, largest=True, sorted=False - ) - valid_topk = topk_values != min_dtype - attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) - attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype)) - attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) - else: - attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) - return attn_bias, attn_mask + return attn_output, attn_weights class DogeMLP(nn.Module):