Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
Expand Down
15 changes: 14 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {

ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

int nth = 32; // SIMD width

while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}

nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, (int) n);

const int nsg = (nth + 31) / 32;

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);

ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);

ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);

return 1;
}
Expand Down
42 changes: 36 additions & 6 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {

if (tiitg != 0) {
if (args.np == 0) {
return;
}

float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
const uint nsg = (ntg.x + 31) / 32;

float sumf = 0;

for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}

dst[0] = acc;
sumf = simd_sum(sumf);

if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

float total = 0;

if (sgitg == 0) {
float v = 0;

if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}

total = simd_sum(v);

if (tpitg.x == 0) {
dst[0] = total;
}
}
}

template <bool norm>
Expand Down
21 changes: 18 additions & 3 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
struct test_sum : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int64_t, 4> permute;
bool _use_permute;

std::string vars() override {
return VARS_TO_STR2(type, ne);
std::string v = VARS_TO_STR2(type, ne);
if (_use_permute) v += "," + VAR_TO_STR(permute);
return v;
}

test_sum(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
std::array<int64_t, 4> ne = {10, 5, 4, 3},
std::array<int64_t, 4> permute = {0, 0, 0, 0})
: type(type), ne(ne), permute(permute),
_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

if (_use_permute) {
a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
ggml_set_name(a, "a_permuted");
}

ggml_tensor * out = ggml_sum(ctx, a);
ggml_set_name(out, "out");

Expand Down Expand Up @@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
Expand All @@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
Expand Down
Loading