25 commits
d138a03
Add support for CUMSUM and TRI for CUDA.
pwilkin Nov 28, 2025
67207d2
Minor optimizations.
pwilkin Nov 28, 2025
fab0029
Correct warp_prefix_inclusive_sum in float2 variant to return float2
pwilkin Nov 28, 2025
51c40a5
Optimize TRI
pwilkin Dec 1, 2025
c30f565
Whitespace
pwilkin Dec 1, 2025
31b55fa
Fix strides.
pwilkin Dec 1, 2025
d1ca1c2
Implement double loop
pwilkin Dec 1, 2025
5289b53
Whitespace
pwilkin Dec 1, 2025
f422ba8
Fix HIP compilation bugs
pwilkin Dec 1, 2025
df917cc
Optimizations + big case performance tests
pwilkin Dec 2, 2025
76382d7
Implement using CUB with fallback to custom kernel
pwilkin Dec 2, 2025
01d4033
Remove error message.
pwilkin Dec 2, 2025
10a2ea9
Fixes from code review
pwilkin Dec 3, 2025
7a83b05
Comment out CPU-unsupported F16/BF16 cases to fix CI
pwilkin Dec 3, 2025
bbe3743
Fine, you win :P
pwilkin Dec 4, 2025
069413a
Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS
pwilkin Dec 4, 2025
5aa7438
Vary warp-size based on physical warp size
pwilkin Dec 4, 2025
579eba6
Add GGML_UNUSED_VARS in tri as well
pwilkin Dec 4, 2025
08b3f2d
Use constexpr and call prefix_inclusive with warp_size template param
pwilkin Dec 4, 2025
9cd0eff
Update ggml/src/ggml-cuda/cumsum.cu
pwilkin Dec 4, 2025
9574264
Apply suggestions from code review
pwilkin Dec 4, 2025
efd619a
Change to tid % warp_size
pwilkin Dec 4, 2025
86a0853
Fix strides; hardcode mask; add ggml_lane_mask_t
pwilkin Dec 4, 2025
de45c63
Missing renames, remove unused get_warp_mask(), explicit calls to ggm…
pwilkin Dec 4, 2025
8a7375c
Too hasty...
pwilkin Dec 4, 2025
47 changes: 47 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}

template<typename T, int width = WARP_SIZE>
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const T t = __shfl_up_sync(0xffffffff, x, offset, width);
if (lane_id >= offset) {
x += t;
}
}
return x;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
if (lane_id >= offset) {
a.x += t_x;
a.y += t_y;
}
}
return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#ifdef FP16_AVAILABLE
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
if (lane_id >= offset) {
a = __hadd2(a, t);
}
}
return a;

#else
NO_DEVICE_CODE;
return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

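The three warp_prefix_inclusive_sum overloads added above share one shuffle-based (Hillis–Steele) scan: at step offset, every lane adds the value shuffled up from lane_id - offset, so after log2(width) steps lane i holds the sum of lanes 0..i within the warp. A minimal standalone sketch of that behaviour, assuming a 32-lane warp and re-inlining the loop so it builds without common.cuh (the kernel and main below are illustrative only, not part of this PR):

#include <cstdio>
#include <cuda_runtime.h>

// Same Hillis-Steele loop as warp_prefix_inclusive_sum<float, 32> above,
// inlined here so the snippet compiles on its own.
__device__ float warp_inclusive_sum_ref(float x) {
    const int lane_id = threadIdx.x % 32;
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float t = __shfl_up_sync(0xffffffff, x, offset, 32);
        if (lane_id >= offset) {
            x += t; // lanes below 'offset' keep their value unchanged
        }
    }
    return x;
}

__global__ void scan_demo(float * out) {
    // Every lane contributes 1.0f, so lane i must end up holding i + 1.
    out[threadIdx.x] = warp_inclusive_sum_ref(1.0f);
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));
    scan_demo<<<1, 32>>>(d_out);
    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; ++i) {
        printf("lane %2d -> %g\n", i, h_out[i]); // expect 1, 2, ..., 32
    }
    cudaFree(d_out);
    return 0;
}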
237 changes: 237 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,237 @@
#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

#ifdef GGML_CUDA_USE_CUB
# include <cub/device/device_scan.cuh>
#endif // GGML_CUDA_USE_CUB

template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
const T * __restrict__ src,
T * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;

__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ T block_carry; // carry from previous tile

const int tid = threadIdx.x;

const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int64_t i3 = blockIdx.z;

if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}

const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

if (tid == 0) {
block_carry = 0;
}
__syncthreads();

for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
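// Each iteration scans one BLOCK_SIZE-wide tile of the row; block_carry holds
// the running total of all previously processed tiles.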
int64_t idx = start + tid;
T x = (idx < ne00) ? src_row[idx] : T(0);

T inclusive;
T block_total;
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);

__syncthreads();

T final_val = inclusive + block_carry;

// store result
if (idx < ne00) {
dst_row[idx] = final_val;
}

__syncthreads();

if (tid == 0) {
block_carry += block_total;
}

__syncthreads();
}
#else
GGML_UNUSED_VARS(src, dst, ne00, ne01, ne02, ne03, s01, s02, s03, s1, s2, s3);
NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}

// Fallback kernel implementation (original)
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {

GGML_UNUSED_VARS(s00, s0);

const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int lane = tid % warp_size;
const int warp = tid / warp_size;
const int warps_per_block = blockDim.x / warp_size;

extern __shared__ float smem[];
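// Dynamic shared memory layout: blockDim.x scanned values, one partial sum per
// warp, then the running carry and the current chunk total (sized at launch).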
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;

// Initialize carry
if (tid == 0) {
*s_carry = 0.0f;
}
__syncthreads();

const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}

const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

for (int64_t start = 0; start < ne00; start += blockDim.x) {
int64_t idx = start + tid;
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;

// 1. Warp inclusive scan
val = warp_prefix_inclusive_sum<float, warp_size>(val);
s_vals[tid] = val;

// Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();

// 2. Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<float, warp_size>(w);
if (tid < warps_per_block) {
s_warp_sums[tid] = inc - w; // exclusive sum
}
if (tid == warps_per_block - 1) {
*s_chunk_total = inc; // total sum of this chunk
}
}
__syncthreads();

float carry = *s_carry;
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
if (idx < ne00) {
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
}
__syncthreads();

// Update carry for next chunk
if (tid == 0) {
*s_carry += *s_chunk_total;
}
__syncthreads();
}
}

template<typename T>
static void cumsum_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
cudaStream_t stream) {

const size_t type_size = sizeof(T);
bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
// Check if we can use CUB (data must be contiguous along innermost dimension)
const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);

if (is_contiguous) {
use_cub = true;
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
const int warp_size = info.warp_size;
const int num_warps = (ne00 + warp_size - 1) / warp_size;
int block_size = num_warps * warp_size;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
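// Example (hypothetical sizes): ne00 = 1000 with warp_size = 32 gives num_warps = 32,
// block_size = min(1024, 256) = 256, warps_per_block = 8 and
// shmem_size = (256 + 8 + 2) * sizeof(float) = 1064 bytes.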

if (use_cub) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
cumsum_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
// The CPU backend does not support F16/BF16 for this op yet, so keep these cases commented out: enabling them causes errors on some CI platforms
/*case GGML_TYPE_F16:
{
cumsum_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_BF16:
{
cumsum_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;*/
default:
GGML_ABORT("fatal error");
}
}
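Both paths above compute an inclusive cumulative sum along dim 0 of every row: the CUB kernel scans BLOCK_SIZE-wide tiles and carries the per-tile aggregate in block_carry, while the fallback combines a warp-level inclusive scan, an exclusive scan of the per-warp totals, and a carry accumulated across chunks (final value = warp scan + warp offset + carry). As a hedged reference for what one contiguous row should contain — plain host C++ with a hypothetical helper name, accumulating in float like the fallback kernel — a sketch:

#include <cstddef>
#include <vector>

// Hypothetical host-side reference: inclusive cumsum of one contiguous row.
// 'running' plays the same role as block_carry / *s_carry in the kernels above.
static std::vector<float> cumsum_row_ref(const std::vector<float> & src) {
    std::vector<float> dst(src.size());
    float running = 0.0f;
    for (size_t i = 0; i < src.size(); ++i) {
        running += src[i];
        dst[i] = running;
    }
    return dst;
}

In the repository itself, test-backend-ops compares the CUDA output against the CPU backend, so a reference like this only serves to make the expected per-row result explicit.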
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CUMSUM_BLOCK_SIZE 256

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Expand Up @@ -54,6 +54,8 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml.h"

#include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;