Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions ggml/src/ggml-cuda/argsort.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,81 @@
#include "argsort.cuh"

#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB

static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
const int col = blockIdx.x * blockDim.x + threadIdx.x;
const int row = blockIdx.y;

if (col < ncols && row < nrows) {
indices[row * ncols + col] = col;
}
}

static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}

#ifdef GGML_CUDA_USE_CUB
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);

int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();

static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);

const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);

cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);

size_t temp_storage_bytes = 0;

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
}

ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();

if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
}
}
#endif // GGML_CUDA_USE_CUB

// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
Expand Down Expand Up @@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
return n;
}

static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
static void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);

Expand All @@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
Expand All @@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
#ifdef GGML_CUDA_USE_CUB
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

if (shared_mem > max_shared_mem || ncols > 1024) {
ggml_cuda_pool & pool = ctx.pool();
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
}
#else
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
}
5 changes: 4 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3642,8 +3642,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
#ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024;
#else
return true;
#endif
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
Expand Down
Loading