From 63918141633e1aa04c0ba2a20eab3ef09d4a8e6b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 13 Nov 2025 18:16:07 -0600 Subject: [PATCH] vulkan: Fuse mul_mat_id+add_id+mul and mul_mat+add+add. These both show up in gpt-oss. Also, cleanup the mul_mat_vec fusion code a bit. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 246 +++++++++++++----- .../vulkan-shaders/mul_mat_vec_base.glsl | 96 ++++--- .../vulkan-shaders/mul_mat_vec_iface.glsl | 33 +++ .../vulkan-shaders/mul_mat_vec_nc.comp | 20 +- .../vulkan-shaders/mul_mat_vec_p021.comp | 20 +- tests/test-backend-ops.cpp | 24 +- 6 files changed, 294 insertions(+), 145 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f5812dc4694cb..ef99c3c1eba45 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -32,6 +32,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include +#include #include #include #include @@ -824,6 +825,12 @@ struct vk_mat_mat_push_constants { uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; }; + +#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1 +#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2 +#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 +#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 + struct vk_mat_vec_push_constants { uint32_t ncols; uint32_t stride_a; @@ -832,8 +839,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; - uint32_t enable_bias; - uint32_t enable_scale; + uint32_t fusion_flags; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -847,7 +853,7 @@ struct vk_mat_vec_p021_push_constants { uint32_t nchannels_y; uint32_t b_offset; uint32_t d_offset; - uint32_t enable_bias; + uint32_t fusion_flags; }; struct vk_mat_vec_nc_push_constants { @@ -863,7 +869,7 @@ struct vk_mat_vec_nc_push_constants { uint32_t nb03; uint32_t nb13; uint32_t nb23; - uint32_t enable_bias; + uint32_t fusion_flags; }; struct vk_mat_mat_id_push_constants { @@ -881,8 +887,7 @@ struct vk_mat_vec_id_push_constants { uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; - uint32_t enable_bias; - uint32_t enable_scale; + uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; }; @@ -3465,8 +3470,8 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0; const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0; - static constexpr uint32_t mul_mat_vec_num_bindings = 4; - static constexpr uint32_t mul_mat_vec_id_num_bindings = 5; + static constexpr uint32_t mul_mat_vec_num_bindings = 5; + static constexpr uint32_t mul_mat_vec_id_num_bindings = 6; for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4); @@ -6871,21 +6876,31 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& groups_x = CEIL_DIV(groups_x, groups_z); } - uint32_t enable_bias = ctx->num_additional_fused_ops > 0; + uint32_t fusion_flags = 0; - vk_subbuffer d_B = d_D; - - if (enable_bias) { + vk_subbuffer d_F0 = d_D; + if (ctx->num_additional_fused_ops > 0) { const ggml_tensor * add = cgraph->nodes[node_idx + 1]; const ggml_tensor * bias = add->src[0] == dst ? 
add->src[1] : add->src[0]; - d_B = ggml_vk_tensor_subbuffer(ctx, bias); + d_F0 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0; + } + + vk_subbuffer d_F1 = d_D; + if (ctx->num_additional_fused_ops == 2) { + const ggml_tensor * add = cgraph->nodes[node_idx + 2]; + const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0]; + + d_F1 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0, + stride_batch_x, stride_batch_y, stride_batch_d, + fusion_flags, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, @@ -6893,7 +6908,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& d_X, d_Y, d_D, - d_B, + d_F0, + d_F1, }, pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); @@ -6946,22 +6962,31 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0); vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true); - vk_subbuffer d_B = d_D; + vk_subbuffer d_F0 = d_D; - uint32_t enable_bias = ctx->num_additional_fused_ops > 0; + uint32_t fusion_flags = 0; - if (enable_bias) { + if (ctx->num_additional_fused_ops > 0) { const ggml_tensor * add = cgraph->nodes[node_idx + 1]; const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0]; - d_B = ggml_vk_tensor_subbuffer(ctx, bias); + d_F0 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0; + } + + vk_subbuffer d_F1 = d_D; + if (ctx->num_additional_fused_ops > 1) { + const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1]; + + d_F1 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } // compute vk_mat_vec_p021_push_constants pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, - 0, 0, enable_bias + 0, 0, fusion_flags }; init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -6977,7 +7002,8 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c d_Qx, d_Qy, d_D, - d_B, + d_F0, + d_F1, }, pc, { 1, (uint32_t)ne01, workgroups_z }); } @@ -7029,15 +7055,24 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true); vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0); vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true); - vk_subbuffer d_B = d_D; + vk_subbuffer d_F0 = d_D; - uint32_t enable_bias = ctx->num_additional_fused_ops > 0; + uint32_t fusion_flags = 0; - if (enable_bias) { + if (ctx->num_additional_fused_ops > 0) { const ggml_tensor * add = cgraph->nodes[node_idx + 1]; const ggml_tensor * bias = add->src[0] == dst ? 
add->src[1] : add->src[0]; - d_B = ggml_vk_tensor_subbuffer(ctx, bias); + d_F0 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0; + } + + vk_subbuffer d_F1 = d_D; + if (ctx->num_additional_fused_ops > 1) { + const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1]; + + d_F1 = ggml_vk_tensor_subbuffer(ctx, bias); + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } // compute @@ -7046,7 +7081,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, 0, 0, - nb03, nb13, nb23, enable_bias + nb03, nb13, nb23, fusion_flags }; init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7056,7 +7091,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con d_Qx, d_Qy, d_D, - d_B, + d_F0, + d_F1, }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } @@ -7477,7 +7513,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0); vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1); vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids); - vk_subbuffer d_B = d_D; + vk_subbuffer d_F0 = d_D; vk_subbuffer d_X, d_Y; if (qx_needs_dequant) { @@ -7530,30 +7566,34 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte groups_x = CEIL_DIV(groups_x, groups_z); } - uint32_t enable_bias = 0; - uint32_t enable_scale = 0; + uint32_t fusion_flags = 0; + if (ctx->num_additional_fused_ops > 0) { + const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; + + d_F0 = ggml_vk_tensor_subbuffer(ctx, bias); + if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) { - enable_scale = 1; + fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0; } else { GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID); - enable_bias = 1; + fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0; } } - if (enable_bias || enable_scale) { - const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; + vk_subbuffer d_F1 = d_D; + if (ctx->num_additional_fused_ops > 1) { + const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1]; - d_B = ggml_vk_tensor_subbuffer(ctx, bias); + d_F1 = ggml_vk_tensor_subbuffer(ctx, scale); + fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } // compute const vk_mat_vec_id_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - - enable_bias, enable_scale, - + fusion_flags, (uint32_t)nei0, (uint32_t)ne11, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, @@ -7561,7 +7601,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte d_X, d_Y, d_D, - d_B, + d_F0, + d_F1, d_ids, }, pc, { groups_x, (uint32_t)nei0, groups_z }); @@ -12305,10 +12346,7 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return false; } } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) { - // additional constraints specific to this fusion - const ggml_tensor *mul = cgraph->nodes[node_idx]; - const ggml_tensor *add = cgraph->nodes[node_idx + 1]; + auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) { const ggml_tensor *bias = add->src[0] == mul ? 
add->src[1] : add->src[0]; // mat-vec only @@ -12328,14 +12366,31 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g if (get_misalign_bytes(ctx, bias) != 0) { return false; } - } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) { + return true; + }; + + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) { // additional constraints specific to this fusion const ggml_tensor *mul = cgraph->nodes[node_idx]; const ggml_tensor *add = cgraph->nodes[node_idx + 1]; - const ggml_tensor *bias = add->src[1]; - if (mul != add->src[0]) { + if (!mm_add_ok(mul, add)) { + return false; + } + if (ops.size() == 3) { + if (ops.begin()[2] != GGML_OP_ADD) { + return false; + } + if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) { + return false; + } + } + } + + auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) { + const ggml_tensor *scale = mul->src[1]; + + if (mmid != mul->src[0]) { return false; } // mat-vec only @@ -12343,30 +12398,34 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return false; } // shaders assume the types match - if (mul->type != bias->type) { + if (mmid->type != scale->type) { return false; } // shaders assume the bias is contiguous - if (!ggml_is_contiguous(bias)) { + if (!ggml_is_contiguous(scale)) { return false; } - // the ID tensor must be the same for mul_mat_id and add_id - if (mul->src[2] != add->src[2]) { + // unaligned bias isn't handled + if (get_misalign_bytes(ctx, scale) != 0) { return false; } - // unaligned bias isn't handled - if (get_misalign_bytes(ctx, bias) != 0) { + // shader only indexes by expert index + if (scale->ne[0] != 1 || + scale->ne[1] != mul->ne[1] || + scale->ne[2] != 1 || + scale->ne[3] != 1) { return false; } - } + return true; + }; - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) { + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) { // additional constraints specific to this fusion - const ggml_tensor *mmid = cgraph->nodes[node_idx]; - const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; - const ggml_tensor *scale = mul->src[1]; + const ggml_tensor *mul = cgraph->nodes[node_idx]; + const ggml_tensor *add = cgraph->nodes[node_idx + 1]; + const ggml_tensor *bias = add->src[1]; - if (mmid != mul->src[0]) { + if (mul != add->src[0]) { return false; } // mat-vec only @@ -12374,22 +12433,37 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return false; } // shaders assume the types match - if (mmid->type != scale->type) { + if (mul->type != bias->type) { return false; } // shaders assume the bias is contiguous - if (!ggml_is_contiguous(scale)) { + if (!ggml_is_contiguous(bias)) { + return false; + } + // the ID tensor must be the same for mul_mat_id and add_id + if (mul->src[2] != add->src[2]) { return false; } // unaligned bias isn't handled - if (get_misalign_bytes(ctx, scale) != 0) { + if (get_misalign_bytes(ctx, bias) != 0) { return false; } - // shader only indexes by expert index - if (scale->ne[0] != 1 || - scale->ne[1] != mul->ne[1] || - scale->ne[2] != 1 || - scale->ne[3] != 1) { + + if (ops.size() == 3) { + if (ops.begin()[2] != GGML_OP_MUL) { + return false; + } + const ggml_tensor *mul = cgraph->nodes[node_idx + 2]; + return mmid_mul_ok(add, mul); + } + } + + if (ops.size() == 2 && 
ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
         // additional constraints specific to this fusion
         const ggml_tensor *mmid = cgraph->nodes[node_idx];
         const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        if (!mmid_mul_ok(mmid, mul)) {
             return false;
         }
     }
