diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index e724b12..60a868a 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -588,7 +588,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)
 
     # Create custom causal mask with cache position
-    cache_position = torch.arange(0, query_len + 0, device=device)
+    cache_position = torch.arange(key_len - query_len, key_len, device=device)
     min_type = torch.finfo(value_states.dtype).min
     causal_mask = torch.full(
         (query_len, key_len), fill_value=min_type,
diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 1d4e0bc..d64ae9a 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -502,7 +502,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
     A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)
 
     # Create custom causal mask with cache position
-    cache_position = torch.arange(0, query_len + 0, device=device)
+    cache_position = torch.arange(key_len - query_len, key_len, device=device)
     min_type = torch.finfo(value_states.dtype).min
     causal_mask = torch.full(
         (query_len, key_len), fill_value=min_type,
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 33b911f..8699ceb 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -385,8 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Init dynamic mask processor
     FLASH_NAMESPACE::Mask<Is_causal> mask(
-        binfo.actual_seqlen_k, binfo.actual_seqlen_q,
-        params.keep_window_size
+        binfo.actual_seqlen_k, binfo.actual_seqlen_q
     );
 
     // For performance reason, we separate out two kinds of iterations:
@@ -961,8 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Init dynamic mask processor
     FLASH_NAMESPACE::Mask<Is_causal> mask(
-        binfo.actual_seqlen_k, binfo.actual_seqlen_q,
-        params.keep_window_size
+        binfo.actual_seqlen_k, binfo.actual_seqlen_q
     );
 
     // For performance reason, we separate out two kinds of iterations:
diff --git a/csrc/src/mask.h b/csrc/src/mask.h
index 166e765..8b2ebba 100644
--- a/csrc/src/mask.h
+++ b/csrc/src/mask.h
@@ -60,16 +60,13 @@ __forceinline__ __device__ void apply_mask(
 template <bool Causal_mask>
 struct Mask {
     const int max_seqlen_k, max_seqlen_q;
-    const int keep_window_size;
 
     __forceinline__ __device__ Mask(
         const int max_seqlen_k,
-        const int max_seqlen_q,
-        const int keep_window_size
+        const int max_seqlen_q
     ) // Constructor
         : max_seqlen_k(max_seqlen_k)
-        , max_seqlen_q(max_seqlen_q)
-        , keep_window_size(keep_window_size) {
+        , max_seqlen_q(max_seqlen_q) {
     };
 
     template <bool Is_even_MN, typename TensorType, typename MaskType, typename BiasType>
@@ -86,7 +83,7 @@ struct Mask {
         static_assert(MaskType::rank == 3, "Mask must be 3D Tensor");
         static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
         static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
-        const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
+
         // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
         Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout()));
@@ -94,53 +91,79 @@ struct Mask {
         const int lane_id = threadIdx.x % 32;
         const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
 
-        if (Need_masking) {
-            #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-                #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    const int row_idx = row_idx_base + i * 8;
-                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
-                    #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                            bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
-                            if (inactive) {
-                                tensor(coord) = -INFINITY;
-                            } else {
-                                // Apply scaling and bias
-                                tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            // If no masking is needed, just scale the tensor and add bias
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
             #pragma unroll
-            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-                // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
                 #pragma unroll
-                for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                    // const int row_idx = row_idx_base + i * 8;
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
                     #pragma unroll
-                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                        // const int col_idx_base = col_idx_offset + nj * 8;
-                        #pragma unroll
-                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                            // const int col_idx = col_idx_base + j;
-                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
+                        if (inactive) {
+                            tensor(coord) = -INFINITY;
+                        } else {
+                            // Apply scaling and bias
                             tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
                         }
                     }
                 }
             }
         }
+        // const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
+        // if (Need_masking) {
+        //     #pragma unroll
+        //     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        //         const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        //         #pragma unroll
+        //         for (int i = 0; i < size<0, 0>(tensor); ++i) {
+        //             const int row_idx = row_idx_base + i * 8;
+        //             const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+        //             #pragma unroll
+        //             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        //                 const int col_idx_base = col_idx_offset + nj * 8;
+        //                 #pragma unroll
+        //                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
+        //                     const int col_idx = col_idx_base + j;
+        //                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+        //                     bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
+        //                     if (inactive) {
+        //                         tensor(coord) = -INFINITY;
+        //                     } else {
+        //                         // Apply scaling and bias
+        //                         tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+        //                     }
+        //                 }
+        //             }
+        //         }
+        //     }
+        // } else {
+        //     // If no masking is needed, just scale the tensor and add bias
+        //     #pragma unroll
+        //     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        //         // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        //         #pragma unroll
+        //         for (int i = 0; i < size<0, 0>(tensor); ++i) {
+        //             // const int row_idx = row_idx_base + i * 8;
+        //             #pragma unroll
+        //             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        //                 // const int col_idx_base = col_idx_offset + nj * 8;
+        //                 #pragma unroll
+        //                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
+        //                     // const int col_idx = col_idx_base + j;
+        //                     auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+        //                     tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+        //                 }
+        //             }
+        //         }
+        //     }
+        // }
 
     }
 };
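For reference, a standalone sketch (not part of the patch) of what the corrected `cache_position` means for the benchmarks' mask construction: `query_len`, `key_len`, and the `torch.full(..., fill_value=min_type)` pattern follow the context lines above, while the comparison against `cache_position` is an assumed stand-in for the benchmark lines the hunks do not show.

```python
# Standalone sketch, not part of the patch: how cache_position places the query rows
# inside the (query_len, key_len) causal mask built by the benchmarks.
# The torch.full(...) construction mirrors the diff context; the comparison against
# cache_position is an assumption about the lines the hunks omit.
import torch

query_len, key_len = 2, 6          # e.g. 2 new queries attending over 4 cached keys plus themselves
dtype, device = torch.bfloat16, "cpu"

def build_causal_mask(cache_position):
    min_type = torch.finfo(dtype).min
    causal_mask = torch.full((query_len, key_len), fill_value=min_type, dtype=dtype, device=device)
    # A key column stays visible (0.0) when its index is <= the query's absolute position.
    visible = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1)
    return causal_mask.masked_fill(visible, 0.0)

# Old: queries are treated as absolute positions 0..query_len-1, so most cached keys get masked.
old = build_causal_mask(torch.arange(0, query_len, device=device))
# New: queries occupy the last query_len positions key_len-query_len..key_len-1.
new = build_causal_mask(torch.arange(key_len - query_len, key_len, device=device))

print(old)  # each row only sees keys 0..row, hiding the cached keys
print(new)  # each query sees every earlier key plus itself
```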