From 456e83d58fc6ca6032bfcde736b1f3839595b993 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 12 Nov 2025 21:23:58 +0100 Subject: [PATCH] ggml-cpu : use template for argsort --- ggml/src/ggml-cpu/ops.cpp | 30 ++++++++++++++++++++++-------- tests/test-backend-ops.cpp | 2 ++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 9f1e5f8d64632..09f53b470b26a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding( // ggml_compute_forward_argsort +template +struct argsort_cmp { + const float * data; + bool operator()(int32_t a, int32_t b) const { + if constexpr (order == GGML_SORT_ORDER_ASC) { + return data[a] < data[b]; + } else { + return data[a] > data[b]; + } + } +}; + static void ggml_compute_forward_argsort_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32( dst_data[j] = j; } - std::function cmp; - - // note: this might be causing memory allocations? ideally should be avoided if it's the case switch (order) { - case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break; - case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break; - default: GGML_ABORT("invalid sort order"); - } + case GGML_SORT_ORDER_ASC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; - std::sort(dst_data, dst_data + ne0, cmp); + case GGML_SORT_ORDER_DESC: + std::sort(dst_data, dst_data + ne0, argsort_cmp{src_data}); + break; + + default: + GGML_ABORT("invalid sort order"); + } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 92c17ac4399c0..38b7ddf221785 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7631,6 +7631,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it)); } + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + return test_cases; }