diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c index d0735e9325e1c..21bf46a542f9e 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.c +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.c @@ -16,13 +16,8 @@ #include "hvx-utils.h" #include "ops-utils.h" -static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) { - static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) - - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); +static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); HVX_Vector out = hvx_vec_exp_fp32(in_vec); @@ -47,6 +42,12 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int HVX_Vector vec_out = Q6_V_vzero(); + static const float kInf = INFINITY; + static const float kMaxExp = 88.02f; // log(INF) + + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_fp32(kInf); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in1 = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; @@ -55,9 +56,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++); - *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in); + *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf); } } } else { @@ -67,9 +68,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf); } } } @@ -83,9 +84,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int if (true == negate) { HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - vec_out = hvx_vec_exp_fp32_guard(neg_vec_in); + vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); } else { - vec_out = hvx_vec_exp_fp32_guard(in); + vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf); } hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c index 953d3e6c16709..4d70634fcd483 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.c @@ -16,6 +16,15 @@ #include "hvx-utils.h" #include "ops-utils.h" +static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_fp32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -32,19 +41,22 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); } + static const uint32_t kNanInfMask = 0x7f800000; + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask); + if (0 == unaligned_loop) { HVX_Vector * p_vec_in = (HVX_Vector *) src; HVX_Vector * p_vec_out = (HVX_Vector *) dst; #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++); + *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask); } } else { #pragma unroll(4) for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); } } @@ -53,7 +65,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const float * dstf = (float *) dst + num_elems_whole; HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_inverse_fp32_guard(in); + HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 5f94645cde3b1..28b0014fb5a22 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -726,24 +726,6 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) { return Q6_Vsf_equals_Vqf32(r_qf); } -static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) { - static const float kInf = INFINITY; - static const uint32_t kNanMask = 0x7fffffff; - static const uint32_t kNanMin = 0x7f800000; - - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf); - - HVX_Vector out = hvx_vec_inverse_fp32(v_sf); - - const HVX_Vector nan_mask = Q6_V_vsplat_R(kNanMask); - const HVX_Vector nan_min = Q6_V_vsplat_R(kNanMin); - HVX_Vector masked_out = Q6_V_vand_VV(out, nan_mask); - const HVX_VectorPred pred = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out); - - return Q6_V_vmux_QVV(pred, out, Q6_V_vzero()); -} - #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 #define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 @@ -958,14 +940,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) { - static const float kMaxExp = -88.02f; // log(INF) - - const HVX_Vector max_exp = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp)); - const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp); +static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v); - return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero()); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { @@ -977,9 +961,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + + const HVX_Vector one = hvx_vec_splat_fp32(1.f); + const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { - v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]); + v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); } }