diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 41ff89c4d6922..1af23588301dd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context { return pool(device); } }; + +struct ggml_cuda_mm_fusion_args_host { + const ggml_tensor * x_bias = nullptr; + const ggml_tensor * gate = nullptr; + const ggml_tensor * gate_bias = nullptr; + ggml_glu_op glu_op; +}; +struct ggml_cuda_mm_fusion_args_device { + const void * x_bias = nullptr; + const void * gate = nullptr; + const void * gate_bias = nullptr; + ggml_glu_op glu_op; +}; diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index ef9e129950c98..8a5e08ef667e0 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bc396b521af07..19f72975c0ee4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2007,6 +2007,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } +static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, + const ggml_tensor * ffn_gate, + const ggml_tensor * glu, + const ggml_tensor * ffn_up_bias = nullptr, + const ggml_tensor * ffn_gate_bias = nullptr) { + const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr; + + if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) { + return false; + } + + const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU; + const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU; + + GGML_ASSERT(ffn_up && ffn_gate && glu); + + if (!is_mul_mat && !is_mul_mat_id) { + return false; + } + + const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (has_bias) { + if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) { + return false; + } + + if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) { + return false; + } + + if (expected_bias_op == GGML_OP_ADD) { + const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up; + const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate; + if (!up_has_mul || !gate_has_mul) { + return false; + } + } else { // GGML_OP_ADD_ID + if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) { + return false; + } + if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) { + return false; + } + } + } else { + if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) { + return false; + } + } + + if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) || + !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) { + return false; + } + + if (ffn_up->src[1] != ffn_gate->src[1]) { + return false; + } + + if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) { + return false; + } + + static constexpr std::array valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI }; + + if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) { + return false; + } + + if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) { + return false; + } + + const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) || + ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + return true; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID; + + bool use_mul_mat_vec_f = + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && + src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]); + + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + + return use_mul_mat_vec_f; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && + ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && + src0->view_src; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + // fusion is not universally faster on Pascal + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (cc <= GGML_CUDA_CC_PASCAL) { + return false; + } + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + return use_mul_mat_vec_q; +} + static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); @@ -2745,7 +2886,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } } - if (node->op == GGML_OP_SCALE && + if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } @@ -2854,6 +2995,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } + std::initializer_list mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; + std::initializer_list mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; + + std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; + std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; + + if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || + ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { + + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; + const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { + return true; + } + } + + if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || + ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { + + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { + return true; + } + } + if (!ggml_can_fuse(cgraph, node_idx, ops)) { return false; } @@ -3004,6 +3177,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) + || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) continue; + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 57ab839393aa0..c2c31cdaf231b 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -1,11 +1,12 @@ #include "ggml.h" #include "common.cuh" -#include "convert.cuh" +#include "unary.cuh" #include "mmvf.cuh" +#include "convert.cuh" -template +template static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { @@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f( y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU; + const T * gate_x = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr; + glu_op = fusion.glu_op; + + if (use_gate) { + gate_x = static_cast(fusion.gate); + } + if (use_bias) { + x_bias = static_cast(fusion.x_bias); + } + if (use_gate_bias) { + gate_bias = static_cast(fusion.gate_bias); + use_gate_bias = use_gate; + } else { + use_gate_bias = false; + } + } + + if (use_gate) { + gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + } + if constexpr (has_fusion) { + const int channel_bias = ids ? channel_x : channel_dst; + if (use_bias) { + x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + if (use_gate_bias) { + gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + } + const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; + float * buf_iw_gate = nullptr; + if constexpr (has_fusion) { + buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); + } if (block_size > warp_size) { if (tid < warp_size) { buf_iw[tid] = 0.0f; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid] = 0.0f; + } + } } __syncthreads(); } float sumf[ncols_dst] = {0.0f}; + float sumf_gate[ncols_dst]; + if constexpr (has_fusion) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = 0.0f; + } + } if constexpr (std::is_same_v) { const float2 * x2 = (const float2 *) x; + const float2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const float2 *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = x2[col2]; + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } } else if constexpr (std::is_same_v) { const half2 * x2 = (const half2 *) x; + const half2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const half2 *) gate_x; + } + } if (std::is_same_v) { for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); - + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = __half22float2(gate_x2[col2]); + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } } else { #ifdef FP16_AVAILABLE half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; + half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}}; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const half2 tmpx = x2[col2]; - + half2 tmpx_gate = make_half2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y); + } + } } } @@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f( for (int j = 0; j < ncols_dst; ++j) { sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); } + + if constexpr (has_fusion) { + if (use_gate) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]); + } + } + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f( //TODO: add support for ggml_cuda_mad for hip_bfloat162 #if defined(GGML_USE_HIP) const int * x2 = (const int *) x; + const int * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const int *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; + int tmpx_gate = 0; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; @@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f( const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + const float tmpx0_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[0]); + const float tmpx1_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[1]); + ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y); + } + } } } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + const nv_bfloat162 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const nv_bfloat162 *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; + nv_bfloat162 tmpx_gate; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } #endif @@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f( for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = warp_reduce_sum(sumf_gate[j]); + } + } + if (block_size > warp_size) { buf_iw[tid/warp_size] = sumf[j]; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid/warp_size] = sumf_gate[j]; + } + } __syncthreads(); if (tid < warp_size) { sumf[j] = buf_iw[tid]; sumf[j] = warp_reduce_sum(sumf[j]); + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = buf_iw_gate[tid]; + sumf_gate[j] = warp_reduce_sum(sumf_gate[j]); + } + } } + if (j < ncols_dst) { __syncthreads(); } @@ -139,12 +313,70 @@ static __global__ void mul_mat_vec_f( return; } - dst[tid*stride_col_dst + row] = sumf[tid]; + float value = sumf[tid]; + + if constexpr (has_fusion) { + if (use_bias) { + value += x_bias[tid*stride_col_dst + row]; + } + + if (use_gate) { + float gate_value = sumf_gate[tid]; + if (use_gate_bias) { + gate_value += gate_bias[tid*stride_col_dst + row]; + } + switch (glu_op) { + case GGML_GLU_OP_SWIGLU: + value *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + value *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + value = ggml_cuda_op_swiglu_oai_single(gate_value, value); + break; + } + default: + break; + } + } + } + + dst[tid*stride_col_dst + row] = value; +} + +template +static void mul_mat_vec_f_switch_fusion( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } template -static void launch_mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, +void launch_mul_mat_vec_f_cuda( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, @@ -176,57 +408,59 @@ static void launch_mul_mat_vec_f_cuda( } } - const int nbytes_shared = warp_size*sizeof(float); + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + + const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 64: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 96: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 128: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 160: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 192: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 224: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 256: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; default: { GGML_ABORT("fatal error"); @@ -236,7 +470,7 @@ static void launch_mul_mat_vec_f_cuda( template static void mul_mat_vec_f_cuda_switch_ncols_dst( - const T * x, const float * y, const int32_t * ids, float * dst, + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, @@ -246,49 +480,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 2: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 3: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 4: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 5: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 6: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 7: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 8: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; @@ -300,29 +534,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( template static void mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { + if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -348,6 +584,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + const int64_t s01 = src0->nb[1] / ts_src0; const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s1 = dst->nb[1] / ts_dst; @@ -370,19 +630,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -409,7 +669,6 @@ void ggml_cuda_op_mul_mat_vec_f( const int cc = ggml_cuda_info().devices[id].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - // ggml_cuda_op provides single, contiguous matrices const int64_t stride_row = ne00; const int64_t stride_col_y = ne10; @@ -426,22 +685,23 @@ void ggml_cuda_op_mul_mat_vec_f( const int64_t stride_sample_y = 0; const int64_t stride_sample_dst = 0; + ggml_cuda_mm_fusion_args_device empty{}; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index 1da460992e784..a205aa8e4c538 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,6 +1,7 @@ #include "common.cuh" -void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion = nullptr); void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3bf0c9ed25038..7a783e4fcf9b4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,5 +1,6 @@ #include "mmvq.cuh" #include "quantize.cuh" +#include "unary.cuh" #include "vecdotq.cuh" #include @@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { case 1: @@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template // tell the compiler to use as many registers as it wants, see nwarps definition below +template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, @@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q( const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + const void * vgate = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + ggml_glu_op active_glu; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + vgate = fusion.gate; + x_bias = (const float *) fusion.x_bias; + gate_bias = (const float *) fusion.gate_bias; + active_glu = fusion.glu_op; + } + + const uint32_t channel_bias = ids ? channel_x : channel_dst; + + if constexpr (has_fusion) { + if (use_bias) { + x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + if (use_gate_bias) { + gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + } + // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; @@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q( for (int i = 0; i < rows_per_cuda_block; ++i) { tmp[j][i] += vec_dot_q_cuda( vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += vec_dot_q_cuda( + vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } } } } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + if constexpr (!has_fusion) { + (void) tmp_shared_gate; + } else if (!use_gate) { + (void) tmp_shared_gate; + } + if (threadIdx.y > 0) { #pragma unroll for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i]; + } + } } } } @@ -216,12 +265,49 @@ static __global__ void mul_mat_vec_q( #pragma unroll for (int l = 0; l < nwarps-1; ++l) { tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x]; + } + } } tmp[j][i] = warp_reduce_sum(tmp[j][i]); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] = warp_reduce_sum(tmp_gate[j][i]); + } + } } if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { - dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; + float result = tmp[j][threadIdx.x]; + if constexpr (has_fusion) { + if (use_bias) { + result += x_bias[j*stride_col_dst + threadIdx.x]; + } + if (use_gate) { + float gate_value = tmp_gate[j][threadIdx.x]; + if (use_gate_bias) { + gate_value += gate_bias[j*stride_col_dst + threadIdx.x]; + } + switch (active_glu) { + case GGML_GLU_OP_SWIGLU: + result *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + result *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + result = ggml_cuda_op_swiglu_oai_single(gate_value, result); + break; + } + default: + result = result * gate_value; + break; + } + } + } + dst[j*stride_col_dst + threadIdx.x] = result; } } } @@ -235,9 +321,37 @@ static std::pair calc_launch_params( return {block_nums, block_dims}; } +template +static void mul_mat_vec_q_switch_fusion( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (c_ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_q<<>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_q<<>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); +} + template static void mul_mat_vec_q_switch_ncols_dst( - const void * vx, const void * vy, const int32_t * ids, float * dst, + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst, @@ -256,80 +370,83 @@ static void mul_mat_vec_q_switch_ncols_dst( const int warp_size = ggml_cuda_info().devices[device].warp_size; const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; default: GGML_ABORT("fatal error"); break; } -} + GGML_UNUSED(has_fusion); +} static void mul_mat_vec_q_switch_type( - const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst, + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst, @@ -339,143 +456,123 @@ static void mul_mat_vec_q_switch_type( switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; default: GGML_ABORT("fatal error"); @@ -484,7 +581,8 @@ static void mul_mat_vec_q_switch_type( } void ggml_cuda_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. @@ -508,6 +606,31 @@ void ggml_cuda_mul_mat_vec_q( const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + // If src0 is a temporary compute buffer, clear any potential padding. if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { const size_t size_data = ggml_nbytes(src0); @@ -549,10 +672,10 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_y = ids ? s11 : s12; mul_mat_vec_q_switch_type( - src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00, + src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -578,8 +701,9 @@ void ggml_cuda_op_mul_mat_vec_q( const int stride_row_x = ne00 / ggml_blck_size(src0->type); const int stride_col_y = src1_padded_row_size / QK8_1; + ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( - src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 39dc7d33eb5ac..4bb10cfaec2b6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -3,7 +3,7 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 3c564566a51ff..5f0d3a6726aef 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) { } static __device__ __forceinline__ float op_gelu(float x) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + return ggml_cuda_op_gelu_single(x); } static __device__ __forceinline__ float op_gelu_erf(float x) { @@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) { } static __device__ __forceinline__ float op_silu(float x) { - return x / (1.0f + expf(-x)); + return ggml_cuda_op_silu_single(x); } static __device__ __forceinline__ float op_tanh(float x) { @@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons float xi = x[j0]; float gi = g[j1]; - xi = fminf(xi, limit); - gi = fmaxf(fminf(gi, limit), -limit); - - float out_glu = xi / (1.0f + expf(-xi * alpha)); - out_glu = out_glu * (1.0f + gi); - dst[i] = out_glu; + dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit); } template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 8e7644fcd9a48..6c738cefecfd2 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_NEG_BLOCK_SIZE 256 @@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { + return x / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + + return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x))); +} + +__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) { + x = fminf(x, limit); + g = fmaxf(fminf(g, limit), -limit); + + float out_glu = x / (1.0f + expf(-x * alpha)); + out_glu = out_glu * (1.0f + g); + return out_glu; +} diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 41fa6894377ea..c1b946e3f715d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -1091,6 +1094,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9eb2b66879c0b..33ac27ff5ca00 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4721,6 +4721,140 @@ struct test_topk_moe: public test_case { } }; +struct test_mul_mat_vec_fusion : public test_case { + const ggml_type type; + const ggml_glu_op glu_op; + const int64_t m; + const int64_t n; + const int64_t k; + const bool use_id; + const int n_mats; + const int n_used; + const bool b; // broadcast b matrix (only for use_id) + const bool with_bias; + const bool with_gate; + + test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k, + bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true) + : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) { + if (use_id) { + GGML_ASSERT(n_used <= n_mats); + } + } + + std::string vars() override { + return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate); + } + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "MUL_MAT_VEC_FUSION"; + } + + bool run_whole_graph() override { return true; } + + ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) { + ggml_tensor * out = nullptr; + if (with_gate) { + if (glu_op == GGML_GLU_OP_SWIGLU_OAI) { + constexpr float alpha = 1.702f; + constexpr float limit = 7.0f; + out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit); + } else { + out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op); + } + } + return out; + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + if (!use_id) { + std::array ne = {k, m, 1, 1}; + std::array ne0 = {k, n, 1, 1}; + + ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr; + ggml_tensor * up = ggml_new_tensor(ctx, type, 4, ne0.data()); + + ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur); + if (with_bias) { + std::array bias_ne = {ffn_up->ne[0], 1, 1, 1}; + ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); + ffn_up = ggml_add(ctx, ffn_up, up_bias); + } + + ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr; + if (with_bias && with_gate) { + std::array bias_ne = {ffn_gate->ne[0], 1, 1, 1}; + ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); + ffn_gate = ggml_add(ctx, ffn_gate, gate_bias); + } + + ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up; + ggml_set_name(out, "out"); + return out; + } else { + ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats); + ggml_tensor * ups = ggml_new_tensor_3d(ctx, type, k, n, n_mats); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m); + + if (n_used != n_mats) { + ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0); + } + + ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m); + ggml_set_name(cur, "cur"); + + ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids); + if (with_bias) { + ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats); + ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids); + } + + ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr; + if (with_bias && with_gate) { + ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats); + ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids); + } + + ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up; + ggml_set_name(out, "out"); + return out; + } + } + + void initialize_tensors(ggml_context * ctx) override { + if (!use_id) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t); + } + } else { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i % n_mats; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } + } + + double max_nmse_err() override { + return 5e-3; + } +}; + // GGML_OP_SUM struct test_sum : public test_case { const ggml_type type; @@ -6983,6 +7117,33 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3})); + for (ggml_type type : base_types) { + for (bool with_gate : {false, true}) { + for (bool use_id : {false, true}) { + for (bool b : {false, true}) { + if (!use_id && b) { + continue; + } + for (bool with_bias : {false, true}) { + if (!with_gate && !with_bias) { + continue; + } + for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) { + if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) { + continue; + } + if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) { + continue; + } + test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256, + use_id, 16, 8, b, with_bias, with_gate)); + } + } + } + } + } + } + for (bool with_norm : {false, true}) { test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));