diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 66dd0bfabd2..0dbf6279f85 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
     uint32_t orig_ncols;
     uint32_t ncols_input;
     uint32_t ncols_output;
+    uint32_t k;
     uint32_t nrows;
     uint32_t first_pass;
     uint32_t last_pass;
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
             timings[name.str()].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_TOP_K) {
+            std::stringstream name;
+            name << ggml_op_name(node->op) <<
+                " K=" << node->ne[0] <<
+                " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
+            timings[name.str()].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -4041,7 +4050,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                                   sizeof(int) * device->subgroup_size +
                                   2 * sizeof(int) +
-                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+                                  2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
             if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
                 nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
                 ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     uint32_t nrows = ggml_nrows(src0);
     uint32_t k = dst->ne[0];

-    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+    vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };

-    // Reserve space for ivec2 per element, double buffered
-    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
-    const size_t x_sz = dbl_buf_size * 2;
-    uint32_t dbl_buf_index = 0;
-
-    if (ctx->prealloc_size_x < x_sz) {
-        ctx->prealloc_size_x = x_sz;
-        ggml_vk_preallocate_buffers(ctx, subctx);
-    }
     if (ctx->prealloc_x_need_sync) {
         ggml_vk_sync_buffers(ctx, subctx);
     }
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     // largest elements. Repeat until we have the top K elements.
     // Need to do at least one iteration to write out the results.
     bool done_one_iter = false;
+    uint32_t dbl_buf_index = 0;
+    size_t dbl_buf_size;
     while (num_elements > k || !done_one_iter) {
-        done_one_iter = true;
         // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
         // But if K is larger, then we need a larger workgroup
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         // Number of elements remaining after this pass
         uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);

+        pc2.ncols_output = num_dst_elements;
+
+        if (!done_one_iter) {
+            // Reserve space for ivec2 per element, double buffered
+            // K per workgroup per row
+            dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
+            dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+            const size_t x_sz = dbl_buf_size * 2;
+
+            if (ctx->prealloc_size_x < x_sz) {
+                ctx->prealloc_size_x = x_sz;
+                ggml_vk_preallocate_buffers(ctx, subctx);
+            }
+        }
+
         vk_subbuffer src_buf;
         vk_subbuffer dst_buf;
@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         if (num_elements > k) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
+        done_one_iter = true;
     }
     ctx->prealloc_x_need_sync = true;
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp
index cd858b7d326..49d4ab8e7c0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
     uint orig_ncols;
     uint ncols_input;
     uint ncols_output;
+    uint k;
     uint nrows;
     uint first_pass;
     uint last_pass;
@@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) {
             const uint row_offset = row * p.ncols_input;
             dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
         } else {
-            const uint row_offset = row * p.orig_ncols;
+            const uint row_offset = row * p.ncols_input;
             dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
         }
     } else {
@@ -44,7 +45,7 @@
     }
     barrier();

-    if (p.ncols_output == 1) {
+    if (p.k == 1) {
         // Fast path for single output - just do a max reduction
         [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
             if (col < s) {
@@ -84,13 +85,17 @@
         }
     }

-    if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+    if (col < p.k) {
         if (p.last_pass != 0) {
-            const uint row_offset = row * p.ncols_output;
-            data_d[row_offset + col] = dst_row[col].x;
+            if (gl_GlobalInvocationID.x < p.ncols_input) {
+                const uint row_offset = row * p.k;
+                data_d[row_offset + col] = dst_row[col].x;
+            }
         } else {
-            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
-            data_t[row_offset + col] = dst_row[col];
+            if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
+                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+                data_t[row_offset + col] = dst_row[col];
+            }
         }
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
index c902e60237a..0b757f38e18 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
     uint orig_ncols;
     uint ncols_input;
     uint ncols_output;
+    uint k;
     uint nrows;
     uint first_pass;
     uint last_pass;
@@ -37,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
 shared int sh_min_idx;
 shared uint sh_total;
 shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];

 // Map float values to uint such that comparisons still work.
 // Positive values set the high bit, negative values are inverted.
@@ -60,7 +62,7 @@ void topk(const uint row) {
             const uint row_offset = row * p.ncols_input;
             dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
         } else {
-            const uint row_offset = row * p.orig_ncols;
+            const uint row_offset = row * p.ncols_input;
             dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
         }
     } else {
@@ -68,7 +70,7 @@
     }
     barrier();

-    if (p.ncols_output == 1) {
+    if (p.k == 1) {
         // Fast path for single output - just do a max reduction
         [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
             if (tid < s) {
@@ -98,7 +100,7 @@
     uint range_max = 0xFF800000;
     // How many are above the current range, and how many we need to find.
     uint total = 0;
-    uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
+    uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);

     while (mask != 0) {
         barrier();
@@ -139,7 +141,7 @@
         range_max = range_min + ((min_idx + 1) << shift);
         range_min = range_min + (min_idx << shift);

-        if (total == p.ncols_output) {
+        if (total == p.k) {
             break;
         }
         total -= counts[min_idx];
@@ -155,37 +157,82 @@
         // We need to compact these values to the start of the dst_row array.
         // Have each subgroup count how many items it'll store, so other
         // subgroups can compute their base offset.
-        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
-        uvec4 b = subgroupBallot(top);
-        uint bit_count = subgroupBallotBitCount(b);
-        if ((tid % SUBGROUP_SIZE) == 0) {
-            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
-        }
-        barrier();
+        // Values strictly greater than range_min must be stored. For values equal
+        // to range_min, there can be ties and it's possible we'll need to store
+        // an arbitrary subset of them.
+        // If total == p.k, have a fast path where we don't need to handle ties.
+        if (total == p.k) {
+            bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+            uvec4 b = subgroupBallot(top);
+            uint bit_count = subgroupBallotBitCount(b);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+            }
+            barrier();

-        uint out_idx = 0;
-        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
-            if (i < tid / SUBGROUP_SIZE) {
-                out_idx += offset_partials[i];
+            uint out_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                }
             }
-        }
-        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
-        if (top) {
-            // TODO: Copy directly to the output?
-            dst_row[out_idx + bit_count_ex] = v;
+            uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex] = v;
+            }
+        } else {
+            bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+            bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+            uvec4 b_top = subgroupBallot(top);
+            uvec4 b_eq_min = subgroupBallot(eq_min);
+            uint bit_count_top = subgroupBallotBitCount(b_top);
+            uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+            if ((tid % SUBGROUP_SIZE) == 0) {
+                offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+                eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+            }
+            barrier();
+
+            uint out_idx = 0;
+            uint eq_min_base = 0;
+            uint eq_min_idx = 0;
+            [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+                if (i < tid / SUBGROUP_SIZE) {
+                    out_idx += offset_partials[i];
+                    eq_min_idx += eq_min_partials[i];
+                }
+                eq_min_base += offset_partials[i];
+            }
+            // range_min values are stored at the end
+            eq_min_idx += eq_min_base;
+
+            uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+            uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+            if (top) {
+                // TODO: Copy directly to the output?
+                dst_row[out_idx + bit_count_ex_top] = v;
+            }
+            if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+                dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+            }
         }

         barrier();
     }

-    if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+    if (tid < p.k) {
         if (p.last_pass != 0) {
-            const uint row_offset = row * p.ncols_output;
-            data_d[row_offset + tid] = dst_row[tid].x;
+            if (gl_GlobalInvocationID.x < p.ncols_input) {
+                const uint row_offset = row * p.k;
+                data_d[row_offset + tid] = dst_row[tid].x;
+            }
         } else {
-            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
-            data_t[row_offset + tid] = dst_row[tid];
+            if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
+                const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
+                data_t[row_offset + tid] = dst_row[tid];
+            }
         }
     }
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 87a61aa1224..16ecc89b7a4 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -270,10 +270,11 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }

-// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
-static double jdst(const int32_t * a, const int32_t * b, size_t n) {
-    std::unordered_map<int32_t, size_t> set_a;
-    std::unordered_map<int32_t, size_t> set_b;
+// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
+template <typename T>
+static double jdst(const T * a, const T * b, size_t n) {
+    std::unordered_map<T, size_t> set_a;
+    std::unordered_map<T, size_t> set_b;

     for (size_t i = 0; i < n; ++i) {
         set_a[a[i]]++;
         set_b[b[i]]++;
     }
@@ -4985,42 +4986,94 @@ struct test_top_k : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int k;
+    const bool ties;
+    ggml_tensor * input {};

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, k);
+        return VARS_TO_STR4(type, ne, k, ties);
     }

     test_top_k(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {16, 10, 10, 10},
-            int k = 4)
-        : type(type), ne(ne), k(k) {}
+            int k = 4, bool ties = false)
+        : type(type), ne(ne), k(k), ties(ties) {}

     double max_err() override {
         return 0.0;
     }

+    // When there are ties, only validate the final result.
+    // The logic in err can't handle the sentinel tensors.
+    bool run_whole_graph() override { return ties; }
+
     double err(const float * a, const float * b, size_t n) override {
-        std::vector<int32_t> ia(n);
-        std::vector<int32_t> ib(n);
+        // When there are no ties, we expect the exact same set of indices,
+        // but possibly in a different order. When there are ties, the indices
+        // can be different but the input values they correspond to should be
+        // the same. The logic for ties could work for non-ties, but only for
+        // the output tensor, not for the sentinel tensors.
+        if (ties) {
+            std::vector<float> src(ggml_nelements(input));
+
+            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
+
+            double diff = 0.0f;
+
+            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
+            int64_t cols = input->ne[0];
+            std::vector<int32_t> ia(k);
+            std::vector<int32_t> ib(k);
+            std::vector<float> asrc(k);
+            std::vector<float> bsrc(k);
+            for (int64_t r = 0; r < ggml_nrows(input); r++) {
+                // Convert indices for the row back to integer
+                for (int64_t c = 0; c < k; c++) {
+                    ia[c] = (int32_t)a[r * k + c];
+                    ib[c] = (int32_t)b[r * k + c];
+                }
+                // The src values for each row should match.
+                for (int64_t c = 0; c < k; c++) {
+                    asrc[c] = src[r * cols + ia[c]];
+                    bsrc[c] = src[r * cols + ib[c]];
+                }
+                diff += jdst(asrc.data(), bsrc.data(), k);
+                // There should be no duplicate indices
+                std::sort(ia.begin(), ia.end());
+                std::sort(ib.begin(), ib.end());
+                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
+                    diff += 1;
+                }
+                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
+                    diff += 1;
+                }
+            }
+            return diff;
+        } else {
+            std::vector<int32_t> ia(n);
+            std::vector<int32_t> ib(n);

-        double diff = 0.0f;
+            double diff = 0.0f;

-        for (size_t i = 0; i < n; i++) {
-            ia[i] = (int32_t) a[i];
-            ib[i] = (int32_t) b[i];
+            for (size_t i = 0; i < n; i++) {
+                ia[i] = (int32_t) a[i];
+                ib[i] = (int32_t) b[i];

-            // penalize the result if the data is not integer valued
-            diff += std::fabs(a[i] - ia[i]);
-            diff += std::fabs(b[i] - ib[i]);
-        }
+                // penalize the result if the data is not integer valued
+                diff += std::fabs(a[i] - ia[i]);
+                diff += std::fabs(b[i] - ib[i]);
+            }

-        return diff + jdst(ia.data(), ib.data(), n);
+            return diff + jdst(ia.data(), ib.data(), n);
+        }
     }

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");

+        // Save 'a' for err()
+        input = a;
+
         ggml_tensor * out = ggml_top_k(ctx, a, k);
         ggml_set_name(out, "out");

@@ -5031,11 +5084,16 @@ struct test_top_k : public test_case {
         std::random_device rd;
         std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // initialize with unique values to avoid ties
+            int tie_denom = std::max(1, std::min(10, k / 2));
             for (int64_t r = 0; r < ggml_nrows(t); r++) {
                 std::vector<float> data(t->ne[0]);
                 for (int i = 0; i < t->ne[0]; i++) {
-                    data[i] = i;
+                    if (ties) {
+                        // integer division to introduce duplicates
+                        data[i] = i / tie_denom;
+                    } else {
+                        data[i] = i;
+                    }
                 }
                 std::shuffle(data.begin(), data.end(), rng);
                 ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
@@ -7640,6 +7698,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (k <= 1<
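
Reviewer notes (not part of the patch):

On the buffer sizing moved inside the reduction loop: each full workgroup of `wg_denoms[0]` input elements keeps at most `k` candidates, and a partial trailing workgroup keeps `min(k, remainder)`, so the double buffer only has to hold the first pass's output rather than `ncols` ivec2s per row. A self-contained sketch of that arithmetic; the numbers in the comments (wg_size 1024, k 16, ncols 5000) are illustrative, not taken from the patch:

```cpp
#include <algorithm>
#include <cstdint>

// Mirrors the num_dst_elements expression in ggml_vk_topk: every full
// workgroup contributes k candidates; a partial tail contributes
// min(k, remainder).
static uint32_t pass_output(uint32_t num_elements, uint32_t k, uint32_t wg_size) {
    return (num_elements / wg_size) * k + std::min(k, num_elements % wg_size);
}

// pass_output(5000, 16, 1024) == 4 * 16 + std::min(16u, 904u) == 80.
// The next pass sees 80 elements, which fits in a single workgroup, so it
// becomes the last pass and writes out the final k indices.
```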
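The new tie handling keys off `f2ui`, the shader's order-preserving float-to-uint mapping described in the context lines ("Positive values set the high bit, negative values are inverted"). A minimal C++ sketch of that mapping, assuming IEEE-754 floats and ignoring NaNs; the GLSL original uses floatBitsToUint where this uses memcpy:

```cpp
#include <cstdint>
#include <cstring>

// Map a float to a uint32_t so that unsigned comparison matches float
// ordering: negatives (sign bit set) are fully inverted, which reverses
// their order and places them below all positives; positives get the
// high bit set, placing them above all negatives.
static uint32_t f2ui(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u)); // bit-cast, like GLSL floatBitsToUint
    return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

// e.g. f2ui(-1.0f) < f2ui(-0.5f), f2ui(-0.5f) < f2ui(0.0f),
//      f2ui(0.0f) < f2ui(2.0f)
```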
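The ties branch of test_top_k::err() compares the source values selected by each backend (rather than the raw indices) via the now-templated jdst, then penalizes duplicate indices. The hunk elides jdst's distance computation; the sketch below is an assumed multiset Jaccard distance built from the two count maps visible in the diff, not a copy of the file's actual body (jdst_sketch is a hypothetical name):

```cpp
#include <algorithm>
#include <cstddef>
#include <unordered_map>

// Hypothetical multiset Jaccard distance: 0.0 when both arrays hold the
// same values with the same multiplicities, 1.0 when they share nothing.
template <typename T>
static double jdst_sketch(const T * a, const T * b, size_t n) {
    std::unordered_map<T, size_t> set_a;
    std::unordered_map<T, size_t> set_b;
    for (size_t i = 0; i < n; ++i) {
        set_a[a[i]]++;
        set_b[b[i]]++;
    }
    // Multiset intersection: per value, the smaller of the two counts.
    size_t isect = 0;
    for (const auto & [val, cnt] : set_a) {
        auto it = set_b.find(val);
        if (it != set_b.end()) {
            isect += std::min(cnt, it->second);
        }
    }
    // |A ∪ B| = |A| + |B| - |A ∩ B|, with |A| == |B| == n
    return 1.0 - double(isect) / double(2 * n - isect);
}
```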