From fc0e0413bcfbe3676c9311cd1448b353f298df48 Mon Sep 17 00:00:00 2001
From: yehudit-dev
Date: Mon, 3 Nov 2025 14:50:15 +0200
Subject: [PATCH 1/6] sycl: initialize flash-attention implementation

Co-authored-by: safranowith
Co-authored-by: ye-NX
---
 ggml/src/ggml-sycl/CMakeLists.txt             |   9 ++
 .../ggml-sycl/flash-attn/flash-attn-sycl.cpp  |  98 ++++++++++++++++
 .../ggml-sycl/flash-attn/flash-attn-sycl.h    |  10 ++
 .../flash-attn/kernels/flash-attn-kernel.h    | 108 ++++++++++++++++++
 ggml/src/ggml-sycl/ggml-sycl.cpp              |   5 +
 5 files changed, 230 insertions(+)
 create mode 100644 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
 create mode 100644 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
 create mode 100644 ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h

diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt
index efd78b912cc65..19f96607b99d9 100644
--- a/ggml/src/ggml-sycl/CMakeLists.txt
+++ b/ggml/src/ggml-sycl/CMakeLists.txt
@@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
+# Include flash-attn sources (SYCL optimized flash attention implementation)
+file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
+file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
+
+# Also include kernel headers under flash-attn/kernels
+file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
+
 if (WIN32)
     # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
     if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
new file mode 100644
index 0000000000000..1dbc4b952a555
--- /dev/null
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -0,0 +1,98 @@
+#include "flash-attn-sycl.h"
+
+#include "kernels/flash-attn-kernel.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <limits>
+
+#define FLASH_ATTN_BR_MAX 32
+#define FLASH_ATTN_BC_MAX 32
+
+// Flash Attention: https://arxiv.org/abs/2205.14135
+void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+
+    GGML_ASSERT(Q != nullptr);
+    GGML_ASSERT(K != nullptr);
+    GGML_ASSERT(V != nullptr);
+    GGML_ASSERT(dst != nullptr);
+
+    if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
+        return;
+    }
+
+    const float * Q_d   = (const float *) Q->data;
+    const float * K_d   = (const float *) K->data;
+    const float * V_d   = (const float *) V->data;
+    float *       dst_d = (float *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    const int64_t d = Q->ne[0];
+    const int64_t N = Q->ne[1];
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+    
+    std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+    std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+    std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    const bool masked = (mask != nullptr);
+
+    const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
+    const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
+
+    const int Tr = (N + Br - 1) / Br;
+    const int Tc = (N + Bc - 1) / Bc;
+
+    float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+    float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+
+    stream->fill(l_d, 0.0f, N);
+    stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
+    stream->fill(dst_d, 0.0f, N * d);
+    stream->wait();
+
+    for (int j = 0; j < Tc; ++j) {
+        stream->submit([&](sycl::handler & cgh) {
+            cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
+                const int i = idx[0];
+                flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
+                                                                              Bc, N, d, masked, scale);
+            });
+        });
+    }
+
+    stream->wait();
+
+    sycl::free(l_d, *stream);
+    sycl::free(m_d, *stream);
+}
+
+bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    if (Q == nullptr || K == nullptr || V == nullptr) {
+        return false;
+    }
+
+    if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    return true;
+}
diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
new file mode 100644
index 0000000000000..c50d09aa0b859
--- /dev/null
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "../common.hpp"
+
+// Flash attention operation for SYCL backend
+// This implements the Flash Attention algorithm optimized for SYCL devices
+void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+// Check if flash attention is supported for given tensor
+bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
diff --git a/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h
new file mode 100644
index 0000000000000..721007eaba686
--- /dev/null
+++ b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include <sycl/sycl.hpp>
+
+template <int Br_MAX, int Bc_MAX>
+inline void flash_attn_tiled_kernel(const float * Q,
+                                    const float * K,
+                                    const float * V,
+                                    float *       O,
+                                    float *       l,
+                                    float *       m,
+                                    const int     i_block,
+                                    const int     j_block,
+                                    const int     Br,
+                                    const int     Bc,
+                                    const int     N,
+                                    const int     d,
+                                    const bool    masked,
+                                    const float   scale) {
+    const int i_start = i_block * Br;
+    const int j_start = j_block * Bc;
+
+    float S[Br_MAX][Bc_MAX];
+    float P[Br_MAX][Bc_MAX];
+    float m_local[Br_MAX];
+    float l_local[Br_MAX];
+
+    for (int qi = 0; qi < Br; ++qi) {
+        const int q_row = i_start + qi;
+        if (q_row >= N) {
+            continue;
+        }
+
+        for (int kj = 0; kj < Bc; ++kj) {
+            const int k_row = j_start + kj;
+            if (k_row >= N) {
+                S[qi][kj] = -INFINITY;
+                continue;
+            }
+
+            if (masked && k_row > q_row) {
+                S[qi][kj] = -INFINITY;
+                continue;
+            }
+
+            float score = 0.0f;
+            for (int k = 0; k < d; ++k) {
+                score += Q[q_row * d + k] * K[k_row * d + k];
+            }
+            S[qi][kj] = score * scale;
+        }
+    }
+
+    for (int qi = 0; qi < Br; ++qi) {
+        const int q_row = i_start + qi;
+        if (q_row >= N) {
+            continue;
+        }
+
+        m_local[qi] = -INFINITY;
+        for (int kj = 0; kj < Bc; ++kj) {
+            if (j_start + kj < N) {
+                m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
+            }
+        }
+
+        l_local[qi] = 0.0f;
+        for (int kj = 0; kj < Bc; ++kj) {
+            if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
+                P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
+                l_local[qi] += P[qi][kj];
+            } else {
+                P[qi][kj] = 0.0f;
+            }
+        }
+    }
+
+    for (int qi = 0; qi < Br; ++qi) {
+        const int q_row = i_start + qi;
+        if (q_row >= N) {
+            continue;
+        }
+
+        const float m_old = m[q_row];
+        const float m_new = sycl::fmax(m_old, m_local[qi]);
+        const float l_old = l[q_row];
+        const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];
+
+        const float correction_old = sycl::exp(m_old - m_new);
+        const float correction_new = sycl::exp(m_local[qi] - m_new);
+
+        for (int k = 0; k < d; ++k) {
+            float pv = 0.0f;
+            for (int kj = 0; kj < Bc; ++kj) {
+                const int v_row = j_start + kj;
+                if (v_row < N) {
+                    pv += P[qi][kj] * V[v_row * d + k];
+                }
+            }
+
+            const int o_idx = q_row * d + k;
+            O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new;
+        }
+
+        l[q_row] = l_new;
+        m[q_row] = m_new;
+    }
+}
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 33f9035075ba7..a37e89d4e47e0 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_sycl_op_flash_attn(ctx, dst);
+            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
@@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_sycl_flash_attn_ext_supported(op);
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;