@@ -12704,8 +12778,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
+                ctx->num_additional_fused_ops = 2;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 2;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
@@ -12872,6 +12950,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     std::vector<ggml_tensor *> new_order;
     std::vector<bool> used(graph->n_nodes, false);
+    std::set<ggml_tensor *> used_node_set;
+
     int first_unused = 0;
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;
@@ -12894,6 +12974,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             if (match_pattern(pattern, first_unused)) {
                 for (size_t j = 0; j < pattern.size(); ++j) {
                     new_order.push_back(graph->nodes[first_unused + j]);
+                    used_node_set.insert(graph->nodes[first_unused + j]);
                     used[first_unused + j] = true;
                 }
                 while (first_unused < graph->n_nodes && used[first_unused]) {
@@ -12997,6 +13078,36 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                     used[set_rows_idx] = true;
                 }
             }
+            // Look for MUL_MAT_ID + ADD_ID + MUL
+            if (j > 0 &&
+                graph->nodes[j]->op == GGML_OP_ADD_ID &&
+                graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
+                for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                    if (graph->nodes[k]->op == GGML_OP_MUL &&
+                        graph->nodes[k]->src[0] == graph->nodes[j] &&
+                        // src1 must either be weights or already processed
+                        (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+                        current_set.push_back(k);
+                        used[k] = true;
+                        break;
+                    }
+                }
+            }
+            // Look for MUL_MAT + ADD + ADD
+            if (j > 0 &&
+                graph->nodes[j]->op == GGML_OP_ADD &&
+                graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
+                for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                    if (graph->nodes[k]->op == GGML_OP_ADD &&
+                        graph->nodes[k]->src[0] == graph->nodes[j] &&
+                        // src1 must either be weights or already processed
+                        (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+                        current_set.push_back(k);
+                        used[k] = true;
+                        break;
+                    }
+                }
+            }
         }
     }
     // Second pass grabs view nodes.
