From 9cc51d394d4bb7acda21fda614efbaaf5074cd94 Mon Sep 17 00:00:00 2001 From: Samuel Wu Date: Mon, 13 Oct 2025 17:44:32 +0800 Subject: [PATCH 1/3] optimise GGML_OP_SUM --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++- ggml/src/ggml-metal/ggml-metal.metal | 42 ++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a61ea8fb5a7b3..d590c37899875 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + int nth = 32; // SIMD width + + while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, (int) n); + + const int nsg = (nth + 31) / 32; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 74a9aa998837c..aa75688353b0c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, device float * dst, - ushort tiitg[[thread_index_in_threadgroup]]) { + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - if (tiitg != 0) { + if (args.np == 0) { return; } - float acc = 0.0f; - for (ulong i = 0; i < args.np; ++i) { - acc += src0[i]; + const uint nsg = (ntg.x + 31) / 32; + + float sumf = 0; + + for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { + sumf += src0[i0]; } - dst[0] = acc; + sumf = simd_sum(sumf); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float total = 0; + + if (sgitg == 0) { + float v = 0; + + if (tpitg.x < nsg) { + v = shmem_f32[tpitg.x]; + } + + total = simd_sum(v); + + if (tpitg.x == 0) { + dst[0] = total; + } + } } template From 4619142b88d8aec62149969152f00bd02b707f6d Mon Sep 17 00:00:00 2001 From: Samuel Wu Date: Mon, 13 Oct 2025 19:27:10 +0800 Subject: [PATCH 2/3] add non-contiguous tests by permuting the input --- tests/test-backend-ops.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fa16b497a6b7..e5a11042b349f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case { struct test_sum : public test_case { const ggml_type type; const std::array ne; + const std::array permute; + bool _use_permute; std::string vars() override { - return VARS_TO_STR2(type, ne); + std::string v = VARS_TO_STR2(type, ne); + if (_use_permute) v += "," + VAR_TO_STR(permute); + return v; } test_sum(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}) - : type(type), ne(ne) {} + std::array ne = {10, 5, 4, 3}, + std::array permute = {0, 0, 0, 0}) + : type(type), ne(ne), permute(permute), + _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); + if (_use_permute) { + a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]); + ggml_set_name(a, "a_permuted"); + } + ggml_tensor * out = ggml_sum(ctx, a); ggml_set_name(out, "out"); @@ -6729,11 +6740,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); From c25a6c7c748f468cff31704aa9f13ba26a9c33f0 Mon Sep 17 00:00:00 2001 From: Samuel Wu Date: Mon, 13 Oct 2025 21:13:22 +0800 Subject: [PATCH 3/3] change tests to require full contiguity of OP_SUM --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + tests/test-backend-ops.cpp | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fc5083043f7c9..c4981ad522dc4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -657,6 +657,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e5a11042b349f..28cd372d5c525 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6735,12 +6735,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));