From d138a03ddfb23b6a373f757503cf107658eff1e2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 29 Nov 2025 00:15:13 +0100 Subject: [PATCH 01/25] Add support for CUMSUM and TRI for CUDA. --- ggml/src/ggml-cuda/common.cuh | 50 +++++++++++++ ggml/src/ggml-cuda/cumsum.cu | 126 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/cumsum.cuh | 5 ++ ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++ ggml/src/ggml-cuda/tri.cu | 104 ++++++++++++++++++++++++++ ggml/src/ggml-cuda/tri.cuh | 5 ++ tests/test-backend-ops.cpp | 6 ++ 7 files changed, 306 insertions(+) create mode 100644 ggml/src/ggml-cuda/cumsum.cu create mode 100644 ggml/src/ggml-cuda/cumsum.cuh create mode 100644 ggml/src/ggml-cuda/tri.cu create mode 100644 ggml/src/ggml-cuda/tri.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0b10e5f6ae0..c53208bed8b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -461,6 +461,56 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +template +static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const T t = __shfl_up_sync(mask, x, offset, width); + if (lane_id >= offset) { + x += t; + } + } + return x; +} + +template +static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) { + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const float t_x = __shfl_up_sync(mask, a.x, offset, width); + const float t_y = __shfl_up_sync(mask, a.y, offset, width); + if (lane_id >= offset) { + a.x += t_x; + a.y += t_y; + } + } + return a; +} + +template +static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { +#ifdef FP16_AVAILABLE + const int lane_id = threadIdx.x % width; + const auto mask = __activemask(); +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const t = __hadd2(__shfl_up_sync(mask, a, offset, width)); + if (lane_id >= offset) { + a += t; + } + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 00000000000..e14be0721c6 --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,126 @@ +#include "cumsum.cuh" + +// Kernel to compute cumulative sum along the innermost dimension (ne[0]) +// Each block processes one row (ne[0] elements) +// Algorithm matches Metal implementation: +// 1. Each warp computes prefix sum within itself +// 2. Last thread of each warp stores result in shared memory +// 3. All warps sync +// 4. 
Each element adds the sum of all preceding warps + +template +static __global__ void cumsum_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + + // Shared memory to store warp sums (always use float for accumulation) + extern __shared__ float shmem[]; + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + const int tid = threadIdx.x; + const int lane_id = tid % WARP_SIZE; + + // Phase 1: Each thread processes elements at stride blockDim.x + // Compute warp-level prefix sums + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { + // Load value and compute prefix sum within warp + float val = static_cast(src_row[i0]); + val = warp_prefix_inclusive_sum(val); + dst_row[i0] = static_cast(val); + + // Last thread of warp stores its sum to shared memory at position based on data index + if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) { + const int shmem_idx = i0 / WARP_SIZE; + shmem[shmem_idx] = val; + } + } + + // Sync once after all warp prefix sums are computed + __syncthreads(); + + // Phase 2: Add the sum of all preceding warp groups to each element + for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { + const int shmem_idx = i0 / WARP_SIZE; + float sum = 0.0f; + for (int j = 0; j < shmem_idx; ++j) { + sum += shmem[j]; + } + dst_row[i0] = static_cast(static_cast(dst_row[i0]) + sum); + } +} + +template +static void cumsum_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + cudaStream_t stream) { + + dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + // Shared memory size: one float per warp + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + const size_t shmem_size = num_warps * sizeof(float); + + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3 + ); +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } 
break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 00000000000..782d1d92e9b --- /dev/null +++ b/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a844a3d99a2..689e5dfc384 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -54,6 +54,8 @@ #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml-cuda/solve_tri.cuh" +#include "ggml-cuda/tri.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml.h" #include @@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; case GGML_OP_RWKV_WKV6: ggml_cuda_op_rwkv_wkv6(ctx, dst); break; @@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: return true; case GGML_OP_SOLVE_TRI: return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 00000000000..b531f696302 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,104 @@ +#include "tri.cuh" +#include "ggml.h" + +// Triangle type comparison - determines which elements to keep +__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) { + switch (type) { + case GGML_TRI_TYPE_LOWER: return i < r; + case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; + case GGML_TRI_TYPE_UPPER: return i > r; + case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; + default: return false; + } +} + +template +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype) { + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); + T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + // Each thread processes elements at stride blockDim.x + for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = tri_compare(i0, i1, ttype) + ? 
src_row[i0] : static_cast(0.f); + } +} + +template +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + nb0, nb1, nb2, nb3, + ttype + ); +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast(ggml_get_op_params_i32(dst, 0)); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 00000000000..a4cc66750d3 --- /dev/null +++ b/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 60bab47b9f2..306fa15b923 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7938,6 +7938,12 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 })); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 })); + test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 })); + + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 })); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { From 67207d21f9f84f1e0ac407606f40bd382100c096 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 29 Nov 2025 00:17:29 +0100 Subject: [PATCH 02/25] Minor optimizations. 
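The optimization clamps the kernel launch to the row length: instead of always launching CUDA_CUMSUM_BLOCK_SIZE threads per block, the block size is derived from the number of warps needed to cover ne00 and only capped at the configured maximum, and threads beyond the row return early. A minimal sketch of the host-side sizing, mirroring the hunk below (ne00 is the row length):

    // warps needed to cover one row of ne00 elements
    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
    // one thread per element, rounded up to a whole warp
    int block_size = num_warps * WARP_SIZE;
    if (block_size > CUDA_CUMSUM_BLOCK_SIZE) {
        block_size = CUDA_CUMSUM_BLOCK_SIZE;   // never exceed the configured maximum
    }
    dim3 block_dims(block_size, 1, 1);

For rows shorter than the maximum block size this avoids launching warps that would immediately sit idle.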
--- ggml/src/ggml-cuda/cumsum.cu | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index e14be0721c6..030397d403d 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -32,6 +32,10 @@ static __global__ void cumsum_kernel( const int tid = threadIdx.x; const int lane_id = tid % WARP_SIZE; + if (tid >= ne00) { + return; + } + // Phase 1: Each thread processes elements at stride blockDim.x // Compute warp-level prefix sums for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { @@ -69,13 +73,18 @@ static void cumsum_cuda( const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, cudaStream_t stream) { - dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1); dim3 grid_dims(ne01, ne02, ne03); // Shared memory size: one float per warp const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; const size_t shmem_size = num_warps * sizeof(float); + int block_size = num_warps * WARP_SIZE; + if (block_size > CUDA_CUMSUM_BLOCK_SIZE) { + block_size = CUDA_CUMSUM_BLOCK_SIZE; + } + dim3 block_dims(block_size, 1, 1); + cumsum_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, From fab002949f9f4458577a3f314dc5772d9e25ec68 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 29 Nov 2025 00:40:51 +0100 Subject: [PATCH 03/25] Correct warp_prefix_inclusive_sum in float2 variant to return float2 --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c53208bed8b..c747c1c80df 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -476,7 +476,7 @@ static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { } template -static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) { +static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) { const int lane_id = threadIdx.x % width; const auto mask = __activemask(); #pragma unroll From 51c40a5a3951b7eeca080dd7a7c9f84025eaec90 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 16:10:05 +0100 Subject: [PATCH 04/25] Optimize TRI --- ggml/src/ggml-cuda/cumsum.cu | 15 ++++++++------- ggml/src/ggml-cuda/tri.cu | 30 ++++++++++++------------------ 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 030397d403d..e758fd8bdba 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,3 +1,5 @@ +#include + #include "cumsum.cuh" // Kernel to compute cumulative sum along the innermost dimension (ne[0]) @@ -26,8 +28,8 @@ static __global__ void cumsum_kernel( return; } - const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); - T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03; + T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3; const int tid = threadIdx.x; const int lane_id = tid % WARP_SIZE; @@ -78,18 +80,17 @@ static void cumsum_cuda( // Shared memory size: one float per warp const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; const size_t shmem_size = num_warps * sizeof(float); + const size_t type_size = sizeof(T); int block_size = num_warps * WARP_SIZE; - if (block_size > CUDA_CUMSUM_BLOCK_SIZE) { - block_size = CUDA_CUMSUM_BLOCK_SIZE; - } + block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); dim3 block_dims(block_size, 1, 1); cumsum_kernel<<>>( src, dst, 
ne00, ne01, ne02, ne03, - nb00, nb01, nb02, nb03, - nb0, nb1, nb2, nb3 + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size ); } diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index b531f696302..9ac13e33d4a 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -1,28 +1,18 @@ #include "tri.cuh" #include "ggml.h" -// Triangle type comparison - determines which elements to keep -__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) { - switch (type) { - case GGML_TRI_TYPE_LOWER: return i < r; - case GGML_TRI_TYPE_LOWER_DIAG: return i <= r; - case GGML_TRI_TYPE_UPPER: return i > r; - case GGML_TRI_TYPE_UPPER_DIAG: return i >= r; - default: return false; - } -} - template static __global__ void tri_kernel( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const ggml_tri_type ttype) { + const int add_to_split, const bool prefix_keep) { const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; + const int64_t split_point = i1 + add_to_split; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; @@ -30,11 +20,11 @@ static __global__ void tri_kernel( const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - + // Each thread processes elements at stride blockDim.x for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { - dst_row[i0] = tri_compare(i0, i1, ttype) - ? src_row[i0] : static_cast(0.f); + const bool keep = ((i0 < split_point) == prefix_keep); + dst_row[i0] = keep ? src_row[i0] : T(0); } } @@ -49,13 +39,17 @@ static void tri_cuda( dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); dim3 grid_dims(ne01, ne02, ne03); + const size_t type_size = sizeof(T); + + const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0; + const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG); tri_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, - nb00, nb01, nb02, nb03, - nb0, nb1, nb2, nb3, - ttype + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size, + add_to_split, prefix_keep ); } From c30f56543eb4c7c2be522f5c2a4458da787e5169 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 16:12:03 +0100 Subject: [PATCH 05/25] Whitespace --- ggml/src/ggml-cuda/tri.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index 9ac13e33d4a..ddc0fb64ce2 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -20,7 +20,7 @@ static __global__ void tri_kernel( const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - + // Each thread processes elements at stride blockDim.x for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { const bool keep = ((i0 < split_point) == prefix_keep); From 31b55fabd03e5f038a2222a7e63e167efe58850d Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 16:15:41 +0100 Subject: [PATCH 06/25] Fix strides. 
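Since the previous commit ("Optimize TRI") the launcher passes strides already divided by sizeof(T), so the kernel has to index the typed pointers directly; keeping the old byte-offset arithmetic through char * would apply the element size twice. In short, the change in the hunk below:

    // before: byte offsets through char * (now double-counts sizeof(T))
    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);

    // after: strides are element counts, so offset the typed pointer directly
    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;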
--- ggml/src/ggml-cuda/tri.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index ddc0fb64ce2..8e7ed14b03f 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -18,8 +18,8 @@ static __global__ void tri_kernel( return; } - const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03); - T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03; + T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3; // Each thread processes elements at stride blockDim.x for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { From d1ca1c2592c196360b1c20e955fc340665ed9af4 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 18:02:47 +0100 Subject: [PATCH 07/25] Implement double loop --- ggml/src/ggml-cuda/tri.cu | 65 ++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index 8e7ed14b03f..0e7dda79318 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -1,14 +1,13 @@ +#include "ggml-cuda/common.cuh" #include "tri.cuh" #include "ggml.h" -template +template static __global__ void tri_kernel( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const int add_to_split, const bool prefix_keep) { - + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; @@ -21,10 +20,20 @@ static __global__ void tri_kernel( const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03; T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3; - // Each thread processes elements at stride blockDim.x - for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { - const bool keep = ((i0 < split_point) == prefix_keep); - dst_row[i0] = keep ? src_row[i0] : T(0); + if constexpr (prefix_keep) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = T(0); + } + } else { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = T(0); + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } } } @@ -44,13 +53,39 @@ static void tri_cuda( const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 
1 : 0; const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG); - tri_kernel<<>>( - src, dst, - ne00, ne01, ne02, ne03, - nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, - nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size, - add_to_split, prefix_keep - ); + if (prefix_keep) { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { // only 0 and 1 supported + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } else { + if (add_to_split == 0) { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + tri_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } } void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { From 5289b530285370604130ca2d43e65a03db4987e7 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 18:03:50 +0100 Subject: [PATCH 08/25] Whitespace --- ggml/src/ggml-cuda/tri.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index 0e7dda79318..a3b1601fe46 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -21,17 +21,17 @@ static __global__ void tri_kernel( T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3; if constexpr (prefix_keep) { - for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { dst_row[i0] = src_row[i0]; } - for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { dst_row[i0] = T(0); } } else { - for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { dst_row[i0] = T(0); } - for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { dst_row[i0] = src_row[i0]; } } From f422ba8ee0d581a36a31d95a55138970d07baf90 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Dec 2025 21:33:43 +0100 Subject: [PATCH 09/25] Fix HIP compilation bugs --- ggml/src/ggml-cuda/common.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c747c1c80df..e4d0f2d5708 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -461,10 +461,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +static __device__ __forceinline__ unsigned int get_warp_mask() { +#ifdef __HIP_PLATFORM_AMD__ + return __ballot(1); // HIP equivalent +#else + return __activemask(); // CUDA +#endif +} + template static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { const int lane_id = threadIdx.x % width; - const auto mask = __activemask(); + const 
auto mask = get_warp_mask(); #pragma unroll for (int offset = 1; offset < width; offset <<= 1) { const T t = __shfl_up_sync(mask, x, offset, width); @@ -478,7 +486,7 @@ static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { template static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) { const int lane_id = threadIdx.x % width; - const auto mask = __activemask(); + const auto mask = get_warp_mask(); #pragma unroll for (int offset = 1; offset < width; offset <<= 1) { const float t_x = __shfl_up_sync(mask, a.x, offset, width); @@ -495,12 +503,12 @@ template static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { #ifdef FP16_AVAILABLE const int lane_id = threadIdx.x % width; - const auto mask = __activemask(); + const auto mask = get_warp_mask(); #pragma unroll for (int offset = 1; offset < width; offset <<= 1) { - const t = __hadd2(__shfl_up_sync(mask, a, offset, width)); + const half2 t = __shfl_up_sync(mask, a, offset, width); if (lane_id >= offset) { - a += t; + a = __hadd2(a, t); } } return a; From df917ccf24874d39a05863b3c8726872ec00d9d0 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 2 Dec 2025 14:29:56 +0100 Subject: [PATCH 10/25] Optimizations + big case performance tests --- ggml/src/ggml-cuda/cumsum.cu | 89 +++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 3 +- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index e758fd8bdba..d6fc0835002 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -17,53 +17,72 @@ static __global__ void cumsum_kernel( const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { - // Shared memory to store warp sums (always use float for accumulation) - extern __shared__ float shmem[]; + const int tid = threadIdx.x; + const int lane = tid & (WARP_SIZE - 1); + const int warp = tid / WARP_SIZE; + const int warps_per_block = blockDim.x / WARP_SIZE; + + extern __shared__ float smem[]; + float* s_vals = smem; + float* s_warp_sums = smem + blockDim.x; + float* s_carry = smem + blockDim.x + warps_per_block; + float* s_chunk_total = s_carry + 1; + + // Initialize carry + if (tid == 0) { + *s_carry = 0.0f; + } + __syncthreads(); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03; - T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3; + const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03; + T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3; - const int tid = threadIdx.x; - const int lane_id = tid % WARP_SIZE; - - if (tid >= ne00) { - return; - } + for (int64_t start = 0; start < ne00; start += blockDim.x) { + int64_t idx = start + tid; + float val = (idx < ne00) ? static_cast(src_row[idx]) : 0.0f; - // Phase 1: Each thread processes elements at stride blockDim.x - // Compute warp-level prefix sums - for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { - // Load value and compute prefix sum within warp - float val = static_cast(src_row[i0]); + // 1. 
Warp inclusive scan val = warp_prefix_inclusive_sum(val); - dst_row[i0] = static_cast(val); + s_vals[tid] = val; - // Last thread of warp stores its sum to shared memory at position based on data index - if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) { - const int shmem_idx = i0 / WARP_SIZE; - shmem[shmem_idx] = val; + // Store warp total + if (lane == WARP_SIZE - 1) { + s_warp_sums[warp] = val; } - } + __syncthreads(); + + // 2. Exclusive scan of warp sums (warp 0 only) + if (warp == 0) { + float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum(w); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; // exclusive sum + } + if (tid == warps_per_block - 1) { + *s_chunk_total = inc; // total sum of this chunk + } + } + __syncthreads(); - // Sync once after all warp prefix sums are computed - __syncthreads(); + float carry = *s_carry; + float final_val = s_vals[tid] + s_warp_sums[warp] + carry; + if (idx < ne00) { + dst_row[idx] = static_cast(final_val); + } + __syncthreads(); - // Phase 2: Add the sum of all preceding warp groups to each element - for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) { - const int shmem_idx = i0 / WARP_SIZE; - float sum = 0.0f; - for (int j = 0; j < shmem_idx; ++j) { - sum += shmem[j]; + // Update carry for next chunk + if (tid == 0) { + *s_carry += *s_chunk_total; } - dst_row[i0] = static_cast(static_cast(dst_row[i0]) + sum); + __syncthreads(); } } @@ -76,15 +95,13 @@ static void cumsum_cuda( cudaStream_t stream) { dim3 grid_dims(ne01, ne02, ne03); - - // Shared memory size: one float per warp const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; - const size_t shmem_size = num_warps * sizeof(float); - const size_t type_size = sizeof(T); - int block_size = num_warps * WARP_SIZE; block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); dim3 block_dims(block_size, 1, 1); + const int warps_per_block = block_size / WARP_SIZE; + const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); + const size_t type_size = sizeof(T); cumsum_kernel<<>>( src, dst, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 306fa15b923..b16a1bbc5b1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7942,7 +7942,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 })); for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { From 76382d7908792701a10796839d5a9b98db3ce10b Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 2 Dec 2025 15:53:23 +0100 Subject: [PATCH 11/25] Implement using CUB with fallback to custom kernel --- ggml/src/ggml-cuda/cumsum.cu | 117 ++++++++++++++++++++++++++++++----- tests/test-backend-ops.cpp | 3 +- 2 files changed, 105 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index d6fc0835002..ebd7579f0c2 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,15 +1,85 @@ #include #include "cumsum.cuh" +#include "ggml-impl.h" + +// Check if CUB is available +#ifdef __has_include +# if __has_include() +# define HAS_CUB_DEVICE_SCAN 1 +# 
include +# else +# define HAS_CUB_DEVICE_SCAN 0 +# endif +#else +# define HAS_CUB_DEVICE_SCAN 0 +#endif + +#if HAS_CUB_DEVICE_SCAN + +template +static __global__ void cumsum_cub_kernel( + const T* __restrict__ src, + T* __restrict__ dst, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t nb01, int64_t nb02, int64_t nb03, + int64_t nb1, int64_t nb2, int64_t nb3) +{ + using BlockScan = cub::BlockScan; + + __shared__ typename BlockScan::TempStorage temp_storage; + __shared__ T block_carry; // carry from previous tile + __shared__ T block_total; // total of current tile -// Kernel to compute cumulative sum along the innermost dimension (ne[0]) -// Each block processes one row (ne[0] elements) -// Algorithm matches Metal implementation: -// 1. Each warp computes prefix sum within itself -// 2. Last thread of each warp stores result in shared memory -// 3. All warps sync -// 4. Each element adds the sum of all preceding warps + const int tid = threadIdx.x; + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + const int64_t i3 = blockIdx.z; + if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { + return; + } + + const T* src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03; + T* dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3; + + if (tid == 0) { + block_carry = 0; + } + __syncthreads(); + + for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { + int64_t idx = start + tid; + + T x = (idx < ne00) ? src_row[idx] : T(0); + + T inclusive; + BlockScan(temp_storage).InclusiveSum(x, inclusive); + + // Last thread stores total + if (tid == BLOCK_SIZE - 1) { + block_total = inclusive; + } + __syncthreads(); + + T final = inclusive + block_carry; + + if (idx < ne00) { + dst_row[idx] = final; + } + __syncthreads(); + + if (tid == 0) { + block_carry += block_total; + } + __syncthreads(); + } +} + +#endif // HAS_CUB_DEVICE_SCAN + +// Fallback kernel implementation (original) template static __global__ void cumsum_kernel( const T * src, T * dst, @@ -94,6 +164,16 @@ static void cumsum_cuda( const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, cudaStream_t stream) { + const size_t type_size = sizeof(T); + bool use_cub = false; +#if HAS_CUB_DEVICE_SCAN + // Check if we can use CUB (data must be contiguous along innermost dimension) + const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); + + if (is_contiguous) { + use_cub = true; + } +#endif // HAS_CUB_DEVICE_SCAN dim3 grid_dims(ne01, ne02, ne03); const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; int block_size = num_warps * WARP_SIZE; @@ -101,14 +181,23 @@ static void cumsum_cuda( dim3 block_dims(block_size, 1, 1); const int warps_per_block = block_size / WARP_SIZE; const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); - const size_t type_size = sizeof(T); - cumsum_kernel<<>>( - src, dst, - ne00, ne01, ne02, ne03, - nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, - nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size - ); + if (use_cub) { + cumsum_cub_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + GGML_LOG_ERROR("Running fallback version"); + cumsum_kernel<<>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } } void ggml_cuda_op_cumsum(ggml_backend_cuda_context & 
ctx, ggml_tensor * dst) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b16a1bbc5b1..43c26e6be95 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7709,6 +7709,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 })); @@ -7941,7 +7942,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 })); - test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 })); From 01d4033ef55ce7c49b6990b9e06a00e35c5e67d0 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Tue, 2 Dec 2025 16:13:55 +0100 Subject: [PATCH 12/25] Remove error message. --- ggml/src/ggml-cuda/cumsum.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index ebd7579f0c2..b6ceb7a1495 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -190,7 +190,6 @@ static void cumsum_cuda( nb1 / type_size, nb2 / type_size, nb3 / type_size ); } else { - GGML_LOG_ERROR("Running fallback version"); cumsum_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, From 10a2ea9d70b87a98762c2c5a91a319dfe0f33fbf Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Wed, 3 Dec 2025 15:13:55 +0100 Subject: [PATCH 13/25] Fixes from code review --- ggml/src/ggml-cuda/cumsum.cu | 55 ++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index b6ceb7a1495..f23f7a87fe5 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,35 +1,21 @@ #include - #include "cumsum.cuh" -#include "ggml-impl.h" - -// Check if CUB is available -#ifdef __has_include -# if __has_include() -# define HAS_CUB_DEVICE_SCAN 1 -# include -# else -# define HAS_CUB_DEVICE_SCAN 0 -# endif -#else -# define HAS_CUB_DEVICE_SCAN 0 -#endif -#if HAS_CUB_DEVICE_SCAN +#ifdef GGML_CUDA_USE_CUB +# include template static __global__ void cumsum_cub_kernel( const T* __restrict__ src, T* __restrict__ dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, - int64_t nb01, int64_t nb02, int64_t nb03, - int64_t nb1, int64_t nb2, int64_t nb3) + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb1, const int64_t nb2, const int64_t nb3) { using BlockScan = cub::BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; __shared__ T block_carry; // carry from previous tile - __shared__ T block_total; // total of current tile const int tid = threadIdx.x; @@ -51,33 +37,40 @@ static __global__ void 
cumsum_cub_kernel( for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { int64_t idx = start + tid; - T x = (idx < ne00) ? src_row[idx] : T(0); T inclusive; - BlockScan(temp_storage).InclusiveSum(x, inclusive); + T block_total; + BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total); - // Last thread stores total - if (tid == BLOCK_SIZE - 1) { - block_total = inclusive; - } __syncthreads(); - T final = inclusive + block_carry; + T final_val = inclusive + block_carry; + // store result if (idx < ne00) { - dst_row[idx] = final; + dst_row[idx] = final_val; } + __syncthreads(); if (tid == 0) { block_carry += block_total; } + __syncthreads(); } } - -#endif // HAS_CUB_DEVICE_SCAN +#else +template +static __global__ void cumsum_cub_kernel( + const T* __restrict__ src, + T* __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb1, const int64_t nb2, const int64_t nb3) {} +// empty function to avoid triggering compilation errors on non-CUB paths, just in case compiler doesn't optimize away +#endif // GGML_CUDA_USE_CUB // Fallback kernel implementation (original) template @@ -166,14 +159,14 @@ static void cumsum_cuda( const size_t type_size = sizeof(T); bool use_cub = false; -#if HAS_CUB_DEVICE_SCAN +#ifdef GGML_CUDA_USE_CUB // Check if we can use CUB (data must be contiguous along innermost dimension) const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); if (is_contiguous) { use_cub = true; } -#endif // HAS_CUB_DEVICE_SCAN +#endif // GGML_CUDA_USE_CUB dim3 grid_dims(ne01, ne02, ne03); const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; int block_size = num_warps * WARP_SIZE; From 7a83b056cb1ae8d55f01dcfac3fb02128d0d21b0 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Wed, 3 Dec 2025 21:54:09 +0100 Subject: [PATCH 14/25] Comment out CPU-unsupported F16/BF16 cases to fix CI --- ggml/src/ggml-cuda/cumsum.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index f23f7a87fe5..457eabe32af 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -208,7 +208,8 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { stream ); } break; - case GGML_TYPE_F16: + // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms + /*case GGML_TYPE_F16: { cumsum_cuda( (const half *)src0->data, (half *)dst->data, @@ -227,7 +228,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream ); - } break; + } break;*/ default: GGML_ABORT("fatal error"); } From bbe374353d6768a8c5088cc0c3b82755c0a89b2d Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 13:04:51 +0100 Subject: [PATCH 15/25] Fine, you win :P --- ggml/src/ggml-cuda/cumsum.cu | 3 ++- ggml/src/ggml-cuda/tri.cu | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 457eabe32af..542f69aa7d8 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,5 +1,6 @@ #include #include "cumsum.cuh" +#include "convert.cuh" #ifdef GGML_CUDA_USE_CUB # include @@ -109,7 +110,7 @@ static __global__ void cumsum_kernel( for (int64_t start = 0; start < ne00; start += blockDim.x) { int64_t idx = start + tid; - float val = (idx < ne00) ? 
static_cast(src_row[idx]) : 0.0f; + float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f; // 1. Warp inclusive scan val = warp_prefix_inclusive_sum(val); diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index a3b1601fe46..dfe2b1fcc5a 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -1,4 +1,5 @@ -#include "ggml-cuda/common.cuh" +#include "common.cuh" +#include "convert.cuh" #include "tri.cuh" #include "ggml.h" @@ -25,11 +26,11 @@ static __global__ void tri_kernel( dst_row[i0] = src_row[i0]; } for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { - dst_row[i0] = T(0); + dst_row[i0] = ggml_cuda_cast(0.0f); } } else { for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { - dst_row[i0] = T(0); + dst_row[i0] = ggml_cuda_cast(0.0f); } for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { dst_row[i0] = src_row[i0]; From 069413abc76889b11eba1b1e06266d0e17bb2c19 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 14:45:07 +0100 Subject: [PATCH 16/25] Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS --- ggml/src/ggml-cuda/cumsum.cu | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 542f69aa7d8..4ab4c793b2c 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,9 +1,12 @@ #include #include "cumsum.cuh" #include "convert.cuh" +#include "ggml.h" #ifdef GGML_CUDA_USE_CUB # include +#endif + template static __global__ void cumsum_cub_kernel( @@ -11,8 +14,8 @@ static __global__ void cumsum_cub_kernel( T* __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb1, const int64_t nb2, const int64_t nb3) -{ + const int64_t nb1, const int64_t nb2, const int64_t nb3) { +#ifdef GGML_CUDA_USE_CUB using BlockScan = cub::BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; @@ -61,17 +64,10 @@ static __global__ void cumsum_cub_kernel( __syncthreads(); } -} #else -template -static __global__ void cumsum_cub_kernel( - const T* __restrict__ src, - T* __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb1, const int64_t nb2, const int64_t nb3) {} -// empty function to avoid triggering compilation errors on non-CUB paths, just in case compiler doesn't optimize away -#endif // GGML_CUDA_USE_CUB + NO_DEVICE_CODE; +#endif +} // Fallback kernel implementation (original) template @@ -81,6 +77,8 @@ static __global__ void cumsum_kernel( const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + GGML_UNUSED_VARS(nb00, nb0); + const int tid = threadIdx.x; const int lane = tid & (WARP_SIZE - 1); const int warp = tid / WARP_SIZE; @@ -138,7 +136,7 @@ static __global__ void cumsum_kernel( float carry = *s_carry; float final_val = s_vals[tid] + s_warp_sums[warp] + carry; if (idx < ne00) { - dst_row[idx] = static_cast(final_val); + dst_row[idx] = ggml_cuda_cast(final_val); } __syncthreads(); From 5aa7438e2fd2406d981b840ea586c0fd75553ef2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 15:14:08 +0100 Subject: [PATCH 17/25] Vary warp-size based on physical warp size --- ggml/src/ggml-cuda/cumsum.cu | 21 
+++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 4ab4c793b2c..74e20e53401 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -3,11 +3,16 @@ #include "convert.cuh" #include "ggml.h" +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +# define CUMSUM_WARP_SIZE 64 +#else +# define CUMSUM_WARP_SIZE 32 +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + #ifdef GGML_CUDA_USE_CUB # include #endif - template static __global__ void cumsum_cub_kernel( const T* __restrict__ src, @@ -80,9 +85,9 @@ static __global__ void cumsum_kernel( GGML_UNUSED_VARS(nb00, nb0); const int tid = threadIdx.x; - const int lane = tid & (WARP_SIZE - 1); - const int warp = tid / WARP_SIZE; - const int warps_per_block = blockDim.x / WARP_SIZE; + const int lane = tid & (CUMSUM_WARP_SIZE - 1); + const int warp = tid / CUMSUM_WARP_SIZE; + const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE; extern __shared__ float smem[]; float* s_vals = smem; @@ -115,7 +120,7 @@ static __global__ void cumsum_kernel( s_vals[tid] = val; // Store warp total - if (lane == WARP_SIZE - 1) { + if (lane == CUMSUM_WARP_SIZE - 1) { s_warp_sums[warp] = val; } __syncthreads(); @@ -167,11 +172,11 @@ static void cumsum_cuda( } #endif // GGML_CUDA_USE_CUB dim3 grid_dims(ne01, ne02, ne03); - const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; - int block_size = num_warps * WARP_SIZE; + const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE; + int block_size = num_warps * CUMSUM_WARP_SIZE; block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); dim3 block_dims(block_size, 1, 1); - const int warps_per_block = block_size / WARP_SIZE; + const int warps_per_block = block_size / CUMSUM_WARP_SIZE; const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); if (use_cub) { From 579eba6e040bb080e28eecd585fc74f2bfbff82b Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 15:15:10 +0100 Subject: [PATCH 18/25] Add GGML_UNUSED_VARS in tri as well --- ggml/src/ggml-cuda/tri.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index dfe2b1fcc5a..a5444ba01e7 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -14,6 +14,8 @@ static __global__ void tri_kernel( const int64_t i1 = blockIdx.x; const int64_t split_point = i1 + add_to_split; + GGML_UNUSED_VARS(nb00, nb0); + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } From 08b3f2d21b606a5be02f7d32bc6a22132a9f9944 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 15:36:05 +0100 Subject: [PATCH 19/25] Use constexpr and call prefix_inclusive with warp_size template param --- ggml/src/ggml-cuda/common.cuh | 9 +++++++++ ggml/src/ggml-cuda/cumsum.cu | 27 ++++++++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e4d0f2d5708..18b04c3b54f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -319,6 +319,15 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) } +static constexpr __host__ int ggml_cuda_get_physical_warp_size_host() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && 
(defined(__GFX9__) || defined(__GFX8__)) +} + + // Maximum number of bytes that can be copied in a single instruction. static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 74e20e53401..334afc6d169 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -1,14 +1,9 @@ #include #include "cumsum.cuh" #include "convert.cuh" +#include "ggml-cuda/common.cuh" #include "ggml.h" -#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) -# define CUMSUM_WARP_SIZE 64 -#else -# define CUMSUM_WARP_SIZE 32 -#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) - #ifdef GGML_CUDA_USE_CUB # include #endif @@ -85,9 +80,10 @@ static __global__ void cumsum_kernel( GGML_UNUSED_VARS(nb00, nb0); const int tid = threadIdx.x; - const int lane = tid & (CUMSUM_WARP_SIZE - 1); - const int warp = tid / CUMSUM_WARP_SIZE; - const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int lane = tid & (warp_size - 1); + const int warp = tid / warp_size; + const int warps_per_block = blockDim.x / warp_size; extern __shared__ float smem[]; float* s_vals = smem; @@ -116,11 +112,11 @@ static __global__ void cumsum_kernel( float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f; // 1. Warp inclusive scan - val = warp_prefix_inclusive_sum(val); + val = warp_prefix_inclusive_sum(val); s_vals[tid] = val; // Store warp total - if (lane == CUMSUM_WARP_SIZE - 1) { + if (lane == warp_size - 1) { s_warp_sums[warp] = val; } __syncthreads(); @@ -128,7 +124,7 @@ static __global__ void cumsum_kernel( // 2. Exclusive scan of warp sums (warp 0 only) if (warp == 0) { float w = (tid < warps_per_block) ? 
s_warp_sums[tid] : 0.0f; - float inc = warp_prefix_inclusive_sum(w); + float inc = warp_prefix_inclusive_sum(w); if (tid < warps_per_block) { s_warp_sums[tid] = inc - w; // exclusive sum } @@ -172,11 +168,12 @@ static void cumsum_cuda( } #endif // GGML_CUDA_USE_CUB dim3 grid_dims(ne01, ne02, ne03); - const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE; - int block_size = num_warps * CUMSUM_WARP_SIZE; + constexpr int warp_size = ggml_cuda_get_physical_warp_size_host(); + const int num_warps = (ne00 + warp_size - 1) / warp_size; + int block_size = num_warps * warp_size; block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); dim3 block_dims(block_size, 1, 1); - const int warps_per_block = block_size / CUMSUM_WARP_SIZE; + const int warps_per_block = block_size / warp_size; const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); if (use_cub) { From 9cd0eff1ab149084da8698023c7cf00c70391aba Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 4 Dec 2025 15:36:57 +0100 Subject: [PATCH 20/25] Update ggml/src/ggml-cuda/cumsum.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/cumsum.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 334afc6d169..5151502c727 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -72,10 +72,10 @@ static __global__ void cumsum_cub_kernel( // Fallback kernel implementation (original) template static __global__ void cumsum_kernel( - const T * src, T * dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { GGML_UNUSED_VARS(nb00, nb0); From 9574264c1b921c547019342f322ae5cbb62517c7 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Thu, 4 Dec 2025 15:38:24 +0100 Subject: [PATCH 21/25] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/cumsum.cu | 36 ++++++++++++++++++------------------ ggml/src/ggml-cuda/tri.cu | 20 ++++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 5151502c727..6263291af47 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -6,15 +6,15 @@ #ifdef GGML_CUDA_USE_CUB # include -#endif +#endif // GGML_CUDA_USE_CUB template static __global__ void cumsum_cub_kernel( - const T* __restrict__ src, - T* __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const T * __restrict__ src, + T * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb1, const int64_t nb2, const int64_t nb3) { 
#ifdef GGML_CUDA_USE_CUB using BlockScan = cub::BlockScan; @@ -31,8 +31,8 @@ static __global__ void cumsum_cub_kernel( return; } - const T* src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03; - T* dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3; + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; if (tid == 0) { block_carry = 0; @@ -66,7 +66,7 @@ static __global__ void cumsum_cub_kernel( } #else NO_DEVICE_CODE; -#endif +#endif // GGML_CUDA_USE_CUB } // Fallback kernel implementation (original) @@ -86,10 +86,10 @@ static __global__ void cumsum_kernel( const int warps_per_block = blockDim.x / warp_size; extern __shared__ float smem[]; - float* s_vals = smem; - float* s_warp_sums = smem + blockDim.x; - float* s_carry = smem + blockDim.x + warps_per_block; - float* s_chunk_total = s_carry + 1; + float * s_vals = smem; + float * s_warp_sums = smem + blockDim.x; + float * s_carry = smem + blockDim.x + warps_per_block; + float * s_chunk_total = s_carry + 1; // Initialize carry if (tid == 0) { @@ -151,11 +151,11 @@ static __global__ void cumsum_kernel( template static void cumsum_cuda( - const T * src, T * dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - cudaStream_t stream) { + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + cudaStream_t stream) { const size_t type_size = sizeof(T); bool use_cub = false; diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu index a5444ba01e7..44156b63e70 100644 --- a/ggml/src/ggml-cuda/tri.cu +++ b/ggml/src/ggml-cuda/tri.cu @@ -5,10 +5,10 @@ template static __global__ void tri_kernel( - const T * src, T * dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; @@ -42,12 +42,12 @@ static __global__ void tri_kernel( template static void tri_cuda( - const T * src, T * dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const ggml_tri_type ttype, - cudaStream_t stream) { + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype, + cudaStream_t stream) { dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); dim3 grid_dims(ne01, ne02, ne03); From efd619a60bbb03d8e4334f5eafd0d155ac56de2e Mon Sep 17 00:00:00 2001 From: Piotr 
Wilkin Date: Thu, 4 Dec 2025 15:39:22 +0100 Subject: [PATCH 22/25] Change to tid % warp_size --- ggml/src/ggml-cuda/cumsum.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 6263291af47..05802819f91 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -81,7 +81,7 @@ static __global__ void cumsum_kernel( const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int lane = tid & (warp_size - 1); + const int lane = tid % warp_size; const int warp = tid / warp_size; const int warps_per_block = blockDim.x / warp_size; From 86a0853f7eb8d830b64d127ddff8921213d9d68f Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 16:30:03 +0100 Subject: [PATCH 23/25] Fix strides; hardcode mask; add ggml_lane_mask_t --- ggml/src/ggml-cuda/common.cuh | 33 ++++++++++++++++----------------- ggml/src/ggml-cuda/cumsum.cu | 6 +++--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 18b04c3b54f..d970610d6b5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -319,15 +319,6 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) } -static constexpr __host__ int ggml_cuda_get_physical_warp_size_host() { -#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) - return 64; -#else - return 32; -#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) -} - - // Maximum number of bytes that can be copied in a single instruction. static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { #ifdef GGML_USE_HIP @@ -470,7 +461,13 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -static __device__ __forceinline__ unsigned int get_warp_mask() { +#ifdef __HIP_PLATFORM_AMD__ +typedef uint64_t ggml_lane_mask_t; +#else +typedef uint32_t ggml_lane_mask_t; +#endif // __HIP_PLATFORM_AMD__ + +static __device__ __forceinline__ ggml_lane_mask_t get_warp_mask() { #ifdef __HIP_PLATFORM_AMD__ return __ballot(1); // HIP equivalent #else @@ -481,10 +478,9 @@ static __device__ __forceinline__ unsigned int get_warp_mask() { template static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { const int lane_id = threadIdx.x % width; - const auto mask = get_warp_mask(); #pragma unroll for (int offset = 1; offset < width; offset <<= 1) { - const T t = __shfl_up_sync(mask, x, offset, width); + const T t = __shfl_up_sync(0xffffffff, x, offset, width); if (lane_id >= offset) { x += t; } @@ -495,11 +491,10 @@ static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { template static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) { const int lane_id = threadIdx.x % width; - const auto mask = get_warp_mask(); #pragma unroll for (int offset = 1; offset < width; offset <<= 1) { - const float t_x = __shfl_up_sync(mask, a.x, offset, width); - const float t_y = __shfl_up_sync(mask, a.y, offset, width); + const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width); + const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width); if (lane_id >= offset) { a.x += t_x; a.y += t_y; @@ -512,10 +507,9 @@ template static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { #ifdef FP16_AVAILABLE const int lane_id = threadIdx.x % width; - const auto mask = get_warp_mask(); 
#pragma unroll for (int offset = 1; offset < width; offset <<= 1) { - const half2 t = __shfl_up_sync(mask, a, offset, width); + const half2 t = __shfl_up_sync(0xffffffff, a, offset, width); if (lane_id >= offset) { a = __hadd2(a, t); } @@ -951,6 +945,11 @@ const ggml_cuda_device_info & ggml_cuda_info(); void ggml_cuda_set_device(int device); int ggml_cuda_get_device(); +static __host__ int ggml_cuda_get_physical_warp_size_host() { + const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()]; + return info.warp_size; +} + struct ggml_cuda_pool { virtual ~ggml_cuda_pool() = default; diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 05802819f91..22e7888a263 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -13,8 +13,8 @@ static __global__ void cumsum_cub_kernel( const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s1, const int64_t s2, const int64_t s3) { #ifdef GGML_CUDA_USE_CUB using BlockScan = cub::BlockScan; @@ -168,7 +168,7 @@ static void cumsum_cuda( } #endif // GGML_CUDA_USE_CUB dim3 grid_dims(ne01, ne02, ne03); - constexpr int warp_size = ggml_cuda_get_physical_warp_size_host(); + const int warp_size = ggml_cuda_get_physical_warp_size_host(); const int num_warps = (ne00 + warp_size - 1) / warp_size; int block_size = num_warps * warp_size; block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); From de45c6323095270b9ef3fc204b71d900b38b47b0 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 16:59:34 +0100 Subject: [PATCH 24/25] Missing renames, remove unused get_warp_mask(), explicit calls to ggml_cuda_info() --- ggml/src/ggml-cuda/common.cuh | 19 ------------------- ggml/src/ggml-cuda/cumsum.cu | 27 ++++++++++++++------------- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d970610d6b5..a744f0eb110 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -461,20 +461,6 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -#ifdef __HIP_PLATFORM_AMD__ -typedef uint64_t ggml_lane_mask_t; -#else -typedef uint32_t ggml_lane_mask_t; -#endif // __HIP_PLATFORM_AMD__ - -static __device__ __forceinline__ ggml_lane_mask_t get_warp_mask() { -#ifdef __HIP_PLATFORM_AMD__ - return __ballot(1); // HIP equivalent -#else - return __activemask(); // CUDA -#endif -} - template static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { const int lane_id = threadIdx.x % width; @@ -945,11 +931,6 @@ const ggml_cuda_device_info & ggml_cuda_info(); void ggml_cuda_set_device(int device); int ggml_cuda_get_device(); -static __host__ int ggml_cuda_get_physical_warp_size_host() { - const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()]; - return info.warp_size; -} - struct ggml_cuda_pool { virtual ~ggml_cuda_pool() = default; diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 22e7888a263..9c4980901b3 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -74,10 +74,10 @@ template static __global__ void cumsum_kernel( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const 
int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) { - GGML_UNUSED_VARS(nb00, nb0); + GGML_UNUSED_VARS(s00, s0); const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); @@ -104,8 +104,8 @@ static __global__ void cumsum_kernel( return; } - const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03; - T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3; + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; for (int64_t start = 0; start < ne00; start += blockDim.x) { int64_t idx = start + tid; @@ -153,22 +153,23 @@ template static void cumsum_cuda( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3, cudaStream_t stream) { const size_t type_size = sizeof(T); bool use_cub = false; #ifdef GGML_CUDA_USE_CUB // Check if we can use CUB (data must be contiguous along innermost dimension) - const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); + const bool is_contiguous = (s00 == type_size) && (s0 == type_size); if (is_contiguous) { use_cub = true; } #endif // GGML_CUDA_USE_CUB dim3 grid_dims(ne01, ne02, ne03); - const int warp_size = ggml_cuda_get_physical_warp_size_host(); + const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()]; + const int warp_size = info.warp_size; const int num_warps = (ne00 + warp_size - 1) / warp_size; int block_size = num_warps * warp_size; block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); @@ -180,15 +181,15 @@ static void cumsum_cuda( cumsum_cub_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, - nb01 / type_size, nb02 / type_size, nb03 / type_size, - nb1 / type_size, nb2 / type_size, nb3 / type_size + s01 / type_size, s02 / type_size, s03 / type_size, + s1 / type_size, s2 / type_size, s3 / type_size ); } else { cumsum_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, - nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, - nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + s00 / type_size, s01 / type_size, s02 / type_size, s03 / type_size, + s0 / type_size, s1 / type_size, s2 / type_size, s3 / type_size ); } } From 8a7375c8675ca0dd1e5f48a6746dc64d7f0a3454 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Thu, 4 Dec 2025 17:07:09 +0100 Subject: [PATCH 25/25] Too hasty... 
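
The host-side wrapper keeps ggml's byte strides (nb*) and converts them to element strides only at the kernel launch; the kernels themselves take s* parameters. Roughly, for the float case (illustrative only, the actual call sites are in the hunks below):

    const size_t type_size = sizeof(float);        // element size of T
    // byte stride between consecutive rows -> element stride seen by the kernel
    const int64_t s01 = nb01 / type_size;
    const int64_t s1  = nb1  / type_size;
    // inside the kernel: src_row = src + i1 * s01;  dst_row = dst + i1 * s1;
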
--- ggml/src/ggml-cuda/cumsum.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 9c4980901b3..d2f2def8bdc 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -153,15 +153,15 @@ template static void cumsum_cuda( const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, cudaStream_t stream) { const size_t type_size = sizeof(T); bool use_cub = false; #ifdef GGML_CUDA_USE_CUB // Check if we can use CUB (data must be contiguous along innermost dimension) - const bool is_contiguous = (s00 == type_size) && (s0 == type_size); + const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); if (is_contiguous) { use_cub = true; @@ -181,15 +181,15 @@ static void cumsum_cuda( cumsum_cub_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, - s01 / type_size, s02 / type_size, s03 / type_size, - s1 / type_size, s2 / type_size, s3 / type_size + nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb1 / type_size, nb2 / type_size, nb3 / type_size ); } else { cumsum_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, - s00 / type_size, s01 / type_size, s02 / type_size, s03 / type_size, - s0 / type_size, s1 / type_size, s2 / type_size, s3 / type_size + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size ); } }
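
Note: the fallback path in cumsum.cu builds on the warp_prefix_inclusive_sum() helpers in common.cuh, which are a textbook Hillis-Steele shuffle scan. A self-contained sketch of the idea, assuming a fixed 32-lane warp and the hardcoded full mask used after the review changes (simplified relative to the templated helpers in the patch):

    static __device__ __forceinline__ float warp_inclusive_prefix_sum(float x) {
        const int lane = threadIdx.x % 32;
    #pragma unroll
        for (int offset = 1; offset < 32; offset <<= 1) {
            // fetch the partial sum from the lane `offset` positions below
            const float t = __shfl_up_sync(0xffffffff, x, offset, 32);
            if (lane >= offset) {
                x += t;   // lanes below `offset` keep their value unchanged
            }
        }
        return x;         // lane i now holds x_0 + ... + x_i
    }

Each of the log2(32) = 5 shuffle steps doubles the span already accumulated in every lane, which is how the helpers combine values before the per-block pass in cumsum_kernel().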
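The CUB path processes one row in chunks of the block size and carries the running total across chunks in shared memory. A simplified float-only sketch of that pattern, assuming a contiguous row and illustrative names (the real cumsum_cub_kernel additionally handles arbitrary element types and the s01/s1 row strides):

    #include <cub/cub.cuh>

    template <int BLOCK_SIZE>
    static __global__ void row_cumsum_cub_sketch(const float * src, float * dst, const int64_t n) {
        using BlockScan = cub::BlockScan<float, BLOCK_SIZE>;
        __shared__ typename BlockScan::TempStorage temp_storage;
        __shared__ float carry;   // running total of all previous chunks

        if (threadIdx.x == 0) {
            carry = 0.0f;
        }
        __syncthreads();

        for (int64_t start = 0; start < n; start += BLOCK_SIZE) {
            const int64_t i = start + threadIdx.x;
            float val = (i < n) ? src[i] : 0.0f;
            float block_total;
            // inclusive scan of this chunk; block_total is the chunk's sum
            BlockScan(temp_storage).InclusiveSum(val, val, block_total);
            if (i < n) {
                dst[i] = val + carry;
            }
            __syncthreads();          // all reads of carry/temp_storage done
            if (threadIdx.x == 0) {
                carry += block_total;
            }
            __syncthreads();          // new carry visible before the next chunk
        }
    }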