From 8898040470c388f0f8d5e973350d022e50514f61 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 11 Sep 2025 22:57:13 +0800
Subject: [PATCH 1/3] CUDA: use fastdiv + ggml_cuda_mad for mmvf

---
 ggml/src/ggml-cuda/mmvf.cu | 59 +++++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 31 insertions(+), 28 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 5b21ef05b3c35..d89d8b40a59bd 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -1,3 +1,4 @@
+
 #include "ggml.h"
 #include "common.cuh"
 #include "convert.cuh"
@@ -7,14 +8,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 channel_ratio_fd, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio_fd, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
     const int row         = blockIdx.x;
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio_fd);
     const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
     const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio_fd);
     const int sample_y    = sample_dst;
     const int tid         = threadIdx.x;
 
@@ -47,8 +48,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += tmpx.x*tmpy.x;
-                sumf[j] += tmpx.y*tmpy.y;
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
     } else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +62,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
                     const float2 tmpy = y2[j*stride_col_y2 + col2];
-                    sumf[j] += tmpx.x * tmpy.x;
-                    sumf[j] += tmpx.y * tmpy.y;
+                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
                 }
             }
         } else {
@@ -94,8 +95,10 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
             }
         }
     } else {
@@ -140,8 +143,8 @@ static void launch_mul_mat_vec_f_cuda(
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_dst / nchannels_x;
-    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
 
     const int device    = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -167,50 +170,50 @@ static void launch_mul_mat_vec_f_cuda(
         case  32: {
             mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  64: {
             mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  96: {
            mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
             mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");

From e1afe75368fe66ea817b78a2167a2d9007855ed9 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Tue, 14 Oct 2025 12:19:55 +0800
Subject: [PATCH 2/3] use bf16 directly + fix formatting

---
 ggml/src/ggml-cuda/common.cuh | 11 +++++++++++
 ggml/src/ggml-cuda/mmvf.cu    | 21 +++++++++------------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index d51abbeafa944..de317fb18aae7 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -532,6 +532,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #endif // defined(GGML_USE_HIP)
 }
 
+
 static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
     acc += v*u;
 }
@@ -570,6 +571,16 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
 #endif // FAST_FP16_AVAILABLE
 }
 
+
+#if defined(GGML_USE_HIP)
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const __hip_bfloat162 v, const __hip_bfloat162 u) {
+    const float2 tmpv = __bfloat162float2(v);
+    const float2 tmpu = __bfloat162float2(u);
+    acc += tmpv.x * tmpu.x;
+    acc += tmpv.y * tmpu.y;
+}
+#endif
+
 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
 template <int nbytes, int alignment = 0>
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index d89d8b40a59bd..2a5f75f12f80b 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -1,4 +1,3 @@
-
 #include "ggml.h"
 #include "common.cuh"
 #include "convert.cuh"
@@ -8,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
-        const uint3 channel_ratio_fd, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const uint3 sample_ratio_fd, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
     const int row         = blockIdx.x;
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio_fd);
+    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
     const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
     const int sample_dst  = blockIdx.z;
-    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio_fd);
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
     const int tid         = threadIdx.x;
 
@@ -89,16 +88,14 @@ static __global__ void mul_mat_vec_f(
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
-        const int * x2 = (const int *) x;
+        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
-            const int tmpx = x2[col2];
+            const nv_bfloat162 tmpx = x2[col2];
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
-                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
-                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
-                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
     } else {
@@ -143,7 +140,7 @@ static void launch_mul_mat_vec_f_cuda(
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
-    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) :  init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
     const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
 
     const int device = ggml_cuda_get_device();

From d6c71e928ce13d88182e263fadf7f26c451c5d54 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Tue, 14 Oct 2025 14:39:52 +0800
Subject: [PATCH 3/3] Add exception for HIP code

---
 ggml/src/ggml-cuda/common.cuh | 11 -----------
 ggml/src/ggml-cuda/mmvf.cu    | 16 ++++++++++++++++
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index de317fb18aae7..d51abbeafa944 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -532,7 +532,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #endif // defined(GGML_USE_HIP)
 }
 
-
 static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
     acc += v*u;
 }
@@ -571,16 +570,6 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
 #endif // FAST_FP16_AVAILABLE
 }
 
-
-#if defined(GGML_USE_HIP)
-static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const __hip_bfloat162 v, const __hip_bfloat162 u) {
-    const float2 tmpv = __bfloat162float2(v);
-    const float2 tmpu = __bfloat162float2(u);
-    acc += tmpv.x * tmpu.x;
-    acc += tmpv.y * tmpu.y;
-}
-#endif
-
 // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
 template <int nbytes, int alignment = 0>
 static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 2a5f75f12f80b..57ab839393aa0 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -88,6 +88,21 @@ static __global__ void mul_mat_vec_f(
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+//TODO: add support for ggml_cuda_mad for hip_bfloat162
+#if defined(GGML_USE_HIP)
+        const int * x2 = (const int *) x;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
+            }
+        }
+#else
         const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const nv_bfloat162 tmpx = x2[col2];
@@ -98,6 +113,7 @@ static __global__ void mul_mat_vec_f(
                 ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                 ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
+#endif
     } else {
         static_assert(std::is_same_v<T, void>, "unsupported type");
     }
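
Note on the fastdiv scheme used by patch 1: init_fastdiv_values() and fastdiv() live in common.cuh rather than in these diffs. The sketch below shows the standard magic-multiplier division they are assumed to implement: for a divisor d that is fixed for the whole launch, precompute L = ceil(log2(d)) and mp = 2^32 * (2^L - d) / d + 1 on the host, after which n / d reduces on the device to a 32-bit multiply-high, an add, and a shift instead of a hardware integer division. The helper names here (mulhi_u32, init_fastdiv, fastdiv_ref) are illustrative, not the common.cuh API, and main() is just a host-side brute-force check.

#include <cassert>
#include <cstdint>

// Host-side stand-in for __umulhi(a, b): the high 32 bits of the 64-bit product.
static uint32_t mulhi_u32(const uint32_t a, const uint32_t b) {
    return (uint32_t) (((uint64_t) a * b) >> 32);
}

// Precompute {mp, L} for a divisor d >= 1 (what init_fastdiv_values() is
// assumed to pack into the uint3 passed to the kernel).
static void init_fastdiv(const uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && ((uint64_t) 1 << L) < d) {
        ++L;
    }
    mp = (uint32_t) ((((uint64_t) 1 << 32) * (((uint64_t) 1 << L) - d)) / d + 1);
}

// Device-side equivalent of fastdiv(): computes n / d with mulhi + add + shift.
static uint32_t fastdiv_ref(const uint32_t n, const uint32_t mp, const uint32_t L) {
    return (mulhi_u32(n, mp) + n) >> L;
}

int main() {
    for (uint32_t d = 1; d < 1000; ++d) {
        uint32_t mp;
        uint32_t L;
        init_fastdiv(d, mp, L);
        for (uint32_t n = 0; n < 1000000; n += 13) {
            assert(fastdiv_ref(n, mp, L) == n / d);
        }
    }
    return 0;
}

This is why patch 1 hoists the channel and sample ratios out of the kernel as uint3 values: the ratios are constant per launch, so each block's division becomes three cheap instructions. When ids is set the channel ratio is never read, hence the make_uint3(0, 0, 0) placeholder on the host side.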
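Note on ggml_cuda_mad (patches 1 and 2): routing the inner product through an overload set keeps the branches of mul_mat_vec_f generic while letting each operand type supply its own fused or packed implementation; the float overload (acc += v*u) and the HIP __hip_bfloat162 overload appear verbatim in patch 2's diff. The float2 overload below is a hypothetical illustration of how a packed variant would slot in without touching any call sites; it is not part of these patches.

// Hypothetical packed overload, for illustration only: call sites written as
// ggml_cuda_mad(acc, v, u) pick this up by overload resolution, and a target
// with dual-issue FMA could lower the two fmaf() calls to one packed op.
static __device__ __forceinline__ void ggml_cuda_mad(float2 & acc, const float2 v, const float2 u) {
    acc.x = fmaf(v.x, u.x, acc.x);
    acc.y = fmaf(v.y, u.y, acc.y);
}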