Remove unused parameters and simplify mask logic #77
@@ -588,7 +588,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
    A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

    # Create custom causal mask with cache position

Suggested change:
-   # Create custom causal mask with cache position
+   # Create custom causal mask with cache position
+   if key_len < query_len:
+       raise ValueError(f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len}).")
@@ -502,7 +502,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
    A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

    # Create custom causal mask with cache position
-   cache_position = torch.arange(0, query_len + 0, device=device)
+   cache_position = torch.arange(key_len - query_len, key_len, device=device)

    min_type = torch.finfo(value_states.dtype).min
    causal_mask = torch.full(
        (query_len, key_len), fill_value=min_type,
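For reference, a minimal sketch of what the old and new formulas produce in a typical decode-style call; the concrete sizes here are made up for illustration, and only query_len, key_len, and device come from the benchmark code above.

    import torch

    query_len, key_len = 2, 8  # e.g. 2 new query tokens attending over an 8-token key/value cache
    device = "cpu"             # the benchmark runs on CUDA; CPU is fine for illustration

    old_positions = torch.arange(0, query_len + 0, device=device)              # tensor([0, 1])
    new_positions = torch.arange(key_len - query_len, key_len, device=device)  # tensor([6, 7])
    # The new formula places the queries at the last query_len positions of the key
    # sequence, so the causal mask is built relative to the actual cache offset.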
@@ -60,16 +60,13 @@ __forceinline__ __device__ void apply_mask(
template <bool Is_causal>
struct Mask {
    const int max_seqlen_k, max_seqlen_q;
-   const int keep_window_size;

    __forceinline__ __device__ Mask(
        const int max_seqlen_k,
-       const int max_seqlen_q,
-       const int keep_window_size
+       const int max_seqlen_q
    ) // Constructor
        : max_seqlen_k(max_seqlen_k)
-       , max_seqlen_q(max_seqlen_q)
-       , keep_window_size(keep_window_size) {
+       , max_seqlen_q(max_seqlen_q) {
    };

    template <bool Causal_mask=false, bool Is_even_MN=true, typename TensorType, typename MaskType, typename BiasType>
@@ -86,61 +83,87 @@ struct Mask {
    static_assert(MaskType::rank == 3, "Mask must be 3D Tensor");
    static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
    static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
-   const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);

    // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
    Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout()));
    Tensor bias = make_tensor(Bias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Bias.layout()));

    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-   if (Need_masking) {
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
                    bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
                    if (inactive) {
                        tensor(coord) = -INFINITY;
                    } else {
                        // Apply scaling and bias
                        tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
                    }
                }
            }
        }
    }
-   } else {
-       // If no masking is needed, just scale the tensor and add bias
-       #pragma unroll
-       for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-           // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-           #pragma unroll
-           for (int i = 0; i < size<0, 0>(tensor); ++i) {
-               // const int row_idx = row_idx_base + i * 8;
-               #pragma unroll
-               for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                   // const int col_idx_base = col_idx_offset + nj * 8;
-                   #pragma unroll
-                   for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                       // const int col_idx = col_idx_base + j;
-                       auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                       tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
-                   }
-               }
-           }
-       }
-   }
+   // const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
+   // if (Need_masking) {
+   // #pragma unroll
+   // for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+   // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+   // #pragma unroll
+   // for (int i = 0; i < size<0, 0>(tensor); ++i) {
+   // const int row_idx = row_idx_base + i * 8;
+   // const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+   // #pragma unroll
+   // for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+   // const int col_idx_base = col_idx_offset + nj * 8;
+   // #pragma unroll
+   // for (int j = 0; j < size<1, 0>(tensor); ++j) {
+   // const int col_idx = col_idx_base + j;
+   // auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+   // bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
+   // if (inactive) {
+   // tensor(coord) = -INFINITY;
+   // } else {
+   // // Apply scaling and bias
+   // tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+   // }
+   // }
+   // }
+   // }
+   // }
+   // } else {
+   // // If no masking is needed, just scale the tensor and add bias
+   // #pragma unroll
+   // for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+   // // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+   // #pragma unroll
+   // for (int i = 0; i < size<0, 0>(tensor); ++i) {
+   // // const int row_idx = row_idx_base + i * 8;
+   // #pragma unroll
+   // for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+   // // const int col_idx_base = col_idx_offset + nj * 8;
+   // #pragma unroll
+   // for (int j = 0; j < size<1, 0>(tensor); ++j) {
+   // // const int col_idx = col_idx_base + j;
+   // auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+   // tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
+   // }
+   // }
+   // }
+   // }
+   // }
    }
};
|
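To make the simplified kernel loop easier to follow, here is a rough PyTorch sketch of the per-element rule it applies; the function name, the (query_len, key_len) tile layout, and the parameter names are assumptions for illustration, not the project's API.

    import torch

    def apply_mask_reference(scores, mask, bias, scale_softmax, causal=True):
        # scores, mask, bias: (query_len, key_len) tensors for a single head
        query_len, key_len = scores.shape
        out = torch.empty_like(scores)
        for row_idx in range(query_len):
            # Same causal limit as the kernel: min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q)
            col_idx_limit = min(key_len, row_idx + 1 + key_len - query_len) if causal else key_len
            for col_idx in range(key_len):
                inactive = (col_idx >= col_idx_limit) or (mask[row_idx, col_idx] == 0.0)
                if inactive:
                    out[row_idx, col_idx] = float("-inf")
                else:
                    # Apply scaling and bias, as in the unconditional loop above
                    out[row_idx, col_idx] = scores[row_idx, col_idx] * scale_softmax + bias[row_idx, col_idx]
        return out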
Review comment: The cache position calculation assumes key_len >= query_len, but if query_len > key_len, it will create a range with a negative start value, which may not be the intended behavior for cache positioning.
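A quick illustration of the behavior described in the comment, using toy values not taken from the test suite:

    import torch

    query_len, key_len = 4, 2  # query longer than the key/value cache
    print(torch.arange(key_len - query_len, key_len))  # tensor([-2, -1,  0,  1])
    # The range starts at a negative position, which is what the suggested
    # key_len >= query_len check is meant to rule out.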