From c50e5cbdbec9ecb390b9c3c2fc85f72b49c53faf Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Thu, 4 Dec 2025 18:50:30 -0600
Subject: [PATCH] vulkan: support solve_tri with larger N/K values

Split N into chunks of rows so each chunk of the A and B matrices fits in
shared memory. If K > 128, use a larger workgroup so there are enough
invocations to assign one column per invocation.

Add perf tests matching qwen3next.
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 14 ++++-
 .../ggml-vulkan/vulkan-shaders/solve_tri.comp | 63 +++++++++++--------
 tests/test-backend-ops.cpp                    | 15 +++++
 3 files changed, 62 insertions(+), 30 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index f917a745d5a..d42b68d1835 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -4070,10 +4070,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
+
+        // Max number of rows to load at a time, limited by shared memory
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
+        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
+        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
+
         ggml_vk_create_pipeline(
             device, s.second, "solve_tri_f32", solve_tri_f32_len, solve_tri_f32_data, "main", 3,
-            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
     }
 
 #define IM2COL(bda) \
@@ -14179,10 +14185,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
 
         const uint32_t N = op->src[0]->ne[0];
         const uint32_t K = op->src[1]->ne[0];
         // K dimension limited to workgroup size
-        if (K > 128) {
+        if (K > 1u << device->max_workgroup_size_log2) {
             return false;
         }
-        if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
+
+        if (batch_N == 0) {
             return false;
         }
         return true;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
index 253a9e7efee..3b65145032c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
@@ -5,8 +5,9 @@
 layout (constant_id = 1) const uint N = 64;
 layout (constant_id = 2) const uint K = 32;
+layout (constant_id = 3) const uint BATCH_N = 32;
 
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
 
 uint a_base, b_base, x_base;
 
@@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
     data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
 }
 
-shared FLOAT_TYPE shA[N * N];
-shared FLOAT_TYPE shB[N * K];
+shared FLOAT_TYPE shA[BATCH_N * N];
+shared FLOAT_TYPE shB[BATCH_N * K];
 
 void main() {
     const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
@@ -39,34 +40,42 @@ void main() {
     b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
     x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
 
-    // Load the A matrix into shA
-    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
-            shA[idx] = get_a(idx / N, idx % N);
+    FLOAT_TYPE X[N];
+
+    // Loop over batches of rows
+    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
+        const uint cur_N = min(BATCH_N, N - row_base);
+
+        // Load the A matrix batch into shA
+        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
+                shA[idx] = get_a(row_base + idx / N, idx % N);
+            }
         }
-    }
 
-    // Load the B matrix into shB
-    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
-            shB[idx] = get_b(idx / K, idx % K);
+        // Load the B matrix batch into shB
+        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
+                shB[idx] = get_b(row_base + idx / K, idx % K);
+            }
         }
-    }
 
-    barrier();
+        barrier();
 
-    FLOAT_TYPE X[N];
-    // Each thread solves one column
-    if (tid < K) {
-        [[unroll]] for (int r = 0; r < N; ++r) {
-            FLOAT_TYPE b = shB[r * K + tid];
-            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
-            [[unroll]] for (int c = 0; c < r; ++c) {
-                b -= shA[r * N + c] * X[c];
+        // Each thread solves one column
+        if (tid < K) {
+            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
+                uint r = row_base + row_offset;
+                FLOAT_TYPE b = shB[row_offset * K + tid];
+                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
+                [[unroll]] for (int c = 0; c < r; ++c) {
+                    b -= shA[row_offset * N + c] * X[c];
+                }
+                FLOAT_TYPE x = b / shA[row_offset * N + r];
+                X[r] = x;
+                store_x(r, tid, x);
             }
-            FLOAT_TYPE x = b / shA[r * N + r];
-            X[r] = x;
-            store_x(r, tid, x);
         }
+        barrier();
     }
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 844db455d29..0e0cb12adc1 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6146,6 +6146,15 @@ struct test_solve_tri : public test_case {
 
     std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        int64_t n = ne_lhs[0];
+        int64_t k = ne_rhs[0];
+        int64_t batch = ne_lhs[2] * ne_lhs[3];
+        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
+        return n * (n + 1) * k * batch;
+    }
+
     test_solve_tri(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
             std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
@@ -7756,6 +7765,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
 
     for (bool v : {false, true}) {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
@@ -7954,6 +7965,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+    // qwen3next with CHUNK_SIZE 64
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
+    // qwen3next with CHUNK_SIZE 128
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
 
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
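
For reference, the chunked forward substitution the shader now performs can be
written as a host-side scalar loop. This is a minimal sketch, not part of the
patch: solve_tri_ref is a hypothetical helper name, and row-major, contiguous
f32 buffers are assumed (A is N x N lower-triangular, B and X are N x K):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Solves A * X = B for X, processing BATCH_N rows at a time, the same
    // chunking the shader uses to bound shared-memory usage.
    void solve_tri_ref(const std::vector<float> & A,
                       const std::vector<float> & B,
                       std::vector<float> & X,
                       uint32_t N, uint32_t K, uint32_t BATCH_N) {
        for (uint32_t row_base = 0; row_base < N; row_base += BATCH_N) {
            // The shader stages rows [row_base, row_base + cur_N) of A and B
            // into shA/shB here, with barrier() before and after the solve.
            const uint32_t cur_N = std::min(BATCH_N, N - row_base);
            for (uint32_t row_offset = 0; row_offset < cur_N; ++row_offset) {
                const uint32_t r = row_base + row_offset;
                // In the shader, each invocation with tid < K owns one column.
                for (uint32_t col = 0; col < K; ++col) {
                    float b = B[r * K + col];
                    for (uint32_t c = 0; c < r; ++c) {
                        b -= A[r * N + c] * X[c * K + col];
                    }
                    X[r * K + col] = b / A[r * N + r];
                }
            }
        }
    }

    int main() {
        // 2x2 example: A = [[2,0],[1,1]], B = [2,3]^T -> X = [1,2]^T
        std::vector<float> A = {2, 0, 1, 1}, B = {2, 3}, X(2);
        solve_tri_ref(A, B, X, 2, 1, 1);
    }

Because rows only depend on earlier rows, the per-column running solution X can
stay in registers across chunks, which is why the shader hoists FLOAT_TYPE X[N]
out of the row_base loop.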
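The sizing logic can be checked in isolation. A sketch under an assumed 48 KiB
maxComputeSharedMemorySize (a common limit; the real value comes from the
device), using the shapes of the new { 64, 64, 4, 4 } x { 300, 64, 4, 4 } eval
test:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t max_smem = 48 * 1024;  // assumed device limit
        const uint32_t N = 64, K = 300;
        // Each chunk stages batch_N rows of A (N floats) and of B (K floats).
        const uint32_t batch_N = max_smem / ((N + K) * sizeof(float));
        // At least K invocations, rounded up to a power of two, minimum 128.
        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(K))));
        printf("batch_N = %u, block_size = %u\n", batch_N, block_size);  // 33, 512
    }

This also shows why supports_op only has to reject batch_N == 0: any positive
batch_N just means more iterations of the row_base loop.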
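Finally, a quick sanity check of the new op_flops accounting for the
CHUNK_SIZE 128 perf case (illustrative arithmetic only):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // lhs { 128, 128, 4, 32 }, rhs { 128, 128, 4, 32 }
        const int64_t n = 128, k = 128, batch = 4 * 32;
        // n*(n+1)/2 nonzero lhs elements, 2 flops (multiply + subtract) each,
        // per rhs column: n*(n+1)*k*batch total.
        const int64_t flops = n * (n + 1) * k * batch;
        printf("%lld flops (~%.0f MFLOP per op)\n", (long long) flops, flops / 1e6);  // ~271 MFLOP
    }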