From 6efd872732159ab88ee7b3c1d77ba5ebc83079bd Mon Sep 17 00:00:00 2001
From: City <125218114+city96@users.noreply.github.com>
Date: Thu, 24 Apr 2025 22:48:14 +0200
Subject: [PATCH 1/4] Force FP32 compute in cuBLAS GEMM

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 40 +++++++++------------------------
 1 file changed, 11 insertions(+), 29 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index e0e0d2137f3be..8943d084f10b2 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1245,37 +1245,19 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-
-        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
-            const float alpha = 1.0f;
-            const float beta = 0.0f;
-            CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                        row_diff, src1_ncols, ne10,
-                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
-                                src1_ptr,  CUDA_R_16F, ne10,
-                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
-                        CUBLAS_COMPUTE_32F,
-                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-        } else {
-            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
-            const half alpha_f16 = 1.0f;
-            const half beta_f16 = 0.0f;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
 
-            CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                        row_diff, src1_ncols, ne10,
-                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
-                                    src1_ptr,      CUDA_R_16F, ne10,
-                        &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
-                        CUBLAS_COMPUTE_16F,
-                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                            src1_ptr,  CUDA_R_16F, ne10,
+                    &beta,   dst_dd_i, CUDA_R_32F, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));

From db52579ac7633c8bd96826be8f9bc2ee27a4d8e8 Mon Sep 17 00:00:00 2001
From: City <125218114+city96@users.noreply.github.com>
Date: Fri, 25 Apr 2025 11:49:09 +0200
Subject: [PATCH 2/4] Revert "Force FP32 compute in cuBLAS GEMM"

This reverts commit 6efd872732159ab88ee7b3c1d77ba5ebc83079bd.
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 40 ++++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8943d084f10b2..e0e0d2137f3be 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1245,19 +1245,37 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
 
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha, src0_ptr,  CUDA_R_16F, ne00,
-                            src1_ptr,  CUDA_R_16F, ne10,
-                    &beta,   dst_dd_i, CUDA_R_32F, ldc,
-                    CUBLAS_COMPUTE_32F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                                src1_ptr,  CUDA_R_16F, ne10,
+                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
+
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));

From 70975676f79192fd4bdddd2b10bcc54e5550bd69 Mon Sep 17 00:00:00 2001
From: City <125218114+city96@users.noreply.github.com>
Date: Fri, 25 Apr 2025 11:49:50 +0200
Subject: [PATCH 3/4] Force F32 compute in GLM4 ffn down

---
 src/llama-graph.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a85e97288e1ae..d9ae64373caa9 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -803,6 +803,10 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have precision issues in F16
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }
 
     if (down_b) {

From 06113f00273e008228f690a1c51404d25210d96f Mon Sep 17 00:00:00 2001
From: City <125218114+city96@users.noreply.github.com>
Date: Fri, 25 Apr 2025 12:58:48 +0200
Subject: [PATCH 4/4] Edit comment to clarify issue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---
 src/llama-graph.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index d9ae64373caa9..b52e3f6203a4b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -804,7 +804,7 @@ ggml_tensor * llm_graph_context::build_ffn(
     if (down) {
         cur = build_lora_mm(down, cur);
         if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have precision issues in F16
+            // GLM4 seems to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
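
Note: patches 3/4 use ggml's per-node precision override rather than forcing FP32 in the cuBLAS path globally (the approach tried and reverted in patches 1/2). The following is a minimal standalone sketch of that mechanism, not part of the series itself; it uses the public ggml C API, and the tensor shapes are illustrative assumptions, not GLM4's actual dimensions.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // illustrative shapes: an F16 weight matrix and one activation column
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4096, 4096);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 1);

    struct ggml_tensor * cur = ggml_mul_mat(ctx, w, x);
    // per-node override: backends that honor GGML_PREC_F32 (e.g. CUDA) skip
    // their half-precision accumulator path for this one mat mul, which is
    // what build_ffn() does for the GLM4 down projection in patch 3/4
    ggml_mul_mat_set_prec(cur, GGML_PREC_F32);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cur);
    // ... hand gf to a backend to compute ...

    ggml_free(ctx);
    return 0;
}

The trade-off is the same one the series settled on: only the node that actually misbehaves pays the FP32 cost, while every other mat mul keeps the faster F16 accumulator path.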