From b9bcb7d7abbecacd345585f74dbe2dcd5c67efdb Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 2 Jul 2025 15:58:10 +0800 Subject: [PATCH 1/3] CUDA: add dynamic shared mem to softmax, refactor general usage --- ggml/src/ggml-cuda/common.cuh | 14 ++++++++++++ ggml/src/ggml-cuda/cross-entropy-loss.cu | 16 ++----------- ggml/src/ggml-cuda/mmq.cuh | 10 ++------ ggml/src/ggml-cuda/softmax.cu | 29 ++++++++++++++++++++++-- tests/test-backend-ops.cpp | 6 ++--- 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ea20355023825..4f5510d65c141 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) +#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0) +#endif + #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 0ce4afbb222bd..f328553b56ad9 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32), nbytes_shared); cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32), nbytes_shared); cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 80baf459c15f2..9696a32046212 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index e0eb921d85d53..dc6ca10d3da34 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -2,6 +2,7 @@ #include "ggml.h" #include "softmax.cuh" #include +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -181,6 +182,24 @@ static __global__ void soft_max_back_f32( } } +template +void increase_shared_mem_limits(std::size_t smpbo) +{ + auto apply_limit = [smpbo](auto I) { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + + CUDA_SET_SHARED_MEMORY_LIMIT( + (soft_max_f32), smpbo); + CUDA_SET_SHARED_MEMORY_LIMIT( + (soft_max_f32), smpbo); + }; + + //unary fold + ( apply_limit(std::integral_constant{}), ... ); +} + + template static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) { int nth = WARP_SIZE; @@ -193,8 +212,14 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + + if (nbytes_shared <= smpbo) { + + increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo); + switch (ncols_x) { case 32: soft_max_f32<<>> diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a4f94c2594038..22c6a13105828 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4901,12 +4901,12 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); From 34e5142d37e0eda8be63df993d30cf8e7f8ac4cb Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 2 Jul 2025 17:18:09 +0800 Subject: [PATCH 2/3] Review: refactor switch statement, change cross_entropy to use full size --- ggml/src/ggml-cuda/common.cuh | 2 +- ggml/src/ggml-cuda/cross-entropy-loss.cu | 4 +- ggml/src/ggml-cuda/softmax.cu | 80 ++++++++---------------- 3 files changed, 30 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4f5510d65c141..954f74d408f9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -187,7 +187,7 @@ static const char * cu_get_error_str(CUresult err) { } while (0) #else #define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0) -#endif +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index f328553b56ad9..0c8b0819724e4 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,7 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { - CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32), smpbo); cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -169,7 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { - CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32), smpbo); cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index dc6ca10d3da34..fb10a4eaec54a 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -182,21 +182,34 @@ static __global__ void soft_max_back_f32( } } -template -void increase_shared_mem_limits(std::size_t smpbo) +template +static void launch_soft_max_kernels(float * x, const T * mask, float * dst, + const soft_max_params & p, cudaStream_t stream) { - auto apply_limit = [smpbo](auto I) { - constexpr int ncols = decltype(I)::value; - constexpr int block = (ncols > 1024 ? 1024 : ncols); - - CUDA_SET_SHARED_MEMORY_LIMIT( - (soft_max_f32), smpbo); - CUDA_SET_SHARED_MEMORY_LIMIT( - (soft_max_f32), smpbo); + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>> + (x, mask, dst, p); + return true; + } + return false; }; - //unary fold - ( apply_limit(std::integral_constant{}), ... ); + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>>(x, mask, dst, p); } @@ -217,47 +230,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (nbytes_shared <= smpbo) { - - increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo); - - switch (ncols_x) { - case 32: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 64: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 128: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 256: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 512: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 1024: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 2048: - soft_max_f32<<>> - (x, mask, dst, params); - break; - case 4096: - soft_max_f32<<>> - (x, mask, dst, params); - break; - default: - soft_max_f32<<>> - (x, mask, dst, params); - break; - } + + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream); } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); soft_max_f32<<>>(x, mask, dst, params); From 7b162818b60c897e060824b70e573c34162040f3 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 2 Jul 2025 21:09:38 +0800 Subject: [PATCH 3/3] rebase --- ggml/src/ggml-cuda/softmax.cu | 11 +++++------ tests/test-backend-ops.cpp | 7 ++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index fb10a4eaec54a..14543e978cf0f 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -183,8 +183,8 @@ static __global__ void soft_max_back_f32( } template -static void launch_soft_max_kernels(float * x, const T * mask, float * dst, - const soft_max_params & p, cudaStream_t stream) +static void launch_soft_max_kernels(const float * x, const T * mask, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) { const int id = ggml_cuda_get_device(); const size_t smpbo = ggml_cuda_info().devices[id].smpbo; @@ -195,7 +195,7 @@ static void launch_soft_max_kernels(float * x, const T * mask, float * dst, if (p.ncols == ncols) { CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); - soft_max_f32<<>> + soft_max_f32<<>> (x, mask, dst, p); return true; } @@ -209,7 +209,7 @@ static void launch_soft_max_kernels(float * x, const T * mask, float * dst, //default case CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); - soft_max_f32<<>>(x, mask, dst, p); + soft_max_f32<<>>(x, mask, dst, p); } @@ -230,8 +230,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (nbytes_shared <= smpbo) { - - launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream); + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); soft_max_f32<<>>(x, mask, dst, params); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 22c6a13105828..ec8ab060180a7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4901,12 +4901,13 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 1024, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));