Skip to content

Commit fd89556

Browse files
authored
[SYCL] Add BF16 support to GET_ROWS operation (#21391)
Add GGML_TYPE_BF16 to the SYCL backend's GET_ROWS operation, both in supports_op and in the kernel dispatch. This fixes a performance regression where models using BF16 embedding tensors (e.g., Gemma4's per_layer_token_embd.weight) fall back to CPU for the GET_ROWS op, causing a full GPU-to-CPU tensor transfer every token. The fix reuses the existing get_rows_sycl_float template with sycl::ext::oneapi::bfloat16, matching the pattern already used for sycl::half (F16) and float (F32).
1 parent 6048993 commit fd89556

2 files changed

Lines changed: 5 additions & 0 deletions

File tree

ggml/src/ggml-sycl/getrows.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
183183
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
184184
src1_i32, (float *)dst->data, ctx.stream());
185185
break;
186+
case GGML_TYPE_BF16:
187+
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::ext::oneapi::bfloat16 *)dst->src[0]->data,
188+
src1_i32, (float *)dst->data, ctx.stream());
189+
break;
186190
case GGML_TYPE_F32:
187191
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
188192
src1_i32, (float *)dst->data, ctx.stream());

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4974,6 +4974,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
49744974
{
49754975
switch (op->src[0]->type) {
49764976
case GGML_TYPE_F16:
4977+
case GGML_TYPE_BF16:
49774978
case GGML_TYPE_F32:
49784979
case GGML_TYPE_Q4_0:
49794980
case GGML_TYPE_Q4_1:

0 commit comments

Comments
 (0)