From d2485730eea2237c5efdca6005330f5e3bbec235 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Tue, 26 Aug 2025 13:14:38 -0700 Subject: [PATCH 01/21] tests : fix test-opt with GGML_BACKEND_DL (#15599) --- tests/test-opt.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 18d3fcf2cb948..8dcb4a7dbfd62 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -3,7 +3,6 @@ #include "ggml.h" #include "ggml-alloc.h" #include "ggml-backend.h" -#include "ggml-cpu.h" #include "ggml-opt.h" #include @@ -899,6 +898,7 @@ static std::pair test_backend( int main(void) { ggml_log_set(nullptr, nullptr); + ggml_backend_load_all(); const size_t dev_count = ggml_backend_dev_count(); printf("Testing %zu devices\n\n", dev_count); size_t n_ok = 0; @@ -911,11 +911,12 @@ int main(void) { ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL); GGML_ASSERT(backend != NULL); -#ifndef _MSC_VER - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2); + + auto * reg = ggml_backend_dev_backend_reg(devs[i]); + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency() / 2); } -#endif backends.push_back(backend); } From dccd7599ff3a98d69af0ff099e4968c41447b6ff Mon Sep 17 00:00:00 2001 From: rmatif Date: Wed, 27 Aug 2025 08:36:05 +0200 Subject: [PATCH 02/21] OpenCL: add fused group_norm/norm, mul, add (#15314) * add fused group_norm/norm, mul, add * fix spacing * revert rms_norm logic * fix trailing whitespace --- ggml/src/ggml-opencl/ggml-opencl.cpp | 189 ++++++++++++++++++++- ggml/src/ggml-opencl/kernels/group_norm.cl | 49 ++++++ ggml/src/ggml-opencl/kernels/norm.cl | 80 +++++++++ tests/test-backend-ops.cpp | 85 +++++++++ 4 files changed, 399 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 36b18ddb8a9ac..c25c2daaf60ea 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -420,9 +420,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; - cl_kernel kernel_norm; + cl_kernel kernel_norm, kernel_norm_mul_add; cl_kernel kernel_rms_norm, kernel_rms_norm_mul; - cl_kernel kernel_group_norm; + cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; @@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err)); GGML_LOG_CONT("."); } @@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_group_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err)); GGML_LOG_CONT("."); } @@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { return false; } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + // norm fusion only supports F32 + if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (norm->src[0]->ne[0] % 4 != 0) { + return false; + } + + if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *gn = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } } return true; } static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); i++; @@ -5039,6 +5086,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(norm_tensor && mul_tensor && add_tensor); + + const ggml_tensor * src0 = norm_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + float eps; + memcpy(&eps, norm_tensor->op_params, sizeof(float)); + + const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; + const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3]; + const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3]; + const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3]; + const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3]; + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) sgs = 64; + else if (backend_ctx->gpu_family == INTEL) sgs = 32; + else GGML_ASSERT(false && "Unsupported GPU"); + + cl_kernel kernel = backend_ctx->kernel_norm_mul_add; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2; + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00/4); + + size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t lws[] = {(size_t)nth, 1, 1}; + size_t num_subgroups = (nth + sgs - 1) / sgs; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst); +} + +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(gn_tensor && mul_tensor && add_tensor); + + const ggml_tensor * src0 = gn_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + int groups; + float eps; + memcpy(&groups, gn_tensor->op_params, sizeof(int)); + memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float)); + + cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int ne = ggml_nelements(src0); + int group_size = ne / groups; + + size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) }; + size_t gws[] = { (size_t)groups * lws[0] }; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst); +} + static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl index 57c9df4d35b09..8e4fa0ed12d11 100644 --- a/ggml/src/ggml-opencl/kernels/group_norm.cl +++ b/ggml/src/ggml-opencl/kernels/group_norm.cl @@ -70,3 +70,52 @@ kernel void kernel_group_norm( dst[j] *= scale; } } + +//------------------------------------------------------------------------------ +// group_norm_mul_add +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_group_norm_mul_add( + global float * src0, ulong offset0, + global float * src1, ulong offset1, + global float * src2, ulong offset2, + global float * dst, ulong offsetd, + int ne, + int group_size, + float eps +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global float *)((global char *)src1 + offset1); + src2 = (global float *)((global char *)src2 + offset2); + dst = (global float *)((global char *)dst + offsetd); + + int start = get_group_id(0) * group_size; + int end = start + group_size; + if (end > ne) { + end = ne; + } + + float sum = 0.0f; + float sum_sq = 0.0f; + + for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) { + float val = src0[j]; + sum += val; + sum_sq += val*val; + } + + sum = sub_group_reduce_add(sum); + sum_sq = sub_group_reduce_add(sum_sq); + + const float mean = sum / group_size; + const float var = sum_sq / group_size - mean * mean; + const float scale = rsqrt(var + eps); + + for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) { + dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl index 43167ba4d2212..170f822787be2 100644 --- a/ggml/src/ggml-opencl/kernels/norm.cl +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -79,3 +79,83 @@ kernel void kernel_norm( y[i00] = y[i00] * scale; } } + +//------------------------------------------------------------------------------ +// norm_mul_add +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_norm_mul_add( + global char * src0_ptr, ulong src0_offset, + global char * src1_ptr, ulong src1_offset, + global char * src2_ptr, ulong src2_offset, + global char * dst_ptr, ulong dst_offset, + int ne00, int ne01, int ne02, int ne03, + ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb11, ulong nb12, ulong nb13, + int ne20, int ne21, int ne22, int ne23, + ulong nb21, ulong nb22, ulong nb23, + ulong nbd1, ulong nbd2, ulong nbd3, + float eps, + local float2 * sums +) { + const int i03 = get_group_id(2); + const int i02 = get_group_id(1); + const int i01 = get_group_id(0); + + global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03); + global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13); + global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23); + global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3); + + float p_sum = 0.0f; + float p_sum_sq = 0.0f; + + const int n_chunks = ne00 / 4; + for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) { + float4 val = x[i00]; + p_sum += val.x + val.y + val.z + val.w; + p_sum_sq += dot(val, val); + } + + p_sum = sub_group_reduce_add(p_sum); + p_sum_sq = sub_group_reduce_add(p_sum_sq); + + if (get_sub_group_local_id() == 0) { + sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq); + } + barrier(CLK_LOCAL_MEM_FENCE); + + if (get_local_id(0) == 0) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (uint i = 0; i < get_num_sub_groups(); ++i) { + float2 s = sums[i]; + sum += s.x; + sum_sq += s.y; + } + + const float inv_ne00 = 1.0f / (float)ne00; + const float mean = sum * inv_ne00; + const float variance = mad(-mean, mean, sum_sq * inv_ne00); + + sums[0] = (float2)(mean, rsqrt(variance + eps)); + } + barrier(CLK_LOCAL_MEM_FENCE); + + const float2 mean_scale = sums[0]; + const float mean = mean_scale.x; + const float scale = mean_scale.y; + const float neg_mean_scale = -mean * scale; + + for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) { + const int w_idx = ne10 > 1 ? i00 : 0; + const int b_idx = ne20 > 1 ? i00 : 0; + const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale); + y[i00] = mad(norm_x, w[w_idx], b[b_idx]); + } +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c84023e05e984..3a58621094d17 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2789,6 +2789,49 @@ struct test_norm : public test_case { } }; +// GGML_OP_NORM + GGML_OP_MUL + GGML_OP_ADD +struct test_norm_mul_add : public test_case { + const ggml_type type; + const std::array ne; + float eps; + const bool broadcast; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "NORM_MUL_ADD"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR4(type, ne, eps, broadcast); + } + + test_norm_mul_add(ggml_type type = GGML_TYPE_F32, + std::array ne = {128, 2, 1, 1}, + float eps = 1e-5f, + bool broadcast = false) + : type(type), ne(ne), eps(eps), broadcast(broadcast) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + std::array broadcast_dims = {ne[0], ne[1] * 2, ne[2] * 2, ne[3] * 2}; + + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data()); + ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); ggml_set_param(w); ggml_set_param(b); + ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b"); + + // Use a, w and b early to avoid OP_NONE in graph + a = ggml_add(ctx, ggml_add(ctx, a, w), b); + + ggml_tensor * n = ggml_norm(ctx, a, eps); + ggml_tensor * m = ggml_mul(ctx, n, w); + ggml_tensor * out = ggml_add(ctx, m, b); + ggml_set_name(out, "out"); + return out; + } +}; // GGML_OP_RMS_NORM struct test_rms_norm : public test_case { const ggml_type type; @@ -4475,6 +4518,44 @@ struct test_group_norm : public test_case { } }; +// GGML_OP_GROUP_NORM + GGML_OP_MUL + GGML_OP_ADD +struct test_group_norm_mul_add : public test_case { + const ggml_type type; + const std::array ne; + int num_groups; + float eps; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "GROUP_NORM_MUL_ADD"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR4(type, ne, num_groups, eps); + } + + test_group_norm_mul_add(ggml_type type = GGML_TYPE_F32, + std::array ne = {128, 1, 1, 1}, + int num_groups = 4, + float eps = 1e-5f) + : type(type), ne(ne), num_groups(num_groups), eps(eps) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); ggml_set_param(w); ggml_set_param(b); + ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b"); + ggml_tensor * n = ggml_group_norm(ctx, a, num_groups, eps); + ggml_tensor * m = ggml_mul(ctx, n, w); + ggml_tensor * out = ggml_add(ctx, m, b); + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_L2_NORM struct test_l2_norm : public test_case { const ggml_type type; @@ -5865,6 +5946,8 @@ static std::vector> make_test_cases_eval() { for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); + test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); + test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } for (uint32_t n : {1, 511, 1025, 8192, 33*512}) { for (bool multi_add : {false, true}) { @@ -6253,6 +6336,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); + test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1})); + test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_pad_reflect_1d()); From 1b76e8bfda86a2664108cccfca942f03b92119cd Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 27 Aug 2025 10:28:53 +0200 Subject: [PATCH 03/21] common : add -m to bash completion for --model [no ci] (#15591) This commit updates the bash completion script to include the -m short option for the --model argument. The motivation for this is that currently tab completion only works the full --model option, and it is nice to have it work for the short option as well. --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 81c4005c5e7fc..1ae3fdbf4a802 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) { printf("\"\n\n"); printf(" case \"$prev\" in\n"); - printf(" --model)\n"); + printf(" --model|-m)\n"); printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n"); printf(" return 0\n"); printf(" ;;\n"); From dbd8ebc9aef308e939a711bd4cbb468e6b85e9bc Mon Sep 17 00:00:00 2001 From: xctan Date: Wed, 27 Aug 2025 16:44:22 +0800 Subject: [PATCH 04/21] ggml-cpu : add basic RVV support for vector f32 ops (#15057) * ggml-cpu : add basic RVV support for vector f32 ops * ggml-cpu : add RVV support for f32 softmax --- ggml/src/ggml-cpu/CMakeLists.txt | 2 +- ggml/src/ggml-cpu/ops.cpp | 7 +- ggml/src/ggml-cpu/simd-mappings.h | 53 +++++++++++---- ggml/src/ggml-cpu/vec.cpp | 21 +++++- ggml/src/ggml-cpu/vec.h | 104 +++++++++++++++++++++++++++++- 5 files changed, 168 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index ce0a3e1285eb0..b70302ec8c9b1 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ) if (GGML_RVV) if (GGML_XTHEADVECTOR) - list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) + list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d) elseif (GGML_RV_ZFH) list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) else() diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 460367cca09e9..93330b43a9b84 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32( } sumf = GGML_F32xt_REDUCE_ONE(sum); + #elif defined(__riscv_v_intrinsic) + // todo: RVV implementation + const int np = 0; #else const int np = (nc & ~(GGML_F32_STEP - 1)); @@ -10023,8 +10026,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - // scalar Route to scalar implementation //TODO: Write SVE code + #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic) + // scalar Route to scalar implementation //TODO: Write SVE code and RVV code for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; int64_t state_offset = head_size * C * (t / (T / n_seqs)); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index b4ad68c9fd647..f71ce580799f6 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -18,6 +18,10 @@ #include #endif +#if defined(__riscv_v_intrinsic) +#include +#endif + #ifdef __cplusplus extern "C" { #endif @@ -94,24 +98,15 @@ extern "C" { } #elif defined(__riscv) && defined(__riscv_zfhmin) static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) { - float f; - __asm__( - "fmv.h.x %[f], %[h]\n\t" - "fcvt.s.h %[f], %[f]" - : [f] "=&f" (f) - : [h] "r" (h) - ); - return f; + _Float16 hf; + memcpy(&hf, &h, sizeof(ggml_fp16_t)); + return hf; } static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) { ggml_fp16_t res; - __asm__( - "fcvt.h.s %[f], %[f]\n\t" - "fmv.x.h %[h], %[f]" - : [h] "=&r" (res) - : [f] "f" (f) - ); + _Float16 hf = (_Float16)f; + memcpy(&res, &hf, sizeof(ggml_fp16_t)); return res; } @@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +#elif defined(__riscv_v_intrinsic) + +// compatible with vlen >= 128 + +#define GGML_SIMD + +// F32 + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vfloat32m1_t +#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR) +#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR) +#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR) +#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR) +#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR) + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 07b377bdd82a7..d8ec3b81d2446 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G } // reduce sum1,sum2 to sum1 GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #elif defined(__riscv_v_intrinsic) + vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1); + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl); + vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl); + } + sumf += __riscv_vfmv_f_s_f32m1_f32(vsum); #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G ggml_float sumf = 0.0; -#if defined(GGML_SIMD) +#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic) const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; @@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv_v_intrinsic) + vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1); + for (int avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl); + __riscv_vse32_v_f32m2(&y[i], val, avl); + vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl); + } + return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum); #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2250d93cb00d1..8ccf340d472ad 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; @@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } +#endif #else for (int i = 0; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { @@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const svst1_f32(pg, y + np2, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); @@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } +#endif #else // scalar for (int i = 0; i < n; ++i) { @@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int y[i] += x[k][i]*v[k][0]; } } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) { + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl); + ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl); + } + __riscv_vse32_v_f32m8(&y[i], ay, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < n; ++i) { y[i] = x[i]*s + b; } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { ay1 = svmul_f32_m(pg, ay1, vx); svst1_f32(pg, y + np, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); @@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } +#endif #else // scalar for (int i = 0; i < n; ++i) { @@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) { return _mm_div_ps(x, one_plus_exp_neg_x); } -#endif // __ARM_NEON / __AVX2__ / __SSE2__ +#elif defined(__riscv_v_intrinsic) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); +#ifdef __riscv_xtheadvector + // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4') + vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl); + z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl); +#else + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); +#endif + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), + 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), + u, vl), u, vl); + if (!__riscv_vcpop_m_b16(c, vl)) + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2( + __riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), + c, vl); + return __riscv_vmerge_vvm_f32m2( + r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), + vl); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { From f01044cd4fa510eba07a87c91d271dff1787132e Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 27 Aug 2025 17:21:41 +0800 Subject: [PATCH 05/21] CANN: refactor mask handling and improve performance in FA (#15561) * CANN(flash-attn): refactor mask handling and improve performance 1. Refactored the mask computation in Flash Attention, unified the logic without separating prefill and decode. 2. Optimized performance in non-alibi scenarios by reducing one repeat operation. 3. Updated operator management to explicitly mark unsupported cases on 310P devices and when dim is not divisible by 16. Signed-off-by: noemotiovon <757486878@qq.com> * [CANN]: fix review Signed-off-by: noemotiovon <757486878@qq.com> * [CANN]: Optimization FA BNSD to BSND Signed-off-by: noemotiovon <757486878@qq.com> --------- Signed-off-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cann/aclnn_ops.cpp | 175 ++++++++++++++++--------------- ggml/src/ggml-cann/ggml-cann.cpp | 12 ++- 2 files changed, 98 insertions(+), 89 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index bc33b99d96ea6..c42871c575822 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, float m, int64_t size, float start, float stop, float step){ int64_t ne[] = {size}; - size_t nb[] = {sizeof(float)}; + size_t nb[] = {sizeof(uint16_t)}; - ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); + ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t)); void* arange_buffer = arange_allocator.get(); aclTensor* arange_tensor = ggml_cann_create_tensor( - arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); + arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); aclTensor* slope_tensor = ggml_cann_create_tensor( - slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); + slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1); aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); @@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - ggml_tensor* src0 = dst->src[0]; // q, fp32 - ggml_tensor* src1 = dst->src[1]; // k, fp16 - ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) ggml_tensor* src3 = dst->src[3]; // mask, fp16 + // B, N, S, D (uncont) -> B, S, N, D (cont) + int64_t src0_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src0_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t)); + int64_t src1_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src1_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t)); + int64_t src2_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src2_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t)); + + auto transpose12 = [](int64_t* ne, size_t* nb) { + int64_t ne_tmp = ne[1]; + size_t nb_tmp = nb[1]; + ne[1] = ne[2]; + nb[1] = nb[2]; + ne[2] = ne_tmp; + nb[2] = nb_tmp; + }; + + transpose12(src0_bsnd_ne, src0_bsnd_nb); + transpose12(src1_bsnd_ne, src1_bsnd_nb); + transpose12(src2_bsnd_ne, src2_bsnd_nb); + float maxBias = 0.0f; float scaleValue = 1.0f; float logitSoftcap = 0.0f; @@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ void* src0_f16_buffer = nullptr; if(ggml_cann_type_mapping(src0->type) != faDataType){ - aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, + src0_bsnd_nb, GGML_MAX_DIMS); src0_f16_buffer = src0_f16_allocator.alloc( ggml_nelements(src0) * faElemSize); - int64_t* src0_f16_ne = src0->ne; + int64_t* src0_f16_ne = src0_bsnd_ne; size_t src0_f16_nb[GGML_MAX_DIMS]; src0_f16_nb[0] = sizeof(uint16_t); for(int i = 1; i < GGML_MAX_DIMS; ++i){ @@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); ggml_cann_release_resources(ctx, acl_src0_f32_tensor); }else{ - acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, + src0_bsnd_nb, GGML_MAX_DIMS); } // Step 2: create the acl tensors for src1 (Key), src2 (Value), // and the direct output from FusedInferAttention - acl_src1_f16_tensor = ggml_cann_create_tensor(src1); - acl_src2_f16_tensor = ggml_cann_create_tensor(src2); + acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, + src1_bsnd_nb, GGML_MAX_DIMS); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, + src2_bsnd_nb, GGML_MAX_DIMS); ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); void* out_f16_buffer = out_f16_allocator.alloc( ggml_nelements(dst) * faElemSize); - int64_t* out_f16_ne = src0->ne; + int64_t* out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for(int i = 1; i < GGML_MAX_DIMS; ++i){ @@ -3251,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp - aclTensor* bcast_pse_tensor = nullptr; - int64_t bcast_pse_ne[GGML_MAX_DIMS]; - size_t bcast_pse_nb[GGML_MAX_DIMS]; ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); - void* bcast_pse_buffer = nullptr; - if(src3 != nullptr){ - bcast_pse_buffer = bcast_pse_allocator.alloc( - ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - - if(src0->ne[1] > 1){ - // Case 1: broadcast pse for prefill stage with multiple head - aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src3->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + // Construct the truncated pse tensor (common for prefill/decode) + int64_t trunc_pse_ne[GGML_MAX_DIMS] = { + src3->ne[0], // D + src0->ne[1], // S (number of Q tokens) + src3->ne[2], // mask N + src3->ne[3] // B + }; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS + ); + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + bcast_pse_ne[0] = src3->ne[0]; // D + bcast_pse_ne[1] = src0->ne[1]; // S + bcast_pse_ne[2] = src0->ne[2]; // N (num_heads) + bcast_pse_ne[3] = src3->ne[3]; // B + if (maxBias == 0.0f) { + // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2) + // Construct the bcast tensor (simulate repeat on the head dimension using stride=0) bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; - } + bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0]; + bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data + bcast_pse_nb[3] = src3->nb[3]; bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); - - int64_t repeats[] = {1, src0->ne[2], 1, 1}; - aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); - - ggml_cann_release_resources(ctx, acl_mask_f16_tensor); - }else{ - // Case 2: trunc the first row and broadcast pse for decode stage with multiple head - int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; - size_t* trunc_pse_nb = src3->nb; - - aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( src3->data, ACL_FLOAT16, sizeof(uint16_t), - trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); - - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src0->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } else { bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ + for (int i = 1; i < GGML_MAX_DIMS; i++) { bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; } + void* bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t) + ); + bcast_pse_tensor = ggml_cann_create_tensor( bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); int64_t repeats[] = {1, src0->ne[2], 1, 1}; aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); - ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); - } - - // Compute the slope if needed. Derived from ggml_cann_softmax(). - if(maxBias != 0.0f){ // alibi + // Compute the slope if needed. Derived from ggml_cann_softmax(). const int64_t n_heads = src0->ne[2]; - ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t)); void* slope_buffer = slope_allocator.get(); aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); int64_t slope_ne[] = {1, 1, n_heads, 1}; size_t slope_nb[GGML_MAX_DIMS]; - slope_nb[0] = sizeof(float); + slope_nb[0] = sizeof(uint16_t); for(int i = 1;ine[0]); // 1/sqrt(d) int64_t preTokens = 65535; int64_t nextTokens = 65535; - char layout[5] = {'B', 'N', 'S', 'D', 0}; + char layout[5] = {'B', 'S', 'N', 'D', 0}; int64_t sparseMode = 0; int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; int64_t blockSize = 0; @@ -3386,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ); // Step 6: post-processing, permute and cast to f32 - - int64_t new_dim[] = {0, 2, 1, 3}; aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - - if(ggml_cann_type_mapping(dst->type) != faDataType){ - ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); - perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - void* perm_out_f16_buffer = perm_out_f16_allocator.get(); - - int64_t* perm_out_f16_ne = dst->ne; - size_t perm_out_f16_nb[GGML_MAX_DIMS]; - perm_out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; - } - aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( - perm_out_f16_buffer, faDataType, faElemSize, - perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); - aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); - aclnn_cast(ctx, - acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); - }else{ - // only need to permute - aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); - } + // TODO: when dst is fp16, don't need cast + aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cb8af42ebf956..81215425618a3 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: #ifdef ASCEND_310P - // Q4 && Q8 per group is not suppor on 310p device + // Q4 && Q8 per group is not support on 310p device return false; #endif // only support contiguous for quantized types. @@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: #ifdef ASCEND_310P - // Q4 && Q8 per group is not suppor on 310p device + // Q4 && Q8 per group is not support on 310p device return false; #endif // only support contiguous for quantized types. @@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } return true; case GGML_OP_FLASH_ATTN_EXT:{ +#ifdef ASCEND_310P + // FA not support on 310p device + return false; +#endif // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ return false; @@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // DeepSeek MLA return false; } + if (op->src[0]->ne[0] % 16 != 0) { + // TODO: padding to support + return false; + } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); if(logitSoftcap != 0.0f) { From 4456c0caaea459a96f9703d529abafac81dc12b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 27 Aug 2025 13:55:12 +0300 Subject: [PATCH 06/21] kv-cache : better estimate of n_kv for multi-sequence batches (#15610) ggml-ci --- src/llama-kv-cache.cpp | 25 ++++++++++++------------- src/llama-kv-cache.h | 6 +++--- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index d7ab56ccd9aac..920c1d0dbdc74 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); } - res.s0 = std::min(res.s0, seq_to_stream[seq_id]); - res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); res.strm[s] = seq_to_stream[seq_id]; res.idxs[s].reserve(n_tokens); @@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const { return result; } -uint32_t llama_kv_cache::get_n_kv() const { +uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; - for (uint32_t s = 0; s < n_stream; ++s) { - const auto & cells = v_cells[s]; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const auto & cells = v_cells[sinfo.strm[s]]; result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); } @@ -1017,18 +1017,18 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] - ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, kv_size), // v->nb[2] - ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } @@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() { } kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); - - n_kv = kv->get_n_kv(); + n_kv = kv->get_n_kv(sinfos[i_cur]); return true; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 76a5cb1e28e7e..3ca82917d3237 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -38,8 +38,8 @@ class llama_kv_cache : public llama_memory_i { using idx_vec_t = std::vector; // number of streams: ns = s1 - s0 + 1 - llama_seq_id s0; - llama_seq_id s1; + uint32_t s0; + uint32_t s1; std::vector strm; // [ns] std::vector idxs; // [ns] @@ -139,7 +139,7 @@ class llama_kv_cache : public llama_memory_i { // graph_build API // - uint32_t get_n_kv() const; + uint32_t get_n_kv(const slot_info & sinfo) const; // TODO: temporary bool get_supports_set_rows() const; From d9c27bff79b676aaab5d7a6846341b6fe9d1d0c1 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 27 Aug 2025 13:58:54 +0200 Subject: [PATCH 07/21] HIP: Enable support for ggml_backend_cuda_register_host_buffer (#15615) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 449488341557f..3a50527248045 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3106,7 +3106,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error From ab02cc24e93cb6995ea5f6ad03ffdf05aaff0742 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 27 Aug 2025 15:48:07 +0300 Subject: [PATCH 08/21] presets : add qwen3-30B-a3b FIM (#15616) --- common/arg.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 1ae3fdbf4a802..d82f55890dee5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3538,6 +3538,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--fim-qwen-30b-default"}, + string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"), + [](common_params & params) { + params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF"; + params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf"; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( { "--diffusion-steps" }, "N", string_format("number of diffusion steps (default: %d)", params.diffusion.steps), From a11f1e58d5f5e687cf1c40d1af6feccfc7108b3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 27 Aug 2025 20:58:09 +0200 Subject: [PATCH 09/21] server: higher timeout for tests (#15621) --- tools/server/tests/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 5f42bcae61e1a..f55a539471194 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -26,10 +26,7 @@ import wget -DEFAULT_HTTP_TIMEOUT = 12 - -if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ: - DEFAULT_HTTP_TIMEOUT = 30 +DEFAULT_HTTP_TIMEOUT = 30 class ServerResponse: From 0d3144a8a61dc42b92e10b6b3f969578e5d55ac8 Mon Sep 17 00:00:00 2001 From: matiaslin <45382001+matiaslin@users.noreply.github.com> Date: Wed, 27 Aug 2025 17:32:36 -0700 Subject: [PATCH 10/21] cuda: Add cublasLt_static linking when GGML_STATIC is enabled (#15622) Prior to this change, we faced undefined cublasLt references when attempting to compile 'llama-cli' with GGML_STATIC=ON on Linux. We add linking with CUDA::cublasLt_static when CUDA version is greater than 10.1. --- ggml/src/ggml-cuda/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index ea824965aae2d..d3dfc7807de0e 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -94,7 +94,11 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + endif() endif() else() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) From 6a85048ea14b265094929adbd26e384b757ea0eb Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 Aug 2025 09:26:48 +0200 Subject: [PATCH 11/21] model-conversion : add mmproj conversion target (#15628) This commit adds a new target to the Makefile for converting models that are multimodal. This target will convert the original model and in addition also create the mmproj GGUF model. The motivation for this change is that for models that are multimodal, for example those that contain a vision encoders, we will often want to upload both the quantized model and the vision encoder model to HuggingFace. Example usage: ```console $ make causal-convert-mm-model MODEL_PATH=~/work/ai/models/gemma-3-4b-it-qat-q4_0-unquantized/ ... The environment variable CONVERTED_MODEL can be set to this path using: export CONVERTED_MODEL=/home/danbev/work/ai/llama.cpp/models/gemma-3-4b-it-qat-q4_0-unquantized.gguf The mmproj model was created in /home/danbev/work/ai/llama.cpp/models/mmproj-gemma-3-4b-it-qat-q4_0-unquantized.gguf ``` The converted original model can then be quantized, and after that both the quantized model and the mmproj file can then be uploaded to HuggingFace. Refs: https://huggingface.co/ggml-org/gemma-3-4b-it-qat-GGUF/tree/main --- examples/model-conversion/Makefile | 14 ++++++++ .../scripts/causal/convert-model.sh | 34 ++++++++++++++++--- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index 37982495b2655..03b928afda3e7 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -37,6 +37,20 @@ causal-convert-model: METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ ./scripts/causal/convert-model.sh +causal-convert-mm-model-bf16: OUTTYPE=bf16 +causal-convert-mm-model-bf16: MM_OUTTYPE=f16 +causal-convert-mm-model-bf16: causal-convert-mm-model + +causal-convert-mm-model: + $(call validate_model_path,causal-convert-mm-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh + + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh --mmproj + causal-run-original-model: $(call validate_model_path,causal-run-original-model) @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py diff --git a/examples/model-conversion/scripts/causal/convert-model.sh b/examples/model-conversion/scripts/causal/convert-model.sh index 56b21f9baa673..9d95025950d42 100755 --- a/examples/model-conversion/scripts/causal/convert-model.sh +++ b/examples/model-conversion/scripts/causal/convert-model.sh @@ -1,5 +1,21 @@ #!/bin/bash +set -e + +# Parse command line arguments +MMPROJ="" +while [[ $# -gt 0 ]]; do + case $1 in + --mmproj) + MMPROJ="--mmproj" + shift + ;; + *) + shift + ;; + esac +done + MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}" OUTPUT_DIR="${OUTPUT_DIR:-../../models}" TYPE="${OUTTYPE:-f16}" @@ -11,12 +27,20 @@ echo "Model name: ${MODEL_NAME}" echo "Data type: ${TYPE}" echo "Converted model path:: ${CONVERTED_MODEL}" echo "Metadata override: ${METADATA_OVERRIDE}" -python ../../convert_hf_to_gguf.py --verbose \ - ${MODEL_PATH} \ - --outfile ${CONVERTED_MODEL} \ - --outtype ${TYPE} \ - --metadata "${METADATA_OVERRIDE}" + +CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose") +CMD_ARGS+=("${MODEL_PATH}") +CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}") +CMD_ARGS+=("--outtype" "${TYPE}") +[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}") +[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}") + +"${CMD_ARGS[@]}" echo "" echo "The environment variable CONVERTED_MODEL can be set to this path using:" echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})" +if [[ -n "$MMPROJ" ]]; then + mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")" + echo "The mmproj model was created in $(realpath "$mmproj_file")" +fi From 852ab81784f6d3819d3d7fb4d1addbd14d0a8454 Mon Sep 17 00:00:00 2001 From: Joshua Cogliati Date: Thu, 28 Aug 2025 01:48:20 -0600 Subject: [PATCH 12/21] cli : change log to warning to explain reason for stopping (#15604) * Change to warn instead of debug, to explain reason for stopping. * Update tools/main/main.cpp Fix printing --2 Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- tools/main/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/main/main.cpp b/tools/main/main.cpp index dc776f59e90a4..865ea4a2f724c 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -587,12 +587,12 @@ int main(int argc, char ** argv) { if (n_past + (int) embd.size() >= n_ctx) { if (!params.ctx_shift){ - LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__); + LOG_WRN("\n\n%s: context full and context shift is disabled => stopping\n", __func__); break; } if (params.n_predict == -2) { - LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + LOG_WRN("\n\n%s: context full and n_predict == %d => stopping\n", __func__, params.n_predict); break; } From d4e26506a6e386d8e789dcb37242849f0e5d9527 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Thu, 28 Aug 2025 10:56:41 +0200 Subject: [PATCH 13/21] gguf-py: byteswapping improvements (#12851) * gguf-py: implement byteswapping for Q4_0 This is needed to byteswap Mistral model. Also restore original shapes after byteswapping tensors. It is not needed at the moment, but do it in case they'd be used in future. * Rework byteswapping code in gguf-py Move out details from byteswapping tensor blocks code --- gguf-py/gguf/scripts/gguf_convert_endian.py | 130 ++++++++++---------- 1 file changed, 67 insertions(+), 63 deletions(-) diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 0e0febaa79178..211a3f536a6a9 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -19,6 +19,61 @@ logger = logging.getLogger("gguf-convert-endian") +def byteswap_q4_0(tensor, block_offs): + # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations. + + # Byte-Swap f16 sized delta field + delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + +def byteswap_q8_0(tensor, block_offs): + # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations. + + # Byte-Swap f16 sized delta field + delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + +def byteswap_q4_k(tensor, block_offs): + # Each block_q4_k consists of 2 f16 values followed by 140 int8 values. + + # Byte-Swap f16 sized fields + delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + +def byteswap_q6_k(tensor, block_offs): + # Each block_q6_k consists of 208 int8 values followed by 1 f16 value. + + # Byte-Swap f16 sized field + delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16) + delta.byteswap(inplace=True) + + +byteswap_tensors = { + gguf.GGMLQuantizationType.Q4_0: { + "block_size": 18, # 18 bytes = + 16 * + "byteswap_func": byteswap_q4_0, + }, + gguf.GGMLQuantizationType.Q8_0: { + "block_size": 34, # 34 bytes = + 32 * + "byteswap_func": byteswap_q8_0, + }, + gguf.GGMLQuantizationType.Q4_K: { + "block_size": 144, # 144 bytes = 2 * + 140 * + "byteswap_func": byteswap_q4_k, + }, + gguf.GGMLQuantizationType.Q6_K: { + "block_size": 210, # 210 bytes = + 208 * + "byteswap_func": byteswap_q6_k, + }, +} + + def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None: file_endian = reader.endianess.name if reader.byte_order == 'S': @@ -32,13 +87,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None sys.exit(0) logger.info("* Checking tensors for conversion compatibility") for tensor in reader.tensors: - if tensor.tensor_type not in ( - gguf.GGMLQuantizationType.F32, - gguf.GGMLQuantizationType.F16, - gguf.GGMLQuantizationType.Q8_0, - gguf.GGMLQuantizationType.Q4_K, - gguf.GGMLQuantizationType.Q6_K, - ): + if tensor.tensor_type not in byteswap_tensors and \ + tensor.tensor_type not in ( + gguf.GGMLQuantizationType.F32, + gguf.GGMLQuantizationType.F16, + ): raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") logger.info(f"* Preparing to convert from {file_endian} to {order}") if args.dry_run: @@ -72,78 +125,29 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None part.byteswap(inplace=True) # Byte-swap tensor data if necessary - if tensor.tensor_type == gguf.GGMLQuantizationType.Q8_0: - # Handle Q8_0 tensor blocks (block_q8_0) - # Specific handling of block_q8_0 is required. - # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations. - - block_size = 34 # 34 bytes = + 32 * - - n_blocks = len(tensor.data) // block_size - for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): - block_offs = block_num * block_size - - # Byte-Swap f16 sized delta field - delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) - delta.byteswap(inplace=True) - - # Byte-Swap Q8 weights - if block_num % 100000 == 0: - inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") - - elif tensor.tensor_type == gguf.GGMLQuantizationType.Q4_K: - # Handle Q4_K tensor blocks (block_q4_k) - # Specific handling of block_q4_k is required. - # Each block_q4_k consists of 2 f16 values followed by 140 int8 values. - + if tensor.tensor_type in byteswap_tensors: # first flatten structure + oldshape = tensor.data.shape newshape = 1 for i in tensor.data.shape: newshape *= i tensor.data.resize(newshape) - block_size = 144 - n_blocks = len(tensor.data) // block_size - for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): - block_offs = block_num * block_size - - # Byte-Swap f16 sized fields - delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) - delta.byteswap(inplace=True) - - delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16) - delta.byteswap(inplace=True) - - # Byte-Swap - if block_num % 100000 == 0: - inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") - - elif tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K: - # Handle Q6_K tensor blocks (block_q6_k) - # Specific handling of block_q6_k is required. - # Each block_q6_k consists of 208 int8 values followed by 1 f16 value. - - # first flatten structure - newshape = 1 - for i in tensor.data.shape: - newshape *= i - - tensor.data.resize(newshape) + block_size = byteswap_tensors[tensor.tensor_type]["block_size"] + byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"] - block_size = 210 n_blocks = len(tensor.data) // block_size for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): block_offs = block_num * block_size - # Byte-Swap f16 sized field - delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16) - delta.byteswap(inplace=True) + byteswap_func(tensor, block_offs) - # Byte-Swap if block_num % 100000 == 0: inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + # restore old shape in case it's ever used + tensor.data.resize(oldshape) else: # Handle other tensor types tensor.data.byteswap(inplace=True) From 246385a1e9e47809d91b8fb3a8f9da9ee74902f4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Aug 2025 12:27:02 +0300 Subject: [PATCH 14/21] kv-cache : remove LLAMA_SET_ROWS checks (#15505) ggml-ci --- ggml/src/ggml-cann/common.h | 9 --- ggml/src/ggml-cann/ggml-cann.cpp | 5 -- src/llama-context.cpp | 22 ------- src/llama-context.h | 4 -- src/llama-graph.cpp | 4 -- src/llama-kv-cache.cpp | 104 ++++++------------------------- src/llama-kv-cache.h | 10 --- 7 files changed, 20 insertions(+), 138 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 33794062f565d..88cc3f481ed38 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -374,7 +374,6 @@ struct ggml_backend_cann_context { #endif cann_task_queue task_queue; bool async_mode; - bool support_set_rows; // Rope Cache void* rope_init_ptr = nullptr; void* rope_sin_ptr = nullptr; @@ -400,14 +399,6 @@ struct ggml_backend_cann_context { async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); - - support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or("")); - GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF"); - - if (!support_set_rows) { - GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. " - "Falling back to eager mode.\n", __func__); - } } /** diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 81215425618a3..558121dff780b 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2251,11 +2251,6 @@ static enum ggml_status ggml_backend_cann_graph_compute( bool use_cann_graph = true; bool cann_graph_update_required = false; - // check environment LLAMA_SET_ROWS - if (!cann_ctx->support_set_rows) { - use_cann_graph = false; - } - if (use_cann_graph) { if (cann_ctx->cann_graph == nullptr) { cann_ctx->cann_graph.reset(new ggml_cann_graph()); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 99bfed75136b2..6b20161a389a0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -102,16 +102,6 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - { - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; - - if (!supports_set_rows && !cparams.kv_unified) { - LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); - cparams.kv_unified = true; - } - } - { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -890,12 +880,6 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1226,12 +1210,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - return 0; } diff --git a/src/llama-context.h b/src/llama-context.h index 3dd9205446483..a372bcfbe41aa 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -283,10 +283,6 @@ struct llama_context { bool has_evaluated_once = false; - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b928e9e16ead8..1f2fc3ab62d4e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -314,8 +314,6 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_kq_mask->ne[0] == mctx->get_n_kv(); res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_supports_set_rows(); // TODO: tmp - return res; } @@ -350,8 +348,6 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp - return res; } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 920c1d0dbdc74..4485f78d5f533 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -197,18 +197,6 @@ llama_kv_cache::llama_kv_cache( const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; - - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows; - - if (!supports_set_rows) { - // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); - } - - if (!supports_set_rows) { - LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); - } } void llama_kv_cache::clear(bool data) { @@ -551,11 +539,8 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vectorne[0]; const int64_t n_tokens = k_cur->ne[2]; k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); - if (k_idxs && supports_set_rows) { - if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); - } - - return ggml_set_rows(ctx, k, k_cur, k_idxs); + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); } - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); - - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*n_embd_k_gqa, - ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); - - return ggml_cpy(ctx, k_cur, k_view); + return ggml_set_rows(ctx, k, k_cur, k_idxs); } ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -1072,48 +1043,25 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); - if (v_idxs && supports_set_rows) { - if (!v_trans) { - if (v->ne[2] > 1) { - v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); - } - - return ggml_set_rows(ctx, v, v_cur, v_idxs); - } - - // [TAG_V_CACHE_VARIABLE] - if (n_embd_v_gqa < v->ne[0]) { - v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - - v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - - return ggml_set_rows(ctx, v_view, v_cur, v_idxs); + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } - ggml_tensor * v_view = nullptr; + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*n_embd_v_gqa, - ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head()); - } else { - v_cur = ggml_transpose(ctx, v_cur); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa, - (v->ne[1] )*ggml_element_size(v), - (sinfo.head())*ggml_element_size(v)); - } - - return ggml_cpy(ctx, v_cur, v_view); + return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { @@ -1143,10 +1091,6 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama } void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1163,10 +1107,6 @@ void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ub } void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -2004,10 +1944,6 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } -bool llama_kv_cache_context::get_supports_set_rows() const { - return kv->get_supports_set_rows(); -} - ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 3ca82917d3237..07d29bb8187e2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -141,9 +141,6 @@ class llama_kv_cache : public llama_memory_i { uint32_t get_n_kv(const slot_info & sinfo) const; - // TODO: temporary - bool get_supports_set_rows() const; - // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; @@ -215,10 +212,6 @@ class llama_kv_cache : public llama_memory_i { // env: LLAMA_KV_CACHE_DEBUG int debug = 0; - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; @@ -318,9 +311,6 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; - // TODO: temporary - bool get_supports_set_rows() const; - // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; From 5b262c0f2fd5051d6254e4a834a64f4eded60d05 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 28 Aug 2025 19:23:22 +0800 Subject: [PATCH 15/21] scripts: add sqlite3 check for compare-commits.sh (#15633) --- scripts/compare-commits.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index a28cd5e535bff..1802d6e5ef9f0 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -25,6 +25,12 @@ fi # verify at the start that the compare script has all the necessary dependencies installed ./scripts/compare-llama-bench.py --check +if ! command -v sqlite3 >/dev/null 2>&1; then + echo "Error: sqlite3 is not installed or not in PATH" + echo "Please install sqlite3 to use this script" + exit 1 +fi + if [ "$tool" = "llama-bench" ]; then db_file="llama-bench.sqlite" target="llama-bench" From 33e76b4af2a8c66906e4fa9cd446ab86ee9abe8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 28 Aug 2025 15:49:50 +0200 Subject: [PATCH 16/21] model : jina-embeddings-v3 support (#13693) * initial jina-embeddings-v3 support * initial jina-embeddings-v3 support * initial jina-embeddings-v3 support * fix vocab parsing with only tokenizer.json * set mask token lstrip attribute * additional unk_token_id fallback just in case [no ci] * revert vocab_size() change [no ci] * merge tensor loading into general bert * rope * add lora embedding and loading (non-functional) * export separate lora ggufs instead * add adapter metadata api * use std::string * convert_hf_to_lora compatibility * fix assert * apply suggestions from review * apply suggestion from review --- common/arg.cpp | 4 +- common/common.cpp | 5 +++ common/common.h | 3 ++ convert_hf_to_gguf.py | 79 ++++++++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 20 +++++++++- include/llama.h | 18 +++++++++ src/llama-adapter.cpp | 72 +++++++++++++++++++++++++++++++++-- src/llama-adapter.h | 3 ++ src/llama-arch.cpp | 21 ++++++++++- src/llama-arch.h | 3 ++ src/llama-model.cpp | 37 ++++++++++++------ src/llama-model.h | 1 + src/llama-vocab.cpp | 2 +- tools/server/server.cpp | 2 + 14 files changed, 246 insertions(+), 24 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index d82f55890dee5..93f0108b2b93b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2555,7 +2555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)", [](common_params & params, const std::string & value) { - params.lora_adapters.push_back({ std::string(value), 1.0, nullptr }); + params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr }); } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); @@ -2563,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora-scaled"}, "FNAME", "SCALE", "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", [](common_params & params, const std::string & fname, const std::string & scale) { - params.lora_adapters.push_back({ fname, std::stof(scale), nullptr }); + params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr }); } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); diff --git a/common/common.cpp b/common/common.cpp index fdce1dcdec19b..054b43be770da 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -988,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + char buf[1024]; la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters } diff --git a/common/common.h b/common/common.h index 390dda5e531be..87ea0606954a3 100644 --- a/common/common.h +++ b/common/common.h @@ -34,6 +34,9 @@ struct common_adapter_lora_info { std::string path; float scale; + std::string task_name; + std::string prompt_prefix; + struct llama_adapter_lora * ptr; }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 31a11cbec0baa..6c8a034060536 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -72,6 +72,7 @@ class ModelBase: endianess: gguf.GGUFEndian use_temp_file: bool lazy: bool + dry_run: bool part_names: list[str] is_safetensors: bool hparams: dict[str, Any] @@ -111,6 +112,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE self.use_temp_file = use_temp_file self.lazy = not eager or (remote_hf_model_id is not None) + self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id if remote_hf_model_id is not None: self.is_safetensors = True @@ -4871,11 +4873,35 @@ def modify_tensors(self, data_torch, name, bid): @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT + _lora_files = {} + _lora_names = [] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + hparams = kwargs.pop("hparams", None) + if hparams is None: + hparams = ModelBase.load_hparams(dir_model, False) + + if lora_names := hparams.get("lora_adaptations"): + self._lora_names = lora_names + self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3 + + super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) self._xlmroberta_tokenizer_init() + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if self._lora_names: + for name in self._lora_names: + fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-") + self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run) + + return super().generate_extra_tensors() + + def set_type(self): + for lora_writer in self._lora_files.values(): + lora_writer.add_type(gguf.GGUFType.ADAPTER) + lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") + super().set_type() + def set_vocab(self): self._xlmroberta_set_vocab() @@ -4885,13 +4911,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.startswith("roberta."): name = name[8:] + # jina-embeddings-v3 + if ".parametrizations." in name: + name = name.replace(".parametrizations.", ".") + if name.endswith(".original"): + name = name[:-9] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: data_torch = data_torch[self._position_offset:,:] + if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"): + if name.startswith("pooler.dense"): + return [] + + num_loras = data_torch.size(0) + assert num_loras == len(self._lora_names) + + # Split out each LoRA in their own GGUF + for i, lora_writer in enumerate(self._lora_files.values()): + new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower() + data = data_torch[i, :, :] + # Transpose/flip token_embd/types into correct shape + if new_name == "token_embd.weight.lora_b": + data = data.T + elif new_name.startswith("token_types.weight."): + new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b") + lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32) + + return [] + return super().modify_tensors(data_torch, name, bid) + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # jina-embeddings-v3 + if rotary_emb_base := self.hparams.get("rotary_emb_base"): + self.gguf_writer.add_rope_freq_base(rotary_emb_base) + lora_alpha = self.hparams.get("lora_alpha") + if lora_prompt_prefixes := self.hparams.get("task_instructions"): + assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys()) + for lora_name, lora_writer in self._lora_files.items(): + lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0) + lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name) + if lora_prompt_prefixes: + lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name]) + + def write(self): + super().write() + for lora_writer in self._lora_files.values(): + lora_writer.write_header_to_file() + lora_writer.write_kv_data_to_file() + lora_writer.write_tensors_to_file(progress=True) + lora_writer.close() + @ModelBase.register("GemmaForCausalLM") class GemmaModel(TextModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b9d1235d1706d..a581f9601f0fd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -231,8 +231,10 @@ class Tokenizer: MIDDLE_ID = "tokenizer.ggml.middle_token_id" class Adapter: - TYPE = "adapter.type" - LORA_ALPHA = "adapter.lora.alpha" + TYPE = "adapter.type" + LORA_ALPHA = "adapter.lora.alpha" + LORA_TASK_NAME = "adapter.lora.task_name" + LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix" class IMatrix: CHUNK_COUNT = "imatrix.chunk_count" @@ -315,6 +317,7 @@ class MODEL_ARCH(IntEnum): NOMIC_BERT_MOE = auto() NEO_BERT = auto() JINA_BERT_V2 = auto() + JINA_BERT_V3 = auto() BLOOM = auto() STABLELM = auto() QWEN = auto() @@ -647,6 +650,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", MODEL_ARCH.NEO_BERT: "neo-bert", MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3", MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", @@ -1234,6 +1238,18 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.LAYER_OUT_NORM, MODEL_TENSOR.CLS, ], + MODEL_ARCH.JINA_BERT_V3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/include/llama.h b/include/llama.h index c5622cc16b4c2..702535385057f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -553,6 +553,24 @@ extern "C" { struct llama_model * model, const char * path_lora); + // Functions to access the adapter's GGUF metadata scalar values + // - The functions return the length of the string on success, or -1 on failure + // - The output string is always null-terminated and cleared on failure + // - When retrieving a string, an extra byte must be allocated to account for the null terminator + // - GGUF array values are not supported by these functions + + // Get metadata value as a string by key name + LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); + + // Get the number of metadata key/value pairs + LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); + + // Get metadata key name by index + LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + + // Get metadata value as a string by index + LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + // Manually free a LoRA adapter // Note: loaded adapters will be free when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index 8d94034aed95d..772ce1b448088 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -163,13 +163,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ // check metadata { + const gguf_context * gguf_ctx = ctx_gguf.get(); + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) { + gguf_type type = gguf_get_kv_type(gguf_ctx, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i)) + : gguf_type_name(type); + const char * name = gguf_get_key(gguf_ctx, i); + const std::string value = gguf_kv_to_str(gguf_ctx, i); + + if (type != GGUF_TYPE_ARRAY) { + adapter.gguf_kv.emplace(name, value); + } + + const size_t MAX_VALUE_LEN = 40; + std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value; + replace_all(print_value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str()); + } + auto get_kv_str = [&](const std::string & key) -> std::string { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id)); }; auto get_kv_f32 = [&](const std::string & key) -> float { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id); }; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); @@ -383,6 +408,45 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p return nullptr; } +int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) { + const auto & it = adapter->gguf_kv.find(key); + if (it == adapter->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) { + return (int)adapter->gguf_kv.size(); +} + +int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + void llama_adapter_lora_free(llama_adapter_lora * adapter) { delete adapter; } diff --git a/src/llama-adapter.h b/src/llama-adapter.h index 65824e972765b..9084e7cab08fd 100644 --- a/src/llama-adapter.h +++ b/src/llama-adapter.h @@ -67,6 +67,9 @@ struct llama_adapter_lora { float alpha; + // gguf metadata + std::unordered_map gguf_kv; + llama_adapter_lora() = default; ~llama_adapter_lora() = default; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 0ca0a4c22f814..a61dc177ac9b8 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -22,6 +22,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -234,8 +235,10 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, - { LLM_KV_ADAPTER_TYPE, "adapter.type" }, - { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" }, + { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, @@ -575,6 +578,20 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS, "cls" }, }, }, + { + LLM_ARCH_JINA_BERT_V3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + }, + }, { LLM_ARCH_BLOOM, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 7008c2514c5d4..94b0bef719bed 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -26,6 +26,7 @@ enum llm_arch { LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_JINA_BERT_V3, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -230,6 +231,8 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + LLM_KV_ADAPTER_LORA_TASK_NAME, + LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, LLM_KV_POSNET_EMBEDDING_LENGTH, LLM_KV_POSNET_BLOCK_COUNT, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7d3429617bef9..30974a723fda1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_558M: return "558M"; case LLM_TYPE_700M: return "700M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; @@ -772,6 +773,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JINA_BERT_V3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -2631,6 +2644,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_JINA_BERT_V3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); @@ -2666,24 +2680,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - - if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); } } @@ -7461,7 +7473,7 @@ struct llm_build_bert : public llm_graph_context { } // RoPE - if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -7520,7 +7532,7 @@ struct llm_build_bert : public llm_graph_context { 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); - } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, @@ -18241,6 +18253,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: @@ -18395,6 +18408,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -18885,6 +18899,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index af4460cc01eb0..fa44d800d5277 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -40,6 +40,7 @@ enum llm_type { LLM_TYPE_450M, LLM_TYPE_475M, LLM_TYPE_537M, + LLM_TYPE_558M, LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index de5d1681dff85..ca02b63a58407 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2470,7 +2470,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // set attributes by model/tokenizer/architecture name if (false || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) - || _contains_any(general_arch, {"nomic-bert-moe"}) + || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"}) ) { if (token_to_id.count("") == 0) { LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6eb5aeb582b3a..6aa319d2f1121 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -4898,6 +4898,8 @@ int main(int argc, char ** argv) { {"id", i}, {"path", lora.path}, {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, }); } res_ok(res, result); From 4f4a8133347aa287426d68739526b9aa5c0bbf06 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Aug 2025 17:09:05 +0300 Subject: [PATCH 17/21] kv-cache : fix find_slot to not search for continuous slot (#15638) ggml-ci --- src/llama-kv-cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 4485f78d5f533..f1c6918738684 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -540,7 +540,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector Date: Thu, 28 Aug 2025 10:11:36 -0400 Subject: [PATCH 18/21] ggml : fix SSM_SCAN for n_groups > 1 (#15625) --- ggml/src/ggml-cpu/ops.cpp | 21 +++++++++++---------- ggml/src/ggml-cuda/ssm-scan.cu | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 10 ++++++---- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 93330b43a9b84..8c1f7948855ac 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // allows optimizing the modulo since n_group should be a power of 2 - GGML_ASSERT((ng & -ng) == ng); + GGML_ASSERT(nh % ng == 0); // heads per thread const int dh = (nh + nth - 1)/nth; @@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32( // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: maybe unroll more? for (int j = 0; j < 1; j++) { GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); - GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); - GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc); t0 = GGML_F32_VEC_MUL(t0, adA); t1 = GGML_F32_VEC_MUL(t1, axdt); @@ -9090,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); - ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); - az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc); ax[j] = GGML_F32_VEC_MUL(ax[j], adA); ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); @@ -9113,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = np; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) @@ -9130,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -9144,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: what happens when (d_state % svcntw()) != 0? for (int64_t k = 0; k < nc; k += svcntw()) { svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]); svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); @@ -9165,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = 0; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index dc9a7d58d057c..6b424381df5a7 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1) const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); const int seq_idx = blockIdx.y; - const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index fa80d6e405978..4fa16c4a553d2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1983,14 +1983,15 @@ kernel void kernel_ssm_scan_f32( device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); const int64_t i = i0 + i1*nc; + const int64_t g = ir / (nh / ng); // repeat_interleave float s0 = s0_buff[i]; float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { @@ -2098,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group( device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); const int64_t i = i0 + i1*nc; + const int64_t g = ir / (nh / ng); // repeat_interleave float s0 = s0_buff[i]; float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { From 491b300d36f7f2a2b18015653d63ea5b8d26975f Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Thu, 28 Aug 2025 22:39:27 +0800 Subject: [PATCH 19/21] ggml-cpu: fix invalid hsum build in debug s390x (#15634) Signed-off-by: Aaron Teo --- ggml/src/ggml-cpu/ggml-cpu-impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 1f6844e16cd34..e08c30a348aa1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { /** * @see https://github.com/ggml-org/llama.cpp/pull/14037 */ -inline float vec_hsum(float32x4_t v) { +inline static float vec_hsum(float32x4_t v) { float32x4_t v_temp = v + vec_reve(v); return v_temp[0] + v_temp[1]; } From b139bd42dba747f1fef7a3d66d3ca83b1e3c8999 Mon Sep 17 00:00:00 2001 From: mnehete32 <33429707+mnehete32@users.noreply.github.com> Date: Fri, 29 Aug 2025 00:03:03 +0530 Subject: [PATCH 20/21] CUDA: add conv2d (#15635) * CUDA: add conv2d * CUDA: conv2d - correct formatting and added const --- ggml/src/ggml-cuda/conv2d.cu | 171 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/conv2d.cuh | 5 + ggml/src/ggml-cuda/ggml-cuda.cu | 5 + 3 files changed, 181 insertions(+) create mode 100644 ggml/src/ggml-cuda/conv2d.cu create mode 100644 ggml/src/ggml-cuda/conv2d.cuh diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu new file mode 100644 index 0000000000000..cf878d1fd18e5 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -0,0 +1,171 @@ +#include "conv2d.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template +static __global__ void conv2d_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= P.TOTAL) { + return; + } + + int64_t n, c_out, out_y, out_x; + Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + + T acc = 0; + + for (int64_t c_in = 0; c_in < P.IC; ++c_in) { + kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + + for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { + const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + + for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { + const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + + T input_val; + if (std::is_same::value) { + input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]); + } else { + input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; + } + + T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * kernel_val); + } + } + } + + // [N, OC, OH, OW] + output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc; +} + +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_kernel<<>>(X_D, K_D, Y_D, P); +} + +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh new file mode 100644 index 0000000000000..ce4802c7ed797 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3a50527248045..4c02b57227a88 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -12,6 +12,7 @@ #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2451,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_CONV_2D: + ggml_cuda_op_conv2d(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -3501,6 +3505,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: From b942adea968ba915fc3eb6efc36a9e97e192bb86 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 28 Aug 2025 15:27:36 -0500 Subject: [PATCH 21/21] fix: Compute the full sum in llama-eval-callback, not just the sum of printed values (#15637) This makes it much easier to compare between llama.cpp and transformers! https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart --- examples/eval-callback/eval-callback.cpp | 50 +++++++++++++++--------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 61eefc7248a7e..d4ef751fbb63a 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -28,9 +28,40 @@ static std::string ggml_ne_string(const ggml_tensor * t) { return str; } +static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I64) { + v = (float) *(int64_t *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + return v; +} + static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { GGML_ASSERT(n > 0); float sum = 0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + } + } + } + } for (int64_t i3 = 0; i3 < ne[3]; i3++) { LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { @@ -50,25 +81,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne LOG("..., "); i0 = ne[0] - n; } - size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; - float v; - if (type == GGML_TYPE_F16) { - v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); - } else if (type == GGML_TYPE_F32) { - v = *(float *) &data[i]; - } else if (type == GGML_TYPE_I64) { - v = (float) *(int64_t *) &data[i]; - } else if (type == GGML_TYPE_I32) { - v = (float) *(int32_t *) &data[i]; - } else if (type == GGML_TYPE_I16) { - v = (float) *(int16_t *) &data[i]; - } else if (type == GGML_TYPE_I8) { - v = (float) *(int8_t *) &data[i]; - } else { - GGML_ABORT("fatal error"); - } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); LOG("%12.4f", v); - sum += v; if (i0 < ne[0] - 1) LOG(", "); } LOG("],\n");