Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4070,10 +4070,16 @@ static void ggml_vk_load_shaders(vk_device& device) {

for (auto &s : device->pipeline_solve_tri_f32) {
const vk_solve_tri_pipeline_state &state = s.first;

// Max number of rows to load at a time, limited by shared memory
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
// Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));

ggml_vk_create_pipeline(
device, s.second, "solve_tri_f32",
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
}

#define IM2COL(bda) \
Expand Down Expand Up @@ -14179,10 +14185,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
const uint32_t N = op->src[0]->ne[0];
const uint32_t K = op->src[1]->ne[0];
// K dimension limited to workgroup size
if (K > 128) {
if (K > 1u << device->max_workgroup_size_log2) {
return false;
}
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));

if (batch_N == 0) {
return false;
}
return true;
Expand Down
63 changes: 36 additions & 27 deletions ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;
layout (constant_id = 3) const uint BATCH_N = 32;

layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;

uint a_base, b_base, x_base;

Expand All @@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}

shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];
shared FLOAT_TYPE shA[BATCH_N * N];
shared FLOAT_TYPE shB[BATCH_N * K];

void main() {
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
Expand All @@ -39,34 +40,42 @@ void main() {
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;

// Load the A matrix into shA
[[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
shA[idx] = get_a(idx / N, idx % N);
FLOAT_TYPE X[N];

// Loop over batches of rows
[[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
const uint cur_N = min(BATCH_N, N - row_base);

// Load the A matrix batch into shA
[[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
shA[idx] = get_a(row_base + idx / N, idx % N);
}
}
}
// Load the B matrix into shB
[[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
shB[idx] = get_b(idx / K, idx % K);
// Load the B matrix batch into shB
[[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
shB[idx] = get_b(row_base + idx / K, idx % K);
}
}
}
barrier();
barrier();

FLOAT_TYPE X[N];
// Each thread solves one column
if (tid < K) {
[[unroll]] for (int r = 0; r < N; ++r) {
FLOAT_TYPE b = shB[r * K + tid];
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
[[unroll]] for (int c = 0; c < r; ++c) {
b -= shA[r * N + c] * X[c];
// Each thread solves one column
if (tid < K) {
[[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
uint r = row_base + row_offset;
FLOAT_TYPE b = shB[row_offset * K + tid];
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
[[unroll]] for (int c = 0; c < r; ++c) {
b -= shA[row_offset * N + c] * X[c];
}
FLOAT_TYPE x = b / shA[row_offset * N + r];
X[r] = x;
store_x(r, tid, x);
}
FLOAT_TYPE x = b / shA[r * N + r];
X[r] = x;
store_x(r, tid, x);
}
barrier();
}
}
15 changes: 15 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6146,6 +6146,15 @@ struct test_solve_tri : public test_case {

std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }

uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
int64_t n = ne_lhs[0];
int64_t k = ne_rhs[0];
int64_t batch = ne_lhs[2] * ne_lhs[3];
// n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
return n * (n + 1) * k * batch;
}

test_solve_tri(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
Expand Down Expand Up @@ -7756,6 +7765,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));

for (bool v : {false, true}) {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
Expand Down Expand Up @@ -7954,6 +7965,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
// qwen3next with CHUNK_SIZE 64
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
// qwen3next with CHUNK_SIZE 128
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));

test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
Expand Down
Loading