From bf92d857cd56da523541211346377dfdce8bdafc Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 28 Jul 2025 12:45:53 +0800
Subject: [PATCH 1/3] Removes keep_window_size parameter from Mask constructor

Simplifies the mask initialization by removing the unused keep_window_size
parameter from both attention computation functions. This streamlines the
interface and reduces unnecessary parameter passing without affecting
functionality.
---
 csrc/src/flash_fwd_kernel.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 33b911f..8699ceb 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -385,8 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Init dynamic mask processor
         FLASH_NAMESPACE::Mask<Is_causal> mask(
-            binfo.actual_seqlen_k, binfo.actual_seqlen_q,
-            params.keep_window_size
+            binfo.actual_seqlen_k, binfo.actual_seqlen_q
         );
 
         // For performance reason, we separate out two kinds of iterations:
@@ -961,8 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // Init dynamic mask processor
         FLASH_NAMESPACE::Mask<Is_causal> mask(
-            binfo.actual_seqlen_k, binfo.actual_seqlen_q,
-            params.keep_window_size
+            binfo.actual_seqlen_k, binfo.actual_seqlen_q
         );
 
         // For performance reason, we separate out two kinds of iterations:

From bc0dd2711b4e185321c17842f698da431d6c631d Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 28 Jul 2025 12:46:15 +0800
Subject: [PATCH 2/3] Simplifies mask logic by removing window size parameter

Removes the keep_window_size parameter from the Mask struct and eliminates
the conditional branching logic that determined whether masking was needed.

Consolidates the masking logic into a single code path that always applies
the mask check, reducing code complexity and potential branching overhead.
The previous optimization that skipped masking when no window size limit
was needed has been removed in favor of a more straightforward approach.
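For reference while reading the mask.h diff below, the per-element rule that the consolidated code path now applies unconditionally can be sketched in PyTorch. This is an illustrative model of the semantics only, not the CUDA kernel; the function name and the 2-D per-head tensor shapes are assumptions made here for clarity:

    import torch

    def apply_dynamic_mask_reference(scores, mask, bias, scale_softmax, causal=True):
        # scores, mask, bias: (seqlen_q, seqlen_k) for a single head.
        seqlen_q, seqlen_k = scores.shape
        row = torch.arange(seqlen_q).unsqueeze(1)  # query positions
        col = torch.arange(seqlen_k).unsqueeze(0)  # key positions
        if causal:
            # Same limit as the kernel: row_idx + 1 + max_seqlen_k - max_seqlen_q,
            # clamped to max_seqlen_k (queries sit at the end of the key sequence).
            col_limit = torch.clamp(row + 1 + seqlen_k - seqlen_q, max=seqlen_k)
        else:
            col_limit = torch.full_like(row, seqlen_k)
        # A position is inactive if it lies beyond the causal limit or is zeroed
        # in the dynamic mask; otherwise the score is scaled and biased.
        inactive = (col >= col_limit) | (mask == 0.0)
        return (scores * scale_softmax + bias).masked_fill(inactive, float("-inf"))

After this change every tile goes through the masked path, so the branch-free fast path for fully unmasked tiles is gone; the commit trades that optimization for a single, simpler code path.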
---
 csrc/src/mask.h | 111 +++++++++++++++++++++++++++++------------------
 1 file changed, 67 insertions(+), 44 deletions(-)

diff --git a/csrc/src/mask.h b/csrc/src/mask.h
index 166e765..8b2ebba 100644
--- a/csrc/src/mask.h
+++ b/csrc/src/mask.h
@@ -60,16 +60,13 @@ __forceinline__ __device__ void apply_mask(
 template <bool Causal_mask=false>
 struct Mask {
     const int max_seqlen_k, max_seqlen_q;
-    const int keep_window_size;
 
     __forceinline__ __device__ Mask(
         const int max_seqlen_k,
-        const int max_seqlen_q,
-        const int keep_window_size
+        const int max_seqlen_q
     ) // Constructor
         : max_seqlen_k(max_seqlen_k)
-        , max_seqlen_q(max_seqlen_q)
-        , keep_window_size(keep_window_size) {
+        , max_seqlen_q(max_seqlen_q) {
     };
 
     template <bool Is_even_MN=true, typename TensorType, typename MaskType, typename BiasType>
@@ -86,7 +83,7 @@ struct Mask {
         static_assert(MaskType::rank == 3, "Mask must be 3D Tensor");
         static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
         static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
-        const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
+
         // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
         Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout()));
@@ -94,53 +91,79 @@ struct Mask {
         const int lane_id = threadIdx.x % 32;
         const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
 
-        if (Need_masking) {
-            #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-                #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    const int row_idx = row_idx_base + i * 8;
-                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
-                            if (inactive) {
-                                tensor(coord) = -INFINITY;
-                            } else {
-                                // Apply scaling and bias
-                                tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            // If no masking is needed, just scale the tensor and add bias
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
             #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
                 #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    // const int row_idx = row_idx_base + i * 8;
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
                     #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        // const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            // const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
+                        if (inactive) {
+                            tensor(coord) = -INFINITY;
+                        } else {
+                            // Apply scaling and bias
                             tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
                         }
                     }
                 }
             }
         }
+        // const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
+        // if (Need_masking) {
+        //     #pragma unroll
+        //     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        //         const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        //         #pragma unroll
+        //         for (int i = 0; i < size<0, 0>(tensor); ++i) {
+        //             const int row_idx = row_idx_base + i * 8;
+        //             const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+        //             #pragma unroll
+        //             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        //                 const int col_idx_base = col_idx_offset + nj * 8;
+        //                 #pragma unroll
+        //                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
+        //                     const int col_idx = col_idx_base + j;
+        //                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+        //                     bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
+        //                     if (inactive) {
+        //                         tensor(coord) = -INFINITY;
+        //                     } else {
+        //                         // Apply scaling and bias
+        //                         tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+        //                     }
+        //                 }
+        //             }
+        //         }
+        //     }
+        // } else {
+        //     // If no masking is needed, just scale the tensor and add bias
+        //     #pragma unroll
+        //     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        //         // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        //         #pragma unroll
+        //         for (int i = 0; i < size<0, 0>(tensor); ++i) {
+        //             // const int row_idx = row_idx_base + i * 8;
+        //             #pragma unroll
+        //             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        //                 // const int col_idx_base = col_idx_offset + nj * 8;
+        //                 #pragma unroll
+        //                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
+        //                     // const int col_idx = col_idx_base + j;
+        //                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+        //                     tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+        //                 }
+        //             }
+        //         }
+        //     }
+        // }
     }
 };

From 832edd1649039fe2e124d4d8116298031f9faac0 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 28 Jul 2025 14:37:21 +0800
Subject: [PATCH 3/3] Fixes cache position calculation for attention mechanism

Corrects the cache position tensor generation to properly handle cases where
key length differs from query length by calculating the starting position as
key_len - query_len instead of starting from zero. This ensures proper causal
masking behavior when dealing with cached key-value pairs in attention
computations.
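A small worked example shows why the starting offset matters; the mask construction below is a simplified stand-in (the benchmark's own construction is truncated in the hunks that follow), and the concrete lengths are chosen only for illustration:

    import torch

    query_len, key_len = 2, 8
    # Fixed: the queries are the last query_len positions of the sequence,
    # so their absolute positions are [6, 7].
    cache_position = torch.arange(key_len - query_len, key_len)
    # Old code, torch.arange(0, query_len + 0), gave [0, 1]: each query could
    # then see at most keys 0..1, and the six cached keys were masked out.

    min_val = torch.finfo(torch.bfloat16).min
    causal_mask = torch.full((query_len, key_len), fill_value=min_val)
    allowed = torch.arange(key_len) <= cache_position.reshape(-1, 1)
    causal_mask = causal_mask.masked_fill(allowed, 0.0)
    # Row 0 (position 6) attends to keys 0..6; row 1 (position 7) to keys 0..7.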
---
 benchmarks/benchmark_forward_equivalence.py | 2 +-
 benchmarks/benchmark_forward_performance.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index e724b12..60a868a 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -588,7 +588,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
         A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)
 
         # Create custom causal mask with cache position
-        cache_position = torch.arange(0, query_len + 0, device=device)
+        cache_position = torch.arange(key_len - query_len, key_len, device=device)
         min_type = torch.finfo(value_states.dtype).min
         causal_mask = torch.full(
             (query_len, key_len), fill_value=min_type,
diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 1d4e0bc..d64ae9a 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -502,7 +502,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
         A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)
 
         # Create custom causal mask with cache position
-        cache_position = torch.arange(0, query_len + 0, device=device)
+        cache_position = torch.arange(key_len - query_len, key_len, device=device)
         min_type = torch.finfo(value_states.dtype).min
         causal_mask = torch.full(
             (query_len, key_len), fill_value=min_type,
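As a closing design note, the corrected expression degenerates gracefully at both extremes: with no cache (key_len == query_len) it reduces to the old torch.arange(0, query_len), and in single-token decoding (query_len == 1) it yields the last position, so the new query attends to every cached key. A minimal check, assuming only PyTorch:

    import torch

    # No cache: identical to the old behavior.
    assert torch.equal(torch.arange(8 - 8, 8), torch.arange(0, 8))
    # Decode step: the single query sits at the final position.
    assert torch.equal(torch.arange(8 - 1, 8), torch.tensor([7]))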