From 73a186be6f79d8590dfa377f618e780fe07b7754 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 19 Nov 2025 07:06:59 +0000 Subject: [PATCH 1/5] cpu : add batching and F16/I32 support to win_part/win_unpart ops/get_rel_pos --- ggml/include/ggml.h | 12 +-- ggml/src/ggml-cpu/ops.cpp | 208 +++++++++++++++++++++++++++++++------- ggml/src/ggml.c | 28 +++-- 3 files changed, 193 insertions(+), 55 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c2..364e1454775 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2329,18 +2329,16 @@ extern "C" { struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed - // example: - // a: 768 64 64 1 - // w: 14 - // res: 768 14 14 25 - // used in sam + // a: [B, H, W, C] + // result: [B*NPY*NPX, w, w, C] + // NPY = ceil(H/w) + // NPX = ceil(W/w) GGML_API struct ggml_tensor * ggml_win_part( struct ggml_context * ctx, struct ggml_tensor * a, int w); // reverse of ggml_win_part - // used in sam GGML_API struct ggml_tensor * ggml_win_unpart( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2358,14 +2356,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_unary_op op); - // used in sam GGML_API struct ggml_tensor * ggml_get_rel_pos( struct ggml_context * ctx, struct ggml_tensor * a, int qh, int kh); - // used in sam GGML_API struct ggml_tensor * ggml_add_rel_pos( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b6209588db1..ee95940abb3 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8840,35 +8840,80 @@ static void ggml_compute_forward_win_part_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_UNARY_OP_LOCALS const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t w = ((const int32_t *)(dst->op_params))[2]; + const int32_t bs = ((const int32_t *)(dst->op_params))[2]; + const int32_t w = ((const int32_t *)(dst->op_params))[3]; assert(ne00 == ne0); - assert(ne3 == nep0*nep1); + assert(ne3 == nep0*nep1*bs); // TODO: optimize / multi-thread - for (int py = 0; py < nep1; ++py) { - for (int px = 0; px < nep0; ++px) { - const int64_t i3 = py*nep0 + px; - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i02 = py*w + i2; - const int64_t i01 = px*w + i1; - const int64_t i00 = i0; - - const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; - const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; - - if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { - ((float *) dst->data)[i] = 0.0f; - } else { - ((float *) dst->data)[i] = ((float *) src0->data)[j]; - } + for (int64_t i3 = 0; i3 < ne3; i3++) { + int px = i3 % nep0; + int py = (i3 / nep0) % nep1; + int b = i3 / (nep0 * nep1); + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i03 = b; + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + *((float *) dp) = 0; + } else { + *((float *) dp) = *((float *) sp); + } + } + } + } + } +} + +static void 
ggml_compute_forward_win_part_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + GGML_UNUSED(params); + + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t bs = ((const int32_t *)(dst->op_params))[2]; + const int32_t w = ((const int32_t *)(dst->op_params))[3]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1*bs); + + // TODO: optimize / multi-thread + for (int64_t i3 = 0; i3 < ne3; i3++) { + int px = i3 % nep0; + int py = (i3 / nep0) % nep1; + int b = i3 / (nep0 * nep1); + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i03 = b; + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + *((ggml_fp16_t *) dp) = 0; + } else { + *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp); } } } @@ -8883,10 +8928,16 @@ void ggml_compute_forward_win_part( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_I32: case GGML_TYPE_F32: { ggml_compute_forward_win_part_f32(params, dst); } break; + case GGML_TYPE_BF16: + case GGML_TYPE_F16: + { + ggml_compute_forward_win_part_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -8903,35 +8954,82 @@ static void ggml_compute_forward_win_unpart_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_UNARY_OP_LOCALS const int32_t w = ((const int32_t *)(dst->op_params))[0]; // padding const int px = (w - ne1%w)%w; - //const int py = (w - ne2%w)%w; + const int py = (w - ne2%w)%w; const int npx = (px + ne1)/w; - //const int npy = (py + ne2)/w; + const int npy = (py + ne2)/w; assert(ne0 == ne00); + assert(ne03 == npx*npy*ne3); // TODO: optimize / multi-thread - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int ip2 = i2/w; - const int ip1 = i1/w; + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i03 = i3*npx*npy + ip2*npx + ip1; + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + + *((float *) dp) = *((float *) sp); + } + } + } + } +} - const int64_t i02 = i2%w; - const int64_t i01 = i1%w; - const int64_t i00 = i0; +static void ggml_compute_forward_win_unpart_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + GGML_UNUSED(params); - const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; - const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + const int32_t w = ((const int32_t *)(dst->op_params))[0]; + + // padding + const int px = (w - ne1%w)%w; + const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + assert(ne03 == 
npx*npy*ne3); - ((float *) dst->data)[j] = ((float *) src0->data)[i]; + // TODO: optimize / multi-thread + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i03 = i3*npx*npy + ip2*npx + ip1; + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + + *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp); + } } } } @@ -8944,10 +9042,16 @@ void ggml_compute_forward_win_unpart( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_I32: case GGML_TYPE_F32: { ggml_compute_forward_win_unpart_f32(params, dst); } break; + case GGML_TYPE_BF16: + case GGML_TYPE_F16: + { + ggml_compute_forward_win_unpart_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -9101,6 +9205,32 @@ void ggml_compute_forward_glu( // ggml_compute_forward_get_rel_pos +static void ggml_compute_forward_get_rel_pos_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + GGML_UNUSED(params); + + const ggml_tensor * src0 = dst->src[0]; + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t w = ne1; + + float * src0_data = (float *) src0->data; + float * dst_data = (float *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + static void ggml_compute_forward_get_rel_pos_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -9134,6 +9264,10 @@ void ggml_compute_forward_get_rel_pos( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rel_pos_f32(params, dst); + } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5846a23937..c4cd3f42dde 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5419,21 +5419,19 @@ struct ggml_tensor * ggml_win_part( struct ggml_context * ctx, struct ggml_tensor * a, int w) { - GGML_ASSERT(a->ne[3] == 1); - GGML_ASSERT(a->type == GGML_TYPE_F32); - // padding const int px = (w - a->ne[1]%w)%w; const int py = (w - a->ne[2]%w)%w; + const int bs = a->ne[3]; const int npx = (px + a->ne[1])/w; const int npy = (py + a->ne[2])/w; - const int np = npx*npy; + const int np = npx*npy*bs; const int64_t ne[4] = { a->ne[0], w, w, np, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - int32_t params[] = { npx, npy, w }; + int32_t params[] = { npx, npy, bs, w }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_WIN_PART; @@ -5450,10 +5448,20 @@ struct ggml_tensor * ggml_win_unpart( int w0, int h0, int w) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + return ggml_win_unpart_ext(ctx, a, w0, h0, 1, w); +} + +struct ggml_tensor * ggml_win_unpart_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int b0, + int w) { + const int64_t ne[4] = { a->ne[0], w0, h0, b0 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; - struct 
ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + GGML_ASSERT(ggml_is_contiguous(a)); int32_t params[] = { w }; ggml_set_op_params(result, params, sizeof(params)); @@ -5471,8 +5479,7 @@ struct ggml_tensor * ggml_get_rel_pos( struct ggml_tensor * a, int qh, int kh) { - GGML_ASSERT(qh == kh); - GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + GGML_ASSERT(qh + kh - 1 <= a->ne[1]); const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); @@ -6560,6 +6567,7 @@ static void ggml_compute_backward( } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: case GGML_OP_UNARY: { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_ABS: { From 4d52d20fa5108cbb27260aa085fa6dc995090d23 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 19 Nov 2025 08:39:41 +0000 Subject: [PATCH 2/5] cuda : implement CUDA backend support for rel-pos and window operations --- ggml/include/ggml.h | 20 ++ ggml/src/ggml-cpu/ops.cpp | 32 +-- ggml/src/ggml-cuda/ggml-cuda.cu | 22 ++ ggml/src/ggml-cuda/rel-pos.cu | 135 ++++++++++ ggml/src/ggml-cuda/rel-pos.cuh | 6 + ggml/src/ggml-cuda/win.cu | 430 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/win.cuh | 10 + ggml/src/ggml.c | 6 +- 8 files changed, 642 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-cuda/rel-pos.cu create mode 100644 ggml/src/ggml-cuda/rel-pos.cuh create mode 100644 ggml/src/ggml-cuda/win.cu create mode 100644 ggml/src/ggml-cuda/win.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 364e1454775..7bbf0177bc7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2346,6 +2346,20 @@ extern "C" { int h0, int w); + // reverse of ggml_win_part with explicit output dimensions + // a: [C, w, w, B*NPY*NPX] + // result: [C, w0, h0, b0] + // w0, h0: output width and height (may differ from input due to padding removal) + // b0: output batch size + // w: window size (must match the one used in ggml_win_part) + GGML_API struct ggml_tensor * ggml_win_unpart_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int b0, + int w); + GGML_API struct ggml_tensor * ggml_unary( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2356,6 +2370,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_unary_op op); + // relative position encoding + // a: [C, rel_pos_size] + // res: [C, kh, qh] + // where rel_pos_size >= qh + kh - 1 + // extracts relative position embeddings for attention + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 GGML_API struct ggml_tensor * ggml_get_rel_pos( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ee95940abb3..7984542f0e1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8863,13 +8863,13 @@ static void ggml_compute_forward_win_part_f32( const int64_t i01 = px*w + i1; const int64_t i00 = i0; - void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; - void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { *((float *) dp) = 0; } else { - *((float *) dp) = *((float *) sp); + *((float *) dp) = *((const float *) sp); } } } @@ -8907,13 +8907,13 @@ static void 
ggml_compute_forward_win_part_f16( const int64_t i01 = px*w + i1; const int64_t i00 = i0; - void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; - void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { *((ggml_fp16_t *) dp) = 0; } else { - *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp); + *((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp); } } } @@ -8981,10 +8981,10 @@ static void ggml_compute_forward_win_unpart_f32( const int64_t i01 = i1%w; const int64_t i00 = i0; - void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; - void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - *((float *) dp) = *((float *) sp); + *((float *) dp) = *((const float *) sp); } } } @@ -9025,10 +9025,10 @@ static void ggml_compute_forward_win_unpart_f16( const int64_t i01 = i1%w; const int64_t i00 = i0; - void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; - void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; + const char * sp = ((const char *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00; + char * dp = ((char *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp); + *((ggml_fp16_t *) dp) = *((const ggml_fp16_t *) sp); } } } @@ -9216,14 +9216,14 @@ static void ggml_compute_forward_get_rel_pos_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t w = ne1; + const int64_t kh = ne1; float * src0_data = (float *) src0->data; float * dst_data = (float *) dst->data; for (int64_t i2 = 0; i2 < ne2; ++i2) { for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (w - i1 - 1) + i2; + const int64_t pos = (kh - i1 - 1) + i2; for (int64_t i0 = 0; i0 < ne0; ++i0) { dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; } @@ -9242,14 +9242,14 @@ static void ggml_compute_forward_get_rel_pos_f16( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t w = ne1; + const int64_t kh = ne1; ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; for (int64_t i2 = 0; i2 < ne2; ++i2) { for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (w - i1 - 1) + i2; + const int64_t pos = (kh - i1 - 1) + i2; for (int64_t i0 = 0; i0 < ne0; ++i0) { dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 889801cb5da..dfa50d4f752 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -34,6 +34,7 @@ #include "ggml-cuda/pad.cuh" #include "ggml-cuda/pool2d.cuh" #include "ggml-cuda/quantize.cuh" +#include "ggml-cuda/rel-pos.cuh" #include "ggml-cuda/rope.cuh" #include "ggml-cuda/roll.cuh" #include "ggml-cuda/scale.cuh" @@ -48,6 +49,7 @@ #include "ggml-cuda/topk-moe.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" +#include "ggml-cuda/win.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" #include "ggml-cuda/set.cuh" @@ -2717,6 +2719,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case 
GGML_OP_WIN_PART: + ggml_cuda_op_win_part(ctx, dst); + break; + case GGML_OP_WIN_UNPART: + ggml_cuda_op_win_unpart(ctx, dst); + break; + case GGML_OP_GET_REL_POS: + ggml_cuda_op_get_rel_pos(ctx, dst); + break; default: return false; } @@ -4152,6 +4163,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + return true; + default: + return false; + } default: return false; } diff --git a/ggml/src/ggml-cuda/rel-pos.cu b/ggml/src/ggml-cuda/rel-pos.cu new file mode 100644 index 00000000000..1eaef024bad --- /dev/null +++ b/ggml/src/ggml-cuda/rel-pos.cu @@ -0,0 +1,135 @@ +#include "common.cuh" +#include "ggml.h" +#include "ggml-cuda/rel-pos.cuh" + +/* + +static void ggml_compute_forward_get_rel_pos_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + GGML_UNUSED(params); + + const ggml_tensor * src0 = dst->src[0]; + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t kh = ne1; + + ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; + ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (kh - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + + +void ggml_compute_forward_get_rel_pos( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rel_pos_f32(params, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + ggml_compute_forward_get_rel_pos_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh) { + GGML_ASSERT(qh + kh - 1 <= a->ne[1]); + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne); + + result->op = GGML_OP_GET_REL_POS; + result->src[0] = a; + + return result; +} + +*/ + +template +__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) { + int kh = gridDim.x; + int ki = blockIdx.x; + int qi = blockIdx.y; + int pos = (kh - 1) + qi - ki; + + int s0 = C; + int s1 = C * kh; + + for (int ci = threadIdx.x; ci < C; ci += blockDim.x) { + ((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci]; + } +} + +static unsigned int round_to_pow2(unsigned int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + + return v; +} + +void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == dst->type); + + int C = ne0; + int kh = ne1; + int qh = ne2; + + int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C))); + dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 }; + + const void * src0_d = (const void *)src0->data; + void * dst_d = (void *)dst->data; + cudaStream_t stream = ctx.stream(); + + switch (src0->type) + { + case GGML_TYPE_F32: + 
get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        case GGML_TYPE_F16:
+            get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        case GGML_TYPE_BF16:
+            get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cuda/rel-pos.cuh b/ggml/src/ggml-cuda/rel-pos.cuh
new file mode 100644
index 00000000000..ecf816e67e8
--- /dev/null
+++ b/ggml/src/ggml-cuda/rel-pos.cuh
@@ -0,0 +1,6 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_GET_REL_POS_BLOCK_SIZE 256
+
+void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
\ No newline at end of file
diff --git a/ggml/src/ggml-cuda/win.cu b/ggml/src/ggml-cuda/win.cu
new file mode 100644
index 00000000000..c9e6793ae5f
--- /dev/null
+++ b/ggml/src/ggml-cuda/win.cu
@@ -0,0 +1,430 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "ggml-cuda/win.cuh"
+
+/*
+
+C++ CPU Implementation:
+
+
+static void ggml_compute_forward_win_part_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t bs   = ((const int32_t *)(dst->op_params))[2];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[3];
+
+    assert(ne00 == ne0);
+    assert(ne3 == nep0*nep1*bs);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        int px = i3 % nep0;
+        int py = (i3 / nep0) % nep1;
+        int b  = i3 / (nep0 * nep1);
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int64_t i03 = b;
+                    const int64_t i02 = py*w + i2;
+                    const int64_t i01 = px*w + i1;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+
+                    if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                        *((ggml_fp16_t *) dp) = 0;
+                    } else {
+                        *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_win_part(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_win_part_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w) {
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int bs  = a->ne[3];
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy*bs;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { npx, npy, bs, w };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_WIN_PART;
+    result->src[0] = a;
+
+    return result;
+}
+
+
+*/
+
+struct win_param {
+    int w;
+    int C;
+    int npx;
+    int npy;
+    int ne1;
+    int ne2;
+    size_t nb00;
+    size_t nb01;
+    size_t nb02;
+    size_t nb03;
+};
+
+template <typename T>
+__global__ static void win_part_kernel(
+    const void * src,
+    void * dst,
+    win_param p)
+{
+    int i1 = blockIdx.x;
+    int i2 = blockIdx.y;
+    int i3 = blockIdx.z;
+    int px = i3 % p.npx;
+    int py = (i3 / p.npx) % p.npy;
+    int b  = i3 / (p.npx * p.npy);
+
+    const int nb0 = sizeof(T);
+    const int nb1 = p.C * sizeof(T);
+    const int nb2 = p.C * p.w * sizeof(T);
+    const int nb3 = p.C * p.w * p.w * sizeof(T);
+
+    if (py*p.w + i2 >= p.ne2 || px*p.w + i1 >= p.ne1) {
+        for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
+            char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+            *((T *) dp) = 0;
+        }
+        return;
+    }
+
+    for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
+        int i03 = b;
+        int i02 = py*p.w + i2;
+        int i01 = px*p.w + i1;
+        int i00 = i0;
+
+        const char * sp = (const char *)src + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+        char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+
+        *((T *) dp) = *((const T *) sp);
+    }
+}
+
+
+struct win_unpart_param {
+    int w;
+    int C;
+    int npx;
+    int npy;
+    int w0;
+    int h0;
+    size_t nb00;
+    size_t nb01;
+    size_t nb02;
+    size_t nb03;
+};
+
+template <typename T>
+__global__ static void win_unpart_kernel(
+    const void * src,
+    void * dst,
+    win_unpart_param p)
+{
+    int i1 = blockIdx.x;
+    int i2 = blockIdx.y;
+    int i3 = blockIdx.z;
+    int ip2 = i2/p.w;
+    int ip1 = i1/p.w;
+
+    int i03 = i3*p.npx*p.npy + ip2*p.npx + ip1;
+    int i02 = i2%p.w;
+    int i01 = i1%p.w;
+
+    const int nb0 = sizeof(T);
+    const int nb1 = p.C * sizeof(T);
+    const int nb2 = p.C * p.w0 * sizeof(T);
+    const int nb3 = p.C * p.w0 * p.h0 * sizeof(T);
+
+    for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
+        int i00 = i0;
+        const char * sp = (const char *)src + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+        char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+
+        *((T *) dp) = *((const T *) sp);
+    }
+}
+
+static unsigned int round_to_pow2(unsigned int v) {
+    v--;
+    v |= v >> 1;
+    v |= v >> 2;
+    v |= v >> 4;
+    v |= v >> 8;
+    v |= v >> 16;
+    v++;
+
+    return v;
+}
+
+void ggml_cuda_op_win_part(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    int npx = dst->op_params[0];
+    int npy = dst->op_params[1];
+    int w   = dst->op_params[3];
+    int C   = ne0;
+    int np  = ne3;
+
+    GGML_ASSERT(ne1 == w && ne2 == w);
+
+    win_param params = {
+        w,
+        C,
+        npx,
+        npy,
+        (int)ne01,
+        (int)ne02,
+        src0->nb[0],
+        src0->nb[1],
+        src0->nb[2],
+        src0->nb[3]
+    };
+
+    dim3 grid { (unsigned int)w, (unsigned int)w, (unsigned int)np };
+    int num_threads = MIN(CUDA_WINPART_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+
+    const void * src0_d = (const void *)src0->data;
+    void * dst_d = (void *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    switch (src0->type)
+    {
+        case GGML_TYPE_F32:
+            win_part_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        case GGML_TYPE_F16:
+            win_part_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        case GGML_TYPE_BF16:
+            win_part_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+}
+
+
+/*
+
+C++ CPU Implementation:
+
+static void ggml_compute_forward_win_unpart_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    GGML_UNUSED(params);
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+    assert(ne03 == npx*npy*ne3);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                    const int ip2 = i2/w;
+                    const int ip1 = i1/w;
+
+                    const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
+                    const int64_t i02 = i2%w;
+                    const int64_t i01 = i1%w;
+                    const int64_t i00 = i0;
+
+                    void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
+                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+
+                    *((ggml_fp16_t *) dp) = *((ggml_fp16_t *) sp);
+                }
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_win_unpart(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_win_unpart_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   w) {
+    return ggml_win_unpart_ext(ctx, a, w0, h0, 1, w);
+}
+
+struct ggml_tensor * ggml_win_unpart_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   b0,
+        int                   w) {
+    const int64_t ne[4] = { a->ne[0], w0, h0, b0 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    GGML_ASSERT(ggml_is_contiguous(a));
+
+    int32_t params[] = { w };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_WIN_UNPART;
+    result->src[0] = a;
+
+    return result;
+}
+
+
+*/
+void ggml_cuda_op_win_unpart(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    int w  = dst->op_params[0];
+    int C  = ne0;
+    int w0 = ne1;
+    int h0 = ne2;
+    int b0 = ne3;
+
+    const int px = (w - ne1%w)%w;
+    const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+    assert(ne03 == npx*npy*ne3);
+
+    win_unpart_param params = {
+        w,
+        C,
+        npx,
+        npy,
+        w0,
+        h0,
+        src0->nb[0],
+        src0->nb[1],
+        src0->nb[2],
+        src0->nb[3]
+    };
+
+    dim3 grid { (unsigned int)w0, (unsigned int)h0, (unsigned int)b0 };
+    int num_threads = MIN(CUDA_WINPART_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+
+    const void * src0_d = (const void *)src0->data;
+    void * dst_d = (void *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    switch (src0->type)
+    {
+        case GGML_TYPE_F32:
+            win_unpart_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        case GGML_TYPE_F16:
+            win_unpart_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        case GGML_TYPE_BF16:
+            win_unpart_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, params);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cuda/win.cuh b/ggml/src/ggml-cuda/win.cuh
new file mode 100644
index 00000000000..f9a3255284c
--- /dev/null
+++ b/ggml/src/ggml-cuda/win.cuh
@@ -0,0 +1,10 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_WINPART_BLOCK_SIZE 256
+#define CUDA_WINUNPART_BLOCK_SIZE 256
+
+
+void ggml_cuda_op_win_part(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_win_unpart(ggml_backend_cuda_context & ctx, ggml_tensor * dst); \
No newline at end of file diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c4cd3f42dde..5c263196fa1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5429,7 +5429,7 @@ struct ggml_tensor * ggml_win_part( const int np = npx*npy*bs; const int64_t ne[4] = { a->ne[0], w, w, np, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); int32_t params[] = { npx, npy, bs, w }; ggml_set_op_params(result, params, sizeof(params)); @@ -5459,7 +5459,7 @@ struct ggml_tensor * ggml_win_unpart_ext( int b0, int w) { const int64_t ne[4] = { a->ne[0], w0, h0, b0 }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); GGML_ASSERT(ggml_is_contiguous(a)); @@ -5482,7 +5482,7 @@ struct ggml_tensor * ggml_get_rel_pos( GGML_ASSERT(qh + kh - 1 <= a->ne[1]); const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne); result->op = GGML_OP_GET_REL_POS; result->src[0] = a; From 72cdf76e7ce4a97f09227c46af4438a719faa284 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 19 Nov 2025 10:21:28 +0000 Subject: [PATCH 3/5] ggml : add scaling to get_rel_pos for different query/key heights --- ggml/src/ggml-cpu/ops.cpp | 10 ++++- ggml/src/ggml-cuda/rel-pos.cu | 74 ++--------------------------------- ggml/src/ggml.c | 3 +- 3 files changed, 14 insertions(+), 73 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7984542f0e1..04c9889a844 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9217,13 +9217,16 @@ static void ggml_compute_forward_get_rel_pos_f32( GGML_TENSOR_UNARY_OP_LOCALS const int64_t kh = ne1; + const int64_t qh = ne2; + const float k_scale = MAX(qh / kh, 1.0f); + const float q_scale = MAX(kh / qh, 1.0f); float * src0_data = (float *) src0->data; float * dst_data = (float *) dst->data; for (int64_t i2 = 0; i2 < ne2; ++i2) { for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (kh - i1 - 1) + i2; + const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale); for (int64_t i0 = 0; i0 < ne0; ++i0) { dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; } @@ -9243,13 +9246,16 @@ static void ggml_compute_forward_get_rel_pos_f16( GGML_TENSOR_UNARY_OP_LOCALS const int64_t kh = ne1; + const int64_t qh = ne2; + const float k_scale = MAX(qh / kh, 1.0f); + const float q_scale = MAX(kh / qh, 1.0f); ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; for (int64_t i2 = 0; i2 < ne2; ++i2) { for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (kh - i1 - 1) + i2; + const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale); for (int64_t i0 = 0; i0 < ne0; ++i0) { dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; } diff --git a/ggml/src/ggml-cuda/rel-pos.cu b/ggml/src/ggml-cuda/rel-pos.cu index 1eaef024bad..5c1f5f4d74b 100644 --- a/ggml/src/ggml-cuda/rel-pos.cu +++ b/ggml/src/ggml-cuda/rel-pos.cu @@ -2,82 +2,16 @@ #include "ggml.h" #include "ggml-cuda/rel-pos.cuh" -/* - -static void ggml_compute_forward_get_rel_pos_f16( - const ggml_compute_params * params, - ggml_tensor * dst) { - GGML_UNUSED(params); - - const ggml_tensor * src0 = dst->src[0]; - - // ref: 
https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - - GGML_TENSOR_UNARY_OP_LOCALS - - const int64_t kh = ne1; - - ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; - ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (kh - i1 - 1) + i2; - for (int64_t i0 = 0; i0 < ne0; ++i0) { - dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; - } - } - } -} - - -void ggml_compute_forward_get_rel_pos( - const ggml_compute_params * params, - ggml_tensor * dst) { - - const ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_get_rel_pos_f32(params, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - { - ggml_compute_forward_get_rel_pos_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -struct ggml_tensor * ggml_get_rel_pos( - struct ggml_context * ctx, - struct ggml_tensor * a, - int qh, - int kh) { - GGML_ASSERT(qh + kh - 1 <= a->ne[1]); - - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne); - - result->op = GGML_OP_GET_REL_POS; - result->src[0] = a; - - return result; -} - -*/ template __global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) { int kh = gridDim.x; + int qh = gridDim.x; + float k_scale = MAX(qh / kh, 1.0f); + float q_scale = MAX(kh / qh, 1.0f); int ki = blockIdx.x; int qi = blockIdx.y; - int pos = (kh - 1) + qi - ki; + int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale); int s0 = C; int s1 = C * kh; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5c263196fa1..4e12c479804 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5479,7 +5479,8 @@ struct ggml_tensor * ggml_get_rel_pos( struct ggml_tensor * a, int qh, int kh) { - GGML_ASSERT(qh + kh - 1 <= a->ne[1]); + GGML_ASSERT(qh >= 1 && kh >= 1); + GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 3, ne); From 508def24d48d58dc9ab94b8db95f9a2056ab6b31 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 19 Nov 2025 10:26:33 +0000 Subject: [PATCH 4/5] ggml : fix get_rel_pos scaling bugs and update tests --- ggml/src/ggml-cpu/ops.cpp | 8 +- ggml/src/ggml-cuda/rel-pos.cu | 6 +- tests/test-backend-ops.cpp | 162 ++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 04c9889a844..4663a50ab3f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9218,8 +9218,8 @@ static void ggml_compute_forward_get_rel_pos_f32( const int64_t kh = ne1; const int64_t qh = ne2; - const float k_scale = MAX(qh / kh, 1.0f); - const float q_scale = MAX(kh / qh, 1.0f); + const float k_scale = MAX((float)qh / kh, 1.0f); + const float q_scale = MAX((float)kh / qh, 1.0f); float * src0_data = (float *) src0->data; float * dst_data = (float *) dst->data; @@ -9247,8 +9247,8 @@ static void ggml_compute_forward_get_rel_pos_f16( const int64_t kh = ne1; const int64_t qh = ne2; - const float k_scale = MAX(qh / kh, 1.0f); - const float q_scale = MAX(kh / qh, 1.0f); + const float k_scale = MAX((float)qh / kh, 1.0f); + const float q_scale = MAX((float)kh / qh, 1.0f); ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; diff 
--git a/ggml/src/ggml-cuda/rel-pos.cu b/ggml/src/ggml-cuda/rel-pos.cu
index 5c1f5f4d74b..1d1aba4c737 100644
--- a/ggml/src/ggml-cuda/rel-pos.cu
+++ b/ggml/src/ggml-cuda/rel-pos.cu
@@ -6,9 +6,9 @@
 template <typename T>
 __global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
     int kh = gridDim.x;
-    int qh = gridDim.x;
-    float k_scale = MAX(qh / kh, 1.0f);
-    float q_scale = MAX(kh / qh, 1.0f);
+    int qh = gridDim.y;
+    float k_scale = MAX((float)qh / kh, 1.0f);
+    float q_scale = MAX((float)kh / qh, 1.0f);
     int ki = blockIdx.x;
     int qi = blockIdx.y;
     int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6638f28182e..3eaa08b7acc 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5518,6 +5518,121 @@ struct test_pad_reflect_1d : public test_case {
     }
 };
 
+// GGML_OP_WIN_PART
+struct test_win_part : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, W, H, B]
+    const int w; // window size
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, w, v);
+    }
+
+    test_win_part(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 14, 14, 2},
+            int w = 7,
+            bool v = false)
+        : type(type), ne_a(ne_a), w(w), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 2;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3],
+                    a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_win_part(ctx, a, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_WIN_UNPART
+struct test_win_unpart : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a; // [C, w, w, NPX*NPY*B]
+    const int w0; // output width
+    const int h0; // output height
+    const int w; // window size
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, w0, h0, w);
+    }
+
+    test_win_unpart(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {64, 7, 7, 8},
+            int w0 = 14, int h0 = 14,
+            int w = 7)
+        : type(type), ne_a(ne_a), w0(w0), h0(h0), w(w) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_win_unpart(ctx, a, w0, h0, w);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_GET_REL_POS
+struct test_get_rel_pos : public test_case {
+    const ggml_type type;
+    const int C; // channels
+    const int qh; // query height
+    const int kh; // key height
+    const bool v; // view (non-contiguous input)
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, C, qh, kh, v);
+    }
+
+    test_get_rel_pos(ggml_type type = GGML_TYPE_F32,
+            int C = 64,
+            int qh = 7,
+            int kh = 7,
+            bool v = false)
+        : type(type), C(C), qh(qh), kh(kh), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // Input tensor has relative position embeddings table
+        // Shape: [C, 2*max(qh,kh)-1, 1, 1]
+        const int64_t ne_a[4] = {C, 2*std::max(qh, kh) - 1, 1, 1};
+
+        ggml_tensor * a;
+        if (v) {
+            // Create larger tensor and view into it (non-contiguous)
+            int64_t ne_large[4] = {C * 2, 2*std::max(qh, kh) - 1, 1, 1};
+            a = ggml_new_tensor(ctx, type, 4, ne_large);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a,
C, 2*std::max(qh, kh) - 1, 1, 1, + a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view_of_a"); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a); + ggml_set_name(a, "a"); + } + + // Output shape: [C, kh, qh, 1] + ggml_tensor * out = ggml_get_rel_pos(ctx, a, qh, kh); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_ROLL struct test_roll : public test_case { const int shift0; @@ -7565,6 +7680,53 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_ext()); test_cases.emplace_back(new test_pad_reflect_1d()); test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); + + // Window partition tests + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) { + for (bool v : {false, true}) { + // Exact division: 14x14 -> 2x2 windows of 7x7 + test_cases.emplace_back(new test_win_part(type, {64, 14, 14, 2}, 7, v)); + // With padding: 15x15 -> 3x3 windows of 7x7 (padded) + test_cases.emplace_back(new test_win_part(type, {64, 15, 15, 2}, 7, v)); + // Single window: 7x7 -> 1x1 windows of 7x7 + test_cases.emplace_back(new test_win_part(type, {64, 7, 7, 1}, 7, v)); + // Larger: 28x28 -> 4x4 windows of 7x7 + test_cases.emplace_back(new test_win_part(type, {128, 28, 28, 4}, 7, v)); + // Window size 8: 16x16 -> 2x2 windows of 8x8 + test_cases.emplace_back(new test_win_part(type, {96, 16, 16, 1}, 8, v)); + } + } + + // Window unpartition tests (inverse of partition) + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) { + // Exact division: 2x2 windows of 7x7 -> 14x14 + test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 4}, 14, 14, 7)); + // With padding: 3x3 windows of 7x7 -> 15x15 + test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 9}, 15, 15, 7)); + // Single window: 1x1 windows of 7x7 -> 7x7 + test_cases.emplace_back(new test_win_unpart(type, {64, 7, 7, 1}, 7, 7, 7)); + // Larger: 4x4 windows of 7x7 -> 28x28 + test_cases.emplace_back(new test_win_unpart(type, {128, 7, 7, 16}, 28, 28, 7)); + // Window size 8: 2x2 windows of 8x8 -> 16x16 + test_cases.emplace_back(new test_win_unpart(type, {96, 8, 8, 4}, 16, 16, 8)); + } + + // Relative position embedding tests (used in SAM) + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) { + for (bool v : {false, true}) { + // Square small: 3x3 attention + test_cases.emplace_back(new test_get_rel_pos(type, 5, 3, 3, v)); + // Square medium: 7x7 attention (typical SAM) + test_cases.emplace_back(new test_get_rel_pos(type, 13, 7, 7, v)); + // Square large: 14x14 attention + test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 14, v)); + // Rectangular: 14x7 attention + test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 7, v)); + // Edge case: 1x1 attention (minimum) + test_cases.emplace_back(new test_get_rel_pos(type, 1, 1, 1, v)); + } + } + test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_arange()); test_cases.emplace_back(new test_timestep_embedding()); From 8d2e3b2f16cd6445a74cdcd6f531c29c72e05caf Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 19 Nov 2025 13:17:49 +0000 Subject: [PATCH 5/5] fix: replace assert with GGML_ASSERT --- ggml/src/ggml-cpu/ops.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 4663a50ab3f..6b6937d886a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8847,8 +8847,8 @@ static void ggml_compute_forward_win_part_f32( const int32_t bs = 
((const int32_t *)(dst->op_params))[2]; const int32_t w = ((const int32_t *)(dst->op_params))[3]; - assert(ne00 == ne0); - assert(ne3 == nep0*nep1*bs); + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne3 == nep0*nep1*bs); // TODO: optimize / multi-thread for (int64_t i3 = 0; i3 < ne3; i3++) { @@ -8891,8 +8891,8 @@ static void ggml_compute_forward_win_part_f16( const int32_t bs = ((const int32_t *)(dst->op_params))[2]; const int32_t w = ((const int32_t *)(dst->op_params))[3]; - assert(ne00 == ne0); - assert(ne3 == nep0*nep1*bs); + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne3 == nep0*nep1*bs); // TODO: optimize / multi-thread for (int64_t i3 = 0; i3 < ne3; i3++) {
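
Illustrative usage of the new API (a reviewer-side sketch, not part of the patches above; the tensor names and sizes are arbitrary examples). It only builds the ops and prints the resulting shapes, following the updated comments in ggml.h: ggml_win_part pads and splits a [C, W, H, B] tensor into [C, w, w, ceil(W/w)*ceil(H/w)*B] windows, ggml_win_unpart_ext merges the windows back with an explicit batch size, and ggml_get_rel_pos slices a [C, 2*max(qh, kh) - 1] table into [C, kh, qh].

    // sketch only: graph construction, no compute; sizes are examples
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // batched F16 feature map: [C=64, W=14, H=14, B=2]
        struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 14, 14, 2);

        // 7x7 windows: [64, 7, 7, 2*2*2] = [64, 7, 7, 8]
        struct ggml_tensor * part = ggml_win_part(ctx, x, 7);

        // merge back, restoring the batch dimension explicitly: [64, 14, 14, 2]
        struct ggml_tensor * merged = ggml_win_unpart_ext(ctx, part, 14, 14, 2, 7);

        // rel-pos table [C=64, 2*max(qh,kh)-1 = 13] -> embeddings [64, kh=7, qh=7]
        struct ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 13);
        struct ggml_tensor * rp  = ggml_get_rel_pos(ctx, rel, 7, 7);

        printf("win_part:    %lld %lld %lld %lld\n", (long long) part->ne[0], (long long) part->ne[1], (long long) part->ne[2], (long long) part->ne[3]);
        printf("win_unpart:  %lld %lld %lld %lld\n", (long long) merged->ne[0], (long long) merged->ne[1], (long long) merged->ne[2], (long long) merged->ne[3]);
        printf("get_rel_pos: %lld %lld %lld\n", (long long) rp->ne[0], (long long) rp->ne[1], (long long) rp->ne[2]);

        ggml_free(ctx);
        return 0;
    }

For the rectangular case handled by patches 3/4: with qh = 14 and kh = 7, k_scale = max(14/7, 1) = 2 and q_scale = max(7/14, 1) = 1, so pos = qi - 2*ki + 12, which stays within the 2*max(qh, kh) - 1 = 27-entry table (values 0..25) and mirrors the coordinate scaling of get_rel_pos in the SAM reference implementation.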