From 9bfecf4294b7a3b41b10323cb142bdaaa55cd70f Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Sun, 20 Oct 2024 01:15:42 +0800 Subject: [PATCH 1/7] ggml : RISC-V vector gemv for q4_0_8x8 --- ggml/src/ggml-aarch64.c | 75 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index b27f411474f4c..2d30caec67987 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -991,6 +991,81 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * } } return; +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl); + const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1); + const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl); + const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi); + const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8))); + const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8))); + const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8))); + const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8))); + const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8))); + const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8))); + const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8))); + const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8))); + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl); + const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec); + const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl)); + const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4); + const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4); + + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + + const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); + const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); + const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + + const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4); + 
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4); + const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4); + const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4); + const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4); + const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4); + const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4); + const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4); + const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4); + const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4); + const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4); + const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4); + const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4); + const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4); + const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4); + const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4); + + // vector version needs Zvfhmin extension + const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); + const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4); + sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); + } + return; + } #endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) { float sumf[8]; From 3f7fdf24b07a8f8ec4de46921ca74e156f0968b8 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Tue, 22 Oct 2024 02:42:35 +0800 Subject: [PATCH 2/7] ggml : Added WIP rvv q4_0_8x8 gemm --- ggml/src/ggml-aarch64.c | 117 +++++++++++++++++++++++++++------------- 1 file changed, 81 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 2d30caec67987..1cf3f713fa987 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -994,18 +994,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * #elif defined(__riscv_v_intrinsic) if (__riscv_vlenb() >= QK4_0) { const size_t vl = QK4_0; - const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl); - const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1); - const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl); - const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi); - const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8))); - const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8))); - const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8))); - const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8))); - const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8))); - const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8))); - const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8))); - const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8))); + const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF}; + const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2); + const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1013,11 +1004,12 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); for (int l = 0; l < nb; l++) { - const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl); - const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec); - const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl)); - const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4); - const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4); + const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8)); + const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8)); + const vint8m1_t lhs_2_4 = 
__riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8)); + const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4); const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); @@ -1025,25 +1017,34 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - - const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4); - const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4); - const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4); - const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4); - const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4); - const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4); - const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4); - const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4); - const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4); - const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4); - const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4); - const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4); - const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4); - const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4); - const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4); - const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4); + const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); + const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); + const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + + const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, 
vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); + const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); + const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4); + const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4); + const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4); + const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4); + const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4); + const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4); + const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4); // vector version needs Zvfhmin extension const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); @@ -3246,6 +3247,50 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * } } } + return; + } +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF}; + const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + // for (int m = 0; m < 4; m++) { + // for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + // } + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + + { + const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8)); + const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8)); + const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8)); + const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4); + } + } + // for (int m = 0; m < 4; m++) { + // for (int j = 0; j < ncols_interleaved; j++) + // s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + // } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * 
ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + return; } #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) From 238cd6674eacdaffa1dcdab763dab19c8cd95684 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Tue, 22 Oct 2024 15:42:50 +0800 Subject: [PATCH 3/7] ggml : Added initial implementation of rvv gemm --- ggml/src/ggml-aarch64.c | 207 +++++++++++++++++++++++++++++++--------- 1 file changed, 160 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 1cf3f713fa987..c42667893d37a 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -994,9 +994,6 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * #elif defined(__riscv_v_intrinsic) if (__riscv_vlenb() >= QK4_0) { const size_t vl = QK4_0; - const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF}; - const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2); - const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1004,12 +1001,12 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); for (int l = 0; l < nb; l++) { - const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8)); - const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8)); - const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8)); - const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4); + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); @@ -1022,29 +1019,19 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); - const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc5 = 
__riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2); - const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2); - const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4); - const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4); - const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4); - const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4); - const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4); - const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4); - const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4); + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); // vector version needs Zvfhmin extension const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); @@ -3252,16 +3239,11 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * #elif defined(__riscv_v_intrinsic) if (__riscv_vlenb() >= QK4_0) { const size_t vl = QK4_0; - const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF}; - const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - // for (int m = 0; m < 4; m++) { - // for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - // } vfloat32m1_t sumf0 = 
__riscv_vfmv_v_f_f32m1(0.0, vl / 4); vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); @@ -3271,19 +3253,150 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + // vector version needs Zvfhmin extension + const float a_scales[4] = { + GGML_FP16_TO_FP32(a_ptr[l].d[0]), + GGML_FP16_TO_FP32(a_ptr[l].d[1]), + GGML_FP16_TO_FP32(a_ptr[l].d[2]), + GGML_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + vint16m4_t sumi_l0; { - const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8)); - const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8)); - const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8)); - const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4); + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[32], 0, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[64], 0, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[96], 0, vl / 4)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); + const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); + const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); + const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); + const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); + const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + sumi_l0 = sumi; + } + __asm__ __volatile__("" ::: "memory"); + + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[40], 0, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[72], 0, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[104], 0, vl / 4)); + const vint8m4_t 
lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); + const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); + const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); + const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); + const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); + const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + sumi_l1 = sumi; + } + __asm__ __volatile__("" ::: "memory"); + + { + const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l0, sumi_l1); + const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi)); + const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2); + const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2); + const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2); + const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2); + const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl); + const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl); + const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl); + const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4); + const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2)); + const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2)); + const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2); + const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2); + + const vfloat32m1_t facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0); + const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[0], vl / 4); + const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4); + sumf0 = __riscv_vfadd_vv_f32m1(sumf0, tmp02, vl / 4); + const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1); + const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[1], vl / 4); + const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4); + sumf1 = __riscv_vfadd_vv_f32m1(sumf1, tmp12, vl / 4); + } + __asm__ __volatile__("" ::: "memory"); + + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[48], 0, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[80], 0, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[112], 0, vl / 4)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); + const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); + const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); + const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); + const vint16m4_t sumi2_hi = 
__riscv_vget_v_i16m8_i16m4(sumi2, 1); + const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + sumi_l2 = sumi; + } + __asm__ __volatile__("" ::: "memory"); + + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[56], 0, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[88], 0, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[120], 0, vl / 4)); + const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); + const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); + const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); + const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); + const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); + const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); + const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); + const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + sumi_l3 = sumi; + } + __asm__ __volatile__("" ::: "memory"); + + { + const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l2, sumi_l3); + const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi)); + const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2); + const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2); + const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2); + const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2); + const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl); + const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl); + const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl); + const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4); + const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2)); + const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2)); + const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2); + const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2); + + const vfloat32m1_t facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0); + const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[2], vl / 4); + const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4); + sumf2 = __riscv_vfadd_vv_f32m1(sumf2, tmp02, vl / 4); + const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1); + const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[3], vl / 4); + const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4); + sumf3 = __riscv_vfadd_vv_f32m1(sumf3, tmp12, vl / 4); } } - // for (int m = 0; m < 4; m++) { - // for (int j = 0; j < ncols_interleaved; j++) - // s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - // } __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); __riscv_vse32_v_f32m1(&s[(y * 4 + 
2) * bs + x * ncols_interleaved], sumf2, vl / 4); From c039415ecfc6d7fccd1123e6364d99db4b2b8447 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Tue, 22 Oct 2024 21:49:40 +0800 Subject: [PATCH 4/7] ggml : optimize gemm to avoid register spillover --- ggml/src/ggml-aarch64.c | 209 +++++++++++++++++++++------------------- 1 file changed, 111 insertions(+), 98 deletions(-) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index c42667893d37a..121f9d646ae54 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -1005,21 +1005,21 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4)); const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4)); const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); - const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); - const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); - const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi)); + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); @@ -1047,8 +1047,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * }; const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); - const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4); - sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); } __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 
4); } @@ -3252,6 +3251,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); // vector version needs Zvfhmin extension const float a_scales[4] = { @@ -3278,17 +3281,33 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[32], 0, vl / 4)); const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[64], 0, vl / 4)); const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[96], 0, vl / 4)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); - const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); - const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); - const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); - const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); - sumi_l0 = sumi; + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, 
b_scales_vec, vl / 4); } - __asm__ __volatile__("" ::: "memory"); + // __asm__ __volatile__("" ::: "memory"); vint16m4_t sumi_l1; { @@ -3296,44 +3315,33 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[40], 0, vl / 4)); const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[72], 0, vl / 4)); const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[104], 0, vl / 4)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); - const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); - const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); - const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); - const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); - sumi_l1 = sumi; + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; } - __asm__ __volatile__("" ::: "memory"); { - const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l0, sumi_l1); - const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi)); - const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2); - const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2); - const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2); - const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2); - const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl); - const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl); - const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl); - const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4); - const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2)); - const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2)); - const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2); - const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2); - - const vfloat32m1_t facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0); - const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[0], vl / 4); - const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4); - sumf0 = __riscv_vfadd_vv_f32m1(sumf0, tmp02, vl / 4); - const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1); - const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[1], vl / 4); - const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4); - sumf1 = __riscv_vfadd_vv_f32m1(sumf1, tmp12, vl / 4); + const vuint32m4_t sumi_i32 = 
__riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); } - __asm__ __volatile__("" ::: "memory"); + // __asm__ __volatile__("" ::: "memory"); vint16m4_t sumi_l2; { @@ -3341,17 +3349,33 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[48], 0, vl / 4)); const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[80], 0, vl / 4)); const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[112], 0, vl / 4)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); - const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); - const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); - const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); - const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); - sumi_l2 = sumi; + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, 
sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); } - __asm__ __volatile__("" ::: "memory"); + // __asm__ __volatile__("" ::: "memory"); vint16m4_t sumi_l3; { @@ -3359,42 +3383,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[56], 0, vl / 4)); const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[88], 0, vl / 4)); const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[120], 0, vl / 4)); - const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8); - const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8); - const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4); - const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4); - const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4); - const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0); - const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1); - const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2); - sumi_l3 = sumi; + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; } - __asm__ __volatile__("" ::: "memory"); { - const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l2, sumi_l3); - const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi)); - const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2); - const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2); - const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2); - const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2); - const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl); - const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl); - const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl); - const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4); - const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2)); - const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2)); - const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2); - const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2); - - const vfloat32m1_t 
facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0); - const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[2], vl / 4); - const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4); - sumf2 = __riscv_vfadd_vv_f32m1(sumf2, tmp02, vl / 4); - const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1); - const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[3], vl / 4); - const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4); - sumf3 = __riscv_vfadd_vv_f32m1(sumf3, tmp12, vl / 4); + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); } } __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); From 78c78e2af24c912f75deb09aa2237313b3fe1f69 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Thu, 24 Oct 2024 14:18:01 +0800 Subject: [PATCH 5/7] ggml : Fix GCC rvv load alignment issue --- ggml/src/ggml-aarch64.c | 68 +++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index 121f9d646ae54..c31480216b8a7 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -1001,10 +1001,15 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); for (int l = 0; l < nb; l++) { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4)); + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); + 
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); @@ -3275,12 +3280,17 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * }; const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment vint16m4_t sumi_l0; { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[32], 0, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[64], 0, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[96], 0, vl / 4)); + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); @@ -3307,14 +3317,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); } - // __asm__ __volatile__("" ::: "memory"); + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment vint16m4_t sumi_l1; { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[40], 0, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[72], 0, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[104], 0, vl / 4)); + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
                         const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
                         const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
                         const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
@@ -3341,14 +3355,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                         const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
                         sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
                     }
-                    // __asm__ __volatile__("" ::: "memory");
+                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
+                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
+                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
+                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
 
                     vint16m4_t sumi_l2;
                     {
-                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4));
-                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[48], 0, vl / 4));
-                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[80], 0, vl / 4));
-                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[112], 0, vl / 4));
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
                         const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
                         const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
                         const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
@@ -3375,14 +3393,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                         const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
                         sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
                     }
-                    // __asm__ __volatile__("" ::: "memory");
+                    const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
+                    const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
+                    const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
+                    const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
 
                     vint16m4_t sumi_l3;
                     {
-                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4));
-                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[56], 0, vl / 4));
-                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[88], 0, vl / 4));
-                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[120], 0, vl / 4));
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
                         const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
                         const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
                         const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);

From 37057a0fcf9da4ca09b68e55fa08f414b174b743 Mon Sep 17 00:00:00 2001
From: Xiongchuan Tan
Date: Fri, 25 Oct 2024 16:32:13 +0800
Subject: [PATCH 6/7] ggml : Format gemm rvv code

---
 ggml/src/ggml-aarch64.c | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index c31480216b8a7..6ade9c145ae0b 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -3371,10 +3371,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                         const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
                         const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
                         const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
-                        
+
                         sumi_l2 = sumi_hi_m;
                     }
-                    
+
                     {
                         const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
                         const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
@@ -3409,7 +3409,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                         const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
                         const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
                         const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
-                        
+
                         sumi_l3 = sumi_hi_m;
                     }
 

From 274a7720d1161b6354f5be5baf6a8da99bb015b9 Mon Sep 17 00:00:00 2001
From: Xiongchuan Tan
Date: Wed, 30 Oct 2024 14:17:26 +0800
Subject: [PATCH 7/7] ggml : Fix a typo in RVV q4_0_8_8 GEMM

---
 ggml/src/ggml-aarch64.c | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 6ade9c145ae0b..eb30f89448c20 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -3352,7 +3352,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                         const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
                         const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
 
-                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
                         sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
                     }
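
Note on the alignment fix in PATCH 5: the change replaces the strided vlse64.v load of the LHS bytes (64-bit elements, which require 8-byte aligned addresses that block_q8_0 data does not guarantee) with plain scalar 64-bit loads, a compiler barrier, and a vmv.v.x splat. The sketch below is a minimal, self-contained illustration of that pattern, not code from the patch; the helper name load_q8_row_as_i8m2 and the use of memcpy instead of the patch's pointer cast are illustrative choices only.

```c
#include <riscv_vector.h>
#include <stdint.h>
#include <string.h>

// Hypothetical helper showing the workaround: read 8 bytes with an
// alignment-safe scalar load, place a compiler barrier so GCC cannot
// re-fuse the scalar loads back into a single strided vlse64.v, then
// broadcast the 64-bit value across an i8m2 register. In the patch,
// vl is QK4_0 and the splat uses vl / 4 elements.
static inline vint8m2_t load_q8_row_as_i8m2(const int8_t * qs, size_t vl) {
    int64_t row;
    memcpy(&row, qs, sizeof(row));           // no alignment requirement
    __asm__ __volatile__("" ::: "memory");   // barrier, as in the patch
    return __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(row, vl));
}
```

The patch itself groups the four scalar loads of a block before a single barrier rather than placing one barrier per load; the per-call barrier above keeps the sketch short but is not byte-for-byte the upstream shape.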