From 82dd2c78f4fff2f12b050b536fcb0caf464a3677 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Wed, 10 Sep 2025 14:50:59 +0800
Subject: [PATCH 1/4] CUDA: MUL_MAT_ID optimizations for mmf

---
 ggml/src/ggml-cuda/mmf.cuh | 58 ++++++++++++++++----------------------
 1 file changed, 25 insertions(+), 33 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index bf724bc57b8a0..a302be5260f46 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -57,31 +57,37 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-        __shared__ int has_any;
-        if (threadIdx.y == 0) {
-            int local_has_any = 0;
-            for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
-                int slot = -1;
-                for (int k = 0; k < nchannels_dst; ++k) {
-                    const int idv = ids[j*stride_row_id + k*stride_col_id];
-                    if (idv == expert_idx) {
-                        slot = k;
-                        break;
+        int found = 0;
+
+        for (int j = threadIdx.y; j < cols_per_block; j += nwarps) {
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
+            if (threadIdx.x == 0) {
+                slot_map[j] = -1;
+            }
+
+            for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
+                int k = k_base + threadIdx.x;
+                int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
+
+                unsigned mask = __ballot_sync(0xffffffff, match);
+                if (mask) {
+                    int leader = __ffs(mask) - 1;
+                    if (threadIdx.x == leader) {
+                        slot_map[j] = k_base + leader;
                     }
-                }
-                if (j < cols_per_block) {
-                    local_has_any |= (slot >= 0);
-                    slot_map[j] = slot;
+                    found = 1;
+                    break;
                 }
             }
-            has_any = warp_reduce_any(local_has_any);
         }
-        __syncthreads();
-        if (has_any == 0) {
+
+        if (!__syncthreads_or(found)) {
             return;
         }
     }
+
     for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
         tile_A A[ntA][warp_size / tile_A::J];
 #pragma unroll
@@ -106,14 +112,7 @@ static __global__ void mul_mat_f(
             if constexpr (!has_ids) {
                 tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
             } else {
-                float val = 0.0f;
-                if (j < cols_per_block) {
-                    const int slot = slot_map[j];
-                    if (slot >= 0) {
-                        val = y[slot*stride_channel_y + j*stride_col_y + col];
-                    }
-                }
-                tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+                tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
             }
         }
     } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +124,7 @@ static __global__ void mul_mat_f(
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             } else {
-                float2 tmp = make_float2(0.0f, 0.0f);
-                if (j < cols_per_block) {
-                    const int slot = slot_map[j];
-                    if (slot >= 0) {
-                        const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
-                        tmp = y2_slot[j*stride_col_y + col];
-                    }
-                }
+                float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             }
         }

From bb831b2af71648e8f0e8f6e9ae5b7d7fb7147e12 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 11 Sep 2025 00:50:49 +0800
Subject: [PATCH 2/4] unroll n_expert_used loop + remove warp syncs

---
 ggml/src/ggml-cuda/mmf.cuh | 109 ++++++++++++++++++++++++++++++++-----
 1 file changed, 94 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index a302be5260f46..248ad5904d857 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, int n_expert_used>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,27 +57,40 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
+
         int found = 0;
 
-        for (int j = threadIdx.y; j < cols_per_block; j += nwarps) {
+#pragma unroll
+        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
             const int32_t * __restrict__ id_row = ids + j*stride_row_id;
 
             if (threadIdx.x == 0) {
                 slot_map[j] = -1;
             }
 
-            for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
-                int k = k_base + threadIdx.x;
-                int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
+            if constexpr (n_expert_used == 0) {
+                for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
+                    int k = k_base + threadIdx.x;
+                    int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
 
-                unsigned mask = __ballot_sync(0xffffffff, match);
-                if (mask) {
-                    int leader = __ffs(mask) - 1;
-                    if (threadIdx.x == leader) {
-                        slot_map[j] = k_base + leader;
+                    if (match) {
+                        slot_map[j] = k;
+                        found = 1;
+                        break;
+                    }
+                }
+            } else {
+#pragma unroll
+                for (int k_base = 0; k_base < n_expert_used; k_base += warp_size) {
+                    int k = k_base + threadIdx.x;
+                    int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
+
+                    if (match) {
+                        slot_map[j] = k;
+                        found = 1;
+                        break;
+                    }
-                    found = 1;
-                    break;
                 }
             }
         }
@@ -202,6 +215,71 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+static inline void launch_mul_mat_ids(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nchannels_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int64_t stride_row_id,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
+
+    const int n_expert_used = nchannels_dst;
+
+    switch (n_expert_used) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 16: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 16><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 32: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 32><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 0><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+    }
+}
+
+
 template <typename T, int rows_per_block, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -212,10 +290,11 @@ static inline void mul_mat_f_switch_ids(
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
-        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        launch_mul_mat_ids<T, rows_per_block, cols_per_block, nwarps>(
+            x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
+            block_nums, block_dims, nbytes_shared_total, stream);
     } else {
         mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

From bf08ea517566ccedf0e6a4fa9ebcd7721ddcab6f Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 12 Sep 2025 10:27:07 +0800
Subject: [PATCH 3/4] Remove template from n_expert_used

---
 ggml/src/ggml-cuda/mmf.cuh | 104 ++++---------------------------------
 1 file changed, 11 insertions(+), 93 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 248ad5904d857..cddee758ff28a 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -11,7 +11,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids, int n_expert_used>
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -57,10 +57,8 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-
         int found = 0;
 
-#pragma unroll
         for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
             const int32_t * __restrict__ id_row = ids + j*stride_row_id;
@@ -69,28 +67,14 @@ static __global__ void mul_mat_f(
                 slot_map[j] = -1;
             }
 
-            if constexpr (n_expert_used == 0) {
-                for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
-                    int k = k_base + threadIdx.x;
-                    int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
+            for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
+                int k = k_base + threadIdx.x;
+                int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
 
-                    if (match) {
-                        slot_map[j] = k;
-                        found = 1;
-                        break;
-                    }
-                }
-            } else {
-#pragma unroll
-                for (int k_base = 0; k_base < n_expert_used; k_base += warp_size) {
-                    int k = k_base + threadIdx.x;
-                    int match = (k < n_expert_used) && (id_row[k*stride_col_id] == expert_idx);
-
-                    if (match) {
-                        slot_map[j] = k;
-                        found = 1;
-                        break;
-                    }
+                if (match) {
+                    slot_map[j] = k;
+                    found = 1;
+                    break;
                 }
             }
         }
@@ -215,71 +199,6 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-static inline void launch_mul_mat_ids(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nchannels_dst,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t stride_col_id, const int64_t stride_row_id,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
-        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
-
-    const int n_expert_used = nchannels_dst;
-
-    switch (n_expert_used) {
-        case 1: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 2: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 4: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 6: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 8: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 16: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 16><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 32: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 32><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        default: {
-            mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true, 0><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
-                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-    }
-}
-
-
 template <typename T, int rows_per_block, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -290,11 +209,10 @@ static inline void mul_mat_f_switch_ids(
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
-        launch_mul_mat_ids<T, rows_per_block, cols_per_block, nwarps>(
-            x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
-            block_nums, block_dims, nbytes_shared_total, stream);
+            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
         mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,

From 5016ca50f750d377044c5f4e2b3ee383d7a7b20d Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Mon, 15 Sep 2025 10:47:46 +0800
Subject: [PATCH 4/4] Update k inside the loop as it's not a candidate for unrolling
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/mmf.cuh | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index cddee758ff28a..61e3bf30152c7 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -67,9 +67,8 @@ static __global__ void mul_mat_f(
                 slot_map[j] = -1;
             }
 
-            for (int k_base = 0; k_base < nchannels_dst; k_base += warp_size) {
-                int k = k_base + threadIdx.x;
-                int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
+            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
+                int match = id_row[k*stride_col_id] == expert_idx;
 
                 if (match) {
                     slot_map[j] = k;
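
Editor's note: the sketch below is not part of the patch series or of mmf.cuh. It is a minimal, self-contained CUDA illustration of the expert-slot lookup the series converges on: patch 1 replaces the single-warp scan with one warp per destination column plus a block-wide __syncthreads_or() early exit, and patch 4 simplifies the per-lane scan to a plain strided loop. The kernel name, launch shape, and the contiguous row-major ids layout are assumptions made for this example only.

    #include <cstdint>
    #include <cuda_runtime.h>

    // Assumed launch shape: blockDim = (WARP_SIZE, nwarps); one block covers cols_per_block columns.
    static constexpr int WARP_SIZE = 32;

    template <int nwarps>
    __global__ void find_expert_slots(const int32_t * ids, int * slot_map,
                                      const int cols_per_block, const int n_expert_used,
                                      const int expert_idx) {
        int found = 0;

        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
            const int j = j0 + threadIdx.y;                  // column owned by this warp
            if (j >= cols_per_block) {
                break;
            }
            const int32_t * id_row = ids + j*n_expert_used;  // ids assumed contiguous per column

            if (threadIdx.x == 0) {
                slot_map[j] = -1;                            // default: column j does not use expert_idx
            }

            // Lanes stride over the candidate experts of column j; the lane that sees
            // expert_idx records its slot (the same shape as the final loop in patch 4).
            for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
                if (id_row[k] == expert_idx) {
                    slot_map[j] = k;
                    found = 1;
                    break;
                }
            }
        }

        // Block-wide OR over 'found': if no column routed to this expert, every thread
        // returns here, analogous to the early exit patch 1 adds before the tile loads.
        if (!__syncthreads_or(found)) {
            return;
        }
        // ... a real kernel would now gather y rows through slot_map[j] and run the matmul ...
    }

In mmf.cuh itself this logic runs inline at the top of mul_mat_f under if constexpr (has_ids), using the real stride_row_id/stride_col_id strides rather than the contiguous layout assumed here.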