diff --git a/.gitmodules b/.gitmodules
index 48e3812..8d501cb 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,7 +1,3 @@
 [submodule "csrc/cutlass"]
     path = csrc/cutlass
     url = https://github.com/NVIDIA/cutlass.git
-
-[submodule "csrc/cub"]
-    path = csrc/cub
-    url = https://github.com/NVIDIA/cub.git
\ No newline at end of file
diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index d520c1a..4fb8cc3 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -14,10 +14,17 @@
 import torch
 import torch.nn.functional as F
-import numpy as np
 import argparse
 import time
-from flash_dma_cpp import apply_dynamic_mask_attention  # type: ignore
+
+# Import the compiled CUDA extension
+try:
+    import flash_dma_cuda
+    print("✅ Successfully imported flash_dma_cuda")
+except ImportError as e:
+    print(f"❌ Failed to import flash_dma_cuda: {e}")
+    print("Please make sure the package is properly installed with: pip install .")
+    exit(1)


def prepare_dynamic_mask(
@@ -45,7 +52,6 @@ def prepare_dynamic_mask(
    attn_mask = dt_states[:, :, None, :].expand(
        -1, -1, hidden_states.shape[2], -1
    )  # [batch_size, num_kv_heads, query_len, key_len]
-    active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)

    if attention_mask is not None:
        if attention_mask.dtype == torch.bool:
@@ -65,7 +71,8 @@ def prepare_dynamic_mask(
        active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
        active_mask = active_mask.scatter(-1, topk_indices, 1.0)
        attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
-
+    else:
+        active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
    return attn_mask, active_mask


@@ -140,10 +147,8 @@ def dynamic_mask_attention_python(
    Returns:
        attn_outputs: [batch_size, query_len, num_heads, head_dim]
    """
-    batch_size, num_heads, query_len, head_dim = query_states.shape
-    _, num_kv_heads, key_len, _ = key_states.shape
-    device = query_states.device
-    dtype = query_states.dtype
+    _, num_heads, _, _ = query_states.shape
+    _, num_kv_heads, _, _ = key_states.shape

    num_queries_per_kv = num_heads // num_kv_heads

@@ -201,17 +206,20 @@ def dynamic_mask_attention_cuda(
    Returns:
        attn_outputs: [batch_size, query_len, num_heads, head_dim]
    """
+    # Calculate zero_hold_states
    zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)

+    # Use prepare_dynamic_mask to get the processed attention mask
    _, active_mask = prepare_dynamic_mask(
        query_states,
-        zero_hold_states,
+        zero_hold_states,
        keep_window_size,
        causal_mask if is_causal else None
    )  # [batch_size, num_kv_heads, query_len, key_len]

-    # Ensure correct data types and memory layout
+    # Ensure correct data types and memory layout for CUDA function
+    # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
@@ -220,20 +228,25 @@ def dynamic_mask_attention_cuda(
    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
    active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]

-    result = apply_dynamic_mask_attention(
-        query_states=query_states,
-        key_states=key_states,
-        value_states=value_states,
-        zoh_states=zero_hold_states,
-        active_mask=active_mask,
-        scale=scaling,
-        keep_window_size=keep_window_size,
-        is_causal=is_causal,
-        return_softmax=return_softmax
+    # Call the CUDA implementation using the mha_fwd function signature
+    out_tensor = None  # Let the function allocate the output tensor
+    result = flash_dma_cuda.fwd(  # type: ignore
+        query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
+        key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
+        value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
+        zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+        active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
+        out_tensor,        # out: None to auto-allocate
+        0.0,               # p_dropout
+        scaling,           # softmax_scale
+        is_causal,         # is_causal
+        keep_window_size,  # keep_window_size
+        0.0,               # softcap
+        return_softmax,    # return_softmax
+        None               # gen (generator)
    )
-    # Convert result back to original data type
-    attn_outputs = result[0]
+    attn_outputs = result[0]  # [batch, query_len, num_heads, head_dim]
    return attn_outputs

@@ -343,9 +356,12 @@ def test_forward_equivalence(accuracy_threshold=0.95):
        # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
        (1, 1, 1, 64, 64, 32, True),    # Small scale test, causal mask
        (1, 1, 1, 64, 64, 32, False),   # Small scale test, non-causal mask
-        (1, 2, 1, 128, 128, 32, True),  # Medium scale test, GQA mode
        (1, 1, 1, 128, 128, 32, True),  # Medium scale test, causal mask
        (1, 1, 1, 128, 128, 32, False), # Medium scale test, non-causal mask
+        (1, 1, 1, 256, 256, 32, True),  # Large scale test, causal mask
+        (1, 2, 1, 64, 64, 32, True),    # Medium scale test, GQA mode
+        (2, 1, 1, 128, 128, 32, True),  # Medium scale test, Multi batch
+        (2, 2, 1, 128, 128, 32, True),  # Medium scale test, Multi batch GQA mode
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -456,6 +472,7 @@ def main():
    of dynamic mask attention.
    This script validates numerical consistency including:
+    - Standard forward pass (fwd)
    - Different batch sizes, head counts, sequence lengths and dimensions
    - Causal and non-causal mask options
    - Numerical equivalence analysis
@@ -470,6 +487,9 @@ def main():
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--accuracy-threshold', type=float, default=0.95, help='Minimum accuracy ratio to pass test (default: 0.95)')
+    parser.add_argument('--test-type', type=str, default='all',
+                        choices=['all', 'fwd'],
+                        help='Type of test to run (default: all)')

    args = parser.parse_args()
@@ -477,6 +497,9 @@ def main():
    torch.manual_seed(args.seed)

    # Print test environment information
+    print("🧬" + "=" * 78 + "🧬")
+    print("🔬 Dynamic Mask Attention Forward Pass Equivalence Test Suite 🔬")
+    print("🧬" + "=" * 78 + "🧬")
    print(f"🐍 PyTorch version: {torch.__version__}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_icon = "🔥" if device.type == "cuda" else "💻"
@@ -484,10 +507,40 @@ def main():
    if torch.cuda.is_available():
        print(f"🎮 CUDA device: {torch.cuda.get_device_name()}")
-    print(f"🎲 Random seed: {args.seed}")
+    print(f"🎲 Random seed: {args.seed}")
+    print(f"📊 Test type: {args.test_type}")
+    print(f"🎯 Accuracy threshold: {args.accuracy_threshold*100:.1f}%")

-    # Run equivalence test
-    test_forward_equivalence(args.accuracy_threshold)
+    # Track overall test results
+    test_results = {}
+
+    # Run tests based on user selection
+    if args.test_type in ['all', 'fwd']:
+        print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍")
+        test_results['fwd'] = test_forward_equivalence(args.accuracy_threshold)
+
+
+    # Print overall summary
+    print("\n" + "🏆" + "=" * 78 + "🏆")
+    print("🔬 FINAL TEST SUMMARY 🔬")
+    print("🏆" + "=" * 78 + "🏆")
+
+    all_passed = True
+    for test_name, result in test_results.items():
+        status_icon = "✅" if result else "❌"
+        status_text = "PASSED" if result else "FAILED"
+        print(f" {status_icon} {test_name.upper():12} : {status_text}")
+        all_passed = all_passed and result
+
+    # Overall result
+    overall_icon = "🎉" if all_passed else "😞"
+    overall_text = "ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED"
+    print(f"\n{overall_icon} OVERALL RESULT: {overall_text}")
+    print("🏆" + "=" * 78 + "🏆")
+
+    # Exit with appropriate code
+    import sys
+    sys.exit(0 if all_passed else 1)


if __name__ == "__main__":
diff --git a/csrc/apply_dynamic_mask_api.cpp b/csrc/apply_dynamic_mask_api.cpp
deleted file mode 100644
index 0ddfcb6..0000000
--- a/csrc/apply_dynamic_mask_api.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-#include
-
-// Declare the CUDA function
-torch::Tensor apply_dynamic_mask_cuda(
-    const torch::Tensor& zero_hold_states,
-    const int keep_window_size,
-    const bool is_causal);
-
-// Main API function called from Python
-torch::Tensor apply_dynamic_mask(
-    const torch::Tensor& zero_hold_states,
-    const torch::Tensor& causal_mask,  // kept for Python interface compatibility, but unused
-    const int keep_window_size = 2048,
-    const bool is_causal = true) {
-
-    // Ignore the causal_mask argument and only forward the remaining arguments to the CUDA implementation
-    return apply_dynamic_mask_cuda(
-        zero_hold_states,
-        keep_window_size,
-        is_causal
-    );
-}
-
-// Define the Python module and its functions
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("apply_dynamic_mask", &apply_dynamic_mask,
-        "Apply dynamic mask to attention mechanism",
-        py::arg("zero_hold_states"),
-        py::arg("causal_mask"),
-        py::arg("keep_window_size") = 2048,
-        py::arg("is_causal") = true);
-}
\ No newline at end of file
diff --git a/csrc/apply_dynamic_mask_attention_api.cpp b/csrc/apply_dynamic_mask_attention_api.cpp
deleted file mode 100644
index a25a8a3..0000000
--- a/csrc/apply_dynamic_mask_attention_api.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-#include
-
-// Declare the CUDA function
-std::vector apply_dynamic_mask_attention_cuda(
-    const torch::Tensor& query_states,
-    const torch::Tensor& key_states,
-    const torch::Tensor& value_states,
-    const torch::Tensor& zoh_states,
-    const torch::Tensor& active_mask,
-    float scale,
-    int keep_window_size,
-    bool is_causal,
-    bool return_softmax);
-
-// Main API function called from Python - the redundant causal_mask parameter has been removed
-std::vector apply_dynamic_mask_attention(
-    const torch::Tensor& query_states,
-    const torch::Tensor& key_states,
-    const torch::Tensor& value_states,
-    const torch::Tensor& zoh_states,
-    const torch::Tensor& active_mask,
-    float scale = 1.0f,
-    int keep_window_size = 2048,
-    bool is_causal = true,
-    bool return_softmax = false) {
-
-    // Verify that all tensors are on CUDA
-    TORCH_CHECK(query_states.is_cuda(), "query_states must be a CUDA tensor");
-    TORCH_CHECK(key_states.is_cuda(), "key_states must be a CUDA tensor");
-    TORCH_CHECK(value_states.is_cuda(), "value_states must be a CUDA tensor");
-    TORCH_CHECK(zoh_states.is_cuda(), "zoh_states must be a CUDA tensor");
-    TORCH_CHECK(active_mask.is_cuda(), "active_mask must be a CUDA tensor");
-
-    // All tensors must be on the same device
-    TORCH_CHECK(query_states.device() == key_states.device(), "All tensors must be on the same device");
-    TORCH_CHECK(query_states.device() == value_states.device(), "All tensors must be on the same device");
-    TORCH_CHECK(query_states.device() == zoh_states.device(), "All tensors must be on the same device");
-    TORCH_CHECK(query_states.device() == active_mask.device(), "All tensors must be on the same device");
-
-    // Forward to the CUDA implementation
-    return apply_dynamic_mask_attention_cuda(
-        query_states,
-        key_states,
-        value_states,
-        zoh_states,
-        active_mask,
-        scale,
-        keep_window_size,
-        is_causal,
-        return_softmax
-    );
-}
-
-// Define the Python module and functions
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("apply_dynamic_mask_attention", &apply_dynamic_mask_attention,
-        py::arg("query_states"),
-        py::arg("key_states"),
-        py::arg("value_states"),
-        py::arg("zoh_states"),
-        py::arg("active_mask"),
-        py::arg("scale") = 1.0f,
-        py::arg("keep_window_size") = 2048,
-        py::arg("is_causal") = true,
-        py::arg("return_softmax") = false,
-        "Compute attention with a dynamic mask");
-}
\ No newline at end of file
diff --git a/csrc/apply_dynamic_mask_attention_kernel.cu b/csrc/apply_dynamic_mask_attention_kernel.cu
deleted file mode 100644
index afdbafe..0000000
--- a/csrc/apply_dynamic_mask_attention_kernel.cu
+++ /dev/null
@@ -1,282 +0,0 @@
-#include
-#include
-#include
-
-// CUTE library headers
-#include
-#include
-#include
-
-// CUTLASS library headers
-#include
-#include
-#include
-
-// Project headers
-#include "src/flash.h"  // flash.h includes namespace_config.h
-#include "src/kernel_traits.h"
-#include "src/flash_fwd_kernel.h"
-#include "src/utils.h"
-
-// Make sure the correct namespace is used
-using namespace cute;
-
-namespace FLASH_NAMESPACE {
-
-// Define a dedicated kernel for each case
-template
-__global__ void run_attention_fwd_kernel_template(Flash_fwd_params params) {
-    constexpr int kBlockM = 64;
-    constexpr int kBlockN = 64;
-    constexpr int kNWarps = 4;
-
-    using Kernel_traits = Flash_fwd_kernel_traits;
-    constexpr bool kReturnSoftmax = false;
-
-    compute_attn(params);
-}
-
-// Modified host-side launch function
-template
-void launch_attention_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64;
-    constexpr int kBlockN = 64;
-    constexpr int kNWarps = 4;
-    using Kernel_traits = Flash_fwd_kernel_traits;
-
-    dim3 grid_dim(
-        cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM),
-        params.b,
-        params.h
-    );
-    dim3 block_dim(Kernel_traits::kNThreads);
-    size_t smem_size = Kernel_traits::kSmemSize;
-
-    // Check the shared memory limit
-    int device;
-    cudaGetDevice(&device);
-    cudaDeviceProp prop;
-    cudaGetDeviceProperties(&prop, device);
-
-    if (smem_size > prop.sharedMemPerBlock) {
-        printf("Warning: Shared memory size (%zu) exceeds device limit (%zu)\n",
-               smem_size, prop.sharedMemPerBlock);
-        return;
-    }
-
-    // Fix: check whether the sequence lengths are divisible by the block sizes
-    bool isEvenMN = false;
-
-    // Check whether the head dimension is divisible by the MMA tile size
-    bool isEvenK = true;
-
-    // Set up dynamic shared memory if needed
-    if (smem_size > 48 * 1024) {
-        cudaFuncSetAttribute(
-            run_attention_fwd_kernel_template,
-            cudaFuncAttributeMaxDynamicSharedMemorySize,
-            smem_size
-        );
-    }
-
-    // Dispatch the appropriate kernel version for the actual dimensions
-    run_attention_fwd_kernel_template<<>>(params);
-    AT_CUDA_CHECK(cudaGetLastError());
-}
-
-// Dynamic mask attention dispatch function
-template
-std::vector dynamic_mask_attention_dispatch(
-    const torch::Tensor& query_states,
-    const torch::Tensor& key_states,
-    const torch::Tensor& value_states,
-    const torch::Tensor& zoh_states,
-    const torch::Tensor& active_mask,
-    torch::Tensor& output,
-    torch::Tensor& softmax_lse,
-    float scale,
-    int keep_window_size,
-    bool is_causal,
-    bool return_softmax
-) {
-    const int batch_size = query_states.size(0);
-    const int seq_len_q = query_states.size(1);
-    const int num_heads = query_states.size(2);
-    const int head_dim = query_states.size(3);
-    const int seq_len_k = key_states.size(1);
-    const int num_kv_heads = key_states.size(2);
-
-    // Ensure alignment
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_dim_rounded = round_multiple(head_dim, 32);
-    const int seq_len_q_rounded = round_multiple(seq_len_q, 128);
-    const int seq_len_k_rounded = round_multiple(seq_len_k, 128);
-
-    Flash_fwd_params params;
-    memset(&params, 0, sizeof(params));
-
-    // Set the parameters
-    params.is_bf16 = query_states.scalar_type() == torch::kBFloat16;
-    params.q_ptr = query_states.data_ptr();
-    params.k_ptr = key_states.data_ptr();
-    params.v_ptr = value_states.data_ptr();
-    params.o_ptr = output.data_ptr();
-    params.zoh_ptr = zoh_states.data_ptr();
-    params.active_mask_ptr = active_mask.data_ptr();
-    params.softmax_lse_ptr = softmax_lse.data_ptr();
-
-    // Basic dimension parameters
-    params.b = batch_size;
-    params.h = num_heads;
-    params.h_k = num_kv_heads;
-    params.h_h_k_ratio = num_heads / num_kv_heads;
-    params.seqlen_q = seq_len_q;
-    params.seqlen_k = seq_len_k;
-    params.seqlen_q_rounded = seq_len_q_rounded;
-    params.seqlen_k_rounded = seq_len_k_rounded;
-    params.d = head_dim;
-    params.d_rounded = head_dim_rounded;
-    params.total_q = seq_len_q * batch_size;
-
-    // Stride parameters - must match the memory layout of the PyTorch tensors
-    params.q_batch_stride = query_states.stride(0);
-    params.k_batch_stride = key_states.stride(0);
-    params.v_batch_stride = value_states.stride(0);
-    params.o_batch_stride = output.stride(0);
-    params.zoh_batch_stride = zoh_states.stride(0);
-    params.active_mask_batch_stride = active_mask.stride(0);
-
-    params.q_row_stride = query_states.stride(1);
-    params.k_row_stride = key_states.stride(1);
-    params.v_row_stride = value_states.stride(1);
-    params.o_row_stride = output.stride(1);
-
-    params.q_head_stride = query_states.stride(2);
-    params.k_head_stride = key_states.stride(2);
-    params.v_head_stride = value_states.stride(2);
-    params.o_head_stride = output.stride(2);
-    params.zoh_head_stride = zoh_states.stride(1);
-    params.active_mask_head_stride = active_mask.stride(1);
-
-    /// Scaling and masking parameters
-    params.scale_softmax = scale;
-    params.scale_softmax_log2 = scale * M_LOG2E;
-    params.softcap = 0.0f;
-    params.keep_window_size = keep_window_size;
-
-    // Dropout parameters (disabled)
-    params.p_dropout = 1.0f;
-    params.p_dropout_in_uint8_t = 255;
-    params.rp_dropout = 1.0f;
-    params.scale_softmax_rp_dropout = params.scale_softmax;
-
-    // Causal mask parameter
-    params.is_causal = is_causal;
-
-    // These important parameters must be set as well
-    params.unpadded_lse = false;
-    params.seqlenq_ngroups_swapped = false;
-
-    // Make sure the LSE pointer is set correctly
-    params.softmax_lse_ptr = softmax_lse.data_ptr();
-
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-    if (is_causal) {
-        if (head_dim == 32) { launch_attention_fwd_(params, stream); }
-        // else if (head_dim == 64) { launch_attention_fwd_(params, stream); }
-        // else if (head_dim == 128) { launch_attention_fwd_(params, stream); }
-        else { TORCH_CHECK(false, "Unsupported head_dim for causal attention: ", head_dim); }
-    } else {
-        if (head_dim == 32) { launch_attention_fwd_(params, stream); }
-        // else if (head_dim == 64) { launch_attention_fwd_(params, stream); }
-        // else if (head_dim == 128) { launch_attention_fwd_(params, stream); }
-        else { TORCH_CHECK(false, "Unsupported head_dim for non-causal attention: ", head_dim); }
-    }
-
-    AT_CUDA_CHECK(cudaDeviceSynchronize());
-    return {output, softmax_lse};
-}
-
-} // namespace FLASH_NAMESPACE
-
-// CUDA entry point function, called from the C++ API file. It stays in the global namespace.
-std::vector apply_dynamic_mask_attention_cuda(
-    const torch::Tensor& query_states,
-    const torch::Tensor& key_states,
-    const torch::Tensor& value_states,
-    const torch::Tensor& zoh_states,
-    const torch::Tensor& active_mask,
-    float scale,
-    int keep_window_size,
-    bool is_causal,
-    bool return_softmax
-) {
-
-    // Validate the inputs
-    TORCH_CHECK(query_states.dim() == 4, "query_states must be a 4D tensor");
-    TORCH_CHECK(key_states.dim() == 4, "key_states must be a 4D tensor");
-    TORCH_CHECK(value_states.dim() == 4, "value_states must be a 4D tensor");
-    TORCH_CHECK(zoh_states.dim() == 3, "zoh_states must be a 3D tensor");
-    TORCH_CHECK(active_mask.dim() == 3, "active_mask must be a 3D tensor");
-
-    const int batch_size = query_states.size(0);
-    const int seq_len_q = query_states.size(1);
-    const int num_heads = query_states.size(2);
-    const int head_dim = query_states.size(3);
-    const int seq_len_k = key_states.size(1);
-    const int num_kv_heads = key_states.size(2);
-
-    TORCH_CHECK(key_states.size(0) == batch_size, "Q/K batch mismatch");
-    TORCH_CHECK(value_states.size(1) == seq_len_k, "K/V seq mismatch");
-    TORCH_CHECK(key_states.size(3) == head_dim, "Q/K head_dim mismatch");
-    TORCH_CHECK(value_states.size(0) == batch_size, "Q/V batch mismatch");
-    TORCH_CHECK(value_states.size(2) == num_kv_heads, "K/V kv_heads mismatch");
-    TORCH_CHECK(value_states.size(3) == head_dim, "Q/V head_dim mismatch");
-
-    TORCH_CHECK(query_states.scalar_type() == at::kHalf || query_states.scalar_type() == at::kBFloat16,
-                "Only half/bfloat16 supported");
-    TORCH_CHECK(key_states.scalar_type() == query_states.scalar_type(), "All inputs must have same dtype");
-    TORCH_CHECK(value_states.scalar_type() == query_states.scalar_type(), "All inputs must have same dtype");
-
-    TORCH_CHECK(head_dim == 32 || head_dim == 64 || head_dim == 128, "head_dim must be 32, 64, or 128");
-
-    TORCH_CHECK(query_states.is_contiguous(), "query_states must be contiguous");
-    TORCH_CHECK(key_states.is_contiguous(), "key_states must be contiguous");
-    TORCH_CHECK(value_states.is_contiguous(), "value_states must be contiguous");
-    TORCH_CHECK(zoh_states.is_contiguous(), "zoh_states must be contiguous");
-    TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");
-
-    auto output_options = torch::TensorOptions()
-        .dtype(query_states.dtype())
-        .device(query_states.device());
-    auto output = torch::zeros({batch_size, seq_len_q, num_heads, head_dim}, output_options);
-
-    auto softmax_lse_options = torch::TensorOptions()
-        .dtype(torch::kFloat32)
-        .device(query_states.device());
-    auto softmax_lse = torch::zeros({batch_size, num_heads, seq_len_q}, softmax_lse_options);
-
-    c10::cuda::CUDAGuard device_guard(query_states.device());
-
-    std::vector result_tensors;
-    if (query_states.scalar_type() == at::kHalf) {
-        result_tensors = FLASH_NAMESPACE::dynamic_mask_attention_dispatch(
-            query_states, key_states, value_states, zoh_states, active_mask,
-            output, softmax_lse, scale, keep_window_size, is_causal, return_softmax
-        );
-    } else if (query_states.scalar_type() == at::kBFloat16) {
-        result_tensors = FLASH_NAMESPACE::dynamic_mask_attention_dispatch(
-            query_states, key_states, value_states, zoh_states, active_mask,
-            output, softmax_lse, scale, keep_window_size, is_causal, return_softmax
-        );
-    } else {
-        TORCH_CHECK(false, "apply_attention only supports half and bfloat16");
-    }
-
-    if (return_softmax) {
-        return {result_tensors[0], result_tensors[1]};
-    } else {
-        return {result_tensors[0]};
-    }
-}
\ No newline at end of file
diff --git a/csrc/apply_dynamic_mask_kernel.cu b/csrc/apply_dynamic_mask_kernel.cu
deleted file mode 100644
index b8c5364..0000000
--- a/csrc/apply_dynamic_mask_kernel.cu
+++ /dev/null
@@ -1,214 +0,0 @@
-#include
-#include
-#include
-
-#include "src/namespace_config.h"
-#include "src/mask.h"
-#include "src/utils.h"
-#include "src/hardware_info.h"
-#include "src/static_switch.h"
-
-
-using namespace FLASH_NAMESPACE;
-using namespace cute;
-
-// Redesigned dynamic mask CUDA kernel that uses the DynamicMask struct
-template
-__global__ void apply_dynamic_mask_kernel(
-    scalar_t* output_ptr,
-    const scalar_t* zero_hold_states_ptr,
-    const int batch_size,
-    const int num_kv_heads,
-    const int query_len,
-    const int key_len,
-    const int keep_window_size
-) {
-    // Use the DynamicMask struct from mask.h
-    DynamicMask dynamic_mask(key_len, query_len, keep_window_size);
-
-    // Dynamically allocated shared memory
-    extern __shared__ char smem[];
-    scalar_t* smem_zero_hold = reinterpret_cast(smem);
-    bool* smem_active_indices = reinterpret_cast(smem_zero_hold + kBlockM * kBlockN);
-
-    // Compute the batch and head indices handled by this thread block
-    const int batch_head_idx = blockIdx.y * gridDim.z + blockIdx.z;
-    const int b_idx = batch_head_idx / num_kv_heads;
-    const int kv_idx = batch_head_idx % num_kv_heads;
-
-    if (b_idx >= batch_size) return;
-
-    // Compute the row and column indices handled by this thread block
-    const int row_idx_offset = blockIdx.x * kBlockM;
-    const int col_idx_offset = 0;  // process entire rows
-
-    // Compute the global memory offset
-    const int batch_head_offset = (b_idx * num_kv_heads + kv_idx) * query_len * key_len;
-
-    // Create shared memory tensors - use a 3D layout to match what DynamicMask expects
-    // Layout: (MMA=4, MMA_M, MMA_N)
-    constexpr int MMA = 4;
-    constexpr int MMA_M = kBlockM / (2 * 8);  // 2 outer rows, 8 rows each
-    constexpr int MMA_N = kBlockN / (2 * 1);  // 2 columns
-
-    auto smem_zero_hold_tensor = make_tensor(
-        make_smem_ptr(smem_zero_hold),
-        make_shape(Int{}, Int{}, Int{}),
-        make_stride(Int{}, Int{}, Int<1>{})
-    );
-
-    auto smem_active_indices_tensor = make_tensor(
-        make_smem_ptr(smem_active_indices),
-        make_shape(Int{}, Int{}, Int{}),
-        make_stride(Int{}, Int{}, Int<1>{})
-    );
-
-    // Cooperatively load data into shared memory
-    const int tid = threadIdx.x;
-    const int elements_per_thread = (kBlockM * kBlockN + blockDim.x - 1) / blockDim.x;
-
-    #pragma unroll
-    for (int i = 0; i < elements_per_thread; ++i) {
-        int elem_idx = tid * elements_per_thread + i;
-        if (elem_idx < kBlockM * kBlockN) {
-            int local_row = elem_idx / kBlockN;
-            int local_col = elem_idx % kBlockN;
-            int global_row = row_idx_offset + local_row;
-            int global_col = col_idx_offset + local_col;
-
-            if (global_row < query_len && global_col < key_len) {
-                smem_zero_hold[elem_idx] = zero_hold_states_ptr[
-                    batch_head_offset + global_row * key_len + global_col
-                ];
-            } else {
-                smem_zero_hold[elem_idx] = scalar_t(0.0f);
-            }
-            smem_active_indices[elem_idx] = true;
-        }
-    }
-    __syncthreads();
-
-    // Process with DynamicMask
-    dynamic_mask.get_active_zerohold(
-        smem_zero_hold_tensor,
-        smem_active_indices_tensor,
-        col_idx_offset,
-        row_idx_offset,
-        1  // warp_row_stride
-    );
-
-    // Write the results back to global memory
-    #pragma unroll
-    for (int i = 0; i < elements_per_thread; ++i) {
-        int elem_idx = tid * elements_per_thread + i;
-        if (elem_idx < kBlockM * kBlockN) {
-            int local_row = elem_idx / kBlockN;
-            int local_col = elem_idx % kBlockN;
-            int global_row = row_idx_offset + local_row;
-            int global_col = col_idx_offset + local_col;
-
-            if (global_row < query_len && global_col < key_len) {
-                output_ptr[batch_head_offset + global_row * key_len + global_col] =
-                    smem_zero_hold[elem_idx];
-            }
-        }
-    }
-}
-
-template
-void apply_dynamic_mask_cuda_impl(
-    torch::Tensor& output,
-    const torch::Tensor& zero_hold_states,
-    const int keep_window_size
-) {
-    // Get the dimensions
-    const int batch_size = zero_hold_states.size(0);
-    const int num_kv_heads = zero_hold_states.size(1);
-    const int query_len = zero_hold_states.size(2);
-    const int key_len = zero_hold_states.size(3);
-
-    // Use smaller block sizes to fit in shared memory
-    constexpr int kBlockM = 16;   // process 16 rows
-    constexpr int kBlockN = 128;  // use 128 directly, no min function
-
-    // Compute the shared memory size
-    const int smem_size = kBlockM * kBlockN * sizeof(scalar_t) +
-                          kBlockM * kBlockN * sizeof(bool);
-
-    // Check whether the shared memory size exceeds the device limit
-    cudaDeviceProp props;
-    cudaGetDeviceProperties(&props, zero_hold_states.device().index());
-    TORCH_CHECK(smem_size <= props.sharedMemPerBlock,
-                "Shared memory requirement (", smem_size, " bytes) exceeds the device limit (",
-                props.sharedMemPerBlock, " bytes)");
-
-    // Configure the thread block and grid
-    constexpr int threads_per_block = 256;
-    dim3 block(threads_per_block);
-
-    // Compute the number of blocks needed
-    const int grid_m = (query_len + kBlockM - 1) / kBlockM;
-    const int batch_head_count = batch_size * num_kv_heads;
-
-    // Use the y and z dimensions to cover batches and heads
-    dim3 grid(
-        grid_m,
-        min(batch_head_count, 65535),
-        (batch_head_count + 65534) / 65535
-    );
-
-    // Launch the CUDA kernel
-    apply_dynamic_mask_kernel
-        <<>>(
-            output.data_ptr(),
-            zero_hold_states.data_ptr(),
-            batch_size,
-            num_kv_heads,
-            query_len,
-            key_len,
-            keep_window_size
-        );
-
-    // Check for CUDA errors
-    cudaError_t err = cudaGetLastError();
-    TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed: ", cudaGetErrorString(err));
-}
-
-// Main interface function
-torch::Tensor apply_dynamic_mask_cuda(
-    const torch::Tensor& zero_hold_states,
-    const int keep_window_size,
-    const bool is_causal
-) {
-
-    // Validate the input
-    TORCH_CHECK(zero_hold_states.dim() == 4, "zero_hold_states must be a 4D tensor [batch_size, num_kv_heads, query_len, key_len]");
-
-    // All tensors must be CUDA tensors
-    TORCH_CHECK(zero_hold_states.is_cuda(), "zero_hold_states must be a CUDA tensor");
-
-    // Get the dimensions
-    const int batch_size = zero_hold_states.size(0);
-    const int num_kv_heads = zero_hold_states.size(1);
-    const int query_len = zero_hold_states.size(2);
-    const int key_len = zero_hold_states.size(3);
-
-    // Create the output tensor as a copy of the input (it is modified in place)
-    auto output = zero_hold_states.clone();
-
-    // Set the current device
-    c10::cuda::CUDAGuard device_guard(zero_hold_states.device());
-
-    // Dispatch the implementation based on dtype and the causal mask flag
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(zero_hold_states.scalar_type(), "apply_dynamic_mask", ([&] {
-        if (is_causal) {
-            apply_dynamic_mask_cuda_impl(
-                output, zero_hold_states, keep_window_size);
-        } else {
-            apply_dynamic_mask_cuda_impl(
-                output, zero_hold_states, keep_window_size);
-        }
-    }));
-
-    return output;
-}
\ No newline at end of file
diff --git a/csrc/cutlass b/csrc/cutlass
new file mode 160000
index 0000000..889ff20
--- /dev/null
+++ b/csrc/cutlass
@@ -0,0 +1 @@
+Subproject commit 889ff20648b06085f450e6c5d5bd22fe001ae95d
diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
new file mode 100644
index 0000000..bfeb420
--- /dev/null
+++ b/csrc/flash_api.cpp
@@ -0,0 +1,429 @@
+/******************************************************************************
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
+ ******************************************************************************/
+
+// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#include
+#include
+#include
+#include
+#include  // For at::Generator and at::PhiloxCudaState
+#include  // For at::cuda::philox::unpack
+
+#include
+
+#include "namespace_config.h"
+#include "hardware_info.h"
+#include "flash.h"
+#include "static_switch.h"
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+namespace FLASH_NAMESPACE {
+
+void set_params_fprop(
+    Flash_fwd_params &params,
+    // sizes
+    const size_t b,
+    const size_t seqlen_q,
+    const size_t seqlen_k,
+    const size_t seqlen_q_rounded,
+    const size_t seqlen_k_rounded,
+    const size_t h,
+    const size_t h_k,
+    const size_t d,
+    const size_t d_rounded,
+    const size_t keep_window_size,
+    // device pointers
+    const at::Tensor q,
+    const at::Tensor k,
+    const at::Tensor v,
+    const at::Tensor zoh,
+    const at::Tensor active_mask,
+    at::Tensor out,
+    void *cu_seqlens_q_d,
+    void *cu_seqlens_k_d,
+    void *seqused_k,
+    void *p_d,
+    void *softmax_lse_d,
+    float p_dropout,
+    float softmax_scale,
+    bool is_causal,
+    const float softcap,
+    bool seqlenq_ngroups_swapped=false,
+    const bool unpadded_lse=false
+) {
+
+    // Reset the parameters
+    params = {};
+
+    params.is_bf16 = q.dtype() == torch::kBFloat16;
+
+    // Set the pointers and strides.
+    params.q_ptr = q.data_ptr();
+    params.k_ptr = k.data_ptr();
+    params.v_ptr = v.data_ptr();
+    params.zoh_ptr = zoh.data_ptr();
+    params.active_mask_ptr = active_mask.data_ptr();
+    params.o_ptr = out.data_ptr();
+
+    // All strides are in elements, not bytes.
+    params.q_row_stride = q.stride(-3);
+    params.k_row_stride = k.stride(-3);
+    params.v_row_stride = v.stride(-3);
+    params.zoh_row_stride = zoh.stride(-2);
+    params.active_mask_row_stride = active_mask.stride(-2);
+    params.q_head_stride = q.stride(-2);
+    params.k_head_stride = k.stride(-2);
+    params.v_head_stride = v.stride(-2);
+    params.zoh_head_stride = zoh.stride(-3);
+    params.active_mask_head_stride = active_mask.stride(-3);
+    params.o_row_stride = out.stride(-3);
+    params.o_head_stride = out.stride(-2);
+
+    if (cu_seqlens_q_d == nullptr) {
+        params.q_batch_stride = q.stride(0);
+        params.k_batch_stride = k.stride(0);
+        params.v_batch_stride = v.stride(0);
+        params.zoh_batch_stride = zoh.stride(0);
+        params.active_mask_batch_stride = active_mask.stride(0);
+        params.o_batch_stride = out.stride(0);
+        if (seqlenq_ngroups_swapped) {
+            params.q_batch_stride *= seqlen_q;
+            params.o_batch_stride *= seqlen_q;
+        }
+    }
+
+    params.cu_seqlens_q = static_cast(cu_seqlens_q_d);
+    params.cu_seqlens_k = static_cast(cu_seqlens_k_d);
+    params.seqused_k = static_cast(seqused_k);
+
+    // P = softmax(QK^T)
+    params.p_ptr = p_d;
+
+    // Softmax sum
+    params.softmax_lse_ptr = softmax_lse_d;
+
+    // Set the dimensions.
+    params.b = b;
+    params.h = h;
+    params.h_k = h_k;
+    params.h_h_k_ratio = h / h_k;
+    params.seqlen_q = seqlen_q;
+    params.seqlen_k = seqlen_k;
+    params.seqlen_q_rounded = seqlen_q_rounded;
+    params.seqlen_k_rounded = seqlen_k_rounded;
+    params.d = d;
+    params.d_rounded = d_rounded;
+    params.keep_window_size = keep_window_size;
+
+    // Set the different scale values.
+    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    #endif
+    if (softcap > 0.0) {
+        params.softcap = softmax_scale / softcap;
+        params.scale_softmax = softcap;
+        params.scale_softmax_log2 = softcap * M_LOG2E;
+    } else {
+        // Remove potential NaN
+        params.softcap = 0.0;
+        params.scale_softmax = softmax_scale;
+        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    }
+
+    // Set this to probability of keeping an element to simplify things.
+    params.p_dropout = 1.f - p_dropout;
+    // Convert p from float to int so we don't have to convert the random uint to float to compare.
+    // [Minor] We want to round down since when we do the comparison we use <= instead of <
+    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
+    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
+    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+    params.rp_dropout = 1.f / params.p_dropout;
+    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+    TORCH_CHECK(p_dropout < 1.f);
+    #ifdef FLASHATTENTION_DISABLE_DROPOUT
+    TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+    #endif
+
+    params.is_causal = is_causal;
+    params.is_seqlens_k_cumulative = true;
+
+    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+    TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    #endif
+
+    params.unpadded_lse = unpadded_lse;
+    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
+}
+
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                    run_mha_fwd_(params, stream);
+                } else {
+                    run_mha_fwd_splitkv_dispatch(params, stream);
+                }
+            });
+        });
+    });
+}
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 85%
+// of the best efficiency.
+inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
+    // If we have enough to almost fill the SMs, then just use 1 split
+    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
+    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+    float max_efficiency = 0.f;
+    std::vector efficiency;
+    efficiency.reserve(max_splits);
+    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+    // (i.e. it's 11 splits anyway).
+    // So we check if the number of blocks per split is the same as the previous num_splits.
+    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+    };
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (!is_split_eligible(num_splits)) {
+            efficiency.push_back(0.f);
+        } else {
+            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+            float eff = n_waves / ceil(n_waves);
+            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+            if (eff > max_efficiency) { max_efficiency = eff; }
+            efficiency.push_back(eff);
+        }
+    }
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (!is_split_eligible(num_splits)) { continue; }
+        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
+std::tuple set_params_splitkv(
+    Flash_fwd_params &params,
+    const int batch_size,
+    const int num_heads,
+    const int head_size,
+    const int max_seqlen_k,
+    const int max_seqlen_q,
+    const int head_size_rounded,
+    const float p_dropout,
+    const int num_splits,
+    const int num_sm,
+    struct c10::TensorOptions opts
+) {
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+    // In any case we don't expect seqlen_q to be larger than 64 for inference.
+    const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+    params.num_splits = num_splits;
+    at::Tensor softmax_lse_accum;
+    at::Tensor out_accum;
+
+    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+        if (num_splits < 1) {
+            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
+        }
+        if (params.num_splits > 1) {
+            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+            params.oaccum_ptr = out_accum.data_ptr();
+        }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    }
+
+    return std::make_tuple(softmax_lse_accum, out_accum);
+}
+
+std::vector
+mha_fwd(
+    at::Tensor &q,                     // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+    const at::Tensor &k,               // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+    const at::Tensor &v,               // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+    const at::Tensor &zoh,             // batch_size x num_heads_k x seqlen_q x seqlen_k
+    const at::Tensor &active_mask,     // batch_size x num_heads_k x seqlen_q x seqlen_k
+    std::optional &out_,               // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+    const float p_dropout,
+    const float softmax_scale,
+    bool is_causal,
+    const int keep_window_size,
+    const float softcap,
+    const bool return_softmax,
+    std::optional gen_
+) {
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    at::cuda::CUDAGuard device_guard{q.device()};
+
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x_min = cc_major >= 8;
+    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+    TORCH_CHECK(zoh.dtype() == q_dtype, "zoh must have the same dtype as inputs");
+    TORCH_CHECK(active_mask.dtype() == q_dtype, "active_mask must have the same dtype as inputs");
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
+    const int head_size = sizes[3];
+    const int seqlen_k = k.size(1);
+    const int num_heads_k = k.size(2);
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1) { is_causal = false; }
+
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
+    const int ngroups = num_heads / num_heads_k;
+    if (seqlenq_ngroups_swapped) {
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
+    }
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
+    CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+        }
+    } else {
+        out = torch::empty_like(q);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+    auto opts = q.options();
+
+    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+    at::Tensor p;
+    // Only return softmax if there's dropout to reduce compilation time
+    if (return_softmax) {
+        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
+        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
+    }
+    else {
+        p = torch::empty({ 0 }, opts);
+    }
+
+    Flash_fwd_params params;
+    set_params_fprop(
+        params,
+        batch_size,
+        seqlen_q, seqlen_k,
+        seqlen_q_rounded, seqlen_k_rounded,
+        num_heads, num_heads_k,
+        head_size, head_size_rounded,
+        keep_window_size,
+        q, k, v, zoh, active_mask, out,
+        /*cu_seqlens_q_d=*/nullptr,
+        /*cu_seqlens_k_d=*/nullptr,
+        /*seqused_k=*/nullptr,
+        return_softmax ? p.data_ptr() : nullptr,
+        softmax_lse.data_ptr(),
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        softcap
+    );
+
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
+    );
+
+    // number of times random will be generated per thread, to offset philox counter in thc random
+    // state
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast(rng_state.data_ptr());
+
+    if (p_dropout > 0.0) {
+        auto gen = at::get_generator_or_default(
+            gen_, at::cuda::detail::getDefaultCUDAGenerator());
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard lock(gen->mutex_);
+        params.philox_args = gen->philox_cuda_state(counter_offset);
+    }
+
+    if (seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits::infinity());
+    }
+
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+    }
+    return {out, softmax_lse, p, rng_state};
+}
+} // namespace FLASH_NAMESPACE
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "FlashDynamicMaskAttention";
+    m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
+}
diff --git a/csrc/src/mask.h b/csrc/src/mask.h
index d2394cf..849d6e1 100644
--- a/csrc/src/mask.h
+++ b/csrc/src/mask.h
@@ -87,16 +87,16 @@ struct DynamicMask {
        // If no masking is needed, just scale the tensor and add zoh
        #pragma unroll
        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
            #pragma unroll
            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
+                // const int row_idx = row_idx_base + i * 8;
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
+                    // const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
+                        // const int col_idx = col_idx_base + j;
                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
                        tensor(coord) = tensor(coord) * scale_softmax + zoh(coord);
                    }
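For reference, below is a minimal usage sketch (not part of the patch) of the `fwd` binding that csrc/flash_api.cpp exposes through the `flash_dma_cuda` module, mirroring the positional argument order used in benchmark_forward_equivalence.py above. The tensor shapes follow the comments on mha_fwd; the concrete sizes, dtype, scale value, and variable names are illustrative assumptions.

    import torch
    import flash_dma_cuda  # compiled extension built by this patch

    batch, seqlen_q, seqlen_k = 1, 128, 128
    num_heads, num_heads_k, head_dim = 2, 1, 32  # head_dim must be a multiple of 8
    device, dtype = "cuda", torch.float16

    # q: [batch, seqlen_q, num_heads, head_dim]; k/v: [batch, seqlen_k, num_heads_k, head_dim]
    q = torch.randn(batch, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, seqlen_k, num_heads_k, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, seqlen_k, num_heads_k, head_dim, device=device, dtype=dtype)
    # zoh and active_mask: [batch, num_heads_k, seqlen_q, seqlen_k], same dtype as q/k/v
    zoh = torch.zeros(batch, num_heads_k, seqlen_q, seqlen_k, device=device, dtype=dtype)
    active_mask = torch.ones(batch, num_heads_k, seqlen_q, seqlen_k, device=device, dtype=dtype)

    out, softmax_lse, p, rng_state = flash_dma_cuda.fwd(
        q, k, v, zoh, active_mask,
        None,              # out: let the kernel allocate the output tensor
        0.0,               # p_dropout
        head_dim ** -0.5,  # softmax_scale
        True,              # is_causal
        2048,              # keep_window_size
        0.0,               # softcap
        False,             # return_softmax (only allowed when p_dropout > 0)
        None,              # gen (random generator)
    )
    print(out.shape)  # -> (batch, seqlen_q, num_heads, head_dim)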