From 3b330488e7f4b2a90aaf7df7169072523b07fee6 Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Tue, 28 Oct 2025 17:40:33 +0000
Subject: [PATCH 1/7] Support multiple XPU loading models

---
 flash-attn2/flash_attn_xpu/flash_api.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index 1c3bc0a..2db4fcc 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,6 +1,8 @@
 #include "src/prefill.hpp"
 #include
 
+#include "cute/util/compat/device.hpp"
+
 namespace FLASH_NAMESPACE {
 
 inline int round_multiple(int x, int m) {
@@ -22,6 +24,10 @@ mha_fwd(
     const float softcap,
     const bool return_softmax,
     std::optional gen_) {
+
+    auto device_idx = q.device().index();
+    compat::select_device(device_idx);
+
     // XPU requires head_size to be a multiple of 32
     const int head_size_og = q.size(3);
     const int head_size_padded = round_multiple(head_size_og, 32);
@@ -91,6 +97,10 @@ mha_varlen_fwd(
     const float softcap,
     const bool return_softmax,
     std::optional gen_) {
+
+    auto device_idx = q.device().index();
+    compat::select_device(device_idx);
+
     // XPU requires head_size to be a multiple of 32
     const int head_size_og = q.size(2);
     const int head_size_padded = round_multiple(head_size_og, 32);

From 032fa30c395a74d9d4f1349d192c5be643faa46c Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Tue, 28 Oct 2025 17:44:22 +0000
Subject: [PATCH 2/7] Update flake.lock

---
 flash-attn2/flake.lock | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash-attn2/flake.lock b/flash-attn2/flake.lock
index 66cc1fc..e05ef3d 100644
--- a/flash-attn2/flake.lock
+++ b/flash-attn2/flake.lock
@@ -98,11 +98,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1761222363,
-        "narHash": "sha256-jJqlTvy6T4nn1bEMLRbfBDeaRdTolm8WsNsRC6g4x0s=",
+        "lastModified": 1761645431,
+        "narHash": "sha256-Ns3m/L+FMAYnmKhwt4vlIf8lq6dOJWHAocFL23HasTM=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "889428ef4ef25c2b9ef275dbd46ae272022d4fd5",
+        "rev": "289788986c318e6ccb92608f011c49d61b25b5b6",
         "type": "github"
       },
       "original": {

From 9fdbe7d490d811431d1c196c5881c3c2dfa49450 Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Wed, 29 Oct 2025 01:57:23 +0000
Subject: [PATCH 3/7] Fix

---
 flash-attn2/flash_attn_xpu/flash_api.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index 2db4fcc..b6326bf 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,7 +1,7 @@
-#include "src/prefill.hpp"
 #include
+#include
 
-#include "cute/util/compat/device.hpp"
+#include "src/prefill.hpp"
 
 namespace FLASH_NAMESPACE {
 

From c0cb2eaefc5c2333fc4d679b9248fffddb374f21 Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Wed, 29 Oct 2025 01:59:51 +0000
Subject: [PATCH 4/7] Update flake.lock

---
 flash-attn2/flake.lock | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash-attn2/flake.lock b/flash-attn2/flake.lock
index e05ef3d..b2c6e3d 100644
--- a/flash-attn2/flake.lock
+++ b/flash-attn2/flake.lock
@@ -98,11 +98,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1761645431,
-        "narHash": "sha256-Ns3m/L+FMAYnmKhwt4vlIf8lq6dOJWHAocFL23HasTM=",
+        "lastModified": 1761680882,
+        "narHash": "sha256-7C4kAI5gysfIpgrZPuKOY2WtNBQ8BW0kVIYW//R9iQc=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "289788986c318e6ccb92608f011c49d61b25b5b6",
+        "rev": "d509f5f9f6dcb0b402212df8c5a7a15536cb891a",
         "type": "github"
       },
       "original": {

From ed83412d112dba984999d49c580b0b3da0d99479 Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Thu, 30 Oct 2025 08:45:38 +0000
Subject: [PATCH 5/7] Add check func and fix multi-XPU BUG

---
 flash-attn2/flash_attn_xpu/flash_api.cpp | 46 ++++++++++++++++++++++--
 1 file changed, 44 insertions(+), 2 deletions(-)

diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index b6326bf..2647993 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -3,6 +3,10 @@
 
 #include "src/prefill.hpp"
 
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
 namespace FLASH_NAMESPACE {
 
 inline int round_multiple(int x, int m) {
@@ -28,8 +32,20 @@ mha_fwd(
     auto device_idx = q.device().index();
     compat::select_device(device_idx);
 
+    // check inputs
+    const auto sizes = q.sizes();
+    const int batch_size = sizes[0];
+    const int seqlen_q = sizes[1];
+    const int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+    const int seqlen_k = k.size(1);
+    const int num_heads_k = k.size(2);
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+
     // XPU requires head_size to be a multiple of 32
-    const int head_size_og = q.size(3);
     const int head_size_padded = round_multiple(head_size_og, 32);
 
     at::Tensor q_padded = q;
@@ -65,6 +81,7 @@
         out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
     }
+    out = out.contiguous();
 
     // TODO: current do not support store softmax_lse out
     // hard code to return empty tensor for softmax_lse, S_dmask, rng_state
 
@@ -101,8 +118,32 @@ mha_varlen_fwd(
     auto device_idx = q.device().index();
     compat::select_device(device_idx);
 
+    // check inputs
+    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+    CHECK_DEVICE(cu_seqlens_q);
+    CHECK_DEVICE(cu_seqlens_k);
+    TORCH_CHECK(cu_seqlens_q.dim() == 1, "cu_seqlens_q must be 1-dimensional, but got ", cu_seqlens_q.dim(), " dimensions");
+    TORCH_CHECK(cu_seqlens_k.dim() == 1, "cu_seqlens_k must be 1-dimensional, but got ", cu_seqlens_k.dim(), " dimensions");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
+
+    // Extract dimensions
+    const auto sizes = q.sizes();
+    const int total_q = sizes[0];
+    const int num_heads = sizes[1];
+    const int head_size_og = sizes[2];
+    const int total_k = k.size(0);
+    const int num_heads_k = k.size(1);
+    const int batch_size = cu_seqlens_q.numel() - 1;
+
+    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
     // XPU requires head_size to be a multiple of 32
-    const int head_size_og = q.size(2);
     const int head_size_padded = round_multiple(head_size_og, 32);
 
     at::Tensor q_padded = q;
@@ -141,6 +182,7 @@
         out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
     }
+    out = out.contiguous();
 
     // TODO: current do not support store softmax_lse out
     // hard code to return empty tensor for softmax_lse, S_dmask, rng_state
 

From 46b34026bb3e55ab8dac8afa439117596781adc6 Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Thu, 30 Oct 2025 13:58:14 +0000
Subject: [PATCH 6/7] Add support for old api

---
 flash-attn2/flash_attn_xpu/flash_api.cpp          | 5 ++---
 flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp | 2 ++
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index 2647993..539e30d 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,5 +1,4 @@
 #include
-#include
 
 #include "src/prefill.hpp"
 
@@ -30,7 +29,7 @@ mha_fwd(
     std::optional gen_) {
 
     auto device_idx = q.device().index();
-    compat::select_device(device_idx);
+    COMPAT::select_device(device_idx);
 
     // check inputs
     const auto sizes = q.sizes();
@@ -116,7 +115,7 @@ mha_varlen_fwd(
     std::optional gen_) {
 
     auto device_idx = q.device().index();
-    compat::select_device(device_idx);
+    COMPAT::select_device(device_idx);
 
     // check inputs
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
diff --git a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
index f3614dd..31a3575 100644
--- a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
+++ b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
@@ -3,6 +3,8 @@
 // Define namespace based on CUTLASS_SYCL_REVISION
 #if defined(OLD_API)
   #define COMPAT syclcompat
+  #include
 #else
   #define COMPAT compat
+  #include
 #endif

From af667ca5f5f02ecec63fbcf4aafe59371de6640a Mon Sep 17 00:00:00 2001
From: YangKai0616
Date: Thu, 30 Oct 2025 15:56:12 +0000
Subject: [PATCH 7/7] Fix compile bug

---
 flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
index 31a3575..e4de1d9 100644
--- a/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
+++ b/flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
@@ -3,7 +3,7 @@
 // Define namespace based on CUTLASS_SYCL_REVISION
 #if defined(OLD_API)
   #define COMPAT syclcompat
-  #include
+  #include
 #else
   #define COMPAT compat
   #include
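Note on the head-size padding these patches leave in place: the XPU kernels require head_size to be a multiple of 32, so flash_api.cpp pads head_size_og up to the next multiple, runs attention on the padded tensors, slices the output back to the original width, and (as of PATCH 5/7) makes the sliced output contiguous. The snippet below is a minimal standalone sketch of just the rounding step; the body of round_multiple is not shown in the diffs, so the (x + m - 1) / m * m form used here is an assumption, and the sample head size is purely illustrative.

#include <cstdio>

// Assumed implementation of the helper declared in flash_api.cpp:
// rounds x up to the nearest multiple of m.
inline int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  const int head_size_og = 40;  // example of a head size the XPU kernel cannot take directly
  const int head_size_padded = round_multiple(head_size_og, 32);
  // Prints "head_size 40 -> padded 64": the kernel would compute on 64 and slice back to 40.
  std::printf("head_size %d -> padded %d\n", head_size_og, head_size_padded);
  return 0;
}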