6 changes: 3 additions & 3 deletions flake.lock


18 changes: 12 additions & 6 deletions flash-attn2/build.toml
@@ -121,16 +121,22 @@ backend = "xpu"
 src = [
   "flash_attn_xpu/flash_api.cpp",
 
-  "flash_attn_xpu/src/prefill.hpp",
+  "flash_attn_xpu/src/fixed.hpp",
+  "flash_attn_xpu/src/varlen.hpp",
   "flash_attn_xpu/src/fmha_utils.hpp",
   "flash_attn_xpu/src/compat_wrapper.hpp",
 
   "flash_attn_xpu/src/collective/fmha_fusion.hpp",
-  "flash_attn_xpu/src/collective/xe_flash_attn_prefill_epilogue.hpp",
-  "flash_attn_xpu/src/collective/xe_flash_attn_prefill_mma.hpp",
-  "flash_attn_xpu/src/collective/xe_flash_attn_prefill_softmax_epilogue.hpp",
+  "flash_attn_xpu/src/collective/fixed_epilogue.hpp",
+  "flash_attn_xpu/src/collective/fixed_mma.hpp",
+  "flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp",
+  "flash_attn_xpu/src/collective/varlen_epilogue.hpp",
+  "flash_attn_xpu/src/collective/varlen_mma.hpp",
+  "flash_attn_xpu/src/collective/varlen_softmax_epilogue.hpp",
 
-  "flash_attn_xpu/src/kernel/tile_scheduler.hpp",
-  "flash_attn_xpu/src/kernel/xe_flash_attn_prefill.hpp",
+  "flash_attn_xpu/src/kernel/fixed_scheduler.hpp",
+  "flash_attn_xpu/src/kernel/fixed_kernel.hpp",
+  "flash_attn_xpu/src/kernel/varlen_scheduler.hpp",
+  "flash_attn_xpu/src/kernel/varlen_kernel.hpp",
 ]
 depends = ["torch", "cutlass_sycl"]
62 changes: 33 additions & 29 deletions flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,15 +1,19 @@
 #include <torch/all.h>
 
-#include "src/prefill.hpp"
-
-#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#include "src/fixed.hpp"
+#include "src/varlen.hpp"
 
 namespace FLASH_NAMESPACE {
 
 inline int round_multiple(int x, int m) {
-    return (x + m - 1) / m * m;
+    int pad_res = (x + m - 1) / m * m;
+    if (pad_res == 224)
+        pad_res = 256;
+    return pad_res;
 }
 
+inline at::Tensor ensure_contiguous(const at::Tensor& tensor) {
+    return tensor.is_contiguous() ? tensor : tensor.contiguous();
+}
+
 std::vector<at::Tensor>
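
The new round_multiple still rounds the head dimension up to the next multiple of 32, but a result of 224 is now bumped to 256, presumably because the XPU kernels have no 224-wide configuration. A minimal standalone sketch of the same arithmetic (the round_multiple_sketch name and the main() harness are illustrative, not part of the PR):

#include <cassert>

// Round x up to a multiple of m, then promote 224 to 256, mirroring the
// padding rule added to round_multiple above.
inline int round_multiple_sketch(int x, int m) {
    int pad_res = (x + m - 1) / m * m;
    if (pad_res == 224)
        pad_res = 256;
    return pad_res;
}

int main() {
    assert(round_multiple_sketch(64, 32)  == 64);   // already aligned
    assert(round_multiple_sketch(80, 32)  == 96);   // rounded up
    assert(round_multiple_sketch(200, 32) == 256);  // would pad to 224, bumped to 256
    assert(round_multiple_sketch(224, 32) == 256);  // exact 224 is bumped as well
    return 0;
}

In effect, original head sizes from 193 through 256 are now all padded to 256.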
@@ -32,6 +36,7 @@ mha_fwd(
     COMPAT::select_device(device_idx);
 
     // check inputs
+    q = ensure_contiguous(q);
     const auto sizes = q.sizes();
     const int batch_size = sizes[0];
     const int seqlen_q = sizes[1];
@@ -40,9 +45,11 @@ mha_fwd(
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
 
-    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
     // XPU requires head_size to be a multiple of 32
     const int head_size_padded = round_multiple(head_size_og, 32);
@@ -72,15 +79,18 @@ mha_fwd(
         out_padded = torch::zeros_like(q_padded);
     }
 
-    cutlass_prefill_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal);
+    q_padded = ensure_contiguous(q_padded);
+    k_padded = ensure_contiguous(k_padded);
+    v_padded = ensure_contiguous(v_padded);
+    cutlass::flash_attention::fixed::cutlass_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal);
 
     // Remove padding from output
     at::Tensor out = out_padded;
     if (head_size_og != head_size_padded) {
         out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
                                 torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
     }
-    out = out.contiguous();
+    out = ensure_contiguous(out);
 
     // TODO: current do not support store softmax_lse out
     // hard code to return empty tensor for softmax_lse, S_dmask, rng_state
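
Around this hunk, mha_fwd pads q, k, and v up to head_size_padded before invoking the fixed-length kernel (the padding itself sits in the collapsed part of the diff) and slices the padding off the output afterwards. A self-contained sketch of that pad/slice pattern in plain libtorch on CPU with float32 (shapes and names are illustrative; the real entry point additionally requires fp16/bf16 tensors on an XPU device):

#include <torch/torch.h>

int main() {
    namespace F = torch::nn::functional;
    using torch::indexing::Slice;

    // Illustrative shapes: batch=2, seqlen=8, heads=4, head_size=100.
    const int head_size_og = 100;
    const int head_size_padded = (head_size_og + 31) / 32 * 32;  // 128

    auto q = torch::randn({2, 8, 4, head_size_og});

    // Zero-pad the last dimension up to the padded head size, as mha_fwd
    // does before calling into the kernel.
    auto q_padded = F::pad(q, F::PadFuncOptions({0, head_size_padded - head_size_og}));

    // ... the attention kernel would consume q_padded / k_padded / v_padded here ...
    auto out_padded = q_padded.clone();

    // Slice the padding back off, matching the output handling above.
    auto out = out_padded
                   .index({Slice(), Slice(), Slice(), Slice(0, head_size_og)})
                   .contiguous();

    TORCH_CHECK(out.sizes() == q.sizes());
    return 0;
}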
@@ -118,16 +128,7 @@ mha_varlen_fwd(
     COMPAT::select_device(device_idx);
 
     // check inputs
-    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
-    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
-    CHECK_DEVICE(cu_seqlens_q);
-    CHECK_DEVICE(cu_seqlens_k);
-    TORCH_CHECK(cu_seqlens_q.dim() == 1, "cu_seqlens_q must be 1-dimensional, but got ", cu_seqlens_q.dim(), " dimensions");
-    TORCH_CHECK(cu_seqlens_k.dim() == 1, "cu_seqlens_k must be 1-dimensional, but got ", cu_seqlens_k.dim(), " dimensions");
-    CHECK_CONTIGUOUS(cu_seqlens_q);
-    CHECK_CONTIGUOUS(cu_seqlens_k);
-
-    // Extract dimensions
+    q = ensure_contiguous(q);
     const auto sizes = q.sizes();
     const int total_q = sizes[0];
     const int num_heads = sizes[1];
@@ -136,11 +137,11 @@ mha_varlen_fwd(
     const int num_heads_k = k.size(1);
     const int batch_size = cu_seqlens_q.numel() - 1;
 
-    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
-    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
     // XPU requires head_size to be a multiple of 32
     const int head_size_padded = round_multiple(head_size_og, 32);
@@ -169,8 +170,11 @@ mha_varlen_fwd(
     } else {
         out_padded = torch::zeros_like(q_padded);
     }
-
-    cutlass_prefill_varlen_impl(q_padded, k_padded, v_padded, out_padded,
+
+    q_padded = ensure_contiguous(q_padded);
+    k_padded = ensure_contiguous(k_padded);
+    v_padded = ensure_contiguous(v_padded);
+    cutlass::flash_attention::varlen::cutlass_varlen_impl(q_padded, k_padded, v_padded, out_padded, block_table_,
                                 cu_seqlens_q, cu_seqlens_k,
                                 max_seqlen_q, max_seqlen_k,
                                 softmax_scale, is_causal);
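
For the varlen path, q, k, and v are packed along a single token dimension, and cu_seqlens_q / cu_seqlens_k are the cumulative-sequence-length tensors the old code validated as one-dimensional int32; the code above still derives batch_size as cu_seqlens_q.numel() - 1. A small sketch of how a caller might build such a tensor with libtorch (the sequence lengths are made up):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Per-sequence token counts for a packed batch of 3 sequences.
    auto seqlens = torch::tensor({5, 3, 7}, torch::kInt32);

    // cu_seqlens has batch_size + 1 entries: a leading 0 followed by the
    // running sum, so sequence i occupies rows [cu[i], cu[i+1]) of the
    // packed (total_tokens, num_heads, head_size) tensors.
    auto cu_seqlens = torch::cat({torch::zeros({1}, torch::kInt32),
                                  torch::cumsum(seqlens, /*dim=*/0).to(torch::kInt32)});

    std::cout << cu_seqlens << std::endl;  // -> [0, 5, 8, 15]; total tokens = 15
    return 0;
}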
@@ -181,7 +185,7 @@ mha_varlen_fwd(
         out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
                                 torch::indexing::Slice(0, head_size_og)});
     }
-    out = out.contiguous();
+    out = ensure_contiguous(out);
 
     // TODO: current do not support store softmax_lse out
     // hard code to return empty tensor for softmax_lse, S_dmask, rng_state