Merged
249 changes: 249 additions & 0 deletions csrc/flash_api.cpp
@@ -430,9 +430,258 @@ mha_fwd(
}
return {out, softmax_lse, p, rng_state};
}

std::vector<at::Tensor>
mha_varlen_fwd(
at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &attn_mask, // total_q x num_heads_k x max_seqlen_k
const at::Tensor &attn_bias, // total_q x num_heads_k x max_seqlen_k
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
Copilot AI Jul 25, 2025

The comment incorrectly states 'total_k := \sum_{i=0}^{b} s_i' for the out_ parameter, but it should be 'total_q := \sum_{i=0}^{b} s_i' since out_ has the same shape as the query tensor q.

Suggested change
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i

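For context on the varlen packing convention used throughout this function, here is a minimal standalone sketch (illustrative only, not part of the diff; lengths and sizes are made up) of how cu_seqlens_q, total_q, and the packed q/out tensors relate:

#include <torch/torch.h>
#include <iostream>

// Sketch of the varlen layout: b sequences of lengths s_i are packed back to
// back, cu_seqlens_q holds the b+1 prefix sums, and total_q = sum_i s_i.
int main() {
  auto seqlens = torch::tensor({3, 5, 2}, torch::kInt32);               // hypothetical s_0, s_1, s_2
  auto cu_seqlens_q = torch::cat({torch::zeros({1}, torch::kInt32),
                                  torch::cumsum(seqlens, 0).to(torch::kInt32)});
  const int64_t total_q = cu_seqlens_q[-1].item<int64_t>();             // 10 in this example

  const int64_t num_heads = 8, head_size = 64;
  auto q   = torch::empty({total_q, num_heads, head_size}, torch::kFloat16);
  auto out = torch::empty_like(q);  // out is packed exactly like q: (total_q, num_heads, head_size)
  std::cout << cu_seqlens_q << "\ntotal_q = " << total_q << std::endl;
  return 0;
}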
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
std::optional<const at::Tensor> &leftpad_k_, // batch_size
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
const int keep_window_size,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_
) {
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x_min = cc_major >= 8;
TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);

at::Tensor block_table;
// const bool paged_KV = block_table_.has_value();
Copilot AI Jul 25, 2025

This commented-out code should be removed since the functionality is explicitly disabled on the next line with a TODO comment.

Suggested change
// const bool paged_KV = block_table_.has_value();

Copilot AI Jul 25, 2025

This commented-out line should be removed as it's replaced by the hard-coded false value below. Keeping dead code reduces readability.

Suggested change
// const bool paged_KV = block_table_.has_value();

const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
Copilot AI Jul 25, 2025

Hard-coding paged_KV to false makes the subsequent paged_KV conditional logic unreachable. Consider removing the paged_KV related code blocks or adding a feature flag instead of hard-coding false.

Suggested change
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
const bool paged_KV = ENABLE_PAGED_KV; // Use feature flag to control Paged KV functionality.

if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
Comment on lines +480 to +484
Copilot AI Jul 25, 2025

The paged_KV feature is hardcoded to false but the function still accepts and validates the block_table parameter. Consider either removing the block_table parameter entirely or adding an early return/error when block_table is provided while paged_KV is disabled to avoid confusion and unnecessary validation overhead.

Suggested change
if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
if (block_table_.has_value()) {
TORCH_CHECK(false, "block_table is not supported because paged_KV is currently disabled.");

Comment on lines +478 to +484
Copilot AI Jul 25, 2025

Hard-coding paged_KV = false ignores the block_table_ parameter and makes the related validation code unreachable. Consider removing the paged KV logic entirely or adding a runtime check to reject paged KV requests with a clear error message.

Suggested change
// const bool paged_KV = block_table_.has_value();
const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
if (block_table_.has_value()) {
TORCH_CHECK(false, "Paged KV is currently not supported. Please disable the block_table_ parameter.");

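One way to act on the two comments above, sketched here under the assumption that Paged KV stays disabled for now (illustrative only, not a change in this PR): reject a provided block_table_ up front with a clear error instead of silently ignoring it.

// Illustrative sketch: fail fast if a caller passes block_table_ while Paged KV is disabled.
const bool paged_KV = false;  // TODO: Temporarily disable Paged KV, because some bugs are still being fixed.
TORCH_CHECK(!block_table_.has_value(),
            "block_table is not supported yet: Paged KV is temporarily disabled in FlashDynamicMaskAttention.");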
}

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(attn_mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(attn_bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

const auto sizes = q.sizes();

const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
cu_seqlens_q_d = nullptr;
}

const int total_q = q.sizes()[0];

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

CHECK_SHAPE(q, total_q, num_heads, head_size);
if (!paged_KV) {
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
CHECK_SHAPE(attn_mask, total_q, num_heads_k, max_seqlen_k);
Copilot AI Jul 25, 2025

The shape check for attn_mask appears incorrect. The comment on line 439 indicates attn_mask should be total_mask x num_heads_k where total_mask := total_q x total_k, but this check expects total_q x num_heads_k x max_seqlen_k. This mismatch could cause runtime errors when the attention mask has the correct shape according to the documentation.

Suggested change
CHECK_SHAPE(attn_mask, total_q, num_heads_k, max_seqlen_k);
CHECK_SHAPE(attn_mask, total_q * total_k, num_heads_k);

CHECK_SHAPE(attn_bias, total_q, num_heads_k, max_seqlen_k);
Copilot AI Jul 25, 2025

The shape check for attn_bias appears incorrect. The comment on line 440 indicates attn_bias should be total_bias x num_heads_k where total_bias := total_q x total_k, but this check expects total_q x num_heads_k x max_seqlen_k. This mismatch could cause runtime errors when the attention bias has the correct shape according to the documentation.

Suggested change
CHECK_SHAPE(attn_bias, total_q, num_heads_k, max_seqlen_k);
CHECK_SHAPE(attn_bias, total_q * total_k, num_heads_k);

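For concreteness, a minimal sketch (illustrative only; sizes, dtype, and values are placeholders, and the real path requires CUDA tensors) of attn_mask and attn_bias tensors that satisfy the CHECK_SHAPE and stride checks as written in this hunk:

// Hypothetical sizes; the checks above expect (total_q, num_heads_k, max_seqlen_k)
// with a contiguous last dimension and the same dtype as q/k/v.
const int64_t total_q = 10, num_heads_k = 2, max_seqlen_k = 5;
auto mask_opts = torch::dtype(torch::kFloat16);
auto attn_mask = torch::ones({total_q, num_heads_k, max_seqlen_k}, mask_opts);
auto attn_bias = torch::zeros({total_q, num_heads_k, max_seqlen_k}, mask_opts);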
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}

CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
CHECK_SHAPE(seqused_k_, batch_size);
}

at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
}
} else {
out = torch::empty_like(q);
}

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
at::Tensor p;
// Only return softmax if there's dropout to reduce compilation time
if (return_softmax) {
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
}
else {
p = torch::empty({ 0 }, opts);
}

if (zero_tensors) {
out.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_softmax) { p.zero_(); }
}

Flash_fwd_params params;
set_params_fprop(
params,
batch_size,
max_seqlen_q, max_seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
keep_window_size,
q, k, v, attn_mask, attn_bias, out,
cu_seqlens_q_d,
cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal,
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true
);
params.total_q = total_q;

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
}
params.page_block_size = page_block_size;
// Keep references to these tensors to extend their lifetime
at::Tensor softmax_lse_accum, out_accum;
if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
std::tie(softmax_lse_accum, out_accum) =
set_params_splitkv(
params, batch_size, num_heads, head_size,
max_seqlen_k, max_seqlen_q, head_size_rounded,
p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
);
}

if (leftpad_k_.has_value()) {
auto leftpad_k = leftpad_k_.value();
TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
CHECK_DEVICE(leftpad_k);
CHECK_CONTIGUOUS(leftpad_k);
CHECK_SHAPE(leftpad_k, batch_size);
params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
}

// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream, paged_KV);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}

if (seqlenq_ngroups_swapped) {
int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
Comment on lines +671 to +674
Copilot AI Jul 25, 2025

[nitpick] Magic array initialization with raw arrays makes the code less maintainable. Consider using std::array or initializer lists with clear variable names to improve readability.

Suggested change
int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size};
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q = q.reshape(size_before).transpose(1, 2).reshape(size_after);

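A quick standalone check (illustrative; uses only the public ATen API, with placeholder dimensions) that a std::array<int64_t, N> converts implicitly to at::IntArrayRef, which is what lets the suggested change above pass size_before/size_after to reshape directly:

#include <ATen/ATen.h>
#include <array>

int main() {
  std::array<int64_t, 3> shape = {2, 3, 4};
  auto t = at::zeros({24});
  auto r = t.reshape(shape);        // IntArrayRef is constructed from the std::array
  return (r.dim() == 3 && r.size(0) == 2) ? 0 : 1;
}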
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
}

return {out, softmax_lse, p, rng_state};
}

} // namespace FLASH_NAMESPACE

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashDynamicMaskAttention";
m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
}