From dd1fde5b8e4689a86a2deccd61cd4b37c860454b Mon Sep 17 00:00:00 2001
From: safranowith
Date: Tue, 4 Nov 2025 10:47:28 +0200
Subject: [PATCH 2/6] Update ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp

Co-authored-by: safranowith
Co-authored-by: ye-NX
Co-authored-by: Neo Zhang Jianyu
---
 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
index 1dbc4b952a555..c439710a35691 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -59,7 +59,6 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     stream->fill(l_d, 0.0f, N);
     stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
     stream->fill(dst_d, 0.0f, N * d);
-    stream->wait();
 
     for (int j = 0; j < Tc; ++j) {
         stream->submit([&](sycl::handler & cgh) {

From 4f52591e946b543e9e6e91ad4b09d28d8aec9bd6 Mon Sep 17 00:00:00 2001
From: safranowith
Date: Tue, 4 Nov 2025 10:47:49 +0200
Subject: [PATCH 3/6] Update ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp

Co-authored-by: Neo Zhang Jianyu
Co-authored-by: ye-NX
Co-authored-by: safranowith
---
 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
index c439710a35691..5b74630eef13e 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -70,7 +70,6 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
         });
     }
-
     stream->wait();
 
     sycl::free(l_d, *stream);
     sycl::free(m_d, *stream);

From af5b6446426e4dae20fcbaa86c8108f3776b5d10 Mon Sep 17 00:00:00 2001
From: YehuditE
Date: Tue, 4 Nov 2025 11:10:45 +0200
Subject: [PATCH 4/6] Update ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: safranowith
Co-authored-by: ye-NX
Co-authored-by: Sigbjørn Skjæret
---
 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
index 5b74630eef13e..609fc29cba611 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -40,7 +40,7 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     float scale;
     float max_bias;
     float logit_softcap;
-    
+
     std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
     std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
     std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

