diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 0c1714fdbc093..5555f91bb8cce 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3605,8 +3605,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); - ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); std::swap(bid_dst, bid_tmp); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8afc7318f684e..294b5ffc1e414 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4628,12 +4628,13 @@ kernel void kernel_argsort_merge_f32_i32( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int im = tgpig[0] / args.ne01; - int i01 = tgpig[0] % args.ne01; - int i02 = tgpig[1]; - int i03 = tgpig[2]; - const int start = im * (2*args.len); + const int im = tgpig[0] / args.ne01; + const int i01 = tgpig[0] % args.ne01; + const int i02 = tgpig[1]; + const int i03 = tgpig[2]; + + const int start = im * (2 * args.len); const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start))); const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len))); @@ -4657,54 +4658,101 @@ kernel void kernel_argsort_merge_f32_i32( + args.nb02*i02 + args.nb03*i03); - for (int k = tpitg.x; k < (int) total; k += ntg.x) { - // find partition (i,j) such that i+j = k - int low = k > len1 ? k - len1 : 0; - int high = MIN(k, len0); + if (total == 0) { + return; + } - while (low < high) { - const int mid = (low + high) >> 1; + const int chunk = (total + ntg.x - 1) / ntg.x; - const int32_t idx0 = tmp0[mid]; - const int32_t idx1 = tmp1[k - mid - 1]; + const int k0 = tpitg.x * chunk; + const int k1 = min(k0 + chunk, total); - const float val0 = src0_row[idx0]; - const float val1 = src0_row[idx1]; + if (k0 >= total) { + return; + } - if (order == GGML_SORT_ORDER_ASC) { - if (val0 <= val1) { - low = mid + 1; - } else { - high = mid; - } - } else { - if (val0 >= val1) { - low = mid + 1; - } else { - high = mid; - } - } + int low = k0 > len1 ? k0 - len1 : 0; + int high = MIN(k0, len0); + + // binary-search partition (i, j) such that i + j = k + while (low < high) { + const int mid = (low + high) >> 1; + + const int32_t idx0 = tmp0[mid]; + const int32_t idx1 = tmp1[k0 - mid - 1]; + + const float val0 = src0_row[idx0]; + const float val1 = src0_row[idx1]; + + bool take_left; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); } - const int i = low; - const int j = k - i; + if (take_left) { + low = mid + 1; + } else { + high = mid; + } + } + + int i = low; + int j = k0 - i; + + // keep the merge fronts into registers + int32_t idx0 = 0; + float val0 = 0.0f; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + + int32_t idx1 = 0; + float val1 = 0.0f; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + for (int k = k0; k < k1; ++k) { int32_t out_idx; if (i >= len0) { - out_idx = tmp1[j]; + while (k < k1) { + dst[k++] = tmp1[j++]; + } + break; } else if (j >= len1) { - out_idx = tmp0[i]; + while (k < k1) { + dst[k++] = tmp0[i++]; + } + break; } else { - const int32_t idx0 = tmp0[i]; - const int32_t idx1 = tmp1[j]; + bool take_left; - const float val0 = src0_row[idx0]; - const float val1 = src0_row[idx1]; + if (order == GGML_SORT_ORDER_ASC) { + take_left = (val0 <= val1); + } else { + take_left = (val0 >= val1); + } - out_idx = (order == GGML_SORT_ORDER_ASC) - ? (val0 <= val1 ? idx0 : idx1) - : (val0 >= val1 ? idx0 : idx1); + if (take_left) { + out_idx = idx0; + ++i; + if (i < len0) { + idx0 = tmp0[i]; + val0 = src0_row[idx0]; + } + } else { + out_idx = idx1; + ++j; + if (j < len1) { + idx1 = tmp1[j]; + val1 = src0_row[idx1]; + } + } } dst[k] = out_idx;