diff --git a/docs/ops.md b/docs/ops.md index 938efac815fc0..226cd935d698a 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,7 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -42,7 +42,7 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | -| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -84,7 +84,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -111,6 +111,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | -| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index d7efa43cdf3da..bc6319f51fa8c 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -31,6 +31,14 @@ "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" "SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" +"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" @@ -95,6 +103,14 @@ "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL" "SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL" +"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" +"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL" +"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL" "SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index ea252577280d5..298c2a6ccd5fc 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -3263,27 +3263,27 @@ "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan" "Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan" "Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan" -"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan" -"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan" -"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan" +"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan" +"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan" +"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan" "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan" "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan" "Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan" diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 58f5125c9cf6e..810995d0cbf74 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { return x < static_cast(min_val) ? static_cast(min_val) : (x > static_cast(max_val) ? static_cast(max_val) : x); } +template +static __dpct_inline__ T op_floor(T x) { + return sycl::floor(x); +} + +template +static __dpct_inline__ T op_ceil(T x) { + return sycl::ceil(x); +} + +template +static __dpct_inline__ T op_round(T x) { + return sycl::round(x); +} + +template +static __dpct_inline__ T op_trunc(T x) { + return sycl::trunc(x); +} + template static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { SYCL_GLOBAL_ID_LOOP(k, item_ct1) { @@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl: } } +template +static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_floor(x[i]); + } +} + +template +static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_ceil(x[i]); + } +} + +template +static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_round(x[i]); + } +} + +template +static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_trunc(x[i]); + } +} + template static void upscale(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, @@ -897,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens }, min_val, max_val); } +static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); @@ -1122,3 +1222,23 @@ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst); } + +void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_floor(ctx, dst); +} + +void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_ceil(ctx, dst); +} + +void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_round(ctx, dst); +} + +void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_trunc(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index ed96c55f75a7a..fcf93295cb215 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -80,6 +80,10 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a7e077ec8ebe0..1a007ffe2bca6 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3698,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_sycl_elu(ctx, dst); break; + case GGML_UNARY_OP_FLOOR: + ggml_sycl_floor(ctx, dst); + break; + case GGML_UNARY_OP_CEIL: + ggml_sycl_ceil(ctx, dst); + break; + case GGML_UNARY_OP_ROUND: + ggml_sycl_round(ctx, dst); + break; + case GGML_UNARY_OP_TRUNC: + ggml_sycl_trunc(ctx, dst); + break; default: return false; } @@ -4262,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_ABS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); #else diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 82bb55ea0e184..fa98db2982ce7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3759,6 +3759,130 @@ struct test_clamp : public test_case { } }; +// GGML_OP_FLOOR +struct test_floor : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_floor(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 2, 2, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_floor(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.0f, 10.0f); + } + } +}; + +// GGML_OP_CEIL +struct test_ceil : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_ceil(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 2, 2, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_ceil(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.0f, 10.0f); + } + } +}; + +// GGML_OP_ROUND +struct test_round : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_round(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 2, 2, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_round(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.0f, 10.0f); + } + } +}; + +// GGML_OP_TRUNC +struct test_trunc : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_trunc(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 2, 2, 2}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_trunc(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.0f, 10.0f); + } + } +}; + // GGML_OP_DIAG_MASK_INF struct test_diag_mask_inf : public test_case { const ggml_type type; @@ -6585,6 +6709,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cos (type)); test_cases.emplace_back(new test_clamp (type)); test_cases.emplace_back(new test_leaky_relu(type)); + test_cases.emplace_back(new test_floor (type)); + test_cases.emplace_back(new test_ceil (type)); + test_cases.emplace_back(new test_round (type)); + test_cases.emplace_back(new test_trunc (type)); test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_log (type, {7, 1, 5, 3})); @@ -6592,6 +6720,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));