diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 6ff3215d5a439..8df5bfa5625bd 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -32,6 +32,7 @@ #include "pad.hpp" #include "quantize.hpp" #include "quants.hpp" +#include "roll.hpp" #include "rope.hpp" #include "set_rows.hpp" #include "softmax.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a7e077ec8ebe0..e80f77edf1502 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3836,6 +3836,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_ROLL: + ggml_sycl_roll(ctx, dst); + break; case GGML_OP_ARANGE: ggml_sycl_arange(ctx, dst); break; @@ -4491,6 +4494,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_ROLL: + return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; default: diff --git a/ggml/src/ggml-sycl/roll.cpp b/ggml/src/ggml-sycl/roll.cpp new file mode 100644 index 0000000000000..1e05181789c28 --- /dev/null +++ b/ggml/src/ggml-sycl/roll.cpp @@ -0,0 +1,122 @@ +#include "roll.hpp" +#include "common.hpp" + +using namespace sycl; + +static inline int wrap_add(int i, int shift, int n) { + + int s = i + shift; + return (s >= n) ? (s - n) : s; +} + +static void kernel_roll_fused_i0_i1( + queue &q, + const float *src_d, + float *dst_d, + int ne0, int ne1, int ne2, int ne3, + int sh0, int sh1, int sh2, int sh3) +{ + if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return; + + + const int stride1 = ne0; + const int stride2 = ne0 * ne1; + const int stride3 = ne0 * ne1 * ne2; + + + const int shNe0 = (ne0 - sh0) % ne0; + const int shNe1 = (ne1 - sh1) % ne1; + const int shNe2 = (ne2 - sh2) % ne2; + const int shNe3 = (ne3 - sh3) % ne3; + + + const size_t g0 = (size_t) ne3; + const size_t g1 = (size_t) ne2; + const size_t g2 = (size_t) (ne1 * ne0); + + const range<3> global{ g0, g1, g2 }; + + q.submit([&](handler &h) { + h.parallel_for(global, [=](id<3> idx) { + const int i3 = (int) idx[0]; + const int i2 = (int) idx[1]; + + const int fused = (int) idx[2]; + const int i1 = fused / ne0; + const int i0 = fused - i1 * ne0; // fused % ne0 + + + const int idx_dst = i0 + + i1 * stride1 + + i2 * stride2 + + i3 * stride3; + + + const int s0 = wrap_add(i0, shNe0, ne0); + const int s1 = wrap_add(i1, shNe1, ne1); + const int s2 = wrap_add(i2, shNe2, ne2); + const int s3 = wrap_add(i3, shNe3, ne3); + + const int idx_src = s0 + + s1 * stride1 + + s2 * stride2 + + s3 * stride3; + + dst_d[idx_dst] = src_d[idx_src]; + }); + }); +} + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const ggml_tensor *src = dst->src[0]; + GGML_ASSERT(src && src->type == GGML_TYPE_F32); + + const int ne0 = (int) dst->ne[0]; + const int ne1 = (int) dst->ne[1]; + const int ne2 = (int) dst->ne[2]; + const int ne3 = (int) dst->ne[3]; + + const int32_t *params = (const int32_t *) dst->op_params; + int shift0 = params[0]; + int shift1 = params[1]; + int shift2 = params[2]; + int shift3 = params[3]; + + + if ((shift0 | shift1 | shift2 | shift3) == 0) { + const size_t nb = ggml_nbytes(src); + queue *q = ctx.stream(); + SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb))); + return; + } + + auto norm = [](int sh, int n) -> int { + if (n <= 0) return 0; + sh %= n; + if (sh < 0) sh += n; + return sh; + }; + shift0 = norm(shift0, ne0); + shift1 = norm(shift1, ne1); + shift2 = norm(shift2, ne2); + shift3 = norm(shift3, ne3); + + try { + queue *q = ctx.stream(); + + const float *src_d = (const float *) src->data; + float *dst_d = (float *) dst->data; + GGML_ASSERT(src_d && dst_d); + + kernel_roll_fused_i0_i1( + *q, src_d, dst_d, + ne0, ne1, ne2, ne3, + shift0, shift1, shift2, shift3 + ); + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what()); + throw; + } +} diff --git a/ggml/src/ggml-sycl/roll.hpp b/ggml/src/ggml-sycl/roll.hpp new file mode 100644 index 0000000000000..97dc03d64b24d --- /dev/null +++ b/ggml/src/ggml-sycl/roll.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_ROLL_HPP +#define GGML_SYCL_ROLL_HPP + +#include "common.hpp" + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_ROLL_HPP