From 4961dd97ab1b5b9d6bc5513b89ca96fb724b33ee Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 26 Jun 2025 19:09:12 +0800 Subject: [PATCH 1/5] Removes cub submodule dependency Eliminates the NVIDIA cub submodule from the project configuration, likely due to deprecation or integration into CUDA toolkit, reducing external dependencies and simplifying the build process. --- .gitmodules | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.gitmodules b/.gitmodules index 48e3812..8d501cb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ [submodule "csrc/cutlass"] path = csrc/cutlass url = https://github.com/NVIDIA/cutlass.git - -[submodule "csrc/cub"] - path = csrc/cub - url = https://github.com/NVIDIA/cub.git \ No newline at end of file From 598f1dee4a1b7624c4aeefc537d74f312ec2f6b3 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 26 Jun 2025 19:46:54 +0800 Subject: [PATCH 2/5] Adds cutlass submodule and removes unused variables Includes cutlass library as a git submodule for enhanced matrix operations support. Comments out unused index calculation variables in mask optimization path to eliminate compiler warnings while preserving the logic structure for potential future use. --- csrc/cutlass | 1 + csrc/src/mask.h | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 160000 csrc/cutlass diff --git a/csrc/cutlass b/csrc/cutlass new file mode 160000 index 0000000..889ff20 --- /dev/null +++ b/csrc/cutlass @@ -0,0 +1 @@ +Subproject commit 889ff20648b06085f450e6c5d5bd22fe001ae95d diff --git a/csrc/src/mask.h b/csrc/src/mask.h index d2394cf..849d6e1 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -87,16 +87,16 @@ struct DynamicMask { // If no masking is needed, just scale the tensor and add zoh #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; + // const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; + // const int row_idx = row_idx_base + i * 8; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; + // const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; + // const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); tensor(coord) = tensor(coord) * scale_softmax + zoh(coord); } From bad45324aaefd59b59dba30985fe1ce7e76562e1 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 26 Jun 2025 19:47:57 +0800 Subject: [PATCH 3/5] Adds FlashDynamicMaskAttention C++ API implementation Implements complete forward pass functionality for flash attention with dynamic masking support, including parameter setup, kernel dispatch logic, and memory optimization through split-kv heuristics. Supports key features like dropout, softcapping, causal masking, and multi-head attention with grouped query attention optimization for single-token sequences. Provides Python bindings for integration with PyTorch tensors and CUDA operations. --- csrc/flash_api.cpp | 429 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 csrc/flash_api.cpp diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp new file mode 100644 index 0000000..bfeb420 --- /dev/null +++ b/csrc/flash_api.cpp @@ -0,0 +1,429 @@ +/****************************************************************************** + * Copyright (c) 2025, Jingze Shi and Tri Dao. + ******************************************************************************/ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include +#include +#include +#include // For at::Generator and at::PhiloxCudaState +#include // For at::cuda::philox::unpack + +#include + +#include "namespace_config.h" +#include "hardware_info.h" +#include "flash.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace FLASH_NAMESPACE { + +void set_params_fprop( + Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + const size_t keep_window_size, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor zoh, + const at::Tensor active_mask, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + bool is_causal, + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false +) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.zoh_ptr = zoh.data_ptr(); + params.active_mask_ptr = active_mask.data_ptr(); + params.o_ptr = out.data_ptr(); + + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.zoh_row_stride = zoh.stride(-2); + params.active_mask_row_stride = active_mask.stride(-2); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.zoh_head_stride = zoh.stride(-3); + params.active_mask_head_stride = active_mask.stride(-3); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.zoh_batch_stride = zoh.stride(0); + params.active_mask_batch_stride = active_mask.stride(0); + params.o_batch_stride = out.stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + params.keep_window_size = keep_window_size; + + // Set the different scale values. + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); + #endif + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + params.is_causal = is_causal; + params.is_seqlens_k_cumulative = true; + + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif + + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +std::tuple set_params_splitkv( + Flash_fwd_params ¶ms, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_k, + const int max_seqlen_q, + const int head_size_rounded, + const float p_dropout, + const int num_splits, + const int num_sm, + struct c10::TensorOptions opts +) { + + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + at::Tensor softmax_lse_accum; + at::Tensor out_accum; + + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); + } + if (params.num_splits > 1) { + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + } + + return std::make_tuple(softmax_lse_accum, out_accum); +} + +std::vector +mha_fwd( + at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &zoh, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &active_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const float p_dropout, + const float softmax_scale, + bool is_causal, + const int keep_window_size, + const float softcap, + const bool return_softmax, + std::optional gen_ +) { + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(zoh.dtype() == q_dtype, "zoh must have the same dtype as inputs"); + TORCH_CHECK(active_mask.dtype() == q_dtype, "active_mask must have the same dtype as inputs"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1) { is_causal = false; } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0; + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + } + } else { + out = torch::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } + else { + p = torch::empty({ 0 }, opts); + } + + Flash_fwd_params params; + set_params_fprop( + params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + keep_window_size, + q, k, v, zoh, active_mask, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + p_dropout, + softmax_scale, + is_causal, + softcap + ); + + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( + params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, + head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts + ); + + // number of times random will be generated per thread, to offset philox counter in thc random + // state + // We use a custom RNG that increases the offset by batch_size * nheads * 32. + int64_t counter_offset = params.b * params.h * 32; + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + // Forward kernel will populate memory with the seed and offset. + params.rng_state = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, softmax_lse, p, rng_state}; +} +} // namespace FLASH_NAMESPACE + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashDynamicMaskAttention"; + m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); +} From c90c74a363404c56ff45eb866112c1d209f31b78 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 26 Jun 2025 19:48:13 +0800 Subject: [PATCH 4/5] Refactors benchmark to use new CUDA extension API Updates the benchmark to import and use the new flash_dma_cuda module instead of the previous flash_dma_cpp implementation. Adds proper error handling for the CUDA extension import with informative messages and graceful exit on failure. Refactors the CUDA attention function to use the new mha_fwd API signature with proper parameter mapping and tensor format requirements. Improves test configuration with additional test cases for multi-batch and GQA scenarios, and enhances the test runner with better reporting and exit codes. Fixes a bug in the dynamic mask preparation logic by moving active_mask initialization to the correct conditional branch. --- benchmarks/benchmark_forward_equivalence.py | 105 +++++++++++++++----- 1 file changed, 79 insertions(+), 26 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index d520c1a..4fb8cc3 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -14,10 +14,17 @@ import torch import torch.nn.functional as F -import numpy as np import argparse import time -from flash_dma_cpp import apply_dynamic_mask_attention # type: ignore + +# Import the compiled CUDA extension +try: + import flash_dma_cuda + print("✅ Successfully imported flash_dma_cuda") +except ImportError as e: + print(f"❌ Failed to import flash_dma_cuda: {e}") + print("Please make sure the package is properly installed with: pip install .") + exit(1) def prepare_dynamic_mask( @@ -45,7 +52,6 @@ def prepare_dynamic_mask( attn_mask = dt_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] - active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) if attention_mask is not None: if attention_mask.dtype == torch.bool: @@ -65,7 +71,8 @@ def prepare_dynamic_mask( active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) active_mask = active_mask.scatter(-1, topk_indices, 1.0) attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) - + else: + active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) return attn_mask, active_mask @@ -140,10 +147,8 @@ def dynamic_mask_attention_python( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - batch_size, num_heads, query_len, head_dim = query_states.shape - _, num_kv_heads, key_len, _ = key_states.shape - device = query_states.device - dtype = query_states.dtype + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape num_queries_per_kv = num_heads // num_kv_heads @@ -201,17 +206,20 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ + # Calculate zero_hold_states zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask) + # Use prepare_dynamic_mask to get the processed attention mask _, active_mask = prepare_dynamic_mask( query_states, - zero_hold_states, + zero_hold_states, keep_window_size, causal_mask if is_causal else None ) # [batch_size, num_kv_heads, query_len, key_len] - # Ensure correct data types and memory layout + # Ensure correct data types and memory layout for CUDA function + # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] @@ -220,20 +228,25 @@ def dynamic_mask_attention_cuda( ).contiguous() # [batch, num_kv_heads, query_len, key_len] active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] - result = apply_dynamic_mask_attention( - query_states=query_states, - key_states=key_states, - value_states=value_states, - zoh_states=zero_hold_states, - active_mask=active_mask, - scale=scaling, - keep_window_size=keep_window_size, - is_causal=is_causal, - return_softmax=return_softmax + # Call the CUDA implementation using the mha_fwd function signature + out_tensor = None # Let the function allocate the output tensor + result = flash_dma_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask + active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) - # Convert result back to original data type - attn_outputs = result[0] + attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] return attn_outputs @@ -343,9 +356,12 @@ def test_forward_equivalence(accuracy_threshold=0.95): # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) (1, 1, 1, 64, 64, 32, True), # Small scale test, causal mask (1, 1, 1, 64, 64, 32, False), # Small scale test, non-causal mask - (1, 2, 1, 128, 128, 32, True), # Medium scale test, GQA mode (1, 1, 1, 128, 128, 32, True), # Medium scale test, causal mask (1, 1, 1, 128, 128, 32, False), # Medium scale test, non-causal mask + (1, 1, 1, 256, 256, 32, True), # Large scale test, causal mask + (1, 2, 1, 64, 64, 32, True), # Medium scale test, GQA mode + (2, 1, 1, 128, 128, 32, True), # Medium scale test, Multi batch + (2, 2, 1, 128, 128, 32, True), # Medium scale test, Multi batch GQA mode ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -456,6 +472,7 @@ def main(): of dynamic mask attention. This script validates numerical consistency including: + - Standard forward pass (fwd) - Different batch sizes, head counts, sequence lengths and dimensions - Causal and non-causal mask options - Numerical equivalence analysis @@ -470,6 +487,9 @@ def main(): parser.add_argument('--verbose', action='store_true', help='Verbose output') parser.add_argument('--accuracy-threshold', type=float, default=0.95, help='Minimum accuracy ratio to pass test (default: 0.95)') + parser.add_argument('--test-type', type=str, default='all', + choices=['all', 'fwd'], + help='Type of test to run (default: all)') args = parser.parse_args() @@ -477,6 +497,9 @@ def main(): torch.manual_seed(args.seed) # Print test environment information + print("🧬" + "=" * 78 + "🧬") + print("🔬 Dynamic Mask Attention Forward Pass Equivalence Test Suite 🔬") + print("🧬" + "=" * 78 + "🧬") print(f"🐍 PyTorch version: {torch.__version__}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device_icon = "🔥" if device.type == "cuda" else "💻" @@ -484,10 +507,40 @@ def main(): if torch.cuda.is_available(): print(f"🎮 CUDA device: {torch.cuda.get_device_name()}") - print(f"🎲 Random seed: {args.seed}") + print(f"🎲 Random seed: {args.seed}") + print(f"📊 Test type: {args.test_type}") + print(f"🎯 Accuracy threshold: {args.accuracy_threshold*100:.1f}%") - # Run equivalence test - test_forward_equivalence(args.accuracy_threshold) + # Track overall test results + test_results = {} + + # Run tests based on user selection + if args.test_type in ['all', 'fwd']: + print("\n" + "📍" + " Starting Standard Forward Pass Tests " + "📍") + test_results['fwd'] = test_forward_equivalence(args.accuracy_threshold) + + + # Print overall summary + print("\n" + "🏆" + "=" * 78 + "🏆") + print("🔬 FINAL TEST SUMMARY 🔬") + print("🏆" + "=" * 78 + "🏆") + + all_passed = True + for test_name, result in test_results.items(): + status_icon = "✅" if result else "❌" + status_text = "PASSED" if result else "FAILED" + print(f" {status_icon} {test_name.upper():12} : {status_text}") + all_passed = all_passed and result + + # Overall result + overall_icon = "🎉" if all_passed else "😞" + overall_text = "ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED" + print(f"\n{overall_icon} OVERALL RESULT: {overall_text}") + print("🏆" + "=" * 78 + "🏆") + + # Exit with appropriate code + import sys + sys.exit(0 if all_passed else 1) if __name__ == "__main__": From 0351b6ad5ad2bf9f9cace7497fdcd59fbb25b5bf Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Thu, 26 Jun 2025 19:48:29 +0800 Subject: [PATCH 5/5] Removes dynamic mask functionality from CUDA extension Cleans up codebase by removing unused dynamic mask attention implementation. Eliminates API wrappers, CUDA kernels, and associated infrastructure that was no longer needed, reducing maintenance overhead and code complexity. --- csrc/apply_dynamic_mask_api.cpp | 32 --- csrc/apply_dynamic_mask_attention_api.cpp | 67 ----- csrc/apply_dynamic_mask_attention_kernel.cu | 282 -------------------- csrc/apply_dynamic_mask_kernel.cu | 214 --------------- 4 files changed, 595 deletions(-) delete mode 100644 csrc/apply_dynamic_mask_api.cpp delete mode 100644 csrc/apply_dynamic_mask_attention_api.cpp delete mode 100644 csrc/apply_dynamic_mask_attention_kernel.cu delete mode 100644 csrc/apply_dynamic_mask_kernel.cu diff --git a/csrc/apply_dynamic_mask_api.cpp b/csrc/apply_dynamic_mask_api.cpp deleted file mode 100644 index 0ddfcb6..0000000 --- a/csrc/apply_dynamic_mask_api.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include - -// 声明CUDA函数 -torch::Tensor apply_dynamic_mask_cuda( - const torch::Tensor& zero_hold_states, - const int keep_window_size, - const bool is_causal); - -// 从Python调用的主API函数 -torch::Tensor apply_dynamic_mask( - const torch::Tensor& zero_hold_states, - const torch::Tensor& causal_mask, // 保留此参数以兼容Python接口,但不会使用 - const int keep_window_size = 2048, - const bool is_causal = true) { - - // 忽略causal_mask参数,只转发其他参数到CUDA实现 - return apply_dynamic_mask_cuda( - zero_hold_states, - keep_window_size, - is_causal - ); -} - -// 定义Python模块及其函数 -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_dynamic_mask", &apply_dynamic_mask, - "Apply dynamic mask to attention mechanism", - py::arg("zero_hold_states"), - py::arg("causal_mask"), - py::arg("keep_window_size") = 2048, - py::arg("is_causal") = true); -} \ No newline at end of file diff --git a/csrc/apply_dynamic_mask_attention_api.cpp b/csrc/apply_dynamic_mask_attention_api.cpp deleted file mode 100644 index a25a8a3..0000000 --- a/csrc/apply_dynamic_mask_attention_api.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include - -// 声明CUDA函数 -std::vector apply_dynamic_mask_attention_cuda( - const torch::Tensor& query_states, - const torch::Tensor& key_states, - const torch::Tensor& value_states, - const torch::Tensor& zoh_states, - const torch::Tensor& active_mask, - float scale, - int keep_window_size, - bool is_causal, - bool return_softmax); - -// 主API函数,从Python调用 - 移除了冗余的causal_mask参数 -std::vector apply_dynamic_mask_attention( - const torch::Tensor& query_states, - const torch::Tensor& key_states, - const torch::Tensor& value_states, - const torch::Tensor& zoh_states, - const torch::Tensor& active_mask, - float scale = 1.0f, - int keep_window_size = 2048, - bool is_causal = true, - bool return_softmax = false) { - - // 验证所有张量都在CUDA上 - TORCH_CHECK(query_states.is_cuda(), "query_states必须是CUDA张量"); - TORCH_CHECK(key_states.is_cuda(), "key_states必须是CUDA张量"); - TORCH_CHECK(value_states.is_cuda(), "value_states必须是CUDA张量"); - TORCH_CHECK(zoh_states.is_cuda(), "zoh_states必须是CUDA张量"); - TORCH_CHECK(active_mask.is_cuda(), "active_mask必须是CUDA张量"); - - // 所有张量必须在同一设备上 - TORCH_CHECK(query_states.device() == key_states.device(), "所有张量必须在同一设备上"); - TORCH_CHECK(query_states.device() == value_states.device(), "所有张量必须在同一设备上"); - TORCH_CHECK(query_states.device() == zoh_states.device(), "所有张量必须在同一设备上"); - TORCH_CHECK(query_states.device() == active_mask.device(), "所有张量必须在同一设备上"); - - // 转发到CUDA实现 - return apply_dynamic_mask_attention_cuda( - query_states, - key_states, - value_states, - zoh_states, - active_mask, - scale, - keep_window_size, - is_causal, - return_softmax - ); -} - -// 定义Python模块和函数 -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_dynamic_mask_attention", &apply_dynamic_mask_attention, - py::arg("query_states"), - py::arg("key_states"), - py::arg("value_states"), - py::arg("zoh_states"), - py::arg("active_mask"), - py::arg("scale") = 1.0f, - py::arg("keep_window_size") = 2048, - py::arg("is_causal") = true, - py::arg("return_softmax") = false, - "使用动态掩码计算注意力"); -} \ No newline at end of file diff --git a/csrc/apply_dynamic_mask_attention_kernel.cu b/csrc/apply_dynamic_mask_attention_kernel.cu deleted file mode 100644 index afdbafe..0000000 --- a/csrc/apply_dynamic_mask_attention_kernel.cu +++ /dev/null @@ -1,282 +0,0 @@ -#include -#include -#include - -// 包含CUTE库相关头文件 -#include -#include -#include - -// 包含CUTLASS库相关头文件 -#include -#include -#include - -// 项目相关头文件 -#include "src/flash.h" // flash.h 包含了 namespace_config.h -#include "src/kernel_traits.h" -#include "src/flash_fwd_kernel.h" -#include "src/utils.h" - -// 确保使用正确的命名空间 -using namespace cute; - -namespace FLASH_NAMESPACE { - -// 为每种情况定义专用内核 -template -__global__ void run_attention_fwd_kernel_template(Flash_fwd_params params) { - constexpr int kBlockM = 64; - constexpr int kBlockN = 64; - constexpr int kNWarps = 4; - - using Kernel_traits = Flash_fwd_kernel_traits; - constexpr bool kReturnSoftmax = false; - - compute_attn(params); -} - -// 修改host-side启动函数 -template -void launch_attention_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr int kBlockM = 64; - constexpr int kBlockN = 64; - constexpr int kNWarps = 4; - using Kernel_traits = Flash_fwd_kernel_traits; - - dim3 grid_dim( - cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM), - params.b, - params.h - ); - dim3 block_dim(Kernel_traits::kNThreads); - size_t smem_size = Kernel_traits::kSmemSize; - - // 检查共享内存限制 - int device; - cudaGetDevice(&device); - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, device); - - if (smem_size > prop.sharedMemPerBlock) { - printf("Warning: Shared memory size (%zu) exceeds device limit (%zu)\n", - smem_size, prop.sharedMemPerBlock); - return; - } - - // 修正: 检查序列长度是否能被块大小整除 - bool isEvenMN = false; - - // 检查头部维度是否能被 MMA tile 大小整除 - bool isEvenK = true; - - // 如果需要,设置动态共享内存 - if (smem_size > 48 * 1024) { - cudaFuncSetAttribute( - run_attention_fwd_kernel_template, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size - ); - } - - // 根据实际维度分派不同的内核版本 - run_attention_fwd_kernel_template<<>>(params); - AT_CUDA_CHECK(cudaGetLastError()); -} - -// 动态掩码注意力调度函数 -template -std::vector dynamic_mask_attention_dispatch( - const torch::Tensor& query_states, - const torch::Tensor& key_states, - const torch::Tensor& value_states, - const torch::Tensor& zoh_states, - const torch::Tensor& active_mask, - torch::Tensor& output, - torch::Tensor& softmax_lse, - float scale, - int keep_window_size, - bool is_causal, - bool return_softmax -) { - const int batch_size = query_states.size(0); - const int seq_len_q = query_states.size(1); - const int num_heads = query_states.size(2); - const int head_dim = query_states.size(3); - const int seq_len_k = key_states.size(1); - const int num_kv_heads = key_states.size(2); - - // 确保对齐 - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_dim_rounded = round_multiple(head_dim, 32); - const int seq_len_q_rounded = round_multiple(seq_len_q, 128); - const int seq_len_k_rounded = round_multiple(seq_len_k, 128); - - Flash_fwd_params params; - memset(¶ms, 0, sizeof(params)); - - // 设置参数 - params.is_bf16 = query_states.scalar_type() == torch::kBFloat16; - params.q_ptr = query_states.data_ptr(); - params.k_ptr = key_states.data_ptr(); - params.v_ptr = value_states.data_ptr(); - params.o_ptr = output.data_ptr(); - params.zoh_ptr = zoh_states.data_ptr(); - params.active_mask_ptr = active_mask.data_ptr(); - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - // 基本维度参数 - params.b = batch_size; - params.h = num_heads; - params.h_k = num_kv_heads; - params.h_h_k_ratio = num_heads / num_kv_heads; - params.seqlen_q = seq_len_q; - params.seqlen_k = seq_len_k; - params.seqlen_q_rounded = seq_len_q_rounded; - params.seqlen_k_rounded = seq_len_k_rounded; - params.d = head_dim; - params.d_rounded = head_dim_rounded; - params.total_q = seq_len_q * batch_size; - - // 步长参数 - 确保与PyTorch tensor的内存布局匹配 - params.q_batch_stride = query_states.stride(0); - params.k_batch_stride = key_states.stride(0); - params.v_batch_stride = value_states.stride(0); - params.o_batch_stride = output.stride(0); - params.zoh_batch_stride = zoh_states.stride(0); - params.active_mask_batch_stride = active_mask.stride(0); - - params.q_row_stride = query_states.stride(1); - params.k_row_stride = key_states.stride(1); - params.v_row_stride = value_states.stride(1); - params.o_row_stride = output.stride(1); - - params.q_head_stride = query_states.stride(2); - params.k_head_stride = key_states.stride(2); - params.v_head_stride = value_states.stride(2); - params.o_head_stride = output.stride(2); - params.zoh_head_stride = zoh_states.stride(1); - params.active_mask_head_stride = active_mask.stride(1); - - /// 缩放和掩码参数 - params.scale_softmax = scale; - params.scale_softmax_log2 = scale * M_LOG2E; - params.softcap = 0.0f; - params.keep_window_size = keep_window_size; - - // Dropout参数(禁用) - params.p_dropout = 1.0f; - params.p_dropout_in_uint8_t = 255; - params.rp_dropout = 1.0f; - params.scale_softmax_rp_dropout = params.scale_softmax; - - // 因果掩码参数 - params.is_causal = is_causal; - - // 添加这些重要的参数设置 - params.unpadded_lse = false; - params.seqlenq_ngroups_swapped = false; - - // 确保LSE指针设置正确 - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (is_causal) { - if (head_dim == 32) { launch_attention_fwd_(params, stream); } - // else if (head_dim == 64) { launch_attention_fwd_(params, stream); } - // else if (head_dim == 128) { launch_attention_fwd_(params, stream); } - else { TORCH_CHECK(false, "Unsupported head_dim for causal attention: ", head_dim); } - } else { - if (head_dim == 32) { launch_attention_fwd_(params, stream); } - // else if (head_dim == 64) { launch_attention_fwd_(params, stream); } - // else if (head_dim == 128) { launch_attention_fwd_(params, stream); } - else { TORCH_CHECK(false, "Unsupported head_dim for non-causal attention: ", head_dim); } - } - - AT_CUDA_CHECK(cudaDeviceSynchronize()); - return {output, softmax_lse}; -} - -} // namespace FLASH_NAMESPACE - -// CUDA入口点函数,从C++ API文件调用。它保持在全局命名空间中。 -std::vector apply_dynamic_mask_attention_cuda( - const torch::Tensor& query_states, - const torch::Tensor& key_states, - const torch::Tensor& value_states, - const torch::Tensor& zoh_states, - const torch::Tensor& active_mask, - float scale, - int keep_window_size, - bool is_causal, - bool return_softmax -) { - - // 验证输入 - TORCH_CHECK(query_states.dim() == 4, "query_states must be a 4D tensor"); - TORCH_CHECK(key_states.dim() == 4, "key_states must be a 4D tensor"); - TORCH_CHECK(value_states.dim() == 4, "value_states must be a 4D tensor"); - TORCH_CHECK(zoh_states.dim() == 3, "zoh_states must be a 3D tensor"); - TORCH_CHECK(active_mask.dim() == 3, "active_mask must be a 3D tensor"); - - const int batch_size = query_states.size(0); - const int seq_len_q = query_states.size(1); - const int num_heads = query_states.size(2); - const int head_dim = query_states.size(3); - const int seq_len_k = key_states.size(1); - const int num_kv_heads = key_states.size(2); - - TORCH_CHECK(key_states.size(0) == batch_size, "Q/K batch mismatch"); - TORCH_CHECK(value_states.size(1) == seq_len_k, "K/V seq mismatch"); - TORCH_CHECK(key_states.size(3) == head_dim, "Q/K head_dim mismatch"); - TORCH_CHECK(value_states.size(0) == batch_size, "Q/V batch mismatch"); - TORCH_CHECK(value_states.size(2) == num_kv_heads, "K/V kv_heads mismatch"); - TORCH_CHECK(value_states.size(3) == head_dim, "Q/V head_dim mismatch"); - - TORCH_CHECK(query_states.scalar_type() == at::kHalf || query_states.scalar_type() == at::kBFloat16, - "Only half/bfloat16 supported"); - TORCH_CHECK(key_states.scalar_type() == query_states.scalar_type(), "All inputs must have same dtype"); - TORCH_CHECK(value_states.scalar_type() == query_states.scalar_type(), "All inputs must have same dtype"); - - TORCH_CHECK(head_dim == 32 || head_dim == 64 || head_dim == 128, "head_dim must be 32, 64, or 128"); - - TORCH_CHECK(query_states.is_contiguous(), "query_states must be contiguous"); - TORCH_CHECK(key_states.is_contiguous(), "key_states must be contiguous"); - TORCH_CHECK(value_states.is_contiguous(), "value_states must be contiguous"); - TORCH_CHECK(zoh_states.is_contiguous(), "zoh_states must be contiguous"); - TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous"); - - auto output_options = torch::TensorOptions() - .dtype(query_states.dtype()) - .device(query_states.device()); - auto output = torch::zeros({batch_size, seq_len_q, num_heads, head_dim}, output_options); - - auto softmax_lse_options = torch::TensorOptions() - .dtype(torch::kFloat32) - .device(query_states.device()); - auto softmax_lse = torch::zeros({batch_size, num_heads, seq_len_q}, softmax_lse_options); - - c10::cuda::CUDAGuard device_guard(query_states.device()); - - std::vector result_tensors; - if (query_states.scalar_type() == at::kHalf) { - result_tensors = FLASH_NAMESPACE::dynamic_mask_attention_dispatch( - query_states, key_states, value_states, zoh_states, active_mask, - output, softmax_lse, scale, keep_window_size, is_causal, return_softmax - ); - } else if (query_states.scalar_type() == at::kBFloat16) { - result_tensors = FLASH_NAMESPACE::dynamic_mask_attention_dispatch( - query_states, key_states, value_states, zoh_states, active_mask, - output, softmax_lse, scale, keep_window_size, is_causal, return_softmax - ); - } else { - TORCH_CHECK(false, "apply_attention only supports half and bfloat16"); - } - - if (return_softmax) { - return {result_tensors[0], result_tensors[1]}; - } else { - return {result_tensors[0]}; - } -} \ No newline at end of file diff --git a/csrc/apply_dynamic_mask_kernel.cu b/csrc/apply_dynamic_mask_kernel.cu deleted file mode 100644 index b8c5364..0000000 --- a/csrc/apply_dynamic_mask_kernel.cu +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include -#include - -#include "src/namespace_config.h" -#include "src/mask.h" -#include "src/utils.h" -#include "src/hardware_info.h" -#include "src/static_switch.h" - - -using namespace FLASH_NAMESPACE; -using namespace cute; - -// 重新设计的动态掩码CUDA内核,使用DynamicMask结构体 -template -__global__ void apply_dynamic_mask_kernel( - scalar_t* output_ptr, - const scalar_t* zero_hold_states_ptr, - const int batch_size, - const int num_kv_heads, - const int query_len, - const int key_len, - const int keep_window_size -) { - // 使用mask.h中的DynamicMask结构体 - DynamicMask dynamic_mask(key_len, query_len, keep_window_size); - - // 动态分配共享内存 - extern __shared__ char smem[]; - scalar_t* smem_zero_hold = reinterpret_cast(smem); - bool* smem_active_indices = reinterpret_cast(smem_zero_hold + kBlockM * kBlockN); - - // 计算当前线程块处理的批次和头部索引 - const int batch_head_idx = blockIdx.y * gridDim.z + blockIdx.z; - const int b_idx = batch_head_idx / num_kv_heads; - const int kv_idx = batch_head_idx % num_kv_heads; - - if (b_idx >= batch_size) return; - - // 计算当前线程块处理的行和列索引 - const int row_idx_offset = blockIdx.x * kBlockM; - const int col_idx_offset = 0; // 处理整行 - - // 计算全局内存偏移 - const int batch_head_offset = (b_idx * num_kv_heads + kv_idx) * query_len * key_len; - - // 创建共享内存张量 - 使用3D布局以匹配DynamicMask的期望 - // 布局: (MMA=4, MMA_M, MMA_N) - constexpr int MMA = 4; - constexpr int MMA_M = kBlockM / (2 * 8); // 2个外部行,每个8行 - constexpr int MMA_N = kBlockN / (2 * 1); // 2列 - - auto smem_zero_hold_tensor = make_tensor( - make_smem_ptr(smem_zero_hold), - make_shape(Int{}, Int{}, Int{}), - make_stride(Int{}, Int{}, Int<1>{}) - ); - - auto smem_active_indices_tensor = make_tensor( - make_smem_ptr(smem_active_indices), - make_shape(Int{}, Int{}, Int{}), - make_stride(Int{}, Int{}, Int<1>{}) - ); - - // 协作加载数据到共享内存 - const int tid = threadIdx.x; - const int elements_per_thread = (kBlockM * kBlockN + blockDim.x - 1) / blockDim.x; - - #pragma unroll - for (int i = 0; i < elements_per_thread; ++i) { - int elem_idx = tid * elements_per_thread + i; - if (elem_idx < kBlockM * kBlockN) { - int local_row = elem_idx / kBlockN; - int local_col = elem_idx % kBlockN; - int global_row = row_idx_offset + local_row; - int global_col = col_idx_offset + local_col; - - if (global_row < query_len && global_col < key_len) { - smem_zero_hold[elem_idx] = zero_hold_states_ptr[ - batch_head_offset + global_row * key_len + global_col - ]; - } else { - smem_zero_hold[elem_idx] = scalar_t(0.0f); - } - smem_active_indices[elem_idx] = true; - } - } - __syncthreads(); - - // 使用DynamicMask处理 - dynamic_mask.get_active_zerohold( - smem_zero_hold_tensor, - smem_active_indices_tensor, - col_idx_offset, - row_idx_offset, - 1 // warp_row_stride - ); - - // 将结果写回全局内存 - #pragma unroll - for (int i = 0; i < elements_per_thread; ++i) { - int elem_idx = tid * elements_per_thread + i; - if (elem_idx < kBlockM * kBlockN) { - int local_row = elem_idx / kBlockN; - int local_col = elem_idx % kBlockN; - int global_row = row_idx_offset + local_row; - int global_col = col_idx_offset + local_col; - - if (global_row < query_len && global_col < key_len) { - output_ptr[batch_head_offset + global_row * key_len + global_col] = - smem_zero_hold[elem_idx]; - } - } - } -} - -template -void apply_dynamic_mask_cuda_impl( - torch::Tensor& output, - const torch::Tensor& zero_hold_states, - const int keep_window_size -) { - // 获取维度 - const int batch_size = zero_hold_states.size(0); - const int num_kv_heads = zero_hold_states.size(1); - const int query_len = zero_hold_states.size(2); - const int key_len = zero_hold_states.size(3); - - // 使用较小的块尺寸以适应共享内存 - constexpr int kBlockM = 16; // 处理16行 - constexpr int kBlockN = 128; // 直接使用 128,不用 min 函数 - - // 计算共享内存大小 - const int smem_size = kBlockM * kBlockN * sizeof(scalar_t) + - kBlockM * kBlockN * sizeof(bool); - - // 检查共享内存大小是否超过限制 - cudaDeviceProp props; - cudaGetDeviceProperties(&props, zero_hold_states.device().index()); - TORCH_CHECK(smem_size <= props.sharedMemPerBlock, - "共享内存需求(", smem_size, "字节)超过设备限制(", - props.sharedMemPerBlock, "字节)"); - - // 配置线程块和网格 - constexpr int threads_per_block = 256; - dim3 block(threads_per_block); - - // 计算需要的块数 - const int grid_m = (query_len + kBlockM - 1) / kBlockM; - const int batch_head_count = batch_size * num_kv_heads; - - // 使用y和z维度来处理批次和头部 - dim3 grid( - grid_m, - min(batch_head_count, 65535), - (batch_head_count + 65534) / 65535 - ); - - // 启动CUDA内核 - apply_dynamic_mask_kernel - <<>>( - output.data_ptr(), - zero_hold_states.data_ptr(), - batch_size, - num_kv_heads, - query_len, - key_len, - keep_window_size - ); - - // 检查CUDA错误 - cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed: ", cudaGetErrorString(err)); -} - -// 主接口函数 -torch::Tensor apply_dynamic_mask_cuda( - const torch::Tensor& zero_hold_states, - const int keep_window_size, - const bool is_causal -) { - - // 验证输入 - TORCH_CHECK(zero_hold_states.dim() == 4, "zero_hold_states必须是4D张量 [batch_size, num_kv_heads, query_len, key_len]"); - - // 所有张量必须是CUDA张量 - TORCH_CHECK(zero_hold_states.is_cuda(), "zero_hold_states必须是CUDA张量"); - - // 获取维度 - const int batch_size = zero_hold_states.size(0); - const int num_kv_heads = zero_hold_states.size(1); - const int query_len = zero_hold_states.size(2); - const int key_len = zero_hold_states.size(3); - - // 创建输出张量并复制输入(因为需要原地修改) - auto output = zero_hold_states.clone(); - - // 设置当前设备 - c10::cuda::CUDAGuard device_guard(zero_hold_states.device()); - - // 根据数据类型和因果掩码标志分发实现 - AT_DISPATCH_FLOATING_TYPES_AND_HALF(zero_hold_states.scalar_type(), "apply_dynamic_mask", ([&] { - if (is_causal) { - apply_dynamic_mask_cuda_impl( - output, zero_hold_states, keep_window_size); - } else { - apply_dynamic_mask_cuda_impl( - output, zero_hold_states, keep_window_size); - } - })); - - return output; -} \ No newline at end of file