diff --git a/docs/ops.md b/docs/ops.md index bd26c0eb45a85..5df72d25015d2 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,6 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -41,6 +42,7 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | +| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -82,6 +84,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | +| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -108,5 +111,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/CPU.csv b/docs/ops/CPU.csv index 21e0d1b3c9117..1820028c9a2fe 100644 --- a/docs/ops/CPU.csv +++ b/docs/ops/CPU.csv @@ -59,6 +59,14 @@ "CPU","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" +"CPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" 
+"CPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" "CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" @@ -119,6 +127,14 @@ "CPU","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","CPU" "CPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU" +"CPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU" +"CPU","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CPU" "CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=1","support","1","yes","CPU" diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 60c6b63d05978..d948b00cc7f30 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -577,6 +577,10 @@ extern "C" { GGML_UNARY_OP_EXP, 
GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_XIELU, + GGML_UNARY_OP_FLOOR, + GGML_UNARY_OP_CEIL, + GGML_UNARY_OP_ROUND, + GGML_UNARY_OP_TRUNC, GGML_UNARY_OP_COUNT, }; @@ -1151,6 +1155,46 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + /** + * Truncates the fractional part of each element in the tensor (towards zero). + * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0 + * Similar to truncf in C (std::trunc in C++). 
+ */ + + GGML_API struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_trunc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + + // xIELU activation function // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ba2a36d999128..29c870600ba93 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: { n_tasks = 1; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 1c43865ff65fc..b52f0f8472cfe 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_FLOOR: + { + ggml_compute_forward_floor(params, dst); + } break; + case GGML_UNARY_OP_CEIL: + { + ggml_compute_forward_ceil(params, dst); + } break; + case GGML_UNARY_OP_ROUND: + { + ggml_compute_forward_round(params, dst); + } break; + case GGML_UNARY_OP_TRUNC: + { + ggml_compute_forward_trunc(params, dst); + } break; case GGML_UNARY_OP_XIELU: { ggml_compute_forward_xielu(params, dst); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index cf1a4615d042c..a047537b34f78 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -73,6 +73,22 @@ static inline float op_log(float x) { return logf(x); } +static inline float op_floor(float x) { + return floorf(x); +} + +static inline float op_ceil(float x) { + return 
ceilf(x); +} + +static inline float op_round(float x) { + return roundf(x); +} + +static inline float op_trunc(float x) { + return truncf(x); +} + template <float (*op)(float), typename src0_t, typename dst_t> static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) { constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32; @@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * unary_op<op_log>(params, dst); } +void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_floor>(params, dst); +} + +void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_ceil>(params, dst); +} + +void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_round>(params, dst); +} + +void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) { + unary_op<op_trunc>(params, dst); +} + void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { const float alpha_n = ggml_get_op_params_f32(dst, 1); const float alpha_p = ggml_get_op_params_f32(dst, 2); diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index 697c1e0da0ace..fa45d9f0e636f 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void 
ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2bce1375ba3c0..86f1c31afd7a6 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1144,9 +1144,13 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "EXP", "GELU_ERF", "XIELU", + "FLOOR", + "CEIL", + "ROUND", + "TRUNC", }; -static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); +static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2749,6 +2753,62 @@ static struct ggml_tensor * ggml_glu_impl( return result; } +// ggml_floor + +struct ggml_tensor * ggml_floor( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR); +} + +struct ggml_tensor * ggml_floor_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR); +} + +// ggml_ceil + +struct ggml_tensor * ggml_ceil( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL); +} + +struct ggml_tensor * ggml_ceil_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL); +} + +// ggml_round + +struct ggml_tensor * ggml_round( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND); +} + +struct ggml_tensor * ggml_round_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND); +} + +// ggml_trunc + +struct ggml_tensor * ggml_trunc( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC); +} + +struct ggml_tensor * ggml_trunc_inplace( + struct 
ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC); +} + struct ggml_tensor * ggml_glu( struct ggml_context * ctx, struct ggml_tensor * a,