diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index 1c3bc0a..539e30d 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,6 +1,11 @@
-#include "src/prefill.hpp"
 #include
+#include "src/prefill.hpp"
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
 namespace FLASH_NAMESPACE {
 
 inline int round_multiple(int x, int m) {
@@ -22,8 +27,24 @@ mha_fwd(
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
+
+  auto device_idx = q.device().index();
+  COMPAT::select_device(device_idx);
+
+  // check inputs
+  const auto sizes = q.sizes();
+  const int batch_size = sizes[0];
+  const int seqlen_q = sizes[1];
+  const int num_heads = sizes[2];
+  const int head_size_og = sizes[3];
+  const int seqlen_k = k.size(1);
+  const int num_heads_k = k.size(2);
+
+  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+
   // XPU requires head_size to be a multiple of 32
-  const int head_size_og = q.size(3);
   const int head_size_padded = round_multiple(head_size_og, 32);
 
   at::Tensor q_padded = q;
@@ -59,6 +80,7 @@ mha_fwd(
     out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
   }
+  out = out.contiguous();
 
   // TODO: current do not support store softmax_lse out
   // hard code to return empty tensor for softmax_lse, S_dmask, rng_state
@@ -91,8 +113,36 @@ mha_varlen_fwd(
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
+
+  auto device_idx = q.device().index();
+  COMPAT::select_device(device_idx);
+
+  // check inputs
+  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+  TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+  CHECK_DEVICE(cu_seqlens_q);
+  CHECK_DEVICE(cu_seqlens_k);
+  TORCH_CHECK(cu_seqlens_q.dim() == 1, "cu_seqlens_q must be 1-dimensional, but got ", cu_seqlens_q.dim(), " dimensions");
+  TORCH_CHECK(cu_seqlens_k.dim() == 1, "cu_seqlens_k must be 1-dimensional, but got ", cu_seqlens_k.dim(), " dimensions");
+  CHECK_CONTIGUOUS(cu_seqlens_q);
+  CHECK_CONTIGUOUS(cu_seqlens_k);
+
+  // Extract dimensions
+  const auto sizes = q.sizes();
+  const int total_q = sizes[0];
+  const int num_heads = sizes[1];
+  const int head_size_og = sizes[2];
+  const int total_k = k.size(0);
+  const int num_heads_k = k.size(1);
+  const int batch_size = cu_seqlens_q.numel() - 1;
+
+  CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+  CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+  CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
   // XPU requires head_size to be a multiple of 32
-  const int head_size_og = q.size(2);
   const int head_size_padded = round_multiple(head_size_og, 32);
 
   at::Tensor q_padded = q;
@@ -131,6 +181,7 @@ mha_varlen_fwd(
     out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
   }
+  out = out.contiguous();
 
   // TODO: current do not support store softmax_lse out
   // hard code to return empty tensor for softmax_lse, S_dmask, rng_state
diff --git a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
index f3614dd..e4de1d9 100644
--- a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
+++ b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
@@ -3,6 +3,8 @@
 // Define namespace based on CUTLASS_SYCL_REVISION
 #if defined(OLD_API)
 #define COMPAT syclcompat
+#include
 #else
 #define COMPAT compat
+#include
 #endif
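
For reviewers who want to exercise the new input checks in isolation, here is a minimal sketch (not part of the patch) of how the CHECK_SHAPE and CHECK_CONTIGUOUS macros behave, assuming a plain libtorch build on CPU rather than the XPU extension:

// Standalone sketch of the CHECK_* macros added above (libtorch only, no XPU).
#include <torch/torch.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

int main() {
  auto q = torch::zeros({2, 128, 8, 64});  // (batch, seqlen, heads, head_size)
  CHECK_SHAPE(q, 2, 128, 8, 64);           // passes: sizes match exactly
  CHECK_CONTIGUOUS(q);                     // passes: fresh tensors are contiguous
  auto qt = q.transpose(1, 2);             // strided view over the same storage
  (void)qt;
  // CHECK_CONTIGUOUS(qt) would throw c10::Error with "qt must be contiguous",
  // which is the failure mode the new out.contiguous() calls guard against.
}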
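The head-size padding round-trip can be exercised the same way. A sketch under the assumption that padding is done with torch::nn::functional::pad (the diff does not show which helper the real prefill path uses); the final .contiguous() mirrors the out = out.contiguous(); lines added in both entry points:

// Sketch of the pad -> compute -> slice -> contiguous round-trip (CPU, illustration only).
#include <torch/torch.h>

inline int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  const int head_size_og = 40;                                    // not a multiple of 32
  const int head_size_padded = round_multiple(head_size_og, 32);  // -> 64
  auto q = torch::randn({2, 128, 8, head_size_og});

  // Zero-pad the last dimension up to the XPU-friendly head size.
  namespace F = torch::nn::functional;
  auto q_padded = F::pad(q, F::PadFuncOptions({0, head_size_padded - head_size_og}));

  // ... the attention kernel would run on q_padded and produce out_padded ...
  auto out_padded = q_padded;

  // Slice back to the original head size. The slice is a view with the padded
  // strides, so .contiguous() is needed before returning it to callers.
  auto out = out_padded
                 .index({torch::indexing::Slice(), torch::indexing::Slice(),
                         torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)})
                 .contiguous();
  TORCH_CHECK(out.sizes() == q.sizes());
}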
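Finally, the invariant behind the new cu_seqlens checks in mha_varlen_fwd, shown with hypothetical sequence lengths:

// Sketch of the packed varlen layout the new cu_seqlens checks enforce.
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<int32_t> seqlens = {3, 5, 2};               // batch_size == 3
  std::vector<int32_t> cu_seqlens(seqlens.size() + 1, 0);
  // Exclusive prefix sum of the per-sequence lengths: cu_seqlens == {0, 3, 8, 10}.
  std::partial_sum(seqlens.begin(), seqlens.end(), cu_seqlens.begin() + 1);
  // Sequence i occupies rows [cu_seqlens[i], cu_seqlens[i+1]) of the packed
  // (total_q, num_heads, head_size) q tensor, and total_q == cu_seqlens.back() == 10.
  // Hence cu_seqlens must be 1-D, int32, contiguous, with batch_size + 1 entries,
  // which is exactly what mha_varlen_fwd now asserts before touching the data.
}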