diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912cc65..19f96607b99d9 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp") file(GLOB GGML_SOURCES_SYCL "*.cpp") target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) +# Include flash-attn sources (SYCL optimized flash attention implementation) +file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp") +file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH}) + +# Also include kernel headers under flash-attn/kernels +file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS}) + if (WIN32) # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C")) diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp new file mode 100644 index 0000000000000..d6e39e02c4452 --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp @@ -0,0 +1,96 @@ +#include "flash-attn-sycl.h" + +#include "kernels/flash-attn-kernel.h" + +#include +#include +#include +#include + +#define FLASH_ATTN_BR_MAX 32 +#define FLASH_ATTN_BC_MAX 32 + +// Flash Attention: https://arxiv.org/abs/2205.14135 +void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + GGML_ASSERT(Q != nullptr); + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + GGML_ASSERT(dst != nullptr); + + if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type); + return; + } + + const float * Q_d = (const float *) Q->data; + const float * K_d = (const float *) K->data; + const float * V_d = (const float *) V->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + const int64_t d = Q->ne[0]; + const int64_t N = Q->ne[1]; + + float scale; + float max_bias; + float logit_softcap; + + std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + const bool masked = (mask != nullptr); + + const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N); + const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N); + + const int Tr = (N + Br - 1) / Br; + const int Tc = (N + Bc - 1) / Bc; + + float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); + float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); + + stream->fill(l_d, 0.0f, N); + stream->fill(m_d, -std::numeric_limits::infinity(), N); + stream->fill(dst_d, 0.0f, N * d); + + for (int j = 0; j < Tc; ++j) { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) { + const int i = idx[0]; + flash_attn_tiled_kernel(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br, + Bc, N, d, masked, scale); + }); + }); + } + + + sycl::free(l_d, *stream); + sycl::free(m_d, *stream); +} + +bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + if (Q == nullptr || K == nullptr || V == nullptr) { + return false; + } + + if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} \ No newline at end of file diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h new file mode 100644 index 0000000000000..b15bb705ec068 --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h @@ -0,0 +1,11 @@ +#pragma once + +#include "../common.hpp" + + +// Flash attention operation for SYCL backend +// This implements the Flash Attention algorithm optimized for SYCL devices +void ggml_sycl_op_flash_attn( ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +// Check if flash attention is supported for given tensor +bool ggml_sycl_flash_attn_ext_supported(const struct ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h new file mode 100644 index 0000000000000..721007eaba686 --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +template +inline void flash_attn_tiled_kernel(const float * Q, + const float * K, + const float * V, + float * O, + float * l, + float * m, + const int i_block, + const int j_block, + const int Br, + const int Bc, + const int N, + const int d, + const bool masked, + const float scale) { + const int i_start = i_block * Br; + const int j_start = j_block * Bc; + + float S[Br_MAX][Bc_MAX]; + float P[Br_MAX][Bc_MAX]; + float m_local[Br_MAX]; + float l_local[Br_MAX]; + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + for (int kj = 0; kj < Bc; ++kj) { + const int k_row = j_start + kj; + if (k_row >= N) { + S[qi][kj] = -INFINITY; + continue; + } + + if (masked && k_row > q_row) { + S[qi][kj] = -INFINITY; + continue; + } + + float score = 0.0f; + for (int k = 0; k < d; ++k) { + score += Q[q_row * d + k] * K[k_row * d + k]; + } + S[qi][kj] = score * scale; + } + } + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + m_local[qi] = -INFINITY; + for (int kj = 0; kj < Bc; ++kj) { + if (j_start + kj < N) { + m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]); + } + } + + l_local[qi] = 0.0f; + for (int kj = 0; kj < Bc; ++kj) { + if (j_start + kj < N && !sycl::isinf(S[qi][kj])) { + P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]); + l_local[qi] += P[qi][kj]; + } else { + P[qi][kj] = 0.0f; + } + } + } + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + const float m_old = m[q_row]; + const float m_new = sycl::fmax(m_old, m_local[qi]); + const float l_old = l[q_row]; + const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi]; + + const float correction_old = sycl::exp(m_old - m_new); + const float correction_new = sycl::exp(m_local[qi] - m_new); + + for (int k = 0; k < d; ++k) { + float pv = 0.0f; + for (int kj = 0; kj < Bc; ++kj) { + const int v_row = j_start + kj; + if (v_row < N) { + pv += P[qi][kj] * V[v_row * d + k]; + } + } + + const int o_idx = q_row * d + k; + O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new; + } + + l[q_row] = l_new; + m[q_row] = m_new; + } +} diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 33f9035075ba7..8c6536733ef4b 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -41,6 +41,7 @@ #include "ggml-sycl/element_wise.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" +#include "flash-attn/flash-attn-sycl.h" #include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/set.hpp" #include "ggml-sycl/sycl_hw.hpp" @@ -3839,6 +3840,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_sycl_op_flash_attn(ctx, dst); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_sycl_op_timestep_embedding(ctx, dst); break; @@ -4501,6 +4505,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_MEAN: case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); + case GGML_OP_FLASH_ATTN_EXT: + return ggml_sycl_flash_attn_ext_supported(op); case GGML_OP_POOL_2D: case GGML_OP_ACC: return true;