57 changes: 54 additions & 3 deletions flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -1,6 +1,11 @@
#include "src/prefill.hpp"
#include <torch/all.h>

#include "src/prefill.hpp"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace FLASH_NAMESPACE {

inline int round_multiple(int x, int m) {
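The three macros centralize the input validation added below; CHECK_SHAPE in particular compares the full size vector against the expected dims in a single TORCH_CHECK. A minimal standalone sketch of what one expansion looks like (CPU tensor purely for illustration; the real inputs live on XPU):

// What CHECK_SHAPE(q, 2, 128, 16, 64) roughly expands to.
#include <torch/torch.h>

int main() {
  auto q = torch::randn({2, 128, 16, 64});
  TORCH_CHECK(q.sizes() == torch::IntArrayRef({2, 128, 16, 64}),
              "q must have shape (2, 128, 16, 64)");
  return 0;
}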
@@ -22,8 +27,24 @@ mha_fwd(
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {

auto device_idx = q.device().index();
COMPAT::select_device(device_idx);

// check inputs
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);

CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

// XPU requires head_size to be a multiple of 32
const int head_size_og = q.size(3);
const int head_size_padded = round_multiple(head_size_og, 32);

at::Tensor q_padded = q;
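The prefill kernel assumes the head dimension is a multiple of 32, so a head_size_og of e.g. 40 is padded up to 64 before the kernel launch and sliced back afterwards. The diff cuts off before the actual padding call, so the following is only a sketch, assuming torch::nn::functional::pad widens the last dim with zeros:

// Hypothetical reconstruction of the padding step; only the
// `at::Tensor q_padded = q;` line is visible in the diff.
#include <torch/torch.h>
namespace F = torch::nn::functional;

inline int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  auto q = torch::randn({2, 128, 16, 40});                        // head_size_og = 40
  const int head_size_og = q.size(3);
  const int head_size_padded = round_multiple(head_size_og, 32);  // -> 64
  at::Tensor q_padded = q;
  if (head_size_padded != head_size_og) {
    // Zero-pad the last dimension only: {pad_left, pad_right}.
    q_padded = F::pad(q, F::PadFuncOptions({0, head_size_padded - head_size_og}));
  }
  TORCH_CHECK(q_padded.size(3) == head_size_padded);
  return 0;
}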
@@ -59,6 +80,7 @@ mha_fwd(
out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)});
}
out = out.contiguous();

// TODO: storing softmax_lse out is not supported yet;
// hard-coded to return empty tensors for softmax_lse, S_dmask, rng_state
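The new out = out.contiguous() matters because index() with a partial Slice on the last dim returns a strided view, not a dense tensor. A small sketch of the slice-and-copy under the same assumed shapes:

// The sliced view has last-dim stride 64 while claiming width 40,
// so contiguous() materializes a dense copy for downstream consumers.
#include <torch/torch.h>
using torch::indexing::Slice;

int main() {
  auto out_padded = torch::randn({2, 128, 16, 64});
  const int head_size_og = 40;
  auto out = out_padded.index({Slice(), Slice(), Slice(), Slice(0, head_size_og)});
  TORCH_CHECK(!out.is_contiguous());
  out = out.contiguous();
  TORCH_CHECK(out.is_contiguous());
  return 0;
}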
@@ -91,8 +113,36 @@ mha_varlen_fwd(
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {

auto device_idx = q.device().index();
COMPAT::select_device(device_idx);

// check inputs
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(cu_seqlens_q.dim() == 1, "cu_seqlens_q must be 1-dimensional, but got ", cu_seqlens_q.dim(), " dimensions");
TORCH_CHECK(cu_seqlens_k.dim() == 1, "cu_seqlens_k must be 1-dimensional, but got ", cu_seqlens_k.dim(), " dimensions");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

// Extract dimensions
const auto sizes = q.sizes();
const int total_q = sizes[0];
const int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
const int batch_size = cu_seqlens_q.numel() - 1;

CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

// XPU requires head_size to be a multiple of 32
const int head_size_og = q.size(2);
const int head_size_padded = round_multiple(head_size_og, 32);

at::Tensor q_padded = q;
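In the varlen path the batch is packed along the first dim: q is (total_q, num_heads, head_size) and cu_seqlens_q stores cumulative sequence offsets, which is why batch_size falls out as numel() - 1. A sketch of the layout these checks enforce (CPU tensors for illustration):

// Two sequences of lengths 3 and 5 packed together: total_q = 8,
// cu_seqlens_q = [0, 3, 8]; sequence i occupies rows
// [cu_seqlens_q[i], cu_seqlens_q[i+1]) of q.
#include <torch/torch.h>

int main() {
  auto q = torch::randn({8, 16, 64});  // (total_q, num_heads, head_size)
  auto cu_seqlens_q = torch::tensor({0, 3, 8}, torch::kInt32);
  const int batch_size = cu_seqlens_q.numel() - 1;  // -> 2
  TORCH_CHECK(cu_seqlens_q.dim() == 1);
  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
  TORCH_CHECK(cu_seqlens_q[batch_size].item<int>() == q.size(0));
  return 0;
}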
@@ -131,6 +181,7 @@ mha_varlen_fwd(
out = out_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
torch::indexing::Slice(0, head_size_og)});
}
out = out.contiguous();

// TODO: storing softmax_lse out is not supported yet;
// hard-coded to return empty tensors for softmax_lse, S_dmask, rng_state
2 changes: 2 additions & 0 deletions flash-attn2/flash_attn_xpu/src/compat_wrapper.hpp
@@ -3,6 +3,8 @@
// Define namespace based on CUTLASS_SYCL_REVISION
#if defined(OLD_API)
#define COMPAT syclcompat
#include <syclcompat.hpp>
#else
#define COMPAT compat
#include <cute/util/compat/device.hpp>
#endif
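The wrapper lets call sites spell device selection one way regardless of which CUTLASS SYCL revision is in the tree: older revisions ship syclcompat, newer ones the cute compat layer. A sketch of a caller, assuming compat_wrapper.hpp is on the include path:

// COMPAT::select_device resolves to syclcompat::select_device under
// OLD_API and to compat::select_device otherwise; callers never branch.
#include "compat_wrapper.hpp"

void bind_to_device(int device_idx) {
  COMPAT::select_device(device_idx);  // as used in mha_fwd / mha_varlen_fwd
}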