From f06a8d35dbacff78d6a42bb951ec1f09770bd252 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 17 Nov 2025 10:08:10 +0200 Subject: [PATCH 1/2] metal : faster argsort --- ggml/src/ggml-metal/ggml-metal.metal | 92 ++++++++++++++++++---------- 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8afc7318f684e..cbf18ba0a4571 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4628,12 +4628,13 @@ kernel void kernel_argsort_merge_f32_i32( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int im = tgpig[0] / args.ne01; - int i01 = tgpig[0] % args.ne01; - int i02 = tgpig[1]; - int i03 = tgpig[2]; - const int start = im * (2*args.len); + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); @@ -4657,44 +4658,58 @@ kernel void kernel_argsort_merge_f32_i32( + args.nb02*i02 + args.nb03*i03); - for (int k = tpitg.x; k < (int) total; k += ntg.x) { - // find partition (i,j) such that i+j = k - int low = k > len1 ? k - len1 : 0; - int high = MIN(k, len0); + if (total == 0) { + return; + } - while (low < high) { - const int mid = (low + high) >> 1; + const int chunk = (total + ntg.x - 1) / ntg.x; - const int32_t idx0 = tmp0[mid]; - const int32_t idx1 = tmp1[k - mid - 1]; + const int k0 = tpitg.x * chunk; + const int k1 = min(k0 + chunk, total); - const float val0 = src0_row[idx0]; - const float val1 = src0_row[idx1]; + if (k0 >= total) { + return; + } - if (order == GGML_SORT_ORDER_ASC) { - if (val0 <= val1) { - low = mid + 1; - } else { - high = mid; - } - } else { - if (val0 >= val1) { - low = mid + 1; - } else { - high = mid; - } - } + int low = k0 > len1 ? k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const float val0 = src0_row[idx0]; + const float val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); } - const int i = low; - const int j = k - i; + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + int i = low; + int j = k0 - i; + + for (int k = k0; k < k1; ++k) { int32_t out_idx; if (i >= len0) { out_idx = tmp1[j]; + ++j; } else if (j >= len1) { out_idx = tmp0[i]; + ++i; } else { const int32_t idx0 = tmp0[i]; const int32_t idx1 = tmp1[j]; @@ -4702,9 +4717,20 @@ kernel void kernel_argsort_merge_f32_i32( const float val0 = src0_row[idx0]; const float val1 = src0_row[idx1]; - out_idx = (order == GGML_SORT_ORDER_ASC) - ? (val0 <= val1 ? idx0 : idx1) - : (val0 >= val1 ? idx0 : idx1); + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } + + if (take_left) { + out_idx = idx0; + ++i; + } else { + out_idx = idx1; + ++j; + } } dst[k] = out_idx; From f2c5763391c7dd42ea4d307868c231c4bc9d7b2a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 17 Nov 2025 10:30:43 +0200 Subject: [PATCH 2/2] cont : keep data in registers --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 -- ggml/src/ggml-metal/ggml-metal.metal | 42 ++++++++++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 0c1714fdbc093..5555f91bb8cce 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3605,8 +3605,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); - ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); std::swap(bid_dst, bid_tmp); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index cbf18ba0a4571..294b5ffc1e414 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4701,23 +4701,37 @@ kernel void kernel_argsort_merge_f32_i32( int i = low; int j = k0 - i; + // keep the merge fronts into registers + int32_t idx0 = 0; + float val0 = 0.0f; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + float val1 = 0.0f; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + for (int k = k0; k < k1; ++k) { int32_t out_idx; if (i >= len0) { - out_idx = tmp1[j]; - ++j; + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; } else if (j >= len1) { - out_idx = tmp0[i]; - ++i; + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; } else { - const int32_t idx0 = tmp0[i]; - const int32_t idx1 = tmp1[j]; - - const float val0 = src0_row[idx0]; - const float val1 = src0_row[idx1]; - bool take_left; + if (order == GGML_SORT_ORDER_ASC) { take_left = (val0 <= val1); } else { @@ -4727,9 +4741,17 @@ kernel void kernel_argsort_merge_f32_i32( if (take_left) { out_idx = idx0; ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } } else { out_idx = idx1; ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } } }