Merged
2 changes: 1 addition & 1 deletion benchmarks/benchmark_forward_equivalence.py
@@ -588,7 +588,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

# Create custom causal mask with cache position

Copilot AI Jul 28, 2025

The cache position calculation assumes key_len >= query_len, but if query_len > key_len, this will create a range with a negative start value, which may not be the intended behavior for cache positioning.

Suggested change
- # Create custom causal mask with cache position
+ # Create custom causal mask with cache position
+ if key_len < query_len:
+     raise ValueError(f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len}).")


Copilot AI Jul 28, 2025

The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.

Suggested change
- # Create custom causal mask with cache position
+ # Create custom causal mask with cache position
+ if key_len < query_len:
+     raise ValueError(f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len}).")

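To make the reported failure mode concrete (hypothetical sizes, not taken from the benchmark): when query_len exceeds key_len, the updated torch.arange call starts at a negative index, which is why the guard suggested above is worth adding.

```python
import torch

query_len, key_len = 8, 5  # hypothetical sizes where the guard matters
cache_position = torch.arange(key_len - query_len, key_len)
print(cache_position)  # tensor([-3, -2, -1,  0,  1,  2,  3,  4])
# Negative cache positions would mis-place the causal window, hence the suggested ValueError.
```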
- cache_position = torch.arange(0, query_len + 0, device=device)
+ cache_position = torch.arange(key_len - query_len, key_len, device=device)
min_type = torch.finfo(value_states.dtype).min
causal_mask = torch.full(
(query_len, key_len), fill_value=min_type,
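For readers following the benchmark change, the intent of the new line can be sketched as a small standalone helper. This is a minimal illustration under assumptions, not the benchmark code itself: the function name build_causal_mask, the default device/dtype, and the masked_fill formulation are assumptions; only the torch.arange call, the torch.finfo minimum, and the torch.full shape come from the diff.

```python
import torch

def build_causal_mask(query_len, key_len, device="cpu", dtype=torch.bfloat16):
    # Guard from the review suggestion: torch.arange would otherwise start negative.
    if key_len < query_len:
        raise ValueError(
            f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len})."
        )
    # The queries occupy the last `query_len` slots of the key sequence.
    cache_position = torch.arange(key_len - query_len, key_len, device=device)
    min_value = torch.finfo(dtype).min
    causal_mask = torch.full(
        (query_len, key_len), fill_value=min_value, dtype=dtype, device=device
    )
    # A query may attend to any key at or before its own cache position.
    allowed = torch.arange(key_len, device=device)[None, :] <= cache_position[:, None]
    return causal_mask.masked_fill(allowed, 0.0)
```

With query_len == key_len this reduces to a standard lower-triangular causal mask; with key_len > query_len (decoding against a KV cache) the extra leading keys are visible to every query.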
2 changes: 1 addition & 1 deletion benchmarks/benchmark_forward_performance.py
@@ -502,7 +502,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

# Create custom causal mask with cache position
- cache_position = torch.arange(0, query_len + 0, device=device)
+ cache_position = torch.arange(key_len - query_len, key_len, device=device)

Copilot AI Jul 28, 2025

The cache position calculation assumes key_len >= query_len, but if query_len > key_len, this will create a range with a negative start value, which may not be the intended behavior for cache positioning.


Copilot AI Jul 28, 2025

The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.

min_type = torch.finfo(value_states.dtype).min
causal_mask = torch.full(
(query_len, key_len), fill_value=min_type,
6 changes: 2 additions & 4 deletions csrc/src/flash_fwd_kernel.h
@@ -385,8 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

// Init dynamic mask processor
FLASH_NAMESPACE::Mask<Is_causal> mask(
- binfo.actual_seqlen_k, binfo.actual_seqlen_q,
- params.keep_window_size
+ binfo.actual_seqlen_k, binfo.actual_seqlen_q
);

// For performance reason, we separate out two kinds of iterations:
@@ -961,8 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

// Init dynamic mask processor
FLASH_NAMESPACE::Mask<Is_causal> mask(
- binfo.actual_seqlen_k, binfo.actual_seqlen_q,
- params.keep_window_size
+ binfo.actual_seqlen_k, binfo.actual_seqlen_q
);

// For performance reason, we separate out two kinds of iterations:
111 changes: 67 additions & 44 deletions csrc/src/mask.h
@@ -60,16 +60,13 @@ __forceinline__ __device__ void apply_mask(
template <bool Is_causal>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
- const int keep_window_size;

__forceinline__ __device__ Mask(
const int max_seqlen_k,
- const int max_seqlen_q,
- const int keep_window_size
+ const int max_seqlen_q
) // Constructor
: max_seqlen_k(max_seqlen_k)
- , max_seqlen_q(max_seqlen_q)
- , keep_window_size(keep_window_size) {
+ , max_seqlen_q(max_seqlen_q) {
};

template <bool Causal_mask=false, bool Is_even_MN=true, typename TensorType, typename MaskType, typename BiasType>
@@ -86,61 +83,87 @@ struct Mask {
static_assert(MaskType::rank == 3, "Mask must be 3D Tensor");
static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);

// Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout()));
Tensor bias = make_tensor(Bias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Bias.layout()));

const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if (Need_masking) {
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
if (inactive) {
tensor(coord) = -INFINITY;
} else {
// Apply scaling and bias
tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
}
}
}
}
}
} else {
// If no masking is needed, just scale the tensor and add bias
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
// const int row_idx_base = row_idx_offset + mi * warp_row_stride;
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
// const int row_idx = row_idx_base + i * 8;
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
// const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
// const int col_idx = col_idx_base + j;
auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
if (inactive) {
tensor(coord) = -INFINITY;
} else {
// Apply scaling and bias
tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
}
}
}
}
}
// const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
// if (Need_masking) {
// #pragma unroll
// for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
// const int row_idx_base = row_idx_offset + mi * warp_row_stride;
// #pragma unroll
// for (int i = 0; i < size<0, 0>(tensor); ++i) {
// const int row_idx = row_idx_base + i * 8;
// const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
// #pragma unroll
// for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
// const int col_idx_base = col_idx_offset + nj * 8;
// #pragma unroll
// for (int j = 0; j < size<1, 0>(tensor); ++j) {
// const int col_idx = col_idx_base + j;
Comment on lines +119 to +133

Copilot AI Jul 28, 2025

Large blocks of commented-out code should be removed rather than left in the codebase. This commented code appears to be an alternative implementation that is no longer needed and clutters the file.

// auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
// bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f);
// if (inactive) {
// tensor(coord) = -INFINITY;
// } else {
// // Apply scaling and bias
// tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
// }
// }
// }
// }
// }
// } else {
// // If no masking is needed, just scale the tensor and add bias
// #pragma unroll
// for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
// // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
// #pragma unroll
// for (int i = 0; i < size<0, 0>(tensor); ++i) {
// // const int row_idx = row_idx_base + i * 8;
// #pragma unroll
// for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
// // const int col_idx_base = col_idx_offset + nj * 8;
// #pragma unroll
// for (int j = 0; j < size<1, 0>(tensor); ++j) {
// // const int col_idx = col_idx_base + j;
// auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
// tensor(coord) = tensor(coord) * scale_softmax + bias(coord);
// }
// }
// }
// }
// }
}
};
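Since the keep_window_size fast path is gone, apply_mask now evaluates one uniform rule for every score element: a position outside the causal window, or zeroed out by the dynamic mask, becomes -inf; everything else is scaled by scale_softmax and shifted by the bias. A hedged PyTorch reference sketch of that rule follows; it is not the CUDA code, and the flat (seqlen_q, seqlen_k) layout, function name, and argument names are illustrative only.

```python
import torch

def apply_mask_reference(scores, mask, bias, scale_softmax,
                         seqlen_q, seqlen_k, causal=True):
    # scores, mask, bias: (seqlen_q, seqlen_k) tensors (illustrative layout).
    row_idx = torch.arange(seqlen_q).unsqueeze(1)   # query positions
    col_idx = torch.arange(seqlen_k).unsqueeze(0)   # key positions
    if causal:
        # Same limit as the kernel: min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q).
        col_idx_limit = torch.clamp(row_idx + 1 + seqlen_k - seqlen_q, max=seqlen_k)
    else:
        col_idx_limit = torch.full_like(row_idx, seqlen_k)
    inactive = (col_idx >= col_idx_limit) | (mask == 0.0)
    out = scores * scale_softmax + bias
    return out.masked_fill(inactive, float("-inf"))
```

For seqlen_q == seqlen_k the causal limit reduces to the usual lower-triangular mask, combined with the dynamic zero-mask check that was already present before this change.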
