diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
index 3b330ed..b5241bb 100644
--- a/csrc/src/flash_fwd_launch_template.h
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -112,10 +112,15 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                                 auto kernel = &flash_fwd_splitkv_kernel;
                                 // auto kernel = &flash_fwd_splitkv_kernel;
                                 // auto kernel = &flash_fwd_splitkv_kernel;
+                                // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
                                 if (smem_size >= 48 * 1024) {
                                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                 }
+                                // int ctas_per_sm;
+                                // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                                //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                                 kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                                 C10_CUDA_KERNEL_LAUNCH_CHECK();
                             });
@@ -155,7 +160,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
     run_flash_splitkv_fwd, Is_causal>(params, stream);
 }
 
@@ -163,7 +168,7 @@ template
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
     });
 }
 
@@ -173,13 +178,15 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
             // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-            // Using block size (64 x 256) is 27% slower for seqlen=2k
-            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            // Using block size (64 x 128) is 27% slower for seqlen=2k
+            // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -196,7 +203,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         if (is_sm8x) {
             if constexpr(!Is_causal) {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             } else {
                 run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             }
@@ -218,16 +225,16 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+            // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
+            // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
             if (is_sm8x) {
                 if constexpr(!Is_causal) {
-                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
                 } else {
-                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
                 }
             } else {
-                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             }
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -238,7 +245,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
             // 1st ones are good for H100, A100
             // 2nd one is good for A6000 bc we get slightly better occupancy
         } else {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
             // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -251,9 +258,9 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         }
         // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
@@ -278,17 +285,13 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, we want to run with 128 x 64 (128KB smem).
-        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+        // For A100, we want to run with 64 x 64 (112KB smem).
+        // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
         if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
+            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
         }
-        // 64 KB
-        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
-        // 96 KB
-        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
     });
 }
 
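Note (not part of the patch): the commented-out debug lines added in the first hunk use cudaOccupancyMaxActiveBlocksPerMultiprocessor to check how many CTAs per SM a given block shape and dynamic shared-memory request allow. The standalone sketch below shows the same pattern in isolation; the kernel name, the thread count (128 = 4 warps), and the 72 KB shared-memory request are illustrative placeholders, not values taken from flash_fwd_splitkv_kernel.

#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel so the occupancy query has a concrete __global__
// function to inspect; it stands in for the real attention kernel.
__global__ void dummy_attention_kernel(float *out) {
    extern __shared__ char smem[];
    smem[threadIdx.x] = static_cast<char>(threadIdx.x);
    out[threadIdx.x] = static_cast<float>(smem[threadIdx.x]);
}

int main() {
    // Illustrative values only: 4 warps per CTA and a 72 KB dynamic
    // shared-memory request, roughly the regime the hdim256 comments discuss.
    const int kNThreads = 128;
    const int smem_size = 72 * 1024;

    // Kernels requesting more than 48 KB of dynamic shared memory must opt in,
    // exactly as run_flash_splitkv_fwd does before launching.
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(dummy_attention_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }

    // Ask the runtime how many CTAs of this shape fit on one SM; two or more
    // is what the "2 CTAs per SM" comments in the patch are aiming for.
    int ctas_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &ctas_per_sm, dummy_attention_kernel, kNThreads, smem_size);
    std::printf("smem_size = %d, CTAs per SM = %d\n", smem_size, ctas_per_sm);
    return 0;
}

The 2 * Headdim * (128 + 2 * 64) test in the hdim256 hunk follows the same reasoning: for a 2-byte element type, the Q, K and V tiles occupy roughly 2 * Headdim * (kBlockM + 2 * kBlockN) bytes of shared memory, so shrinking kBlockM/kBlockN is what lets a second CTA fit on an SM.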