diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 23ec8bb08a732..f363910a8f730 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -579,6 +579,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

@@ -597,23 +598,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
-            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)

         if (NOT DOTPROD_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
-                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
         endif()

         if (NOT I8MM_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
         endif()

         if (NOT SME_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index 3eaa5e3f4100f..1d5b44f9fe3cf 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -4,6 +4,7 @@
 // KleidiAI micro-kernels
 #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
 #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"
 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
 #include "kai_common.h"

 #include "simd-mappings.h"

+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
 #include "kernels.h"

 #define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
     Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 }

+template <auto Fn>
+static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
+                                         const void* lhs, const void* rhs, void* dst,
+                                         size_t dst_stride_row, size_t dst_stride_col,
+                                         float clamp_min, float clamp_max) {
+    Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
 template <auto Fn>
 static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
     return Fn(m, k, bl, mr, kr, sr);
 }
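The new adapters (kernel_run_float_fn10 here, and lhs_pack_float_fn9_no_bl and rhs_pack_scale_fn12 below) follow the file's existing pattern: a non-type template parameter binds one concrete KleidiAI entry point at compile time, and the wrapper exposes the single generic signature stored in the ggml_kleidiai_kernels tables, dropping the unused block-length argument and restoring the pointer types that the generic interface erases to void*. A self-contained sketch of the pattern, assuming only C++17 (fake_ukernel and run_float_fn10 are illustrative names, not part of the patch):

    #include <cstddef>

    // Stand-in for a concrete ukernel such as
    // kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.
    static void fake_ukernel(size_t m, size_t n, size_t k,
                             const void * lhs, const void * rhs, float * dst,
                             size_t dst_stride_row, size_t dst_stride_col,
                             float clamp_min, float clamp_max) {
        (void)m; (void)n; (void)k; (void)lhs; (void)rhs; (void)dst;
        (void)dst_stride_row; (void)dst_stride_col; (void)clamp_min; (void)clamp_max;
    }

    template <auto Fn>
    static void run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
                               const void * lhs, const void * rhs, void * dst,
                               size_t dst_stride_row, size_t dst_stride_col,
                               float clamp_min, float clamp_max) {
        // The only adaptation: drop the unused block-length argument and
        // cast the type-erased dst back to float*.
        Fn(m, n, k, lhs, rhs, static_cast<float *>(dst),
           dst_stride_row, dst_stride_col, clamp_min, clamp_max);
    }

    // One uniform function-pointer type regardless of the ukernel's own signature:
    using run_fn10 = void (*)(size_t, size_t, size_t, size_t,
                              const void *, const void *, void *,
                              size_t, size_t, float, float);
    constexpr run_fn10 run = &run_float_fn10<&fake_ukernel>;

This is why the kernel tables below can mix SME, i8mm, and dotprod variants behind identical field types.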
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
     Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
 }

+template <auto Fn>
+static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
+                                            size_t m_idx_start, const void * lhs,
+                                            size_t lhs_stride, void * lhs_packed) {
+    Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float *>(lhs), lhs_stride, lhs_packed);
+}
+
 template <auto Fn>
 static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
     return Fn(n, k, nr, kr, bl);
 }
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
         static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
 }

+template <auto Fn>
+static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
+                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+                                       void* rhs_packed, size_t extra_bytes, const void* params) {
+    Fn(num_groups, n, k, nr, kr, sr,
+       static_cast<const int8_t*>(rhs),
+       static_cast<const float*>(bias),
+       static_cast<const float*>(scale),
+       rhs_packed, extra_bytes,
+       static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
+}
+
 template <auto Fn>
 static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                  size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
     GGML_UNUSED(kr);
 }

+static void dequantize_row_qsi8cxp(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    GGML_UNUSED(bl);
+    GGML_UNUSED(num_bytes_multiplier);
+
+    const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
+    const size_t group_idx = row_idx / nr;
+    const size_t row_in_group = row_idx % nr;
+
+    const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
+    const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
+
+    const size_t num_blocks = k_internal / kr;
+
+    for (size_t block = 0; block < num_blocks; ++block) {
+        const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
+        for (size_t i = 0; i < kr; ++i) {
+            const size_t k_idx = block * kr + i;
+            if (k_idx < (size_t) k) {
+                out[k_idx] = static_cast<float>(block_ptr[i]);
+            }
+        }
+    }
+
+    const uint8_t * sums_ptr = group_ptr + nr * k_internal;
+    GGML_UNUSED(sums_ptr);
+
+    const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
+    const float scale = scale_ptr[row_in_group];
+
+    if (scale == 0.0f) {
+        for (size_t i = 0; i < (size_t) k; ++i) {
+            out[i] = 0.0f;
+        }
+        return;
+    }
+
+    for (size_t i = 0; i < (size_t) k; ++i) {
+        out[i] *= scale;
+    }
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
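For orientation, the offsets in dequantize_row_qsi8cxp imply the following per-group layout for the buffer produced by kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon: rows are packed nr at a time; each group holds nr * k_internal int8 values interleaved in kr-wide chunks, then nr int32 values (apparently per-row sums, unused by the dequantizer), then nr float scales, where k_internal is k rounded up to a multiple of QK8_0. The helper below only illustrates that offset arithmetic; its names are hypothetical and it is not part of the patch:

    #include <cstddef>
    #include <cstdint>

    struct qsi8cxp_row_ref {
        const int8_t  * data;    // first kr-wide chunk of this row inside the group
        const int32_t * row_sum; // per-row int32 metadata stored after the data
        const float   * scale;   // per-row dequantization scale
    };

    static qsi8cxp_row_ref locate_row(const uint8_t * packed, size_t row_idx,
                                      size_t nr, size_t kr, size_t k_internal,
                                      size_t packed_row_stride) {
        const uint8_t * group = packed + (row_idx / nr) * packed_row_stride;
        const size_t    r     = row_idx % nr;
        const uint8_t * sums  = group + nr * k_internal;      // after the int8 data
        const uint8_t * scls  = sums + nr * sizeof(int32_t);  // after the row sums
        return {
            reinterpret_cast<const int8_t  *>(group) + r * kr,
            reinterpret_cast<const int32_t *>(sums) + r,
            reinterpret_cast<const float   *>(scls) + r,
        };
    }

Element k_idx of row r then lives at offset ((k_idx / kr) * nr + r) * kr + (k_idx % kr) from the group's int8 base, which is exactly the indexing the block loop above performs.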
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #endif
 };

+static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
+#if defined(__ARM_FEATURE_SME)
+    {
+        /* SME GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* SME GEMV */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SME,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* I8MM GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* I8MM GEMV (dotprod fallback) */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* DOTPROD GEMV */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+};
+
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
     ggml_kleidiai_kernels * kernel = nullptr;

@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                 break;
             }
         }
+        if (!kernel) {
+            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
+                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
+                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
+                    gemm_gemv_kernels_q8[i].op_type == tensor->type) {
+                    kernel = &gemm_gemv_kernels_q8[i];
+                    break;
+                }
+            }
+        }
 #endif
     }

@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)

     return kernels;
 }
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
+    ggml_kleidiai_kernels * kernels = nullptr;
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+        if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
+            kernels = &gemm_gemv_kernels_q8[i];
+            break;
+        }
+    }
+#endif
+
+    return kernels;
+}
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index a84795a6b2e50..129245400b47f 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.h
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);

 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
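Both selector loops depend on gemm_gemv_kernels_q8 being ordered from most to least capable (SME, then I8MM, then DOTPROD), so the first entry whose required_cpu bits are all present wins; ggml_kleidiai_select_kernels additionally falls back from the q4 table to the q8 table for MUL_MAT. A minimal stand-alone sketch of that selection logic (types and values here are simplified stand-ins, not the real definitions):

    #include <cstddef>

    enum cpu_feature_bits { FEAT_NONE = 0, FEAT_DOTPROD = 1, FEAT_I8MM = 2, FEAT_SME = 4 };
    struct kernels_entry { int required_cpu; };

    static kernels_entry table[] = {
        { FEAT_SME },                 // preferred: listed first
        { FEAT_DOTPROD | FEAT_I8MM },
        { FEAT_DOTPROD },             // broadest fallback last
    };

    static kernels_entry * select_first_supported(int features) {
        for (size_t i = 0; i < sizeof(table) / sizeof(*table); ++i) {
            // Every required bit must be present, not just any of them.
            if ((features & table[i].required_cpu) == table[i].required_cpu) {
                return &table[i];     // first match wins
            }
        }
        return nullptr;               // no usable variant compiled in
    }

For example, a dotprod-only core (features == FEAT_DOTPROD) skips the SME and I8MM rows and lands on the plain dotprod entry.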
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 8b3df7d78009e..6f2a90fbda7bd 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -5,10 +5,13 @@
 #include <arm_neon.h>
 #include <assert.h>
 #include <cfloat>
+#include <algorithm>
+#include <cmath>
 #include <stdint.h>
 #include <stdlib.h>
 #include <string.h>
 #include <string>
+#include <vector>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
@@ -38,8 +41,9 @@

 struct ggml_kleidiai_context {
     cpu_feature features;
-    ggml_kleidiai_kernels * kernels;
-} static ctx = { CPU_FEATURE_NONE, NULL };
+    ggml_kleidiai_kernels * kernels_q4;
+    ggml_kleidiai_kernels * kernels_q8;
+} static ctx = { CPU_FEATURE_NONE, NULL, NULL };

 static const char* cpu_feature_to_string(cpu_feature f) {
     switch (f) {
@@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
         if (sme_enabled != 0) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
 #ifndef NDEBUG
-        if (ctx.kernels) {
-            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        if (ctx.kernels_q4) {
+            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+        }
+        if (ctx.kernels_q8) {
+            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
         }
 #endif
     }
@@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             if (kernels->rhs_type == GGML_TYPE_Q4_0) {
                 if (!lhs_info->packed_size_ex) return false;
                 size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
+            } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
+                if (!lhs_info->packed_size_ex) return false;
+                size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
             } else if (kernels->rhs_type == GGML_TYPE_F16) {
                 if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
                 const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_q8_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                 return compute_forward_get_rows(params, dst);
             }
         }
@@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }

-    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        if (!ctx.kernels) {
+    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        if (!kernels) {
             return false;
         }

+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+
+        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+            return false;
+        }
+
+        const int ith = params->ith;
+        const int nth_raw = params->nth;
+        const int nth = nth_raw > 0 ? nth_raw : 1;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = 0;
+        if (n_start < n) {
+            n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+        }
+
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            const size_t src_stride = src1->nb[1];
+            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, src_stride));
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const size_t dst_stride = dst->nb[1];
+        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
+        float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        if (n_to_process > 0) {
+            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                                  sizeof(float), -FLT_MAX, FLT_MAX);
+        }
+
+        return true;
+    }
+
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];

         GGML_TENSOR_BINARY_OP_LOCALS

-        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
-        kernel_info * kernel = &ctx.kernels->gemm;
+        ggml_kleidiai_kernels * kernels = nullptr;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+
+        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return false;
+            }
+            kernels = ctx.kernels_q4;
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return false;
+            }
+            kernels = ctx.kernels_q8;
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        kernel_info * kernel = &kernels->gemm;

         if (!rhs_info->to_float || !kernel->get_nr) {
             return false;
         }
@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t block_rows = kernel->get_nr();
         const size_t kr = kernel->get_kr();

-        const size_t num_bytes_multiplier = sizeof(uint16_t);
-        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

         const int ith = params->ith;
         const int nth = params->nth;
@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }

         return true;
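The n-dimension split in compute_forward_q8_0 mirrors the existing q4 path: n is first rounded up so it divides evenly across nth threads, then each share is rounded up to the kernel's n_step; threads whose slice begins past n simply get no columns. A worked example of that arithmetic (a local roundup stands in for kai_roundup, and the sizes are made up):

    #include <cstddef>
    #include <cstdio>

    static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

    int main() {
        const size_t n = 4096, nth = 6, n_step = 32;   // example sizes only
        const size_t num_n_per_thread = roundup(roundup(n, nth) / nth, n_step);
        for (size_t ith = 0; ith < nth; ++ith) {
            const size_t n_start = ith * num_n_per_thread;
            // Trailing threads may start past n and therefore do no work.
            const size_t n_todo = n_start < n
                ? (n_start + num_n_per_thread > n ? n - n_start : num_n_per_thread)
                : 0;
            std::printf("thread %zu: cols [%zu, %zu)\n", ith, n_start, n_start + n_todo);
        }
    }

Here num_n_per_thread comes out to 704, so threads 0 through 4 take 704 columns each and thread 5 takes the remaining 576.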
@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {

 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
-        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
-        size_t nr = ctx.kernels->gemm.get_nr();
-        size_t kr = ctx.kernels->gemm.get_kr();
-        size_t sr = ctx.kernels->gemm.get_sr();

-        struct kai_rhs_pack_qs4cxs1s0_param params;
-        params.lhs_zero_point = 1;
-        params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
+        if (tensor->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return -1;
+            }
+            size_t nr = ctx.kernels_q4->gemm.get_nr();
+            size_t kr = ctx.kernels_q4->gemm.get_kr();
+            size_t sr = ctx.kernels_q4->gemm.get_sr();
+
+            struct kai_rhs_pack_qs4cxs1s0_param params;
+            params.lhs_zero_point = 1;
+            params.rhs_zero_point = 8;
+            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                                  static_cast<const uint8_t*>(data),
+                                                  nullptr, nullptr, tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        } else if (tensor->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return -1;
+            }
+
+            const size_t row_stride = tensor->nb[1];
+            const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
+
+            std::vector<int8_t> qdata(n * k, 0);
+            std::vector<float> scales(n, 0.0f);
+
+            for (size_t row = 0; row < n; ++row) {
+                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
+                    static_cast<const uint8_t *>(data) + row * row_stride);
+
+                float max_abs = 0.0f;
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        max_abs = std::max(max_abs, std::fabs(value));
+                    }
+                }
+
+                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
+                scales[row] = scale;
+                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
+
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+                        q = std::clamp(q, -127, 127);
+                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
+                    }
+                }
+            }
+
+            size_t nr = ctx.kernels_q8->gemm.get_nr();
+            size_t kr = ctx.kernels_q8->gemm.get_kr();
+            size_t sr = ctx.kernels_q8->gemm.get_sr();
+
+            struct kai_rhs_pack_qsi8cx_params params;
+            params.lhs_zero_point = 1;
+            params.scale_multiplier = 1.0f;
+
+            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                                  qdata.data(), nullptr, scales.data(),
+                                                  tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        }

-        return 0;
         GGML_UNUSED(data_size);
+        return -1;
     }
 };
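repack() has to bridge two quantization granularities here: Q8_0 carries one fp16 scale per 32-element block, while the qsi8cx packer expects a single float scale per row, so each row is dequantized and requantized against its absolute maximum, with symmetric clamping to the ±127 range. A small numeric illustration of that round trip (the values are made up; this is not code from the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Two already-dequantized values from different blocks of one row.
        const float vals[] = { 0.5f, -1.9f };
        const float max_abs = 1.9f;            // row-wide absolute maximum
        const float scale = max_abs / 127.0f;  // one scale for the whole row
        const float inv_scale = 1.0f / scale;
        for (float v : vals) {
            int32_t q = static_cast<int32_t>(std::lround(v * inv_scale));
            q = std::clamp(q, -127, 127);      // symmetric int8 range
            std::printf("%+.3f -> q=%+d (reconstructed %+.4f)\n", v, q, q * scale);
        }
    }

The row's largest value maps exactly to ±127 and smaller values pick up a little extra rounding error, which is the expected cost of collapsing per-block scales into a per-row scale.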
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp (continued)
@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
 }

 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-    GGML_ASSERT(ctx.kernels);
+    GGML_UNUSED(buft);

-    const size_t n = tensor->ne[1];
-    const size_t k = tensor->ne[0];
-    const size_t nr = ctx.kernels->gemm.get_nr();
-    const size_t kr = ctx.kernels->gemm.get_kr();
+    const size_t n = tensor->ne[1];
+    const size_t k = tensor->ne[0];
+
+    ggml_kleidiai_kernels * kernels = nullptr;
+    size_t block_len = 0;
+
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        GGML_ASSERT(ctx.kernels_q4);
+        kernels = ctx.kernels_q4;
+        block_len = QK4_0;
+    } else if (tensor->type == GGML_TYPE_Q8_0) {
+        GGML_ASSERT(ctx.kernels_q8);
+        kernels = ctx.kernels_q8;
+        block_len = QK8_0;
+    } else {
+        return 0;
+    }

-    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
+    const size_t nr = kernels->gemm.get_nr();
+    const size_t kr = kernels->gemm.get_kr();

+    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
+    const size_t raw    = ggml_nbytes(tensor);

-    GGML_UNUSED(buft);
+    return packed > raw ? packed : raw;
 }

 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }