diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 9f1e5f8d64632..e573e78fae122 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9,7 +9,6 @@
 
 #include
 #include
-#include <functional>
 
 // ggml_compute_forward_dup
 
@@ -7665,6 +7664,18 @@ void ggml_compute_forward_timestep_embedding(
 
 // ggml_compute_forward_argsort
 
+template<enum ggml_sort_order order>
+struct cmp_argsort {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        if constexpr (order == GGML_SORT_ORDER_ASC) {
+            return data[a] < data[b];
+        } else {
+            return data[a] > data[b];
+        }
+    }
+};
+
 static void ggml_compute_forward_argsort_f32(
     const ggml_compute_params * params,
     ggml_tensor * dst) {
@@ -7691,16 +7702,17 @@ static void ggml_compute_forward_argsort_f32(
             dst_data[j] = j;
         }
 
-        std::function<bool(int32_t, int32_t)> cmp;
-
-        // note: this might be causing memory allocations? ideally should be avoided if it's the case
         switch (order) {
-            case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
-            case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
+            case GGML_SORT_ORDER_ASC:
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
+                break;
+
+            case GGML_SORT_ORDER_DESC:
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
+                break;
+
             default: GGML_ABORT("invalid sort order");
         }
-
-        std::sort(dst_data, dst_data + ne0, cmp);
     }
 }
 