9 changes: 9 additions & 0 deletions ggml/src/ggml-sycl/CMakeLists.txt
@@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})

# Include flash-attn sources (SYCL-optimized flash attention implementation)
file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})

# Also include kernel headers under flash-attn/kernels
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})

if (WIN32)
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
96 changes: 96 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -0,0 +1,96 @@
#include "flash-attn-sycl.h"

#include "kernels/flash-attn-kernel.h"

#include <cmath>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>

#define FLASH_ATTN_BR_MAX 32
#define FLASH_ATTN_BC_MAX 32

// Flash Attention: https://arxiv.org/abs/2205.14135
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];

GGML_ASSERT(Q != nullptr);
GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);
GGML_ASSERT(dst != nullptr);

if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
return;
}

const float * Q_d = (const float *) Q->data;
const float * K_d = (const float *) K->data;
const float * V_d = (const float *) V->data;
float * dst_d = (float *) dst->data;

dpct::queue_ptr stream = ctx.stream();

const int64_t d = Q->ne[0];
const int64_t N = Q->ne[1];

float scale;
float max_bias;
float logit_softcap;

std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

// NOTE: max_bias and logit_softcap are read from op_params but not yet applied by the kernel.
// A non-null mask currently enables causal masking; the mask values themselves are not read.
const bool masked = (mask != nullptr);

const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);

const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;

// per-row running softmax state: l = running normalizer (initialized to 0), m = running maximum (initialized to -inf)
float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);

stream->fill(l_d, 0.0f, N);
stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
stream->fill(dst_d, 0.0f, N * d);

// process the key/value tiles one at a time; each pass updates l_d, m_d and dst_d in place
for (int j = 0; j < Tc; ++j) {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
const int i = idx[0];
flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
Bc, N, d, masked, scale);
});
});
}


// make sure all submitted tiles have finished before freeing the temporaries
stream->wait();

sycl::free(l_d, *stream);
sycl::free(m_d, *stream);
}

bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

if (Q == nullptr || K == nullptr || V == nullptr) {
return false;
}

if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
return false;
}

if (dst->type != GGML_TYPE_F32) {
return false;
}

// the current tiled kernel only processes a single head and batch (Q = [d, N, 1, 1])
if (Q->ne[2] != 1 || Q->ne[3] != 1) {
return false;
}

return true;
}
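
For validation, the output of the op above should match plain single-head softmax attention computed directly. The following is a hypothetical host-side reference (naive_attention_f32 is not part of this PR, and it assumes the same row-major [N, d] layout for Q, K, V and O that the kernel uses); comparing it against the SYCL result for N > 32 also exercises the multi-tile rescaling path:

// Hypothetical reference helper, not part of this PR: computes
// O = softmax(scale * Q K^T) V for a single head, with optional causal masking,
// using the same row-major [N, d] layout for Q, K, V and O as the kernel above.
#include <cmath>
#include <vector>

static void naive_attention_f32(const float * Q, const float * K, const float * V,
                                float * O, int N, int d, bool causal, float scale) {
    std::vector<float> s(N);
    for (int i = 0; i < N; ++i) {
        const int n_keys = causal ? i + 1 : N;  // causal: only keys j <= i are attended
        float m = -INFINITY;
        for (int j = 0; j < n_keys; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) {
                dot += Q[i * d + k] * K[j * d + k];
            }
            s[j] = dot * scale;
            m = std::fmax(m, s[j]);
        }
        float l = 0.0f;
        for (int j = 0; j < n_keys; ++j) {
            s[j] = std::exp(s[j] - m);
            l += s[j];
        }
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < n_keys; ++j) {
                acc += s[j] * V[j * d + k];
            }
            O[i * d + k] = acc / l;
        }
    }
}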
11 changes: 11 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
@@ -0,0 +1,11 @@
#pragma once

#include "../common.hpp"


// Flash attention operation for SYCL backend
// This implements the Flash Attention algorithm optimized for SYCL devices
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

// Check if flash attention is supported for given tensor
bool ggml_sycl_flash_attn_ext_supported(const struct ggml_tensor * dst);
108 changes: 108 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h
@@ -0,0 +1,108 @@
#pragma once

#include <cmath>  // INFINITY
#include <sycl/sycl.hpp>

// Computes one (i_block, j_block) tile of the single-head Flash Attention forward pass:
// the tile's scores S = scale * Q K^T, its softmax statistics, and the rescaled
// accumulation into O, l and m (online softmax, Algorithm 1 of the Flash Attention paper).
template <int Br_MAX = 32, int Bc_MAX = 32>
inline void flash_attn_tiled_kernel(const float * Q,
const float * K,
const float * V,
float * O,
float * l,
float * m,
const int i_block,
const int j_block,
const int Br,
const int Bc,
const int N,
const int d,
const bool masked,
const float scale) {
const int i_start = i_block * Br;
const int j_start = j_block * Bc;

float S[Br_MAX][Bc_MAX];
float P[Br_MAX][Bc_MAX];
float m_local[Br_MAX];
float l_local[Br_MAX];

for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

for (int kj = 0; kj < Bc; ++kj) {
const int k_row = j_start + kj;
if (k_row >= N) {
S[qi][kj] = -INFINITY;
continue;
}

if (masked && k_row > q_row) {
S[qi][kj] = -INFINITY;
continue;
}

float score = 0.0f;
for (int k = 0; k < d; ++k) {
score += Q[q_row * d + k] * K[k_row * d + k];
}
S[qi][kj] = score * scale;
}
}

for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

m_local[qi] = -INFINITY;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N) {
m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
}
}

l_local[qi] = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
l_local[qi] += P[qi][kj];
} else {
P[qi][kj] = 0.0f;
}
}
}

for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

const float m_old = m[q_row];
const float m_new = sycl::fmax(m_old, m_local[qi]);
const float l_old = l[q_row];
const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];

// rescaling factors for the previously accumulated output and for the new tile
const float correction_old = sycl::exp(m_old - m_new);
const float correction_new = sycl::exp(m_local[qi] - m_new);

for (int k = 0; k < d; ++k) {
float pv = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
const int v_row = j_start + kj;
if (v_row < N) {
pv += P[qi][kj] * V[v_row * d + k];
}
}

const int o_idx = q_row * d + k;
// the stored O is already normalized by the old l, so scale it back up by l_old
// before combining it with the new tile and renormalizing (FlashAttention Alg. 1)
O[o_idx] = (correction_old * l_old * O[o_idx] + correction_new * pv) / l_new;
}

l[q_row] = l_new;
m[q_row] = m_new;
}
}
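
The last loop of the kernel is the online-softmax merge from Algorithm 1 of the Flash Attention paper. As a rough scalar illustration (hypothetical names, one output value instead of a row of d values, not part of this PR), the merge of one tile into the running state looks like this; because the stored output is kept normalized, the old term has to be scaled back up by the previous normalizer before it is combined and renormalized:

#include <cmath>

// Running per-row state: m = running maximum, l = running normalizer,
// o = output already divided by l (as stored in O by the kernel).
struct attn_acc {
    float m;
    float l;
    float o;
};

// Fold one tile's partial results (m_tile, l_tile, and pv = sum of exp(s - m_tile) * v)
// into the accumulator. Because o is stored normalized, the old term is scaled back up
// by the old normalizer acc.l before combining and renormalizing.
inline attn_acc merge_tile(attn_acc acc, float m_tile, float l_tile, float pv) {
    const float m_new = std::fmax(acc.m, m_tile);
    const float c_old = std::exp(acc.m - m_new);   // rescales the old accumulator
    const float c_new = std::exp(m_tile - m_new);  // rescales the new tile
    const float l_new = c_old * acc.l + c_new * l_tile;
    const float o_new = (c_old * acc.l * acc.o + c_new * pv) / l_new;
    return { m_new, l_new, o_new };
}

Starting from { -INFINITY, 0.0f, 0.0f } and folding in the tiles one by one reproduces the exact (non-tiled) softmax result, up to floating-point rounding.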
6 changes: 6 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -41,6 +41,7 @@
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "flash-attn/flash-attn-sycl.h"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
@@ -3839,6 +3840,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_sycl_op_flash_attn(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
@@ -4501,6 +4505,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(op);
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
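
To exercise the new path end to end, a small standalone program against the public ggml API can be used. This is only a rough smoke-test sketch, not part of this PR: it assumes ggml_flash_attn_ext, ggml_backend_sycl_init, ggml_backend_alloc_ctx_tensors and the other ggml-backend helpers with their current public signatures, and it sticks to a single head and batch in F32, which is all the kernel above handles. With all-ones inputs the softmax weights are uniform, so every output element should come back as 1.0:

// Hypothetical standalone smoke test, not part of this PR: runs a single-head,
// single-batch F32 flash-attention graph on the SYCL backend.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-sycl.h"

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int64_t d = 64;   // head dimension
    const int64_t N = 128;  // sequence length (> 32, so several K/V tiles are processed)

    ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,   // tensor data lives in the backend buffer
    };
    ggml_context * ctx = ggml_init(params);

    // shapes follow ggml_flash_attn_ext's [head_dim, seq_len, n_head, batch] convention
    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d, N, 1, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d, N, 1, 1);
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d, N, 1, 1);

    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/nullptr,
                                            1.0f / std::sqrt((float) d),
                                            /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);

    ggml_backend_t backend = ggml_backend_sycl_init(0);
    if (backend == nullptr) {
        fprintf(stderr, "failed to initialize the SYCL backend\n");
        return 1;
    }
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // all-ones inputs: softmax weights are uniform, so every output element should be 1.0
    std::vector<float> ones(d * N, 1.0f);
    ggml_backend_tensor_set(q, ones.data(), 0, ones.size() * sizeof(float));
    ggml_backend_tensor_set(k, ones.data(), 0, ones.size() * sizeof(float));
    ggml_backend_tensor_set(v, ones.data(), 0, ones.size() * sizeof(float));

    ggml_backend_graph_compute(backend, gf);

    std::vector<float> result(ggml_nelements(out));
    ggml_backend_tensor_get(out, result.data(), 0, result.size() * sizeof(float));
    printf("out[0] = %f (expected 1.0)\n", result[0]);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}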