@@ -13029,6 +13140,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * // Push the current set into new_order for (auto c : current_set) { new_order.push_back(graph->nodes[c]); + used_node_set.insert(graph->nodes[c]); used[c] = true; } while (first_unused < graph->n_nodes && used[first_unused]) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index eb8fa6dc09fb1..e4651a683bf0e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -11,29 +11,7 @@ #define EXPERT_COUNT 8 #endif -#include "types.glsl" - -#ifndef MMQ -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -#else -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; -#endif - -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -#ifdef B_TYPE_VEC2 -layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; -#endif -#ifdef B_TYPE_VEC4 -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; -#endif - -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];}; - -#ifdef MUL_MAT_ID -layout (binding = 4) readonly buffer IDS {int data_ids[];}; -#endif +#include "mul_mat_vec_iface.glsl" #include "dequant_funcs.glsl" @@ -48,8 +26,7 @@ layout (push_constant) uniform parameter uint batch_stride_b; uint batch_stride_d; - uint enable_bias; - uint enable_scale; + uint fusion_flags; #ifdef MUL_MAT_ID uint nei0; @@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t if (tid == 0) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { - if (p.enable_bias != 0) { #ifdef MUL_MAT_ID - temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]); -#else - temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); -#endif + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } -#ifdef MUL_MAT_ID - if (p.enable_scale != 0) { + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + } +#else + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]); } #endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); @@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { temp[j][n] += tmpsh[j][n][s]; } - if (p.enable_bias != 0) { #ifdef MUL_MAT_ID - temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]); -#else - temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); -#endif + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); 
} -#ifdef MUL_MAT_ID - if (p.enable_scale != 0) { + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + } +#else + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) { + temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]); } #endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); @@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs if (tid == 0) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { - if (p.enable_bias != 0) { #ifdef MUL_MAT_ID - tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]); -#else - tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); -#endif + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } -#ifdef MUL_MAT_ID - if (p.enable_scale != 0) { + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]); + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + } +#else + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) { + tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]); } #endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl new file mode 100644 index 0000000000000..14ab1fd74c0a7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -0,0 +1,33 @@ +#include "types.glsl" + +#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1 +#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2 +#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 +#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 + +#ifndef MMQ +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_VEC4) +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +#endif +#else +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#ifdef B_TYPE_VEC2 +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#endif +#ifdef B_TYPE_VEC4 +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];}; +layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];}; + +#ifdef MUL_MAT_ID +layout (binding = 5) readonly buffer IDS {int data_ids[];}; +#endif + 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 3f4584c984c1f..beea5296225e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -8,14 +8,7 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; - -layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];}; +#include "mul_mat_vec_iface.glsl" layout (push_constant) uniform parameter { @@ -31,7 +24,7 @@ layout (push_constant) uniform parameter uint nb03; uint nb13; uint nb23; - uint enable_bias; + uint fusion_flags; } p; shared FLOAT_TYPE tmp[BLOCK_SIZE]; @@ -120,9 +113,12 @@ void main() { } if (tid == 0) { - if (p.enable_bias != 0) { - tmp[0] += FLOAT_TYPE(data_bias[idst]); + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + tmp[0] += FLOAT_TYPE(data_fuse0[idst]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) { + tmp[0] += FLOAT_TYPE(data_fuse1[idst]); } - dst[idst] = tmp[0]; + data_d[idst] = tmp[0]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp index d51424d417573..32628c6e97d43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -10,14 +10,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; - -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; - -layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];}; +#include "mul_mat_vec_iface.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 32; // gqa_ratio is in the range [1,8] @@ -31,7 +24,7 @@ layout (push_constant) uniform parameter uint nchannels_y; uint b_offset; uint d_offset; - uint enable_bias; + uint fusion_flags; } p; #if !USE_SUBGROUP_ADD @@ -151,10 +144,13 @@ void main() { [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { // dst is not transposed and not permuted const uint idst = (channel + c)*nrows_dst + row_dst; - if (p.enable_bias != 0) { - temp[c] += FLOAT_TYPE(data_bias[idst]); + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { + temp[c] += FLOAT_TYPE(data_fuse0[idst]); + } + if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) { + temp[c] += FLOAT_TYPE(data_fuse1[idst]); } - dst[idst] = temp[c]; + data_d[idst] = temp[c]; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a7707eb03fe87..a87190e9f446f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5002,17 +5002,19 @@ struct test_mul_mat_vec_fusion : public test_case { const bool b; // broadcast b matrix (only for use_id) const bool with_bias; const bool with_gate; + std::array batch_dims; test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k, - bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = 
false, bool with_gate = true)
-        : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
+                            bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
+                            std::array<int64_t, 2> batch_dims = {4, 2})
+        : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
         if (use_id) {
             GGML_ASSERT(n_used <= n_mats);
         }
     }
     std::string vars() override {
-        return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
+        return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
     }
     std::string op_desc(ggml_tensor * t) override {
@@ -5038,8 +5040,8 @@ struct test_mul_mat_vec_fusion : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         if (!use_id) {
-            const int channels = 4;
-            const int samples = 2;
+            const int channels = batch_dims[0];
+            const int samples = batch_dims[1];
             std::array<int64_t, 4> ne = { k, m, channels, samples };
             std::array<int64_t, 4> ne0 = { k, n, channels, samples };
@@ -5062,6 +5064,11 @@
             }
             ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
+            std::array<int64_t, 4> bias2_ne = { out->ne[0], 1, channels, samples };
+            ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
+            out = ggml_add(ctx, out, bias2);
+
             ggml_set_name(out, "out");
             return out;
         } else {
@@ -5089,6 +5096,11 @@
             }
             ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
+            std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+            ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
+            out = ggml_mul(ctx, out, scale);
+
             ggml_set_name(out, "out");
             return out;
         }
@@ -7645,6 +7657,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 }
                 test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
                                                                     use_id, 16, 8, b, with_bias, with_gate));
+                test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
+                                                                    use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
             }
         }
     }