From 8e8fb5735d3ffdbc062c5af6bed292b4ac00e026 Mon Sep 17 00:00:00 2001
From: safranowith
Date: Tue, 4 Nov 2025 18:15:29 +0200
Subject: [PATCH 5/6] add include in ggml-sycl.cpp

Co-authored-by: safranowith
Co-authored-by: ye-NX
---
 .../ggml-sycl/flash-attn/flash-attn-sycl.cpp  | 118 ++++++++++++++++++
 .../ggml-sycl/flash-attn/flash-attn-sycl.h    |   5 +-
 ggml/src/ggml-sycl/ggml-sycl.cpp              |   1 +
 3 files changed, 122 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
index 609fc29cba611..5c2fdb7fdb8d8 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -94,3 +94,121 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
 
     return true;
 }
+
+
+
+
+// #include "flash-attn-sycl.h"
+
+// #include "kernels/flash-attn-kernel.h"
+
+// #include <algorithm>
+// #include <cstdio>
+// #include <cstring>
+// #include <limits>
+
+// #ifndef GGML_USE_SYCL
+// #warning "SYCL not enabled. This source file will be ignored."
+// #else
+
+// #define FLASH_ATTN_BR_MAX 32
+// #define FLASH_ATTN_BC_MAX 32
+
+// // RAII helper to free device memory automatically
+// class SyclDeviceBuffer {
+//   public:
+//     SyclDeviceBuffer(sycl::queue & q, size_t count)
+//         : queue(q), ptr(nullptr), size(count) {
+//         ptr = sycl::malloc_device<float>(count, queue);
+//     }
+
+//     ~SyclDeviceBuffer() {
+//         if (ptr) {
+//             sycl::free(ptr, queue);
+//         }
+//     }
+
+//     float * get() const { return ptr; }
+//     bool valid() const { return ptr != nullptr; }
+
+//   private:
+//     sycl::queue & queue;
+//     float * ptr;
+//     size_t size;
+// };
+
+// // Flash Attention: https://arxiv.org/abs/2205.14135
+// void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+//     const ggml_tensor * Q = dst->src[0];
+//     const ggml_tensor * K = dst->src[1];
+//     const ggml_tensor * V = dst->src[2];
+//     const ggml_tensor * mask = dst->src[3];
+
+//     GGML_ASSERT(Q && K && V && dst);
+
+//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
+//         return;
+//     }
+
+//     const float * q_data = static_cast<const float *>(Q->data);
+//     const float * k_data = static_cast<const float *>(K->data);
+//     const float * v_data = static_cast<const float *>(V->data);
+//     float * dst_data = static_cast<float *>(dst->data);
+
+//     sycl::queue & stream = *ctx.stream();
+
+//     const int64_t d = Q->ne[0];
+//     const int64_t N = Q->ne[1];
+
+//     float scale, max_bias, logit_softcap;
+//     std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+//     std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+//     std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+//     const bool masked = (mask != nullptr);
+
+//     const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
+//     const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
+
+//     const int Tr = (N + Br - 1) / Br;
+//     const int Tc = (N + Bc - 1) / Bc;
+
+//     SyclDeviceBuffer l_buf(stream, N);
+//     SyclDeviceBuffer m_buf(stream, N);
+
+//     if (!l_buf.valid() || !m_buf.valid()) {
+//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
+//         return;
+//     }
+
+//     stream.fill(l_buf.get(), 0.0f, N).wait();
+//     stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
+//     stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
+
+//     for (int j = 0; j < Tc; ++j) {
+//         stream.submit([&](sycl::handler & cgh) {
+//             cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
+//                 const int i = idx[0];
+//                 flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
+//                     q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
+//                     i, j, Br, Bc, N, d, masked, scale);
+//             });
+//         });
+//     }
+//     stream.wait();
+// }
+
+// bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
+//     const ggml_tensor * Q = dst->src[0];
+//     const ggml_tensor * K = dst->src[1];
+//     const ggml_tensor * V = dst->src[2];
+
+//     if (!Q || !K || !V) return false;
+//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
+//     if (dst->type != GGML_TYPE_F32) return false;
+
+//     return true;
+// }
+
+// #endif
diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
index c50d09aa0b859..b15bb705ec068 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
@@ -2,9 +2,10 @@
 
 #include "../common.hpp"
 
+
 // Flash attention operation for SYCL backend
 // This implements the Flash Attention algorithm optimized for SYCL devices
