From 26fa8d07a55c30f7e26238d866aa746c3cd15a47 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 16 Oct 2025 13:06:36 +0800 Subject: [PATCH 1/8] CUDA: fuse ffn gate for mmvf --- ggml/src/ggml-cuda/CMakeLists.txt | 4 + ggml/src/ggml-cuda/common.cuh | 13 + ggml/src/ggml-cuda/convert.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 338 +++++++++- ggml/src/ggml-cuda/mmvf.cu | 362 +---------- ggml/src/ggml-cuda/mmvf.cuh | 593 +++++++++++++++++- ggml/src/ggml-cuda/mmvq.cu | 470 +++----------- ggml/src/ggml-cuda/mmvq.cuh | 470 +++++++++++++- .../template-instances/generate_cu_files.py | 24 + .../mmvf-instance-ncols_1.cu | 5 + .../mmvf-instance-ncols_2.cu | 5 + .../mmvf-instance-ncols_3.cu | 5 + .../mmvf-instance-ncols_4.cu | 5 + .../mmvf-instance-ncols_5.cu | 5 + .../mmvf-instance-ncols_6.cu | 5 + .../mmvf-instance-ncols_7.cu | 5 + .../mmvf-instance-ncols_8.cu | 5 + .../template-instances/mmvq-instance-iq1_m.cu | 5 + .../template-instances/mmvq-instance-iq1_s.cu | 5 + .../template-instances/mmvq-instance-iq2_s.cu | 5 + .../mmvq-instance-iq2_xs.cu | 5 + .../mmvq-instance-iq2_xxs.cu | 5 + .../template-instances/mmvq-instance-iq3_s.cu | 5 + .../mmvq-instance-iq3_xxs.cu | 5 + .../mmvq-instance-iq4_nl.cu | 5 + .../mmvq-instance-iq4_xs.cu | 5 + .../template-instances/mmvq-instance-mxfp4.cu | 5 + .../template-instances/mmvq-instance-q2_k.cu | 5 + .../template-instances/mmvq-instance-q3_k.cu | 5 + .../template-instances/mmvq-instance-q4_0.cu | 5 + .../template-instances/mmvq-instance-q4_1.cu | 5 + .../template-instances/mmvq-instance-q4_k.cu | 5 + .../template-instances/mmvq-instance-q5_0.cu | 5 + .../template-instances/mmvq-instance-q5_1.cu | 5 + .../template-instances/mmvq-instance-q5_k.cu | 5 + .../template-instances/mmvq-instance-q6_k.cu | 5 + .../template-instances/mmvq-instance-q8_0.cu | 5 + ggml/src/ggml-cuda/unary.cu | 14 +- ggml/src/ggml-cuda/unary.cuh | 21 + src/llama-graph.cpp | 6 + tests/test-backend-ops.cpp | 160 +++++ 41 files changed, 1874 insertions(+), 742 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu create 
mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 3024775135966..6f349f2a4a07e 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -50,6 +50,10 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmvq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmvf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmf*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 41ff89c4d6922..1af23588301dd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context { return pool(device); } }; + +struct ggml_cuda_mm_fusion_args_host { + const ggml_tensor * x_bias = nullptr; + const ggml_tensor * gate = nullptr; + const ggml_tensor * gate_bias = nullptr; + ggml_glu_op glu_op; +}; +struct ggml_cuda_mm_fusion_args_device { + const void * x_bias = nullptr; + const void * gate = nullptr; + const void * gate_bias = nullptr; + ggml_glu_op glu_op; +}; diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index ef9e129950c98..8a5e08ef667e0 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bc396b521af07..3767189dadeda 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2007,6 +2007,131 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } +static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, + const ggml_tensor * ffn_gate, + const ggml_tensor * glu, + const ggml_tensor * ffn_up_bias = nullptr, + const ggml_tensor * ffn_gate_bias = nullptr) { + const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr; + + if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) { + return false; + } + + const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU; + const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU; + + GGML_ASSERT(ffn_up && ffn_gate && glu); + + if (!is_mul_mat && !is_mul_mat_id) { + return false; + } + + const ggml_op expected_bias_op = is_mul_mat ? 
GGML_OP_ADD : GGML_OP_ADD_ID; + + if (has_bias) { + if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) { + return false; + } + + if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) { + return false; + } + + if (expected_bias_op == GGML_OP_ADD) { + const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up; + const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate; + if (!up_has_mul || !gate_has_mul) { + return false; + } + } else { // GGML_OP_ADD_ID + if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) { + return false; + } + if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) { + return false; + } + } + } else { + if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) { + return false; + } + } + + if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) || + !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) { + return false; + } + + if (ffn_up->src[1] != ffn_gate->src[1]) { + return false; + } + + if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) { + return false; + } + + static constexpr std::array valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI }; + + if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) { + return false; + } + + if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) { + return false; + } + + const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) || + ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + return true; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID; + + bool use_mul_mat_vec_f = + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && + src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? 
src1->ne[2] : src1->ne[1]);
+    if (tensor->op == GGML_OP_MUL_MAT_ID) {
+        use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;
+    }
+
+    return use_mul_mat_vec_f;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
+    ggml_tensor * src0 = tensor->src[0];
+    ggml_tensor * src1 = tensor->src[1];
+    const ggml_tensor * dst = tensor;
+
+    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
+                                   ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
+                                   src0->view_src;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
+                             dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID) {
+        use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne[2] == 1;
+    }
+
+    return use_mul_mat_vec_q;
+}
+
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
@@ -2745,7 +2870,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
         }
     }
 
-    if (node->op == GGML_OP_SCALE &&
+    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
         memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
         return false;
     }
@@ -2854,6 +2979,39 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
+    std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
+
+    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
+
+    std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
+
+    if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, mul_mat_bias_glu_ops,    {node_idx + 4}) ||
+                            ggml_can_fuse_subgraph(cgraph, node_idx, mul_mat_id_bias_glu_ops, {node_idx + 4}))) {
+
+        const ggml_tensor * ffn_gate      = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * ffn_up        = cgraph->nodes[node_idx + 2];
+        const ggml_tensor * ffn_up_bias   = cgraph->nodes[node_idx + 3];
+        const ggml_tensor * glu           = cgraph->nodes[node_idx + 4];
+
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, mul_mat_glu_ops,    {node_idx + 2}) ||
+                            ggml_can_fuse_subgraph(cgraph, node_idx, mul_mat_id_glu_ops, {node_idx + 2}))) {
+
+        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3004,6 +3162,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             }
         }
 
+        bool fused_mul_mat_vec = false;
+        int  fused_node_count  = 0;
+
+        for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+            const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
+            if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
+                ggml_tensor * glu         = cgraph->nodes[i + 4];
+                ggml_tensor * gate_bias_n = glu->src[0];
+                ggml_tensor * up_bias_n   = glu->src[1];
+
+                // we don't assume the order for {gate, up}. Instead infer it from the bias tensor
+                ggml_tensor * gate_n = nullptr;
+                ggml_tensor * up_n   = nullptr;
+
+                if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
+                    gate_n = cgraph->nodes[i];
+                    up_n   = cgraph->nodes[i + 2];
+                } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
+                    gate_n = cgraph->nodes[i + 2];
+                    up_n   = cgraph->nodes[i];
+                } else {
+                    continue;
+                }
+
+                auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
+                    if (op_bias == GGML_OP_ADD) {
+                        if (bias_node->src[0] == mul_node) {
+                            return bias_node->src[1];
+                        }
+                        if (bias_node->src[1] == mul_node) {
+                            return bias_node->src[0];
+                        }
+                        return (ggml_tensor *) nullptr;
+                    }
+                    GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
+                    GGML_ASSERT(bias_node->src[0] == mul_node);
+                    return bias_node->src[1];
+                };
+
+                ggml_tensor * up_bias_tensor   = get_bias_tensor(up_bias_n, up_n, bias_op);
+                ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
+
+                if (!up_bias_tensor || !gate_bias_tensor) {
+                    continue;
+                }
+
+                const ggml_tensor * src0 = up_n->src[0];
+                const ggml_tensor * src1 = up_n->src[1];
+                const ggml_tensor * ids  = up_n->src[2];
+
+                if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
+                    ggml_cuda_mm_fusion_args_host fusion_data{};
+                    fusion_data.gate      = gate_n->src[0];
+                    fusion_data.x_bias    = up_bias_tensor;
+                    fusion_data.gate_bias = gate_bias_tensor;
+                    fusion_data.glu_op    = ggml_get_glu_op(glu);
+
+                    ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                    fused_mul_mat_vec = true;
+                    fused_node_count  = 5;
+                    break;
+                }
+
+                if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
+                    ggml_cuda_mm_fusion_args_host fusion_data{};
+                    fusion_data.gate      = gate_n->src[0];
+                    fusion_data.x_bias    = up_bias_tensor;
+                    fusion_data.gate_bias = gate_bias_tensor;
+                    fusion_data.glu_op    = ggml_get_glu_op(glu);
+
+                    ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                    fused_mul_mat_vec = true;
+                    fused_node_count  = 5;
+                    break;
+                }
+            } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
+                ggml_tensor * glu  = cgraph->nodes[i + 2];
+                ggml_tensor * gate = glu->src[0];
+                ggml_tensor * up   = glu->src[1];
+
+                bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
+                       || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
+
+                if (!ok) continue;
+
+                const ggml_tensor * src0 = up->src[0];
+                const ggml_tensor * src1 = up->src[1];
+                const ggml_tensor * ids  = up->src[2];
+
+                if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
+                    ggml_cuda_mm_fusion_args_host fusion_data{};
+                    fusion_data.gate   = gate->src[0];
+                    fusion_data.glu_op = ggml_get_glu_op(glu);
+
+                    ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                    fused_mul_mat_vec = true;
+                    fused_node_count  = 3;
+                    break;
+                }
+
+                if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
+                    ggml_cuda_mm_fusion_args_host fusion_data{};
+                    fusion_data.gate   = gate->src[0];
+                    fusion_data.glu_op = ggml_get_glu_op(glu);
+
+                    ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                    fused_mul_mat_vec = true;
+                    fused_node_count  = 3;
+                    break;
+                }
+            }
+        }
+
+        if (fused_mul_mat_vec) {
+            i += fused_node_count - 1;
+            continue;
+        }
+
+
fused_mul_mat_vec = false; + fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 57ab839393aa0..4c67215b44336 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -1,328 +1,10 @@ -#include "ggml.h" -#include "common.cuh" -#include "convert.cuh" +#include "ggml-cuda/common.cuh" #include "mmvf.cuh" -template -static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { - const int row = blockIdx.x; - const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); - const int sample_y = sample_dst; - const int tid = threadIdx.x; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - float * buf_iw = (float *) data_mmv; - - if (block_size > warp_size) { - if (tid < warp_size) { - buf_iw[tid] = 0.0f; - } - __syncthreads(); - } - - float sumf[ncols_dst] = {0.0f}; - - if constexpr (std::is_same_v) { - const float2 * x2 = (const float2 *) x; - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = x2[col2]; - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - } - } - } else if constexpr (std::is_same_v) { - const half2 * x2 = (const half2 *) x; - - if (std::is_same_v) { - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = __half22float2(x2[col2]); - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - } - } - } else { -#ifdef FP16_AVAILABLE - half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const half2 tmpx = x2[col2]; - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); - } - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); - } -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE - } - } else if constexpr (std::is_same_v) { -//TODO: add support for ggml_cuda_mad for hip_bfloat162 -#if defined(GGML_USE_HIP) - const int * x2 = (const int *) x; - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const int tmpx = x2[col2]; -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); - const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); - ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); - } - } -#else - const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const nv_bfloat162 tmpx = x2[col2]; -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - } - } -#endif - } else { - static_assert(std::is_same_v, "unsupported type"); - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = warp_reduce_sum(sumf[j]); - - if (block_size > warp_size) { - buf_iw[tid/warp_size] = sumf[j]; - __syncthreads(); - if (tid < warp_size) { - sumf[j] = buf_iw[tid]; - sumf[j] = warp_reduce_sum(sumf[j]); - } - if (j < ncols_dst) { - __syncthreads(); - } - } - } - - if (tid >= ncols_dst) { - return; - } - - dst[tid*stride_col_dst + row] = sumf[tid]; -} - -template -static void 
launch_mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); - const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t block_size_best = warp_size; - int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); - int64_t max_block_size = 256; - if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { - max_block_size = 128; - } - for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { - const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); - if (niter < niter_best) { - niter_best = niter; - block_size_best = block_size; - } - } - - const int nbytes_shared = warp_size*sizeof(float); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); - const dim3 block_dims(block_size_best, 1, 1); - switch (block_size_best) { - case 32: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 64: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 96: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 128: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 160: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 192: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 224: { - mul_mat_vec_f<<>> - 
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 256: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_vec_f_cuda_switch_ncols_dst( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 2: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 3: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 4: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 5: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 6: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 7: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 8: - launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, 
nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } -} - -template -static void mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { - if constexpr(std::is_same_v) { - if (prec == GGML_PREC_DEFAULT) { - mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - return; - } - } - mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); -} +#include "ggml.h" -void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -348,6 +30,28 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int32_t * ids_d = ids ? 
(const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + const int64_t s01 = src0->nb[1] / ts_src0; const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s1 = dst->nb[1] / ts_dst; @@ -370,19 +74,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -409,7 +113,6 @@ void ggml_cuda_op_mul_mat_vec_f( const int cc = ggml_cuda_info().devices[id].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? 
ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
-
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row = ne00;
     const int64_t stride_col_y = ne10;
@@ -426,22 +129,23 @@ void ggml_cuda_op_mul_mat_vec_f(
     const int64_t stride_sample_y = 0;
     const int64_t stride_sample_dst = 0;
 
+    ggml_cuda_mm_fusion_args_device empty{};
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh
index 1da460992e784..dc476a4b169d7 100644
--- a/ggml/src/ggml-cuda/mmvf.cuh
+++ b/ggml/src/ggml-cuda/mmvf.cuh
@@ -1,6 +1,566 @@
+#pragma once
+
+#include "ggml.h"
 #include "common.cuh"
+#include "convert.cuh"
+#include "unary.cuh"
+
+#include <type_traits>
+
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion>
+static __global__ void mul_mat_vec_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    x   += int64_t(sample_x) *stride_sample_x  + channel_x  *stride_channel_x  + row*stride_row;
+    y   += int64_t(sample_y) *stride_sample_y  + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    bool use_gate      = false;
+    bool use_bias      = false;
+    bool use_gate_bias = false;
+    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
+    const T     * gate_x    = nullptr;
+    const float * x_bias    = nullptr;
+    const float * gate_bias = nullptr;
+
+    if constexpr (has_fusion) {
+        use_gate      = fusion.gate      != nullptr;
+        use_bias      = fusion.x_bias    != nullptr;
+        use_gate_bias = fusion.gate_bias != nullptr;
+        glu_op        = fusion.glu_op;
+
+        if (use_gate) {
+            gate_x = static_cast<const T *>(fusion.gate);
+        }
+        if (use_bias) {
+            x_bias = static_cast<const float *>(fusion.x_bias);
+        }
+        if (use_gate_bias) {
+            gate_bias = static_cast<const float *>(fusion.gate_bias);
+            use_gate_bias = use_gate;
+        } else {
+            use_gate_bias = false;
+        }
+    }
+
+    if (use_gate) {
+        gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+    }
+    if constexpr (has_fusion) {
+        const int channel_bias = ids ? channel_x : channel_dst;
+        if (use_bias) {
+            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+        if (use_gate_bias) {
+            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+    }
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+    float * buf_iw_gate = nullptr;
+    if constexpr (has_fusion) {
+        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
+    }
+
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
+            buf_iw[tid] = 0.0f;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid] = 0.0f;
+                }
+            }
+        }
+        __syncthreads();
+    }
+
+    float sumf[ncols_dst] = {0.0f};
+    float sumf_gate[ncols_dst];
+    if constexpr (has_fusion) {
+#pragma unroll
+        for (int j = 0; j < ncols_dst; ++j) {
+            sumf_gate[j] = 0.0f;
+        }
+    }
+
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 * x2 = (const float2 *) x;
+        const float2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const float2 *) gate_x;
+            }
+        }
+
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+            float2 tmpx_gate = make_float2(0.0f, 0.0f);
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                    }
+                }
+            }
+        }
+    } else if constexpr (std::is_same_v<T, half>) {
+        const half2 * x2 = (const half2 *) x;
+        const half2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const half2 *) gate_x;
+            }
+        }
+
+        if (std::is_same_v<type_acc, float>) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                float2 tmpx_gate = make_float2(0.0f, 0.0f);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmpx_gate = __half22float2(gate_x2[col2]);
+                    }
+                }
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                    if constexpr (has_fusion) {
+                        if (use_gate) {
+                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                        }
+                    }
+                }
+            }
+        } else {
+#ifdef FP16_AVAILABLE
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+            half2 sumh2_gate[ncols_dst];
+            if constexpr (has_fusion) {
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    sumh2_gate[j] = make_half2(0.0f, 0.0f);
+                }
+            }
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
+                half2 tmpx_gate = make_half2(0.0f, 0.0f);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmpx_gate = gate_x2[col2];
+                    }
+                }
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+
+                    if constexpr (has_fusion) {
+                        if (use_gate) {
+                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
+                        }
+                    }
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
+
+            if constexpr (has_fusion) {
+                if (use_gate) {
+#pragma unroll
+                    for (int j = 0; j < ncols_dst; ++j) {
+                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
+                    }
+                }
+            }
+#else
+            NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+#if defined(GGML_USE_HIP)
+        const int * x2 = (const int *) x;
+        const int * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const int *) gate_x;
+            }
+        }
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+            int tmpx_gate = 0;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
+                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
+                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
+                    }
+                }
+            }
+        }
+#else
+        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+        const nv_bfloat162 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const nv_bfloat162 *) gate_x;
+            }
+        }
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const nv_bfloat162 tmpx = x2[col2];
+            nv_bfloat162 tmpx_gate;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
-void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                    }
+                }
+            }
+        }
+#endif
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum(sumf[j]);
+
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                sumf_gate[j] = warp_reduce_sum(sumf_gate[j]);
+            }
+        }
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
+                }
+            }
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum(sumf[j]);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        sumf_gate[j] = buf_iw_gate[tid];
+                        sumf_gate[j] = warp_reduce_sum(sumf_gate[j]);
+                    }
+                }
+            }
+
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
+        }
+    }
+
+    if (tid >= ncols_dst) {
+        return;
+    }
+
+    float value = sumf[tid];
+
+    if constexpr (has_fusion) {
+        if (use_bias) {
+            value += x_bias[tid*stride_col_dst + row];
+        }
+
+        if (use_gate) {
+            float gate_value = sumf_gate[tid];
+            if (use_gate_bias) {
+                gate_value += gate_bias[tid*stride_col_dst + row];
+            }
+            switch (glu_op) {
+                case GGML_GLU_OP_SWIGLU:
+                    value *= ggml_cuda_op_silu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    value *= ggml_cuda_op_gelu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_SWIGLU_OAI: {
+                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
+                    break;
+                }
+                default:
+                    break;
+            }
+        }
+    }
+
+    dst[tid*stride_col_dst + row] = value;
+}
+
+template <typename T, typename type_acc, int ncols_dst, int block_size>
+static void mul_mat_vec_f_switch_fusion(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    if (has_fusion) {
+        mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+    } else {
+        mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+    }
+}
+
+template <typename T, typename type_acc, int ncols_dst>
+void launch_mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst % nsamples_x  == 0);
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);
+
+    const int device    = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+
+    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case 32: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 64: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 96: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 128: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 160: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 192: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 224: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        case 256: {
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T, typename type_acc>
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+template <typename T>
+void mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+
+    if constexpr(std::is_same_v<T, half>) {
+        if (prec == GGML_PREC_DEFAULT) {
+            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            return;
+        }
+    }
+    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+}
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+    const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
 void ggml_cuda_op_mul_mat_vec_f(
     ggml_backend_cuda_context & ctx,
@@ -9,3 +569,34 @@ void ggml_cuda_op_mul_mat_vec_f(
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
 bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+
+#define DECL_MMVF_CASE_HELPER(T, type_acc, ncols_dst) \
+    template void launch_mul_mat_vec_f_cuda<T, type_acc, ncols_dst>( \
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, \
+        const int64_t ncols, const int64_t nrows, \
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, \
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
+        cudaStream_t stream);
+
+#define DECL_MMVF_CASE_EXTERN(ncols_dst) \
+    extern DECL_MMVF_CASE_HELPER(float,       float, ncols_dst) \
+    extern DECL_MMVF_CASE_HELPER(half,        half,  ncols_dst) \
+    extern DECL_MMVF_CASE_HELPER(half,        float, ncols_dst) \
+    extern DECL_MMVF_CASE_HELPER(nv_bfloat16, float, ncols_dst)
+
+#define DECL_MMVF_CASE(ncols_dst) \
+    DECL_MMVF_CASE_HELPER(float,       float, ncols_dst) \
+    DECL_MMVF_CASE_HELPER(half,        half,  ncols_dst) \
+    DECL_MMVF_CASE_HELPER(half,        float, ncols_dst) \
+    DECL_MMVF_CASE_HELPER(nv_bfloat16, float, ncols_dst)
+
+DECL_MMVF_CASE_EXTERN(1);
+DECL_MMVF_CASE_EXTERN(2);
+DECL_MMVF_CASE_EXTERN(3);
+DECL_MMVF_CASE_EXTERN(4);
+DECL_MMVF_CASE_EXTERN(5);
+DECL_MMVF_CASE_EXTERN(6);
+DECL_MMVF_CASE_EXTERN(7);
+DECL_MMVF_CASE_EXTERN(8); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3bf0c9ed25038..45c61d2ba0d1b 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,335 +1,7 @@ #include "mmvq.cuh" -#include "quantize.cuh" -#include "vecdotq.cuh" - -#include - -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); - -static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; - case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; - case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; - case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; - case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; - case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; - case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; - case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; - case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; - case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; - case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; - case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; - case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; - case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; - case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; - case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; - case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; - case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; - case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; - case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; - default: return nullptr; - } -} - -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; - case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; - case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; - case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; - case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; - case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; - case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; - case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; - case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; - case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; - case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; - case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; - case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; - case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; - case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; - case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; - case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; - case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; - default: return 1; - } -} - -enum mmvq_parameter_table_id { - MMVQ_PARAMETERS_GENERIC = 0, - MMVQ_PARAMETERS_GCN, - MMVQ_PARAMETERS_RDNA2 -}; - -static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) - return MMVQ_PARAMETERS_RDNA2; -#elif defined(GCN) || defined(CDNA) - return MMVQ_PARAMETERS_GCN; -#else - return MMVQ_PARAMETERS_GENERIC; -#endif -} - -static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { - return MMVQ_PARAMETERS_RDNA2; - } - if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { - return MMVQ_PARAMETERS_GCN; - } - return MMVQ_PARAMETERS_GENERIC; -} - -static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { - if (table_id == MMVQ_PARAMETERS_GENERIC) { - switch (ncols_dst) { - 
case 1: - case 2: - case 3: - case 4: - return 4; - case 5: - case 6: - case 7: - case 8: - return 2; - default: - return 1; - } - } else if (table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_dst) { - case 1: - case 2: - case 3: - case 4: - return 2; - case 5: - case 6: - case 7: - case 8: - default: - return 1; - } - } - return 1; -} - -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { - if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_dst) { - case 1: - return 1; - case 2: - case 3: - case 4: - case 5: - case 6: - case 7: - case 8: - return 2; - default: - return 1; - } - } - return 1; -} - -template -// tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) -static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, - const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, - const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, - const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { - - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); - - const int tid = warp_size*threadIdx.y + threadIdx.x; - const int row0 = rows_per_cuda_block*blockIdx.x; - const int blocks_per_row_x = ncols_x / qk; - constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. - const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; - const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); - const uint32_t sample_y = sample_dst; - - // partial sum for each thread - float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; - - const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; - const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; - - for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { - const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx - - // x block quant index when casting the quants to int - const int kqs = vdr * (tid % (qi/vdr)); - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { -#pragma unroll - for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp[j][i] += vec_dot_q_cuda( - vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); - } - } - } - - __shared__ float tmp_shared[nwarps-1 > 0 ? 
nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - if (threadIdx.y > 0) { -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { -#pragma unroll - for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; - } - } - } - __syncthreads(); - if (threadIdx.y > 0) { - return; - } - - dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; - - // sum up partial sums and write back result -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { -#pragma unroll - for (int i = 0; i < rows_per_cuda_block; ++i) { -#pragma unroll - for (int l = 0; l < nwarps-1; ++l) { - tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; - } - tmp[j][i] = warp_reduce_sum(tmp[j][i]); - } - - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { - dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; - } - } -} - -static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); - const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); - return {block_nums, block_dims}; -} - -template -static void mul_mat_vec_q_switch_ncols_dst( - const void * vx, const void * vy, const int32_t * ids, float * dst, - const int ncols_x, const int nrows_x, const int ncols_dst, - const int stride_row_x, const int stride_col_y, const int stride_col_dst, - const int nchannels_x, const int nchannels_y, const int nchannels_dst, - const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { - - GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); - GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); - - const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); - const uint3 channel_ratio_fd = ids ? 
make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); - const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); - - GGML_ASSERT(!ids || ncols_dst == 1); - switch (ncols_dst) { - case 1: { - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - constexpr int c_ncols_dst = 2; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - constexpr int c_ncols_dst = 3; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - constexpr int c_ncols_dst = 4; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - constexpr int c_ncols_dst = 5; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - constexpr int c_ncols_dst = 6; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - constexpr int c_ncols_dst = 7; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - constexpr int c_ncols_dst = 8; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, 
warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: - GGML_ABORT("fatal error"); - break; - } -} static void mul_mat_vec_q_switch_type( - const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst, + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst, @@ -339,143 +11,123 @@ static void mul_mat_vec_q_switch_type( switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, 
stride_sample_dst, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, 
stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, 
stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; default: GGML_ABORT("fatal error"); @@ -484,10 +136,11 @@ static void mul_mat_vec_q_switch_type( } void ggml_cuda_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_TENSOR_BINARY_OP_LOCALS; @@ -502,13 +155,34 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 == 1); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; - // If src0 is a temporary compute buffer, clear any potential padding. + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { const size_t size_data = ggml_nbytes(src0); const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); @@ -539,7 +213,6 @@ void ggml_cuda_mul_mat_vec_q( const int64_t s12 = ne11*s11; const int64_t s13 = ne12*s12; - // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? 
ne1 : ne2; @@ -549,10 +222,10 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_y = ids ? s11 : s12; mul_mat_vec_q_switch_type( - src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00, + src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -571,15 +244,14 @@ void ggml_cuda_op_mul_mat_vec_q( int id = ggml_cuda_get_device(); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; const int stride_row_x = ne00 / ggml_blck_size(src0->type); const int stride_col_y = src1_padded_row_size / QK8_1; + ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( - src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 39dc7d33eb5ac..cff8c3b3252ce 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,9 +1,477 @@ +#pragma once + #include "common.cuh" +#include "quantize.cuh" +#include "unary.cuh" +#include "vecdotq.cuh" + +#include +#include #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. 
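// Sketch of the intended host-side call pattern for the fused path (illustrative only; the
// tensor variables ffn_up_w, ffn_gate_w, up_bias, gate_bias and cur are hypothetical). The
// fused kernel accumulates the up and gate dot products for each row in a single pass and
// applies the GLU at write-back, replacing two mat-vec launches plus a separate GLU kernel:
static void sketch_fused_ffn_mat_vec(ggml_backend_cuda_context & ctx,
        const ggml_tensor * ffn_up_w, const ggml_tensor * ffn_gate_w,
        const ggml_tensor * up_bias, const ggml_tensor * gate_bias,
        const ggml_tensor * cur, ggml_tensor * dst) {
    ggml_cuda_mm_fusion_args_host fusion{};
    fusion.gate      = ffn_gate_w;         // must match ffn_up_w's type and strides (asserted in ggml_cuda_mul_mat_vec_q)
    fusion.x_bias    = up_bias;            // optional F32 bias, ne[0] == dst->ne[0]
    fusion.gate_bias = gate_bias;          // optional, only used together with a gate
    fusion.glu_op    = GGML_GLU_OP_SWIGLU; // or GGML_GLU_OP_GEGLU / GGML_GLU_OP_SWIGLU_OAI
    ggml_cuda_mul_mat_vec_q(ctx, ffn_up_w, cur, /*ids=*/nullptr, dst, &fusion);
}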
+using vec_dot_q_cuda_t = float (*)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } +} + +static constexpr __device__ int get_vdr_mmvq(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } +} + +enum mmvq_parameter_table_id { + MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_GCN, + MMVQ_PARAMETERS_RDNA2 +}; + +static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) + return MMVQ_PARAMETERS_RDNA2; +#elif defined(GCN) || defined(CDNA) + return MMVQ_PARAMETERS_GCN; +#else + return MMVQ_PARAMETERS_GENERIC; +#endif +} + +static __host__ __forceinline__ mmvq_parameter_table_id get_device_table_id(const int cc) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA2; + } + if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { + return MMVQ_PARAMETERS_GCN; + } + return MMVQ_PARAMETERS_GENERIC; +} + +static constexpr __host__ __device__ int calc_nwarps(const int ncols_dst, const mmvq_parameter_table_id table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } else if (table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 2; + case 5: + case 6: + 
case 7: + case 8: + default: + return 1; + } + } + return 1; +} + +static constexpr __host__ __device__ int calc_rows_per_block(const int ncols_dst, const int table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + return 1; + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } + return 1; +} + +template +__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr mmvq_parameter_table_id table_id = get_device_table_id(); + constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const int tid = warp_size*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + const uint32_t sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; + + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + const void * vgate = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + ggml_glu_op active_glu; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + vgate = fusion.gate; + x_bias = (const float *) fusion.x_bias; + gate_bias = (const float *) fusion.gate_bias; + active_glu = fusion.glu_op; + } + + const uint32_t channel_bias = ids ? 
channel_x : channel_dst; + + if constexpr (has_fusion) { + if (use_bias) { + x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + if (use_gate_bias) { + gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + } + + float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + + const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += vec_dot_q_cuda( + vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + if constexpr (!has_fusion) { + (void) tmp_shared_gate; + } else if (!use_gate) { + (void) tmp_shared_gate; + } + + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i]; + } + } + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x]; + } + } + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] = warp_reduce_sum(tmp_gate[j][i]); + } + } + } + + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { + float result = tmp[j][threadIdx.x]; + if constexpr (has_fusion) { + if (use_bias) { + result += x_bias[j*stride_col_dst + threadIdx.x]; + } + if (use_gate) { + float gate_value = tmp_gate[j][threadIdx.x]; + if (use_gate_bias) { + gate_value += gate_bias[j*stride_col_dst + threadIdx.x]; + } + switch (active_glu) { + case GGML_GLU_OP_SWIGLU: + result *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + result *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + result = ggml_cuda_op_swiglu_oai_single(gate_value, result); + break; + } + default: + result = result * gate_value; + break; + } + } + } + dst[j*stride_col_dst + threadIdx.x] = result; + } + } +} + +static inline std::pair calc_launch_params( + const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int warp_size, const 
mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); + return {block_nums, block_dims}; +} + +template <ggml_type type, int ncols_dst> +inline void mul_mat_vec_q_switch_fusion( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if (has_fusion) { + mul_mat_vec_q<type, ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_vec_q<type, ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template <ggml_type type> +void mul_mat_vec_q_switch_ncols_dst( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + cudaStream_t stream) { + + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); + GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); + + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ?
make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + + GGML_ASSERT(!ids || ncols_dst == 1); + switch (ncols_dst) { + case 1: { + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 2: { + constexpr int c_ncols_dst = 2; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 3: { + constexpr int c_ncols_dst = 3; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 4: { + constexpr int c_ncols_dst = 4; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 5: { + constexpr int c_ncols_dst = 5; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 6: { + constexpr int c_ncols_dst = 6; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 7: { + constexpr int c_ncols_dst = 7; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, 
warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + case 8: { + constexpr int c_ncols_dst = 8; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); + } break; + default: + GGML_ABORT("fatal error"); + break; + } + + GGML_UNUSED(has_fusion); +} + +#define DECL_MMVQ_CASE(type) \ + template void mul_mat_vec_q_switch_ncols_dst( \ + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, \ + const int ncols_x, const int nrows_x, const int ncols_dst, \ + const int stride_row_x, const int stride_col_y, const int stride_col_dst, \ + const int nchannels_x, const int nchannels_y, const int nchannels_dst, \ + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, \ + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, \ + cudaStream_t stream); + +extern DECL_MMVQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMVQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMVQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMVQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMVQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMVQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMVQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMVQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMVQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMVQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMVQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ1_M); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ4_XS); +extern DECL_MMVQ_CASE(GGML_TYPE_IQ3_S); + void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 81a986f38cacf..d48927f9337eb 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -45,6 +45,22 @@ DECL_MMQ_CASE({type}); """ +TYPES_MMVQ = TYPES_MMQ + ["GGML_TYPE_IQ1_M"] + +SOURCE_MMVQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE({type}); +""" + +SOURCE_MMVF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmvf.cuh" + +DECL_MMVF_CASE({ncols}); +""" + SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. #include "../mmf.cuh" @@ -92,6 +108,14 @@ def get_short_name(long_quant_name): with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) +for type in TYPES_MMVQ: + with open(f"mmvq-instance-{get_short_name(type)}.cu", "w") as f: + f.write(SOURCE_MMVQ.format(type=type)) + +for ncols in range(1, 9): + with open(f"mmvf-instance-ncols_{ncols}.cu", "w") as f: + f.write(SOURCE_MMVF.format(ncols=ncols)) + for type in range(1, 17): with open(f"mmf-instance-ncols_{type}.cu", "w") as f: f.write(SOURCE_MMF.format(type=type)) diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu new file mode 100644 index 0000000000000..6fc3f7986b05b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu new file mode 100644 index 0000000000000..8b7459d9ee2ee --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(2); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu new file mode 100644 index 0000000000000..468491f8800d4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(3); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu new file mode 100644 index 0000000000000..7f7115f38fa5d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(4); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu new file mode 100644 index 0000000000000..407c275578b32 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(5); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu new file mode 100644 index 0000000000000..cfbad12e78c62 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
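// A compressed sketch (hypothetical sketch_* names) of the extern-template idiom these generated
// instance files implement: the header declares each specialization `extern` so including
// translation units do not re-instantiate it, and exactly one generated .cu file per variant
// emits the definition, which is what lets the build compile all ncols_dst variants in parallel.
template <int ncols_dst>
void sketch_launcher(const float * x, float * dst) {
    dst[0] = x[0] * float(ncols_dst); // stand-in for the real launcher body
}
extern template void sketch_launcher<6>(const float *, float *); // header side, cf. DECL_MMVF_CASE_EXTERN(6)
template void sketch_launcher<6>(const float *, float *);        // instance-file side, cf. DECL_MMVF_CASE(6)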
+ +#include "../mmvf.cuh" + +DECL_MMVF_CASE(6); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu new file mode 100644 index 0000000000000..b88526ebe34df --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(7); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu new file mode 100644 index 0000000000000..86b293692b4d4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvf.cuh" + +DECL_MMVF_CASE(8); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu new file mode 100644 index 0000000000000..778c579d2a3b5 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ1_M); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu new file mode 100644 index 0000000000000..9b157d0ec0ca7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ1_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu new file mode 100644 index 0000000000000..d9dc4a83f04de --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu new file mode 100644 index 0000000000000..91fa0dbb4c6e2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu new file mode 100644 index 0000000000000..49e5f73667413 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu new file mode 100644 index 0000000000000..098d4d0e099ac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu new file mode 100644 index 0000000000000..c07376b02e649 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu new file mode 100644 index 0000000000000..4eca44cb27ea7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ4_NL); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu new file mode 100644 index 0000000000000..b36fcecc28ab0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu new file mode 100644 index 0000000000000..f4f12547fa94a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu new file mode 100644 index 0000000000000..9c984278abd53 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q2_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu new file mode 100644 index 0000000000000..80036cfab6b2b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu new file mode 100644 index 0000000000000..b000d8d3c4b73 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu new file mode 100644 index 0000000000000..747a7af93a0d6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu new file mode 100644 index 0000000000000..0ecd40853c2fa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu new file mode 100644 index 0000000000000..2fa4e67923a76 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu new file mode 100644 index 0000000000000..10b50256d1687 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu new file mode 100644 index 0000000000000..592af709d5f3f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q5_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu new file mode 100644 index 0000000000000..7386448ab5bb7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu new file mode 100644 index 0000000000000..0417846ef53f4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../mmvq.cuh" + +DECL_MMVQ_CASE(GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 3c564566a51ff..5f0d3a6726aef 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) { } static __device__ __forceinline__ float op_gelu(float x) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + return ggml_cuda_op_gelu_single(x); } static __device__ __forceinline__ float op_gelu_erf(float x) { @@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) { } static __device__ __forceinline__ float op_silu(float x) { - return x / (1.0f + expf(-x)); + return ggml_cuda_op_silu_single(x); } static __device__ __forceinline__ float op_tanh(float x) { @@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons float xi = x[j0]; float gi = g[j1]; - xi = fminf(xi, limit); - gi = fmaxf(fminf(gi, limit), -limit); - - float out_glu = xi / (1.0f + expf(-xi * alpha)); - out_glu = out_glu * (1.0f + gi); - dst[i] = out_glu; + dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit); } template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 8e7644fcd9a48..6c738cefecfd2 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_NEG_BLOCK_SIZE 256 @@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { + return x / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + + return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x))); +} + +__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) { + x = fminf(x, limit); + g = fmaxf(fminf(g, limit), -limit); + + float out_glu = x / (1.0f + expf(-x * alpha)); + out_glu = out_glu * (1.0f + g); + return out_glu; +} diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 41fa6894377ea..c1b946e3f715d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -1091,6 +1094,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9eb2b66879c0b..64e70c1051f9e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4721,6 +4721,140 @@ struct test_topk_moe: public test_case { } }; +struct test_fused_ffn_gate : 
public test_case {
+    const ggml_type type;
+    const ggml_glu_op glu_op;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const bool use_id;
+    const int n_mats;
+    const int n_used;
+    const bool b; // broadcast b matrix (only for use_id)
+    const bool with_bias;
+    const bool with_gate;
+
+    test_fused_ffn_gate(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
+            bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
+        : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
+        if (use_id) {
+            GGML_ASSERT(n_used <= n_mats);
+        }
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "FUSED_FFN_GATE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
+        ggml_tensor * out = nullptr;
+        if (with_gate) {
+            if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
+            } else {
+                out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
+            }
+        }
+        return out;
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        if (!use_id) {
+            std::array<int64_t, 4> ne  = {k, m, 1, 1};
+            std::array<int64_t, 4> ne0 = {k, n, 1, 1};
+
+            ggml_tensor * cur  = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
+            ggml_tensor * up   = ggml_new_tensor(ctx, type, 4, ne0.data());
+
+            ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
+            if (with_bias) {
+                std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
+                ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
+                ffn_up = ggml_add(ctx, ffn_up, up_bias);
+            }
+
+            ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
+            if (with_bias && with_gate) {
+                std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
+                ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
+                ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
+            }
+
+            ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+            ggml_set_name(out, "out");
+            return out;
+        } else {
+            ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
+            ggml_tensor * ups   = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
+            ggml_tensor * ids   = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);
+
+            if (n_used != n_mats) {
+                ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
+            }
+
+            ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
+            ggml_set_name(cur, "cur");
+
+            ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
+            if (with_bias) {
+                ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
+                ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
+            }
+
+            ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
+            if (with_bias && with_gate) {
+                ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
+                ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
+            }
+
+            ggml_tensor * out = with_gate ?
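+            // matches the fused CUDA epilogue: GLU(gate, up) when a gate matmul exists, plain up projection otherwise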
build_gate(ctx, ffn_gate, ffn_up) : ffn_up; + ggml_set_name(out, "out"); + return out; + } + } + + void initialize_tensors(ggml_context * ctx) override { + if (!use_id) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t); + } + } else { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i % n_mats; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } + } + + double max_nmse_err() override { + return 5e-3; + } +}; + // GGML_OP_SUM struct test_sum : public test_case { const ggml_type type; @@ -6983,6 +7117,32 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3})); + for (ggml_type type : base_types) { + for (bool with_gate : {false, true}) { + for (bool use_id : {false, true}) { + for (bool b : {false, true}) { + if (!use_id && b) { + continue; + } + for (bool with_bias : {false, true}) { + if (!with_gate && !with_bias) { + continue; + } + for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) { + if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) { + continue; + } + if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) { + continue; + } + test_cases.emplace_back(new test_fused_ffn_gate(type, glu_op, 1, 32, 256, use_id, 16, 8, b, with_bias, with_gate)); + } + } + } + } + } + } + for (bool with_norm : {false, true}) { test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm)); From 8366599a4c93402b052cc655998ec72bd9312ee1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 22 Oct 2025 16:27:49 +0800 Subject: [PATCH 2/8] fix hip build --- ggml/src/ggml-hip/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 6b499320e7b12..3479b79df1e5c 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -61,6 +61,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) +file(GLOB SRCS "../ggml-cuda/template-instances/mmv*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") From bf349cb674a5d6213237eb23cbf2e8a150094d29 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 22 Oct 2025 16:49:29 +0800 Subject: [PATCH 3/8] fix musa build --- ggml/src/ggml-musa/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977f90..15e38a841309f 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -36,6 +36,9 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) + 
file(GLOB SRCS "../ggml-cuda/template-instances/mmv*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + if (GGML_MUSA_MUDNN_COPY) file(GLOB SRCS "../ggml-musa/*.cu") From 010a23ad0f9bcdd8a00b49b6934700f7e3df8ad7 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 24 Oct 2025 00:17:07 +0800 Subject: [PATCH 4/8] only fuse ncols_dst=1 --- ggml/src/ggml-cuda/CMakeLists.txt | 4 - ggml/src/ggml-cuda/ggml-cuda.cu | 24 +- ggml/src/ggml-cuda/mmvf.cu | 562 ++++++++++++++++- ggml/src/ggml-cuda/mmvf.cuh | 590 ------------------ ggml/src/ggml-cuda/mmvq.cu | 444 ++++++++++++- ggml/src/ggml-cuda/mmvq.cuh | 468 -------------- .../template-instances/generate_cu_files.py | 24 - .../mmvf-instance-ncols_1.cu | 5 - .../mmvf-instance-ncols_2.cu | 5 - .../mmvf-instance-ncols_3.cu | 5 - .../mmvf-instance-ncols_4.cu | 5 - .../mmvf-instance-ncols_5.cu | 5 - .../mmvf-instance-ncols_6.cu | 5 - .../mmvf-instance-ncols_7.cu | 5 - .../mmvf-instance-ncols_8.cu | 5 - .../template-instances/mmvq-instance-iq1_m.cu | 5 - .../template-instances/mmvq-instance-iq1_s.cu | 5 - .../template-instances/mmvq-instance-iq2_s.cu | 5 - .../mmvq-instance-iq2_xs.cu | 5 - .../mmvq-instance-iq2_xxs.cu | 5 - .../template-instances/mmvq-instance-iq3_s.cu | 5 - .../mmvq-instance-iq3_xxs.cu | 5 - .../mmvq-instance-iq4_nl.cu | 5 - .../mmvq-instance-iq4_xs.cu | 5 - .../template-instances/mmvq-instance-mxfp4.cu | 5 - .../template-instances/mmvq-instance-q2_k.cu | 5 - .../template-instances/mmvq-instance-q3_k.cu | 5 - .../template-instances/mmvq-instance-q4_0.cu | 5 - .../template-instances/mmvq-instance-q4_1.cu | 5 - .../template-instances/mmvq-instance-q4_k.cu | 5 - .../template-instances/mmvq-instance-q5_0.cu | 5 - .../template-instances/mmvq-instance-q5_1.cu | 5 - .../template-instances/mmvq-instance-q5_k.cu | 5 - .../template-instances/mmvq-instance-q6_k.cu | 5 - .../template-instances/mmvq-instance-q8_0.cu | 5 - ggml/src/ggml-hip/CMakeLists.txt | 2 - ggml/src/ggml-musa/CMakeLists.txt | 3 - tests/test-backend-ops.cpp | 9 +- 38 files changed, 1024 insertions(+), 1246 deletions(-) delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu delete mode 100644 
ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 6f349f2a4a07e..3024775135966 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -50,10 +50,6 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/mmvq*.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/mmvf*.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmf*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3767189dadeda..36060ddfcbf6a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2106,10 +2106,16 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? 
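    // the effective batch size (ncols_dst) is src1->ne[2] for MUL_MAT_ID and src1->ne[1] for plain MUL_MAT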
src1->ne[2] : src1->ne[1]);
 
-    if (tensor->op == GGML_OP_MUL_MAT_ID) {
-        use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;
+    // we only support fusion for ncols_dst = 1
+    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+        return false;
     }
 
+    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+        return false;
+    }
+
+
     return use_mul_mat_vec_f;
 }
 
@@ -2125,8 +2131,13 @@
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
-    if (tensor->op == GGML_OP_MUL_MAT_ID) {
-        use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne[2] == 1;
+    // we only support fusion for ncols_dst = 1
+    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+        return false;
+    }
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+        return false;
+    }
 
     return use_mul_mat_vec_q;
@@ -2979,12 +2990,11 @@
-    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
+    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_id_glu_ops      = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
-
-    std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
+    std::initializer_list<enum ggml_op> mul_mat_glu_ops         = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
 
     if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
                             ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 4c67215b44336..a3ca6fb895846 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -1,7 +1,563 @@
-#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "common.cuh"
+#include "unary.cuh"
 #include "mmvf.cuh"
-#include "ggml.h"
 
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
+static __global__ void mul_mat_vec_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
+    const int channel_y   = ids ?
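+                                  // with expert ids, dst channels wrap onto the available y channels (modulo nchannels_y)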
channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    x   += int64_t(sample_x) *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += int64_t(sample_y) *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    bool use_gate      = false;
+    bool use_bias      = false;
+    bool use_gate_bias = false;
+    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
+    const T     * gate_x    = nullptr;
+    const float * x_bias    = nullptr;
+    const float * gate_bias = nullptr;
+
+    if constexpr (has_fusion) {
+        use_gate      = fusion.gate      != nullptr;
+        use_bias      = fusion.x_bias    != nullptr;
+        use_gate_bias = fusion.gate_bias != nullptr;
+        glu_op        = fusion.glu_op;
+
+        if (use_gate) {
+            gate_x = static_cast<const T *>(fusion.gate);
+        }
+        if (use_bias) {
+            x_bias = static_cast<const float *>(fusion.x_bias);
+        }
+        if (use_gate_bias) {
+            gate_bias = static_cast<const float *>(fusion.gate_bias);
+            use_gate_bias = use_gate;
+        } else {
+            use_gate_bias = false;
+        }
+    }
+
+    if (use_gate) {
+        gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+    }
+    if constexpr (has_fusion) {
+        const int channel_bias = ids ? channel_x : channel_dst;
+        if (use_bias) {
+            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+        if (use_gate_bias) {
+            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+    }
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw      = (float *) data_mmv;
+    float * buf_iw_gate = nullptr;
+    if constexpr (has_fusion) {
+        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
+    }
+
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
+            buf_iw[tid] = 0.0f;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid] = 0.0f;
+                }
+            }
+        }
+        __syncthreads();
+    }
+
+    float sumf[ncols_dst] = {0.0f};
+    float sumf_gate[ncols_dst];
+    if constexpr (has_fusion) {
+#pragma unroll
+        for (int j = 0; j < ncols_dst; ++j) {
+            sumf_gate[j] = 0.0f;
+        }
+    }
+
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 * x2      = (const float2 *) x;
+        const float2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const float2 *) gate_x;
+            }
+        }
+
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+            float2 tmpx_gate  = make_float2(0.0f, 0.0f);
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                    }
+                }
+            }
+        }
+    } else if constexpr (std::is_same_v<T, half>) {
+        const half2 * x2      = (const half2 *) x;
+        const half2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const half2 *) gate_x;
+            }
+        }
+
+        if (std::is_same_v<type_acc, float>) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                float2 tmpx_gate  = make_float2(0.0f, 0.0f);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmpx_gate = __half22float2(gate_x2[col2]);
+                    }
+                }
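+                // the j-loop below reuses each activation pair for both the up and the gate
+                // accumulators, so the fused gate adds FMAs but no extra global loads of y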
+#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } + } + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; + half2 sumh2_gate[ncols_dst]; + if constexpr (has_fusion) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumh2_gate[j] = make_half2(0.0f, 0.0f); + } + } + + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const half2 tmpx = x2[col2]; + half2 tmpx_gate = make_half2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y); + } + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); + } + + if constexpr (has_fusion) { + if (use_gate) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]); + } + } + } +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + } else if constexpr (std::is_same_v) { +#if defined(GGML_USE_HIP) + const int * x2 = (const int *) x; + const int * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const int *) gate_x; + } + } + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const int tmpx = x2[col2]; + int tmpx_gate = 0; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); + const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); + ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + const float tmpx0_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[0]); + const float tmpx1_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[1]); + ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y); + } + } + } + } +#else + const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + const nv_bfloat162 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const nv_bfloat162 *) gate_x; + } + } + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const nv_bfloat162 tmpx = x2[col2]; + nv_bfloat162 tmpx_gate; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } + } + } +#endif + } else { + static_assert(std::is_same_v, "unsupported type"); + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = warp_reduce_sum(sumf[j]); + + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = 
warp_reduce_sum(sumf_gate[j]);
+            }
+        }
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
+                }
+            }
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum(sumf[j]);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        sumf_gate[j] = buf_iw_gate[tid];
+                        sumf_gate[j] = warp_reduce_sum(sumf_gate[j]);
+                    }
+                }
+            }
+
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
+        }
+    }
+
+    if (tid >= ncols_dst) {
+        return;
+    }
+
+    float value = sumf[tid];
+
+    if constexpr (has_fusion) {
+        if (use_bias) {
+            value += x_bias[tid*stride_col_dst + row];
+        }
+
+        if (use_gate) {
+            float gate_value = sumf_gate[tid];
+            if (use_gate_bias) {
+                gate_value += gate_bias[tid*stride_col_dst + row];
+            }
+            switch (glu_op) {
+                case GGML_GLU_OP_SWIGLU:
+                    value *= ggml_cuda_op_silu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    value *= ggml_cuda_op_gelu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_SWIGLU_OAI: {
+                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
+                    break;
+                }
+                default:
+                    break;
+            }
+        }
+    }
+
+    dst[tid*stride_col_dst + row] = value;
+}
+
+template <typename T, typename type_acc, int ncols_dst, int block_size>
+static void mul_mat_vec_f_switch_fusion(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+
+    if constexpr (ncols_dst == 1) {
+        const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+        if (has_fusion) {
+            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            return;
+        }
+    }
+
+    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+
+}
+
+template <typename T, typename type_acc, int ncols_dst>
+void launch_mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const uint3 channel_ratio_fd = ids ?
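+                                     // the fastdiv ratio is only consulted on the non-ids path, so a dummy value suffices here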
make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + + const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); + const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 64: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 96: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 128: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 160: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 192: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + case 224: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, 
stream); + } break; + case 256: { + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +template +static void mul_mat_vec_f_cuda_switch_ncols_dst( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 2: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 3: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 4: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 5: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 6: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 7: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 8: + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, 
stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +template +static void mul_mat_vec_f_cuda( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + enum ggml_prec prec, cudaStream_t stream) { + + if constexpr(std::is_same_v) { + if (prec == GGML_PREC_DEFAULT) { + mul_mat_vec_f_cuda_switch_ncols_dst + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + return; + } + } + mul_mat_vec_f_cuda_switch_ncols_dst + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); +} void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion) { @@ -33,6 +589,8 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_mm_fusion_args_device fusion_local{}; if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); if (fusion->x_bias) { GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index dc476a4b169d7..a205aa8e4c538 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,563 +1,4 @@ -#pragma once - -#include "ggml.h" #include "common.cuh" -#include "convert.cuh" -#include "unary.cuh" - -#include - -template -static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { - const int row = blockIdx.x; - const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); - const int sample_y = sample_dst; - const int tid = threadIdx.x; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - bool use_gate = false; - bool use_bias = false; - bool use_gate_bias = false; - ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU; - const T * gate_x = nullptr; - const float * x_bias = nullptr; - const float * gate_bias = nullptr; - - if constexpr (has_fusion) { - use_gate = fusion.gate != nullptr; - use_bias = fusion.x_bias != nullptr; - use_gate_bias = fusion.gate_bias != nullptr; - glu_op = fusion.glu_op; - - if (use_gate) { - gate_x = static_cast(fusion.gate); - } - if (use_bias) { - x_bias = static_cast(fusion.x_bias); - } - if (use_gate_bias) { - gate_bias = static_cast(fusion.gate_bias); - use_gate_bias = use_gate; - } else { - use_gate_bias = false; - } - } - - if (use_gate) { - gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - } - if constexpr (has_fusion) { - const int channel_bias = ids ? channel_x : channel_dst; - if (use_bias) { - x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; - } - if (use_gate_bias) { - gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; - } - } - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - float * buf_iw = (float *) data_mmv; - float * buf_iw_gate = nullptr; - if constexpr (has_fusion) { - buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); - } - - if (block_size > warp_size) { - if (tid < warp_size) { - buf_iw[tid] = 0.0f; - if constexpr (has_fusion) { - if (use_gate) { - buf_iw_gate[tid] = 0.0f; - } - } - } - __syncthreads(); - } - - float sumf[ncols_dst] = {0.0f}; - float sumf_gate[ncols_dst]; - if constexpr (has_fusion) { -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf_gate[j] = 0.0f; - } - } - - if constexpr (std::is_same_v) { - const float2 * x2 = (const float2 *) x; - const float2 * gate_x2 = nullptr; - if constexpr (has_fusion) { - if (use_gate) { - gate_x2 = (const float2 *) gate_x; - } - } - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = x2[col2]; - float2 tmpx_gate = make_float2(0.0f, 0.0f); - if constexpr (has_fusion) { - if (use_gate) { - tmpx_gate = gate_x2[col2]; - } - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - - if constexpr (has_fusion) { - if (use_gate) { - ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); - ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); - } - } - } - } - } else if constexpr (std::is_same_v) { - const half2 * x2 = (const half2 *) x; - const half2 * gate_x2 = nullptr; - if constexpr (has_fusion) { - if (use_gate) { - gate_x2 = (const half2 *) gate_x; - } - } - - if (std::is_same_v) { - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = __half22float2(x2[col2]); - float2 tmpx_gate = make_float2(0.0f, 0.0f); - if constexpr (has_fusion) { - if (use_gate) { - tmpx_gate = __half22float2(gate_x2[col2]); - } - } 
-#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - - if constexpr (has_fusion) { - if (use_gate) { - ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); - ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); - } - } - } - } - } else { -#ifdef FP16_AVAILABLE - half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; - half2 sumh2_gate[ncols_dst]; - if constexpr (has_fusion) { -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumh2_gate[j] = make_half2(0.0f, 0.0f); - } - } - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const half2 tmpx = x2[col2]; - half2 tmpx_gate = make_half2(0.0f, 0.0f); - if constexpr (has_fusion) { - if (use_gate) { - tmpx_gate = gate_x2[col2]; - } - } -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); - - if constexpr (has_fusion) { - if (use_gate) { - sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y); - } - } - } - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); - } - - if constexpr (has_fusion) { - if (use_gate) { -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]); - } - } - } -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE - } - } else if constexpr (std::is_same_v) { -#if defined(GGML_USE_HIP) - const int * x2 = (const int *) x; - const int * gate_x2 = nullptr; - if constexpr (has_fusion) { - if (use_gate) { - gate_x2 = (const int *) gate_x; - } - } - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const int tmpx = x2[col2]; - int tmpx_gate = 0; - if constexpr (has_fusion) { - if (use_gate) { - tmpx_gate = gate_x2[col2]; - } - } -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - const float tmpx0 = ggml_cuda_cast(reinterpret_cast(&tmpx)[0]); - const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); - ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); - - if constexpr (has_fusion) { - if (use_gate) { - const float tmpx0_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[0]); - const float tmpx1_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[1]); - ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x); - ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y); - } - } - } - } -#else - const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - const nv_bfloat162 * gate_x2 = nullptr; - if constexpr (has_fusion) { - if (use_gate) { - gate_x2 = (const nv_bfloat162 *) gate_x; - } - } - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const nv_bfloat162 tmpx = x2[col2]; - nv_bfloat162 tmpx_gate; - if constexpr (has_fusion) { - if (use_gate) { - tmpx_gate = gate_x2[col2]; - } - } -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); - ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); - - if constexpr (has_fusion) { - if (use_gate) { - ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); - ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); - } - } - } - } -#endif - } else { - static_assert(std::is_same_v, "unsupported type"); - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = warp_reduce_sum(sumf[j]); - - if constexpr (has_fusion) { - if (use_gate) { - sumf_gate[j] = 
warp_reduce_sum(sumf_gate[j]); - } - } - - if (block_size > warp_size) { - buf_iw[tid/warp_size] = sumf[j]; - if constexpr (has_fusion) { - if (use_gate) { - buf_iw_gate[tid/warp_size] = sumf_gate[j]; - } - } - __syncthreads(); - if (tid < warp_size) { - sumf[j] = buf_iw[tid]; - sumf[j] = warp_reduce_sum(sumf[j]); - if constexpr (has_fusion) { - if (use_gate) { - sumf_gate[j] = buf_iw_gate[tid]; - sumf_gate[j] = warp_reduce_sum(sumf_gate[j]); - } - } - } - - if (j < ncols_dst) { - __syncthreads(); - } - } - } - - if (tid >= ncols_dst) { - return; - } - - float value = sumf[tid]; - - if constexpr (has_fusion) { - if (use_bias) { - value += x_bias[tid*stride_col_dst + row]; - } - - if (use_gate) { - float gate_value = sumf_gate[tid]; - if (use_gate_bias) { - gate_value += gate_bias[tid*stride_col_dst + row]; - } - switch (glu_op) { - case GGML_GLU_OP_SWIGLU: - value *= ggml_cuda_op_silu_single(gate_value); - break; - case GGML_GLU_OP_GEGLU: - value *= ggml_cuda_op_gelu_single(gate_value); - break; - case GGML_GLU_OP_SWIGLU_OAI: { - value = ggml_cuda_op_swiglu_oai_single(gate_value, value); - break; - } - default: - break; - } - } - } - - dst[tid*stride_col_dst + row] = value; -} - -template -static void mul_mat_vec_f_switch_fusion( - const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { - - const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; - if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } else { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } -} - -template -void launch_mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const uint3 channel_ratio_fd = ids ? 
make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
-    const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
-
-    const int device    = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-
-    int64_t block_size_best = warp_size;
-    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
-    int64_t max_block_size  = 256;
-    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
-        max_block_size = 128;
-    }
-    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
-        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
-        if (niter < niter_best) {
-            niter_best      = niter;
-            block_size_best = block_size;
-        }
-    }
-
-    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
-
-    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
-    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
-    const dim3 block_dims(block_size_best, 1, 1);
-    switch (block_size_best) {
-        case 32: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 64: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 96: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 128: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 160: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 192: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 224: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        case 256: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        } break;
-    }
-}
-
-template <typename T, typename type_acc>
-static void mul_mat_vec_f_cuda_switch_ncols_dst(
-    const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-    const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
-    const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-    const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-    const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-    cudaStream_t stream) {
-    switch (ncols_dst) {
-        case 1:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 2:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 3:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 4:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 5:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 6:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 7:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        case 8:
-            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-template <typename T>
-void mul_mat_vec_f_cuda(
-    const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-    const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
-    const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
-    const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-    const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-    const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-    enum ggml_prec prec, cudaStream_t stream) {
-
-    if constexpr(std::is_same_v<T, half>) {
-        if (prec == GGML_PREC_DEFAULT) {
-            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
-                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-            return;
-        }
-    }
-    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
-        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-}
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
     const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
@@ -569,34 +10,3 @@ void ggml_cuda_op_mul_mat_vec_f(
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
 bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
-
-#define DECL_MMVF_CASE_HELPER(T, type_acc, ncols_dst) \
-    template void launch_mul_mat_vec_f_cuda<T, type_acc, ncols_dst>( \
-        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, \
-        const int64_t ncols, const int64_t nrows, \
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, \
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
-        cudaStream_t stream);
-
-#define DECL_MMVF_CASE_EXTERN(ncols_dst) \
-    extern DECL_MMVF_CASE_HELPER(float, float, ncols_dst) \
-    extern DECL_MMVF_CASE_HELPER(half, half, ncols_dst) \
-    extern DECL_MMVF_CASE_HELPER(half, float, ncols_dst) \
-    extern DECL_MMVF_CASE_HELPER(nv_bfloat16, float, ncols_dst)
-
-#define DECL_MMVF_CASE(ncols_dst) \
-    DECL_MMVF_CASE_HELPER(float, float, ncols_dst) \
-    DECL_MMVF_CASE_HELPER(half, half, ncols_dst) \
-    DECL_MMVF_CASE_HELPER(half, float, ncols_dst) \
-    DECL_MMVF_CASE_HELPER(nv_bfloat16, float, ncols_dst)
-
-DECL_MMVF_CASE_EXTERN(1);
-DECL_MMVF_CASE_EXTERN(2);
-DECL_MMVF_CASE_EXTERN(3);
-DECL_MMVF_CASE_EXTERN(4);
-DECL_MMVF_CASE_EXTERN(5);
-DECL_MMVF_CASE_EXTERN(6);
-DECL_MMVF_CASE_EXTERN(7);
-DECL_MMVF_CASE_EXTERN(8);
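Note on the fused epilogue these launchers feed: the fusion computes the up- and gate-projections in one kernel and applies the GLU activation in registers, instead of materializing two mat-vec results and a separate GLU op. A minimal sketch of that epilogue, written with standard CUDA math rather than the ggml_cuda_op_*_single helpers from unary.cuh that the kernels actually call:

__device__ float fused_glu_epilogue(float up, float gate, ggml_glu_op op) {
    // SWIGLU: up * SiLU(gate); GEGLU: up * GELU(gate); default: plain gating.
    switch (op) {
        case GGML_GLU_OP_SWIGLU: return up * (gate / (1.0f + expf(-gate)));
        case GGML_GLU_OP_GEGLU:  return up * 0.5f*gate*(1.0f + erff(gate*0.70710678f));
        default:                 return up * gate;
    }
}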
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 45c61d2ba0d1b..8b164d3a7002f 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -1,5 +1,442 @@
 #include "mmvq.cuh"
+#include "quantize.cuh"
+#include "unary.cuh"
+#include "vecdotq.cuh"
+#include <cstdint>
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;
+        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;
+        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
+        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
+        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
+        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
+        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
+        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
+        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
+        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;
+        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;
+        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
+        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;
+        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;
+        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
+        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;
+        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;
+        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;
+        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;
+        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;
+        default:                return nullptr;
+    }
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;
+        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
+        case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
+        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
+        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;
+        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;
+        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
+        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;
+        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;
+        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;
+        default:                return 1;
+    }
+}
+
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_dst) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_dst) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_dst) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+template <ggml_type type, int ncols_dst, bool has_fusion>
+__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
+    const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+    const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+    const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
+
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+    const int tid  = warp_size*threadIdx.y + threadIdx.x;
+    const int row0 = rows_per_cuda_block*blockIdx.x;
+    const int blocks_per_row_x = ncols_x / qk;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
+
+    const uint32_t channel_dst = blockIdx.y;
+    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+    const uint32_t sample_dst  = blockIdx.z;
+    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
+    const uint32_t sample_y    = sample_dst;
+
+    bool use_gate      = false;
+    bool use_bias      = false;
+    bool use_gate_bias = false;
+    const void  * vgate     = nullptr;
+    const float * x_bias    = nullptr;
+    const float * gate_bias = nullptr;
+    ggml_glu_op active_glu;
+
+    if constexpr (has_fusion) {
+        use_gate      = fusion.gate != nullptr;
+        use_bias      = fusion.x_bias != nullptr;
+        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
+        vgate      = fusion.gate;
+        x_bias     = (const float *) fusion.x_bias;
+        gate_bias  = (const float *) fusion.gate_bias;
+        active_glu = fusion.glu_op;
+    }
+
+    const uint32_t channel_bias = ids ? channel_x : channel_dst;
+
+    if constexpr (has_fusion) {
+        if (use_bias) {
+            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+        }
+        if (use_gate_bias) {
+            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+        }
+    }
+
+    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+
+    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
+
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1);
+        const int kqs = vdr * (tid % (qi/vdr));
+
+#pragma unroll
+        for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(
+                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_gate[j][i] += vec_dot_q_cuda(
+                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+                    }
+                }
+            }
+        }
+    }
+
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+    if constexpr (!has_fusion) {
+        (void) tmp_shared_gate;
+    } else if (!use_gate) {
+        (void) tmp_shared_gate;
+    }
+
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
+                    }
+                }
+            }
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
+    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
+                    }
+                }
+            }
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
+                }
+            }
+        }
+
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+            float result = tmp[j][threadIdx.x];
+            if constexpr (has_fusion) {
+                if (use_bias) {
+                    result += x_bias[j*stride_col_dst + threadIdx.x];
+                }
+                if (use_gate) {
+                    float gate_value = tmp_gate[j][threadIdx.x];
+                    if (use_gate_bias) {
+                        gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
+                    }
+                    switch (active_glu) {
+                        case GGML_GLU_OP_SWIGLU:
+                            result *= ggml_cuda_op_silu_single(gate_value);
+                            break;
+                        case GGML_GLU_OP_GEGLU:
+                            result *= ggml_cuda_op_gelu_single(gate_value);
+                            break;
+                        case GGML_GLU_OP_SWIGLU_OAI: {
+                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
+                            break;
+                        }
+                        default:
+                            result = result * gate_value;
+                            break;
+                    }
+                }
+            }
+            dst[j*stride_col_dst + threadIdx.x] = result;
+        }
+    }
+}
+
+static inline std::pair<dim3, dim3> calc_launch_params(
+    const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+    const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
+    return {block_nums, block_dims};
+}
+
+template <ggml_type type, int c_ncols_dst>
+static void mul_mat_vec_q_switch_fusion(
+    const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+    const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+    const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+    const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+    const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+
+    if constexpr (c_ncols_dst == 1) {
+        const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+        if (has_fusion) {
+            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            return;
+        }
+    }
+
+    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+}
+
+template <ggml_type type>
+static void mul_mat_vec_q_switch_ncols_dst(
+    const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+    const int ncols_x, const int nrows_x, const int ncols_dst,
+    const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+    const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+    const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+    const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+    cudaStream_t stream) {
+
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
+    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
+
+    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);
+
+    const int device    = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
+
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+    switch (ncols_dst) {
+        case 1: {
+            constexpr int c_ncols_dst = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 2: {
+            constexpr int c_ncols_dst = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 3: {
+            constexpr int c_ncols_dst = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 4: {
+            constexpr int c_ncols_dst = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 5: {
+            constexpr int c_ncols_dst = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 6: {
+            constexpr int c_ncols_dst = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 7: {
+            constexpr int c_ncols_dst = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        case 8: {
+            constexpr int c_ncols_dst = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                dims.first, dims.second, 0, stream);
+        } break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(has_fusion);
+}
 
 static void mul_mat_vec_q_switch_type(
     const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
     const int ncols_x, const int nrows_x, const int ncols_dst,
@@ -140,7 +577,7 @@ void ggml_cuda_mul_mat_vec_q(
     const ggml_cuda_mm_fusion_args_host * fusion) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
     GGML_ASSERT(         dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
@@ -155,7 +592,7 @@ void ggml_cuda_mul_mat_vec_q(
     GGML_ASSERT(        nb0       == ts_dst);
     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    GGML_ASSERT(!ids || ne12 == 1);
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
 
    const float   * src1_d =       (const float   *) src1->data;
    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
@@ -164,6 +601,9 @@ void ggml_cuda_mul_mat_vec_q(
 
     ggml_cuda_mm_fusion_args_device fusion_local{};
     if (fusion) {
+        GGML_ASSERT( !ids || dst->ne[2] == 1);
+        GGML_ASSERT(  ids || dst->ne[1] == 1);
+
         if (fusion->x_bias) {
             GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
             GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
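The launcher above uses a common CUDA idiom: a runtime condition (are any fusion pointers set?) is promoted to a compile-time template parameter, so the non-fused kernel carries no fusion branches or dead shared memory at all. Reduced to its essentials (illustrative only, not the kernel above):

template <bool has_fusion>
static __global__ void kernel_impl(float * dst) {
    if constexpr (has_fusion) {
        dst[threadIdx.x] = 2.0f; // fused epilogue would go here
    } else {
        dst[threadIdx.x] = 1.0f; // plain mat-vec path, no extra branches
    }
}

static void launch(bool has_fusion, float * dst, cudaStream_t stream) {
    if (has_fusion) {
        kernel_impl<true ><<<1, 32, 0, stream>>>(dst);
    } else {
        kernel_impl<false><<<1, 32, 0, stream>>>(dst);
    }
}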
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index cff8c3b3252ce..4bb10cfaec2b6 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -1,475 +1,7 @@
-#pragma once
-
 #include "common.cuh"
-#include "quantize.cuh"
-#include "unary.cuh"
-#include "vecdotq.cuh"
-
-#include <cstdint>
-#include <utility>
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
-using vec_dot_q_cuda_t = float (*)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
-
-static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:    return vec_dot_q4_0_q8_1;
-        case GGML_TYPE_Q4_1:    return vec_dot_q4_1_q8_1;
-        case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
-        case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
-        case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
-        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
-        case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
-        case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
-        case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
-        case GGML_TYPE_Q5_K:    return vec_dot_q5_K_q8_1;
-        case GGML_TYPE_Q6_K:    return vec_dot_q6_K_q8_1;
-        case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
-        case GGML_TYPE_IQ2_XS:  return vec_dot_iq2_xs_q8_1;
-        case GGML_TYPE_IQ2_S:   return vec_dot_iq2_s_q8_1;
-        case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
-        case GGML_TYPE_IQ1_S:   return vec_dot_iq1_s_q8_1;
-        case GGML_TYPE_IQ1_M:   return vec_dot_iq1_m_q8_1;
-        case GGML_TYPE_IQ4_NL:  return vec_dot_iq4_nl_q8_1;
-        case GGML_TYPE_IQ4_XS:  return vec_dot_iq4_xs_q8_1;
-        case GGML_TYPE_IQ3_S:   return vec_dot_iq3_s_q8_1;
-        default:                return nullptr;
-    }
-}
-
-static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:    return VDR_Q4_0_Q8_1_MMVQ;
-        case GGML_TYPE_Q4_1:    return VDR_Q4_1_Q8_1_MMVQ;
-        case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
-        case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
-        case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
-        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
-        case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
-        case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
-        case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
-        case GGML_TYPE_Q5_K:    return VDR_Q5_K_Q8_1_MMVQ;
-        case GGML_TYPE_Q6_K:    return VDR_Q6_K_Q8_1_MMVQ;
-        case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
-        case GGML_TYPE_IQ2_XS:  return VDR_IQ2_XS_Q8_1_MMVQ;
-        case GGML_TYPE_IQ2_S:   return VDR_IQ2_S_Q8_1_MMVQ;
-        case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
-        case GGML_TYPE_IQ3_S:   return VDR_IQ3_S_Q8_1_MMVQ;
-        case GGML_TYPE_IQ4_NL:  return VDR_IQ4_NL_Q8_1_MMVQ;
-        case GGML_TYPE_IQ4_XS:  return VDR_IQ4_XS_Q8_1_MMVQ;
-        default:                return 1;
-    }
-}
-
-enum mmvq_parameter_table_id {
-    MMVQ_PARAMETERS_GENERIC = 0,
-    MMVQ_PARAMETERS_GCN,
-    MMVQ_PARAMETERS_RDNA2
-};
-
-static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
-    return MMVQ_PARAMETERS_RDNA2;
-#elif defined(GCN) || defined(CDNA)
-    return MMVQ_PARAMETERS_GCN;
-#else
-    return MMVQ_PARAMETERS_GENERIC;
-#endif
-}
-
-static __host__ __forceinline__ mmvq_parameter_table_id get_device_table_id(const int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
-        return MMVQ_PARAMETERS_RDNA2;
-    }
-    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
-        return MMVQ_PARAMETERS_GCN;
-    }
-    return MMVQ_PARAMETERS_GENERIC;
-}
-
-static constexpr __host__ __device__ int calc_nwarps(const int ncols_dst, const mmvq_parameter_table_id table_id) {
-    if (table_id == MMVQ_PARAMETERS_GENERIC) {
-        switch (ncols_dst) {
-            case 1:
-            case 2:
-            case 3:
-            case 4:
-                return 4;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                return 2;
-            default:
-                return 1;
-        }
-    } else if (table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_dst) {
-            case 1:
-            case 2:
-            case 3:
-            case 4:
-                return 2;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-            default:
-                return 1;
-        }
-    }
-    return 1;
-}
-
-static constexpr __host__ __device__ int calc_rows_per_block(const int ncols_dst, const int table_id) {
-    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_dst) {
-            case 1:
-                return 1;
-            case 2:
-            case 3:
-            case 4:
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                return 2;
-            default:
-                return 1;
-        }
-    }
-    return 1;
-}
-
-template <ggml_type type, int ncols_dst, bool has_fusion>
-__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
-static __global__ void mul_mat_vec_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
-    const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
-    const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
-    const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
-    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
-
-    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
-    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
-    constexpr int vdr = get_vdr_mmvq(type);
-    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
-    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
-    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
-    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
-
-    const int tid  = warp_size*threadIdx.y + threadIdx.x;
-    const int row0 = rows_per_cuda_block*blockIdx.x;
-    const int blocks_per_row_x = ncols_x / qk;
-    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
-
-    const uint32_t channel_dst = blockIdx.y;
-    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
-    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
-    const uint32_t sample_dst  = blockIdx.z;
-    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
-    const uint32_t sample_y    = sample_dst;
-
-    bool use_gate      = false;
-    bool use_bias      = false;
-    bool use_gate_bias = false;
-    const void  * vgate     = nullptr;
-    const float * x_bias    = nullptr;
-    const float * gate_bias = nullptr;
-    ggml_glu_op active_glu;
-
-    if constexpr (has_fusion) {
-        use_gate      = fusion.gate != nullptr;
-        use_bias      = fusion.x_bias != nullptr;
-        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
-        vgate      = fusion.gate;
-        x_bias     = (const float *) fusion.x_bias;
-        gate_bias  = (const float *) fusion.gate_bias;
-        active_glu = fusion.glu_op;
-    }
-
-    const uint32_t channel_bias = ids ? channel_x : channel_dst;
-
-    if constexpr (has_fusion) {
-        if (use_bias) {
-            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
-        }
-        if (use_gate_bias) {
-            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
-        }
-    }
-
-    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
-    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
-
-    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
-    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
-
-    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
-        const int kby = kbx * (qk/QK8_1);
-        const int kqs = vdr * (tid % (qi/vdr));
-
-#pragma unroll
-        for (int j = 0; j < ncols_dst; ++j) {
-#pragma unroll
-            for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(
-                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
-                if constexpr (has_fusion) {
-                    if (use_gate) {
-                        tmp_gate[j][i] += vec_dot_q_cuda(
-                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
-                    }
-                }
-            }
-        }
-    }
-
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
-    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
-    if constexpr (!has_fusion) {
-        (void) tmp_shared_gate;
-    } else if (!use_gate) {
-        (void) tmp_shared_gate;
-    }
-
-    if (threadIdx.y > 0) {
-#pragma unroll
-        for (int j = 0; j < ncols_dst; ++j) {
-#pragma unroll
-            for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
-                if constexpr (has_fusion) {
-                    if (use_gate) {
-                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
-                    }
-                }
-            }
-        }
-    }
-    __syncthreads();
-    if (threadIdx.y > 0) {
-        return;
-    }
-
-    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
-
-#pragma unroll
-    for (int j = 0; j < ncols_dst; ++j) {
-#pragma unroll
-        for (int i = 0; i < rows_per_cuda_block; ++i) {
-#pragma unroll
-            for (int l = 0; l < nwarps-1; ++l) {
-                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
-                if constexpr (has_fusion) {
-                    if (use_gate) {
-                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
-                    }
-                }
-            }
-            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
-            if constexpr (has_fusion) {
-                if (use_gate) {
-                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
-                }
-            }
-        }
-
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
-            float result = tmp[j][threadIdx.x];
-            if constexpr (has_fusion) {
-                if (use_bias) {
-                    result += x_bias[j*stride_col_dst + threadIdx.x];
-                }
-                if (use_gate) {
-                    float gate_value = tmp_gate[j][threadIdx.x];
-                    if (use_gate_bias) {
-                        gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
-                    }
-                    switch (active_glu) {
-                        case GGML_GLU_OP_SWIGLU:
-                            result *= ggml_cuda_op_silu_single(gate_value);
-                            break;
-                        case GGML_GLU_OP_GEGLU:
-                            result *= ggml_cuda_op_gelu_single(gate_value);
-                            break;
-                        case GGML_GLU_OP_SWIGLU_OAI: {
-                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
-                            break;
-                        }
-                        default:
-                            result = result * gate_value;
-                            break;
-                    }
-                }
-            }
-            dst[j*stride_col_dst + threadIdx.x] = result;
-        }
-    }
-}
-
-static inline std::pair<dim3, dim3> calc_launch_params(
-    const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
-    const int warp_size, const mmvq_parameter_table_id table_id) {
-    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
-    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
-    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
-    return {block_nums, block_dims};
-}
-
-template <ggml_type type, int c_ncols_dst>
-inline void mul_mat_vec_q_switch_fusion(
-    const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-    const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
-    const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
-    const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
-    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
-    const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
-    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
-    if (has_fusion) {
-        mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
-            (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-    } else {
-        mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
-            (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
-             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-    }
-}
-
-template <ggml_type type>
-void mul_mat_vec_q_switch_ncols_dst(
-    const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-    const int ncols_x, const int nrows_x, const int ncols_dst,
-    const int stride_row_x, const int stride_col_y, const int stride_col_dst,
-    const int nchannels_x, const int nchannels_y, const int nchannels_dst,
-    const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-    const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-    cudaStream_t stream) {
-
-    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
-    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
-
-    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
-    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
-    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);
-
-    const int device    = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
-
-    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
-
-    GGML_ASSERT(!ids || ncols_dst == 1);
-    switch (ncols_dst) {
-        case 1: {
-            constexpr int c_ncols_dst = 1;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 2: {
-            constexpr int c_ncols_dst = 2;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 3: {
-            constexpr int c_ncols_dst = 3;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 4: {
-            constexpr int c_ncols_dst = 4;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 5: {
-            constexpr int c_ncols_dst = 5;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 6: {
-            constexpr int c_ncols_dst = 6;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 7: {
-            constexpr int c_ncols_dst = 7;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        case 8: {
-            constexpr int c_ncols_dst = 8;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                dims.first, dims.second, 0, stream);
-        } break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
-
-    GGML_UNUSED(has_fusion);
-}
-
-#define DECL_MMVQ_CASE(type) \
-    template void mul_mat_vec_q_switch_ncols_dst<type>( \
-        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, \
-        const int ncols_x, const int nrows_x, const int ncols_dst, \
-        const int stride_row_x, const int stride_col_y, const int stride_col_dst, \
-        const int nchannels_x, const int nchannels_y, const int nchannels_dst, \
-        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, \
-        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, \
-        cudaStream_t stream);
-
-extern DECL_MMVQ_CASE(GGML_TYPE_Q4_0);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q4_1);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q5_0);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q5_1);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q8_0);
-extern DECL_MMVQ_CASE(GGML_TYPE_MXFP4);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q2_K);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q3_K);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q4_K);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q5_K);
-extern DECL_MMVQ_CASE(GGML_TYPE_Q6_K);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_XXS);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_XS);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ2_S);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ3_XXS);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ1_S);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ1_M);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ4_NL);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ4_XS);
-extern DECL_MMVQ_CASE(GGML_TYPE_IQ3_S);
-
 void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
     const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index d48927f9337eb..81a986f38cacf 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -45,22 +45,6 @@
 DECL_MMQ_CASE({type});
 """
 
-TYPES_MMVQ = TYPES_MMQ + ["GGML_TYPE_IQ1_M"]
-
-SOURCE_MMVQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../mmvq.cuh"
-
-DECL_MMVQ_CASE({type});
-"""
-
-SOURCE_MMVF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../mmvf.cuh"
-
-DECL_MMVF_CASE({ncols});
-"""
-
 SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh" @@ -108,14 +92,6 @@ def get_short_name(long_quant_name): with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) -for type in TYPES_MMVQ: - with open(f"mmvq-instance-{get_short_name(type)}.cu", "w") as f: - f.write(SOURCE_MMVQ.format(type=type)) - -for ncols in range(1, 9): - with open(f"mmvf-instance-ncols_{ncols}.cu", "w") as f: - f.write(SOURCE_MMVF.format(ncols=ncols)) - for type in range(1, 17): with open(f"mmf-instance-ncols_{type}.cu", "w") as f: f.write(SOURCE_MMF.format(type=type)) diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu deleted file mode 100644 index 6fc3f7986b05b..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu deleted file mode 100644 index 8b7459d9ee2ee..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_2.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(2); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu deleted file mode 100644 index 468491f8800d4..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_3.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(3); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu deleted file mode 100644 index 7f7115f38fa5d..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_4.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(4); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu deleted file mode 100644 index 407c275578b32..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_5.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(5); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu deleted file mode 100644 index cfbad12e78c62..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_6.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(6); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu deleted file mode 100644 index b88526ebe34df..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_7.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../mmvf.cuh" - -DECL_MMVF_CASE(7); diff --git a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu deleted file mode 100644 index 86b293692b4d4..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvf-instance-ncols_8.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvf.cuh" - -DECL_MMVF_CASE(8); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu deleted file mode 100644 index 778c579d2a3b5..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_m.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ1_M); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu deleted file mode 100644 index 9b157d0ec0ca7..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq1_s.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ1_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu deleted file mode 100644 index d9dc4a83f04de..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_s.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu deleted file mode 100644 index 91fa0dbb4c6e2..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xs.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu deleted file mode 100644 index 49e5f73667413..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq2_xxs.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu deleted file mode 100644 index 098d4d0e099ac..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_s.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu deleted file mode 100644 index c07376b02e649..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq3_xxs.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu deleted file mode 100644 index 4eca44cb27ea7..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_nl.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ4_NL); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu deleted file mode 100644 index b36fcecc28ab0..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-iq4_xs.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu deleted file mode 100644 index f4f12547fa94a..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-mxfp4.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu deleted file mode 100644 index 9c984278abd53..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q2_k.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q2_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu deleted file mode 100644 index 80036cfab6b2b..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q3_k.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu deleted file mode 100644 index b000d8d3c4b73..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu deleted file mode 100644 index 747a7af93a0d6..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu deleted file mode 100644 index 0ecd40853c2fa..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q4_k.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu deleted file mode 100644 index 2fa4e67923a76..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu deleted file mode 100644 index 10b50256d1687..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu deleted file mode 100644 index 592af709d5f3f..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q5_k.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q5_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu deleted file mode 100644 index 7386448ab5bb7..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q6_k.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../mmvq.cuh" - -DECL_MMVQ_CASE(GGML_TYPE_Q6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu b/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu deleted file mode 100644 index 0417846ef53f4..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/mmvq-instance-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
-
-#include "../mmvq.cuh"
-
-DECL_MMVQ_CASE(GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index 3479b79df1e5c..6b499320e7b12 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -61,8 +61,6 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "../ggml-cuda/template-instances/mmv*.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
 if (GGML_CUDA_FA_ALL_QUANTS)
     file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index 15e38a841309f..d76cb51977f90 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -36,9 +36,6 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
-    file(GLOB SRCS "../ggml-cuda/template-instances/mmv*.cu")
-    list(APPEND GGML_SOURCES_MUSA ${SRCS})
-
     if (GGML_MUSA_MUDNN_COPY)
         file(GLOB SRCS "../ggml-musa/*.cu")
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 64e70c1051f9e..33ac27ff5ca00 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4721,7 +4721,7 @@ struct test_topk_moe: public test_case {
     }
 };
 
-struct test_fused_ffn_gate : public test_case {
+struct test_mul_mat_vec_fusion : public test_case {
     const ggml_type type;
     const ggml_glu_op glu_op;
     const int64_t m;
@@ -4734,7 +4734,7 @@ struct test_fused_ffn_gate : public test_case {
     const bool with_bias;
    const bool with_gate;
 
-    test_fused_ffn_gate(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
+    test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
             bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
         : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
         if (use_id) {
@@ -4748,7 +4748,7 @@ struct test_fused_ffn_gate : public test_case {
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
-        return "FUSED_FFN_GATE";
+        return "MUL_MAT_VEC_FUSION";
     }
 
     bool run_whole_graph() override { return true; }
@@ -7135,7 +7135,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
                         continue;
                     }
-                    test_cases.emplace_back(new test_fused_ffn_gate(type, glu_op, 1, 32, 256, use_id, 16, 8, b, with_bias, with_gate));
+                    test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
+                        use_id, 16, 8, b, with_bias, with_gate));
                 }
             }
         }

From e212c8586c6d772f4b97598b13096ab69048a674 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 24 Oct 2025 11:17:49 +0800
Subject: [PATCH 5/8] add missing header

---
 ggml/src/ggml-cuda/mmvf.cu | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index a3ca6fb895846..7da33f31ffd29 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -2,6 +2,7 @@
 #include "common.cuh"
 #include "unary.cuh"
 #include "mmvf.cuh"
+#include "convert.cuh"
 
 template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion>
 static __global__ void mul_mat_vec_f(
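For context on the instance files deleted above: they existed to spread template instantiation across translation units and cut compile times, and with the templates now private to mmvq.cu/mmvf.cu they are no longer needed. The underlying C++ mechanism, in miniature (hypothetical names, not code from this series):

template <int N> int ncols_tag() { return N; }   // template normally lives in a header
extern template int ncols_tag<1>();              // header: suppress implicit instantiation
template int ncols_tag<1>();                     // exactly one .cu file: provide the definition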
From d67fcb87023c0bf0d6ba9fab032f4a2dd05557f1 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Sat, 25 Oct 2025 12:38:25 +0800
Subject: [PATCH 6/8] check fusion=false for ncols_dst!=1

---
 ggml/src/ggml-cuda/mmvf.cu | 12 ++++--------
 ggml/src/ggml-cuda/mmvq.cu |  5 ++++-
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 7da33f31ffd29..abd0f0570bb72 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -163,13 +163,7 @@ static __global__ void mul_mat_vec_f(
    } else {
 #ifdef FP16_AVAILABLE
        half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
-        half2 sumh2_gate[ncols_dst];
-        if constexpr (has_fusion) {
-#pragma unroll
-            for (int j = 0; j < ncols_dst; ++j) {
-                sumh2_gate[j] = make_half2(0.0f, 0.0f);
-            }
-        }
+        half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
 
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const half2 tmpx = x2[col2];
@@ -359,8 +353,8 @@ static void mul_mat_vec_f_switch_fusion(
    const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
    const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
 
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
-        const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
@@ -370,6 +364,8 @@ static void mul_mat_vec_f_switch_fusion(
        }
    }
 
+    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 8b164d3a7002f..c15ac3051dadc 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -161,6 +161,7 @@ static __global__ void mul_mat_vec_q(
    const int blocks_per_row_x = ncols_x / qk;
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
+    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
    const uint32_t channel_dst = blockIdx.y;
    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
@@ -324,8 +325,8 @@ static void mul_mat_vec_q_switch_fusion(
    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
    const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
 
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (c_ncols_dst == 1) {
-        const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
        if (has_fusion) {
            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
@@ -335,6 +336,8 @@ static void mul_mat_vec_q_switch_fusion(
        }
    }
 
+    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,

From 65a098f97790793b32a0bcdc0c4b829027f4fd28 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Sat, 25 Oct 2025 12:45:42 +0800
Subject: [PATCH 7/8] add back comments

---
 ggml/src/ggml-cuda/mmvf.cu |  1 +
 ggml/src/ggml-cuda/mmvq.cu | 13 +++++++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index abd0f0570bb72..c2c31cdaf231b 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -204,6 +204,7 @@ static __global__ void mul_mat_vec_f(
 #endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+//TODO: add support for ggml_cuda_mad for hip_bfloat162
 #if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index c15ac3051dadc..7a783e4fcf9b4 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -137,6 +137,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
    return 1;
 }
 
+// tell the compiler to use as many registers as it wants, see nwarps definition below
 template <ggml_type type, int ncols_dst, bool has_fusion>
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
@@ -198,6 +199,7 @@ static __global__ void mul_mat_vec_q(
        }
    }
 
+    // partial sum for each thread
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
@@ -205,7 +207,9 @@ static __global__ void mul_mat_vec_q(
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
-        const int kby = kbx * (qk/QK8_1);
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
@@ -253,6 +257,7 @@ static __global__ void mul_mat_vec_q(
        }
    }
 
    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
+    // sum up partial sums and write back result
 #pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
@@ -307,7 +312,7 @@ static __global__ void mul_mat_vec_q(
    }
 }
 
-static inline std::pair<dim3, dim3> calc_launch_params(
+static std::pair<dim3, dim3> calc_launch_params(
    const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
    const int warp_size, const mmvq_parameter_table_id table_id) {
    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
@@ -626,6 +631,7 @@ void ggml_cuda_mul_mat_vec_q(
        fusion_local.glu_op = fusion->glu_op;
    }
 
+    // If src0 is a temporary compute buffer, clear any potential padding.
    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
        const size_t size_data  = ggml_nbytes(src0);
        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
@@ -656,6 +662,7 @@ void ggml_cuda_mul_mat_vec_q(
    const int64_t s12 = ne11*s11;
    const int64_t s13 = ne12*s12;
 
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
    const int64_t ncols_dst     = ids ? ne2  : ne1;
    const int64_t nchannels_y   = ids ? ne11 : ne12;
    const int64_t nchannels_dst = ids ? ne1  : ne2;
@@ -687,6 +694,8 @@ void ggml_cuda_op_mul_mat_vec_q(
 
    int id = ggml_cuda_get_device();
 
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
    const int stride_row_x = ne00 / ggml_blck_size(src0->type);
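The gating introduced in PATCH 6/8 and kept through PATCH 7/8 follows the same compile-time dispatch pattern in both mmvf.cu and mmvq.cu: the fused kernel variant lives behind the constexpr ncols_dst == 1 branch, so it is never even instantiated for the wider-column templates. A self-contained sketch of that pattern, with hypothetical stand-ins (fusion_args, launch_kernel, dispatch) for the real fusion struct and kernel launch:

#include <cassert>

// Hypothetical stand-in for ggml_cuda_mm_fusion_args_device.
struct fusion_args {
    const void * x_bias    = nullptr;
    const void * gate      = nullptr;
    const void * gate_bias = nullptr;
};

// Stand-in for the templated CUDA kernel launch.
template <int ncols_dst, bool has_fusion>
static void launch_kernel(const fusion_args &) {
    // kernel<ncols_dst, has_fusion><<<grid, block>>>(...) in the real code
}

template <int ncols_dst>
static void dispatch(const fusion_args & f) {
    const bool has_fusion = f.gate != nullptr || f.x_bias != nullptr || f.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            launch_kernel<1, true>(f);  // fused variant exists only here
            return;
        }
    }
    // Wider-column instantiations must never receive fusion arguments.
    assert(!has_fusion && "fusion only supported for ncols_dst == 1");
    launch_kernel<ncols_dst, false>(f);
}

Because the branch is constexpr, launch_kernel<1, true> is only generated for the single-column template; the runtime GGML_ASSERT then documents and enforces the contract for every other width.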
From 975ef381c55d6fac60332fb2c9bc459d811c6eb4 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Sun, 26 Oct 2025 18:15:16 +0800
Subject: [PATCH 8/8] don't use mmvq in pascal and lower

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 36060ddfcbf6a..19f72975c0ee4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2131,6 +2131,11 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32
        && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
+    // fusion is not universally faster on Pascal
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    if (cc <= GGML_CUDA_CC_PASCAL) {
+        return false;
+    }
    //we only support fusion for ncols_dst = 1
    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
        return false;
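For reference, this is the shape of the subgraph that ggml_cuda_should_fuse_mul_mat_vec_q (together with the ggml_cuda_should_fuse_mul_mat pattern check from the first patch) contracts into a single fused launch. A sketch in ggml graph-building terms, not taken from the patch; ggml_swiglu_split as the split-GLU helper and the tensor names up_w, gate_w, cur are assumptions for illustration:

#include "ggml.h"

// Unfused form as built by the model graph: two matvecs plus a GLU.
static struct ggml_tensor * build_ffn_gate(struct ggml_context * ctx,
                                           struct ggml_tensor * up_w,
                                           struct ggml_tensor * gate_w,
                                           struct ggml_tensor * cur) {
    struct ggml_tensor * up   = ggml_mul_mat(ctx, up_w,   cur); // GGML_OP_MUL_MAT
    struct ggml_tensor * gate = ggml_mul_mat(ctx, gate_w, cur); // GGML_OP_MUL_MAT
    return ggml_swiglu_split(ctx, gate, up);                    // GGML_OP_GLU (SWIGLU)
}

When the checks above pass (quantized src0, F32 src1/dst, ncols_dst == 1, and a post-Pascal device), the CUDA backend executes all three ops as one mul_mat_vec_q launch, passing the gate matrix and any biases through the fusion args instead of materializing the intermediate tensors.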