From d4d1c05b32f6702c8ba9490dc58a6dc80400276e Mon Sep 17 00:00:00 2001
From: mnehete32
Date: Sat, 1 Nov 2025 16:22:41 +0530
Subject: [PATCH] CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++++++
 ggml/src/ggml-cuda/unary.cu     | 32 ++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/unary.cuh    |  8 ++++++++
 3 files changed, 56 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 61a8f1df87de1..5667ec0c4d709 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_XIELU:
                     ggml_cuda_op_xielu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_FLOOR:
+                    ggml_cuda_op_floor(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_CEIL:
+                    ggml_cuda_op_ceil(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_ROUND:
+                    ggml_cuda_op_round(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TRUNC:
+                    ggml_cuda_op_trunc(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_FLOOR:
+                case GGML_UNARY_OP_CEIL:
+                case GGML_UNARY_OP_ROUND:
+                case GGML_UNARY_OP_TRUNC:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 5f0d3a6726aef..c1dc6ddbf8f81 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
     return (x > 0.f) ? x : expm1f(x);
 }
 
+static __device__ __forceinline__ float op_floor(float x) {
+    return floorf(x);
+}
+
+static __device__ __forceinline__ float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static __device__ __forceinline__ float op_round(float x) {
+    return roundf(x);
+}
+
+static __device__ __forceinline__ float op_trunc(float x) {
+    return truncf(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_elu>(ctx, dst);
 }
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_floor>(ctx, dst);
+}
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_ceil>(ctx, dst);
+}
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_round>(ctx, dst);
+}
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_trunc>(ctx, dst);
+}
 
 /* gated ops */
 template <float (*op)(float), typename T>
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 6c738cefecfd2..2800c75ba3f7a 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
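
Not part of the patch: below is a minimal host-side sketch that runs the four new ops through the public ggml API on the CPU backend, producing reference values the CUDA path can be checked against (test-backend-ops performs this comparison automatically). It assumes the ggml_floor/ggml_ceil/ggml_round/ggml_trunc graph builders that accompany the GGML_UNARY_OP_* values this patch dispatches on, and that ggml_graph_compute_with_ctx is declared in ggml-cpu.h, as in current trees; treat it as a sketch under those assumptions, not as code from this change.

// sanity-check sketch (assumed API, see note above) -- not part of the patch
#include "ggml.h"
#include "ggml-cpu.h" // assumed location of ggml_graph_compute_with_ctx
#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // values chosen so floor/ceil/round/trunc all disagree somewhere
    const float vals[4] = { -1.7f, -0.5f, 0.5f, 1.7f };
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    memcpy(x->data, vals, sizeof(vals));

    // one output tensor per new unary op (builders assumed, see note above)
    struct ggml_tensor * out[4] = {
        ggml_floor(ctx, x),
        ggml_ceil (ctx, x),
        ggml_round(ctx, x),
        ggml_trunc(ctx, x),
    };
    const char * names[4] = { "floor", "ceil", "round", "trunc" };

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    for (int i = 0; i < 4; ++i) {
        ggml_build_forward_expand(gf, out[i]);
    }
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    for (int i = 0; i < 4; ++i) {
        printf("%-5s:", names[i]);
        for (int j = 0; j < 4; ++j) {
            printf(" % .1f", ((const float *) out[i]->data)[j]);
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}

On the CUDA side, note that the new cases in ggml_backend_cuda_device_supports_op are gated on ggml_is_contiguous(op->src[0]), consistent with the other unary ops in the same switch.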