-void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_op_flash_attn( ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 // Check if flash attention is supported for given tensor
-bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
+bool ggml_sycl_flash_attn_ext_supported(const struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index a37e89d4e47e0..8c6536733ef4b 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -41,6 +41,7 @@
 #include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "flash-attn/flash-attn-sycl.h"
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/set.hpp"
 #include "ggml-sycl/sycl_hw.hpp"

From dcd7ca522ed6db3da43d94dcda02688e6b104205 Mon Sep 17 00:00:00 2001
From: safranowith
Date: Tue, 4 Nov 2025 18:18:43 +0200
Subject: [PATCH 6/6] remove unrelated changes

Co-authored-by: safranowith
Co-authored-by: ye-NX
---
 .../ggml-sycl/flash-attn/flash-attn-sycl.cpp  | 120 +-----------------
 1 file changed, 1 insertion(+), 119 deletions(-)

diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
index 5c2fdb7fdb8d8..d6e39e02c4452 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
+++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -93,122 +93,4 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
     }
 
     return true;
-}
-
-
-
-
-// #include "flash-attn-sycl.h"
-
-// #include "kernels/flash-attn-kernel.h"
-
-// #include <algorithm>
-// #include <cstdio>
-// #include <cstring>
-// #include <limits>
-
-// #ifndef GGML_USE_SYCL
-// #warning "SYCL not enabled. This source file will be ignored."
-// #else
-
-// #define FLASH_ATTN_BR_MAX 32
-// #define FLASH_ATTN_BC_MAX 32
-
-// // RAII helper to free device memory automatically
-// class SyclDeviceBuffer {
-//   public:
-//     SyclDeviceBuffer(sycl::queue & q, size_t count)
-//         : queue(q), ptr(nullptr), size(count) {
-//         ptr = sycl::malloc_device<float>(count, queue);
-//     }
-
-//     ~SyclDeviceBuffer() {
-//         if (ptr) {
-//             sycl::free(ptr, queue);
-//         }
-//     }
-
-//     float * get() const { return ptr; }
-//     bool valid() const { return ptr != nullptr; }
-
-//   private:
-//     sycl::queue & queue;
-//     float * ptr;
-//     size_t size;
-// };
-
-// // Flash Attention: https://arxiv.org/abs/2205.14135
-// void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-//     const ggml_tensor * Q = dst->src[0];
-//     const ggml_tensor * K = dst->src[1];
-//     const ggml_tensor * V = dst->src[2];
-//     const ggml_tensor * mask = dst->src[3];
-
-//     GGML_ASSERT(Q && K && V && dst);
-
-//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
-//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
-//         return;
-//     }
-
-//     const float * q_data = static_cast<const float *>(Q->data);
-//     const float * k_data = static_cast<const float *>(K->data);
-//     const float * v_data = static_cast<const float *>(V->data);
-//     float * dst_data = static_cast<float *>(dst->data);
-
-//     sycl::queue & stream = *ctx.stream();
-
-//     const int64_t d = Q->ne[0];
-//     const int64_t N = Q->ne[1];
-
-//     float scale, max_bias, logit_softcap;
-//     std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
-//     std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-//     std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-//     const bool masked = (mask != nullptr);
-
-//     const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
-//     const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
-
-//     const int Tr = (N + Br - 1) / Br;
-//     const int Tc = (N + Bc - 1) / Bc;
-
-//     SyclDeviceBuffer l_buf(stream, N);
-//     SyclDeviceBuffer m_buf(stream, N);
-
-//     if (!l_buf.valid() || !m_buf.valid()) {
-//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
-//         return;
-//     }
-
-//     stream.fill(l_buf.get(), 0.0f, N).wait();
-//     stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
-//     stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
-
-//     for (int j = 0; j < Tc; ++j) {
-//         stream.submit([&](sycl::handler & cgh) {
-//             cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
-//                 const int i = idx[0];
-//                 flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
-//                     q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
-//                     i, j, Br, Bc, N, d, masked, scale);
-//             });
-//         });
-//     }
-//     stream.wait();
-// }
-
-// bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
-//     const ggml_tensor * Q = dst->src[0];
-//     const ggml_tensor * K = dst->src[1];
-//     const ggml_tensor * V = dst->src[2];
-
-//     if (!Q || !K || !V) return false;
-//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
-//     if (dst->type != GGML_TYPE_F32) return false;
-
-//     return true;
-// }
-
-// #endif
+}
\ No newline at end of file