From 6f8d9794eda71741320eada29df8136378ddd684 Mon Sep 17 00:00:00 2001
From: cailinxi
Date: Thu, 7 Aug 2025 09:57:06 +0800
Subject: [PATCH 01/16] ggml: add spacemit backend

Change-Id: I249bdc043485d815a9c351867137bc1e27cc2e23
---
 cmake/riscv64-spacemit-linux-gnu-gcc.cmake    |   31 +
 docs/build-riscv64-spacemit.md                |   87 +
 ggml/src/ggml-cpu/CMakeLists.txt              |   11 +-
 ggml/src/ggml-cpu/ggml-cpu.c                  |   40 +
 ggml/src/ggml-cpu/ggml-cpu.cpp                |   10 +
 .../ggml-cpu/spacemit/ggml_spacemit_ime.cpp   | 1056 ++++++
 .../src/ggml-cpu/spacemit/ggml_spacemit_ime.h |    9 +
 .../spacemit/ggml_spacemit_ime_kernels.cpp    | 3219 +++++++++++++++++
 .../spacemit/ggml_spacemit_ime_kernels.h      |   30 +
 ggml/src/ggml-cpu/vec.cpp                     |  248 ++
 ggml/src/ggml-cpu/vec.h                       |  215 +-
 11 files changed, 4952 insertions(+), 4 deletions(-)
 create mode 100644 cmake/riscv64-spacemit-linux-gnu-gcc.cmake
 create mode 100644 docs/build-riscv64-spacemit.md
 create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp
 create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h
 create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp
 create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h

diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake
new file mode 100644
index 0000000000000..e1df484c7c5b8
--- /dev/null
+++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake
@@ -0,0 +1,31 @@
+# Copyright (c) 2023 SpacemiT. All rights reserved.
+set(CMAKE_SYSTEM_NAME Linux)
+set(CMAKE_SYSTEM_PROCESSOR riscv64)
+set(CMAKE_SYSTEM_VERSION 1)
+
+if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
+message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
+else()
+set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
+if(DEFINED ENV{RISCV_ROOT_PATH})
+    file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
+else()
+    message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
+endif()
+
+set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
+set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
+set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
+set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
+set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
+set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
+endif()
+
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
+set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
+set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_CXX_FLAGS}")
+set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
+add_definitions(-D__fp16=_Float16)
diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md
new file mode 100644
index 0000000000000..87cd58d5053f6
--- /dev/null
+++ b/docs/build-riscv64-spacemit.md
@@ -0,0 +1,87 @@
+> [!IMPORTANT]
+> This build documentation is specific only to RISC-V SpacemiT SoCs.
+
+## Build llama.cpp locally (for riscv64)
+
+1. Prepare Toolchain For RISC-V
+~~~
+wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
+~~~
+
+2. Build
+The build script below relies on RISC-V vector instructions for acceleration, so make sure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled.
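Before configuring, the toolchain downloaded in step 1 has to be unpacked and exposed through the `RISCV_ROOT_PATH` environment variable, which the toolchain file requires. A minimal sketch, assuming the archive extracts into a directory named after itself (adjust the paths to your actual layout):

```bash
# Assumed layout: the tarball unpacks to ./spacemit-toolchain-linux-glibc-x86_64-v1.1.2
tar -xf spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
export RISCV_ROOT_PATH=$PWD/spacemit-toolchain-linux-glibc-x86_64-v1.1.2
```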
The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`.
```bash
cmake -B build-riscv64-spacemit \
  -DCMAKE_BUILD_TYPE=Release \
  -DGGML_CPU_RISCV64_SPACEMIT=ON \
  -DLLAMA_CURL=OFF \
  -DGGML_RV_ZFH=ON \
  -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
  -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
  -DCMAKE_INSTALL_PREFIX=build-riscv64-spacemit/installed

cmake --build build-riscv64-spacemit --parallel $(nproc) --config Release

pushd build-riscv64-spacemit
make install
popd
```

## Simulation
You can use QEMU to run the RISC-V binaries on non-RISC-V hosts.

1. Download QEMU
~~~
wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz
~~~

2. Run Simulation
After building llama.cpp, you can run the resulting executables under QEMU, for example:
~~~
export QEMU_ROOT_PATH={your QEMU file path}
export RISCV_ROOT_PATH_IME1={your RISC-V compiler path}

${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build-riscv64-spacemit/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1
~~~
## Performance
#### Quantization Support For Matrix
~~~
model name : Spacemit(R) X60
isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt
mmu : sv39
uarch : spacemit,x60
mvendorid : 0x710
marchid : 0x8000000058000001
~~~

Q4_0
| Model        | Size        | Params   | backend | threads | test  | t/s          |
| ------------ | ----------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 403.20 MiB  | 630.17 M | cpu     | 4       | pp512 | 64.12 ± 0.26 |
| Qwen2.5 0.5B | 403.20 MiB  | 630.17 M | cpu     | 4       | tg128 | 10.03 ± 0.01 |
| Qwen2.5 1.5B | 1011.16 MiB | 1.78 B   | cpu     | 4       | pp512 | 24.16 ± 0.02 |
| Qwen2.5 1.5B | 1011.16 MiB | 1.78 B   | cpu     | 4       | tg128 | 3.83 ± 0.06  |
| Qwen2.5 3B   | 1.86 GiB    | 3.40 B   | cpu     | 4       | pp512 | 12.08 ± 0.02 |
| Qwen2.5 3B   | 1.86 GiB    | 3.40 B   | cpu     | 4       | tg128 | 2.23 ± 0.02  |

Q4_1
| Model        | Size       | Params   | backend | threads | test  | t/s          |
| ------------ | ---------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 351.50 MiB | 494.03 M | cpu     | 4       | pp512 | 62.07 ± 0.12 |
| Qwen2.5 0.5B | 351.50 MiB | 494.03 M | cpu     | 4       | tg128 | 9.91 ± 0.01  |
| Qwen2.5 1.5B | 964.06 MiB | 1.54 B   | cpu     | 4       | pp512 | 22.95 ± 0.25 |
| Qwen2.5 1.5B | 964.06 MiB | 1.54 B   | cpu     | 4       | tg128 | 4.01 ± 0.15  |
| Qwen2.5 3B   | 1.85 GiB   | 3.09 B   | cpu     | 4       | pp512 | 11.55 ± 0.16 |
| Qwen2.5 3B   | 1.85 GiB   | 3.09 B   | cpu     | 4       | tg128 | 2.25 ± 0.04  |

Q4_K
| Model        | Size       | Params   | backend | threads | test  | t/s          |
| ------------ | ---------- | -------- | ------- | ------- | ----- | ------------ |
| Qwen2.5 0.5B | 462.96 MiB | 630.17 M | cpu     | 4       | pp512 | 9.29 ± 0.05  |
| Qwen2.5 0.5B | 462.96 MiB | 630.17 M | cpu     | 4       | tg128 | 5.67 ± 0.04  |
| Qwen2.5 1.5B | 1.04 GiB   | 1.78 B   | cpu     | 4       | pp512 | 10.38 ± 0.10 |
| Qwen2.5 1.5B | 1.04 GiB   | 1.78 B   | cpu     | 4       | tg128 | 3.17 ± 0.08  |
| Qwen2.5 3B   | 1.95 GiB   | 3.40 B   | cpu     | 4       | pp512 | 4.23 ± 0.04  |
| Qwen2.5 3B   | 1.95 GiB   | 3.40 B   | cpu     | 4       | tg128 | 1.73 ± 0.00  |
\ No newline at end of file
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt
b/ggml/src/ggml-cpu/CMakeLists.txt index f188d1638dc5d..bc43a3a6003d0 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -433,7 +433,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp ) - if (GGML_RVV) + if (GGML_CPU_RISCV64_SPACEMIT) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfh_zba_zicbop -mabi=lp64d -DGGML_RV_ZFH -D${RISCV64_SPACEMIT_IME_SPEC}) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT) + list(APPEND GGML_CPU_SOURCES + ggml-cpu/spacemit/ggml_spacemit_ime.cpp + ggml-cpu/spacemit/ggml_spacemit_ime.h + ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp + ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h + ) + elseif (GGML_RVV) if (GGML_XTHEADVECTOR) list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) elseif (GGML_RV_ZFH) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index d89cd8f4ef652..31802d125945d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3209,6 +3209,26 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); } +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zfh) + int64_t n_loop = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "slli t2, t0, 2 \n\t" + "vle32.v v0, (%[IN]) \n\t" + "add %[IN], %[IN], t2 \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "vfncvt.f.f.w v4, v0 \n\t" + "vse16.v v4, (%[DST]) \n\t" + "add %[DST], %[DST], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + + : [ IN ] "+r"(x), [ DST ] "+r"(y), [ n ] "+r"(n_loop) + : + : "cc", "t0", "t1", "t2"); + i += n; #endif for (; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(x[i]); @@ -3250,6 +3270,26 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); vec_xst(v_yh, 0, (float *)(y + i)); } +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zfh) + int64_t n_loop = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 2 \n\t" + "slli t2, t0, 1 \n\t" + "vle16.v v0, (%[IN]) \n\t" + "add %[IN], %[IN], t2 \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vse32.v v4, (%[DST]) \n\t" + "add %[DST], %[DST], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + + : [ IN ] "+r"(x), [ DST ] "+r"(y), [ n ] "+r"(n_loop) + : + : "cc", "t0", "t1", "t2"); + i += n; #endif for (; i < n; ++i) { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 8dacd36714b4c..17da22cfe1533 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -18,6 +18,10 @@ # include "kleidiai/kleidiai.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ggml_spacemit_ime.h" +#endif + #if defined(_WIN32) # define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX @@ -45,6 +49,12 @@ std::vector & ggml_backend_cpu_get_extra_buffer_type } #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp 
b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp new file mode 100644 index 0000000000000..24f6d328be3e3 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp @@ -0,0 +1,1056 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "traits.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT +#include + +#include "ggml_spacemit_ime.h" +#include "ggml_spacemit_ime_kernels.h" +#include "vec.h" + + +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#endif + +#if !defined(__riscv_zfh) || !defined(GGML_RV_ZFH) +#error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +#error "RISCV64_SPACEMIT_IME1 not defined" +#endif + +#else + +#error "riscv not enabled in this build" + +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + + +#if defined(RISCV64_SPACEMIT_IME1) +#define QGEMM_STRIDEN_THREAD_ALIGN 16 +#else +#define QGEMM_STRIDEN_THREAD_ALIGN 32 +#endif + +typedef enum { + ScaleFp32 = 0, + ScaleFp16, +} QNBIT_GEMM_SCALE_TYPE; + +template +struct QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + QNBIT_GEMM_SCALE_TYPE ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp32; ///< datatype of B scale(FP32 or FP16). +}; + +constexpr +size_t +DivRoundup(size_t up, size_t down) +{ + return (up + down - 1) / down; +} +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(BlkSize % alignof(float) == 0); + return BlkSize; +} + +namespace ggml::cpu::riscv64_spacemit { + +const int num_ai_cores = std::thread::hardware_concurrency() / 2; + +} // namespace ggml::cpu::riscv64_spacemit + +static void SQ4BitGemm_CompInt8( + const size_t BlkLen, const size_t K, + const QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, const size_t RangeStartM, + const size_t RangeCountM, const size_t RangeStartN, + const size_t RangeCountN) { + const size_t scale_stride = + DataParams->ScaleType == QNBIT_GEMM_SCALE_TYPE::ScaleFp16 + ? 
sizeof(uint16_t) + : sizeof(float); + + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = DivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * (BlkLen * BlkBitWidth / 8); + const std::byte* QuantA = + static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const size_t zero_point_stride = DataParams->QuantBZeroPoint != nullptr ? sizeof(uint8_t) : 0; + const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); + const std::byte* QuantBData = + static_cast(DataParams->PackedQuantBData) + + RangeStartN * packed_b_stride; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + size_t CountN; + const size_t ComputeBlockCountN = RangeCountM == 1 ? RangeCountN : 16; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, ComputeBlockCountN); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * packed_b_stride; + const std::byte* b_col_zp = (zero_point_stride != 0) + ? b_col + : nullptr; + float* c_blk = C + n; + + size_t RowsRemaining = RangeCountM; + + while (RowsRemaining > 0) { + const auto RowsHandled = sqnbitgemm_spacemit_ime::SQ4BitGemmKernel_CompInt8( + BlkLen, a_row, b_col, nullptr, b_col_zp, c_blk, + RowsRemaining, CountN, K, k_blks, ldc, nullptr, scale_stride); + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +} + + + +template +constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template +struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template +struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), "wrong block_with_zp<4,16> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0* in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... 
[b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1* in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16* dst = (block_q4_0x16*)t->data; + const block_q4_0* src = (const block_q4_0*)data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16* dst = (block_q4_1x16*)t->data; + const block_q4_1* src = (const block_q4_1*)data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, const uint8_t* GGML_RESTRICT q, uint8_t* GGML_RESTRICT d, uint8_t* GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* 
GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16* dst = (block_q4_1x16*)t->data; + const block_q4_K* src = (const block_q4_K*)data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t* q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(struct ggml_tensor*, const void*, size_t); + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor* t, const void* data, size_t data_size) = 0; +}; + +template +class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; + size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_Q4_0 || // + op->src[0]->type == GGML_TYPE_Q4_1 || // + op->src[0]->type == GGML_TYPE_Q4_K) { + forward_mul_mat_q4(params, op); + return true; + } + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat_q4(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + const ggml_tensor* src1 = op->src[1]; + ggml_tensor* dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + int ith = params->ith; + 
int nth = params->nth; + + [[maybe_unused]] const enum ggml_type type = src0->type; + + void* w_data = (void*)src0->data; + const float* feature = (const float*)src1->data; + float* output = (float*)dst->data; + + const auto BatchN = ne12 * ne13; + [[maybe_unused]] const auto BatchWeight = ne02 * ne03; + const auto M = ne11; + const auto K = ne10; + const auto N = ne01; + + assert(BatchWeight == 1); + // constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = QK4_0; + + size_t BlockCountK = DivRoundup(K, BlkLen); + const size_t Size = M * BlockCountK * Q8BlkSize(BlkLen); + auto Alignment = alignof(double); + size_t PerGemmWorkspaceStride = DivRoundup(Size, Alignment) * Alignment; + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + const size_t desired_wsize = WorkspaceSize + Alignment - 1; + + if (params->wsize < desired_wsize && ith == 0) { + throw std::runtime_error( + "wsize less than MlasSQNBitGemmBatchWorkspaceSize"); + } + + std::vector> DataParams(BatchN); + + for (int i = 0; i < BatchN; i++) { + DataParams[i].A = feature + M * K * i; + DataParams[i].lda = K; + DataParams[i].QuantBDataWorkspace = w_data; + DataParams[i].PackedQuantBData = (const std::byte*)w_data; + DataParams[i].QuantBScale = nullptr; + + if constexpr (std::is_same_v) { + DataParams[i].QuantBZeroPoint = nullptr; + } else { + DataParams[i].QuantBZeroPoint = (const uint8_t*)w_data; + } + + DataParams[i].Bias = nullptr; + DataParams[i].C = output + M * N * i; + DataParams[i].ldc = N; + DataParams[i].ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp16; + } + Alignment = alignof(double);; + const uintptr_t WorkspaceAddress = reinterpret_cast(params->wdata); + void* Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + + BlockCountK = DivRoundup(K, BlkLen); + const auto PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + PerGemmWorkspaceStride = + DivRoundup(PerGemmWorkspaceSize, Alignment) * Alignment; + + const auto QuantizeARow = sqnbitgemm_spacemit_ime::QuantizeARow_CompInt8; + const auto QuantizeAM4Row = sqnbitgemm_spacemit_ime::QuantizeAM4Row_CompInt8; + + BlockCountK = DivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + { + const size_t BlockSizeM = 4; + size_t BlockCountM = DivRoundup(M, BlockSizeM); + int task_count = BatchN * BlockCountM; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + auto gemm_idx = compute_idx / BlockCountM; + auto m_idx = compute_idx % BlockCountM * BlockSizeM; + const auto& data = DataParams[gemm_idx]; + auto RowsTobeHandled = (M - m_idx) > 4 ? 
4 : (M - m_idx); + if (RowsTobeHandled == 4) { + const float* ARowPtr = data.A + m_idx * data.lda; + std::byte* QuantARowPtr = static_cast(Workspace) + + gemm_idx * PerGemmWorkspaceStride + + m_idx * QuantAStride; + QuantizeAM4Row(BlkLen, ARowPtr, K, QuantARowPtr); + + } else { + while (RowsTobeHandled) { + const float* ARowPtr = data.A + m_idx * data.lda; + std::byte* QuantARowPtr = static_cast(Workspace) + + gemm_idx * PerGemmWorkspaceStride + + m_idx * QuantAStride; + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + RowsTobeHandled -= 1; + m_idx += 1; + } + } + } + } + + ggml_barrier(params->threadpool); + + if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) + return; + + nth = std::min(nth, int{ggml::cpu::riscv64_spacemit::num_ai_cores}); + + int ThreadsPerGemm = nth / BatchN; + constexpr size_t StrideM = 128; + + size_t nc = N; + const size_t BlockedM = DivRoundup(M, StrideM); + const size_t max_nc = DivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, DivRoundup(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * + QGEMM_STRIDEN_THREAD_ALIGN); + } + + const size_t StrideN = nc; + const size_t ThreadCountM = DivRoundup(M, StrideM); + const size_t ThreadCountN = DivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + { + int task_count = BatchN * ThreadsPerGemm; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + const auto gemm_i = compute_idx / ThreadsPerGemm; + const auto blk_i = compute_idx % ThreadsPerGemm; + const auto* Data = &DataParams[gemm_i]; + + const auto ThreadIdN = blk_i / ThreadCountM; + const auto ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + void* PerGemmWorkspace = reinterpret_cast(Workspace) + + gemm_i * PerGemmWorkspaceStride; + + SQ4BitGemm_CompInt8(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, + RangeCountM, RangeStartN, RangeCountN); + } + } + } + + int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int)NB_COLS, (int)INTER_SIZE); + return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); + } +}; + +class tensor_traits_common : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { + switch (op->op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + size = 0; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_NORM: + forward_norm_f32(params, op); + return true; + case GGML_OP_RMS_NORM: + forward_rms_norm_f32(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_norm_f32(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + ggml_tensor* dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float 
epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto* input = (float*)src0->data; + auto* output = (float*)dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto* p_input = const_cast(input + offset); + + auto* p_output = output + offset; + auto* p_temp_output = p_output; + auto* p_gamma_data = (const float*) nullptr; + auto* p_beta_data = (const float*) nullptr; + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if 
(p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + void forward_rms_norm_f32(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + ggml_tensor* dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto* input = (float*)src0->data; + auto* output = (float*)dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto* p_input = const_cast(input + offset); + auto* p_output = output + offset; + auto* p_temp_output = p_output; + auto* p_gamma_data = (const float*)nullptr; + auto* p_beta_data = (const float*)nullptr; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + // float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { 
+ gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; + +static const tensor_traits q4_0_16x8_q8_0; +static const tensor_traits q4_1_16x8_q8_0; +static const tensor_traits q4_k_16x8_q8_0; +static const tensor_traits_common rvv_impl; + +} // namespace ggml::cpu::riscv64_spacemit + +static const ggml::cpu::tensor_traits* ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor* cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_1) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_F32) { + return &ggml::cpu::riscv64_spacemit::rvv_impl; + } + + return nullptr; +} + +static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor) { + tensor->extra = (void*)const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor, + const void* data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base*)tensor->extra; + if (tensor_traits) { + auto OK = tensor_traits->repack(tensor, data, size); + GGML_ASSERT(OK == 0); + } + + GGML_UNUSED(buffer); +} + +static const char* ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_RISCV64_SPACEMIT"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; + buffer->iface.get_tensor = 
nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 64; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const struct ggml_tensor* tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } else { + nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + if (tensor->type == GGML_TYPE_Q4_K) { + GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); + nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + } + } else { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } + } + + GGML_UNUSED(buft); + return nbytes; +} + +namespace ggml::cpu::riscv64_spacemit { + +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + if (op->src[0]->type == GGML_TYPE_F32) { + return true; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + ggml::cpu::tensor_traits* get_tensor_traits(const struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + return (ggml::cpu::tensor_traits*)op->src[0]->extra; + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return (ggml::cpu::tensor_traits*)(&ggml::cpu::riscv64_spacemit::rvv_impl); + default: + // GGML_ABORT("fatal error"); + break; + } + + return nullptr; + } +}; + +} // namespace ggml::cpu::riscv64_spacemit + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::riscv64_spacemit::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_riscv64_spacemit; +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h new file mode 100644 index 
0000000000000..8020508ac2eef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h @@ -0,0 +1,9 @@ +#pragma once + +#include "traits.h" +#include "ggml.h" +#include + +// #include +// GGML internal header +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp new file mode 100644 index 0000000000000..947490c8d9ab0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp @@ -0,0 +1,3219 @@ +#include "ggml_spacemit_ime_kernels.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// HPMADOT_OFF = 0 is not supported when ref_C used +#define HPMADOT_OFF 0 +namespace sqnbitgemm_spacemit_ime +{ +template +struct Q4_0X16 { + T scale[16]; + // uint8_t zp[8]; + // blklen = 16 + // uint8_t[16][8] + // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 + + // blklen = 32 + // uint8_t[16][2][8] + // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 + // b16 b24, b17 b25, b18 b26, b19 b27, b20 b28, b21 b29, b22 b30, b23 b31 +}; + +#define QUANTIZEM4ROW_KERNEL \ + "vmv.s.x v16, zero \n\t" \ + "vfabs.v v8, v0 \n\t" \ + "vfredmax.vs v16, v8, v16 \n\t" \ + "vfmv.f.s f10, v16 \n\t" \ + "fmul.s f10, f10, %[RMAXREC] \n\t" \ + "fsw f10, (a1) \n\t" \ + "fdiv.s f11, %[FONE], f10 \n\t" \ + "vfmul.vf v16, v0, f11 \n\t" \ + "vfcvt.x.f.v v16, v16 \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vnclip.wx v16, v16, zero \n\t" \ + "vnclip.wx v17, v17, zero \n\t" \ + "vnclip.wx v18, v18, zero \n\t" \ + "vnclip.wx v19, v19, zero \n\t" \ + "vnclip.wx v20, v20, zero \n\t" \ + "vnclip.wx v21, v21, zero \n\t" \ + "vnclip.wx v22, v22, zero \n\t" \ + "vnclip.wx v23, v23, zero \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vnclip.wx v24, v16, zero \n\t" \ + "vnclip.wx v25, v17, zero \n\t" \ + "vnclip.wx v26, v18, zero \n\t" \ + "vnclip.wx v27, v19, zero \n\t" \ + "vnclip.wx v28, v20, zero \n\t" \ + "vnclip.wx v29, v21, zero \n\t" \ + "vnclip.wx v30, v22, zero \n\t" \ + "vnclip.wx v31, v23, zero \n\t" + +#define QUANTIZEM4ROW_STORE \ + "addi t1, %[BlkLen], 0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v24, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v25, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v26, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v27, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v28, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v29, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v30, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v31, (s1) \n\t" + +void +QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) +{ + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index 
* CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 128) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index * CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "addi t2, t2, -128 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v16, v24, v16 \n\t" + "vfredmax.vs v24, v16, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e64, m4 \n\t" + "vsse64.v v16, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "sub t2, t2, t2 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "jal x0, QUANTIZE%= \n\t" + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC 
] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 256) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index * CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "addi t2, t2, -256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v24, v16 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t1, t2, 0 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + + "TAIL_LOOP%=: \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t2, e32, m1 \n\t" + "sub 
t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 32 \n\t" + "vfmul.vf v1, v0, f11 \n\t" + "vfcvt.x.f.v v2, v1 \n\t" + "vsetvli t0, zero, e16, mf2 \n\t" + "vnclip.wx v3, v2, zero \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vnclip.wx v3, v3, zero \n\t" + "vse8.v v3, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bnez t2, TAIL_LOOP%= \n\t" + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } +} + +void +QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) +{ + const float* SRC = A; + std::byte* DST = QuantA; + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + std::byte* QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; + + if (CountK <= BlkLen) { + float max_abs_A = 0.0f; + for (size_t k = 0; k < CountK; k++) { + max_abs_A = std::max(max_abs_A, fabsf(A[k])); + } + float scale_A = max_abs_A * range_max_reciprocal; + + ((float*)QuantA)[0] = scale_A; + + auto* QuantAData_offset = (int8_t*)(QuantA + sizeof(float)); + + for (size_t k = 0; k < CountK; k++) { + QuantAData_offset[k] = + (int8_t)std::clamp(roundf(A[k] / scale_A), (float)std::numeric_limits::lowest(), + (float)std::numeric_limits::max()); + } + for (size_t k = CountK; k < BlkLen; k++) { + QuantAData_offset[k] = 0; + } + + return; + } + + if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { + __asm__ volatile( + "vsetvli t0, zero, e8, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[CNT], e8, m8 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "sub %[CNT], %[CNT], t0 \n\t" + "bnez %[CNT], LOOP%= \n\t" + : [ DST ] "+r"(QuantA_offset), [ CNT ] "+r"(offset) + : + : "cc", "t0"); + } + if (BlkLen == 16) { + float buffer[64] = {0.0f}; + __asm__ volatile( + "addi t3, zero, 16*8 \n\t" + "addi t2, zero, 16 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m2 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v2, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v4, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v6, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v10, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v12, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v14, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v18, v2 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v22, v6 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v26, v10 \n\t" + "vfabs.v v28, v12 \n\t" + "vfabs.v v30, v14 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v18, v18, v19 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v22, v22, v23 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v26, v26, v27 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vse32.v v16, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v18, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v20, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v22, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v 
v24, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v26, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v28, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v30, (a1) \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f11, f3, f7 \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f11, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f12, f3, f7 \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fsw f12, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f13, f3, f7 \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f13, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f14, f3, f7 \n\t" + "fmul.s f14, f14, %[RMAXREC] \n\t" + "fsw f14, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f14, %[FONE], f14 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f15, f3, f7 \n\t" + "fmul.s f15, f15, %[RMAXREC] \n\t" + "fsw f15, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f15, %[FONE], f15 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw 
f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f16, f3, f7 \n\t" + "fmul.s f16, f16, %[RMAXREC] \n\t" + "fsw f16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f16, %[FONE], f16 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f17, f3, f7 \n\t" + "fmul.s f17, f17, %[RMAXREC] \n\t" + "fsw f17, (%[DST]) \n\t" + "addi %[DST], %[DST], -136 \n\t" + "fdiv.s f17, %[FONE], f17 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v18, v2, f11 \n\t" + "vfmul.vf v20, v4, f12 \n\t" + "vfmul.vf v22, v6, f13 \n\t" + "vfmul.vf v24, v8, f14 \n\t" + "vfmul.vf v26, v10, f15 \n\t" + "vfmul.vf v28, v12, f16 \n\t" + "vfmul.vf v30, v14, f17 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v18, v18 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v22, v22 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v26, v26 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vfcvt.x.f.v v30, v30 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v18, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v20, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v22, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v26, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v28, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v30, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vse32.v v16, (%[BUFFER]) \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 
\n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m2 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BUFFER ] "r"(buffer) + : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", + "f13", "f14", "f15", "f16", "f17"); + } else if (BlkLen == 32) { + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], 
t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC), [ DST ] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); + } else if (BlkLen == 64) { + __asm__ volatile( + "addi t3, zero, 64*2 \n\t" + "addi t2, zero, 64 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 68 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vfmax.vv v24, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v25 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 132 \n\t" + "vse8.v v24, (s2) \n\t" + "addi s2, s2, 132 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 256 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m2 \n\t" + "vnclip.wx v16, v16, zero 
\n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ K ] "+r"(CountK) + : [ SRC ] "r"(SRC), [ DST ] "r"(DST), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); + } else if (BlkLen == 128) { + __asm__ volatile( + "addi t2, zero, 128 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "sub %[K], %[K], t2 \n\t" + "QUANT%=: \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "vfmax.vv v28, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v30, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vfredmax.vs v31, v30, v31 \n\t" + "vfmv.f.s f10, v31 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v8, (a2) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "jal x0, QUANT%= \n\t" + "END%=: \n\t" + + : [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC) + : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); + } else { + float buffer[8] = {0.0f}; + size_t cnt = BlkLen / 256; + + __asm__ volatile( + "slli t3, %[BLK], 2 \n\t" + "blt %[K], %[BLK], LOOP_TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_CMP%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v16, v16, v24 \n\t" + "vfmax.vv v0, v0, v16 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, LOOP_CMP%= \n\t" + "sub %[SRC], %[SRC], t3 \n\t" + "addi t6, %[CNT], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 
4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_QUANT%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "vse8.v v4, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bnez t6, LOOP_QUANT%= \n\t" + "sub %[K], %[K], %[BLK] \n\t" + "bge %[K], %[BLK], LOOP_MAIN%= \n\t" + "blez %[K], END%= \n\t" + "LOOP_TAIL%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[K], 0 \n\t" + "addi s1, %[SRC], 0 \n\t" + "TAIL_CMP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t6, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "sub t6, t6, t0 \n\t" + "vfabs.v v0, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, TAIL_CMP%= \n\t" + "addi t6, %[K], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[K], 0 \n\t" + "TAIL_QUANT%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t1, t6, e32, m8 \n\t" + "vle32.v v0, (s1) \n\t" + "addi s1, s1, 256 \n\t" + "sub t6, t6, t1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 64 \n\t" + "bnez t6, TAIL_QUANT%= \n\t" + "END%=: \n\t" 
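+ // Layout note: each BlkLen-sized block of A is written out as a 4-byte fp32
+ // scale followed by BlkLen int8 codes, matching the scalar CountK <= BlkLen
+ // path above. This generic branch (BlkLen other than 16/32/64/128) reduces the
+ // block maximum in chunks of four e32/m8 vector loads per LOOP_CMP iteration
+ // and re-reads the same data in LOOP_QUANT; cnt = BlkLen / 256 appears to
+ // assume VLEN = 256, i.e. 64 floats per m8 load and 256 floats per iteration.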
+ : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BLK ] "r"(BlkLen), [ BUFFER ] "r"(buffer), + [ CNT ] "r"(cnt) + : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); + } +} + +namespace +{ +#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \ + "vmadot v16, v14, v0 \n\t" \ + "vmadot v18, v14, v1 \n\t" \ + "vmadot v20, v14, v2 \n\t" \ + "vmadot v22, v14, v3 \n\t" \ + "vmadot v16, v15, v4 \n\t" \ + "vmadot v18, v15, v5 \n\t" \ + "vmadot v20, v15, v6 \n\t" \ + "vmadot v22, v15, v7 \n\t" + +#define SQ4BIT_KERNEL_ACC_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 16 \n\t" \ + "addi s3, s1, 32 \n\t" \ + "addi s4, s1, 48 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 8 \n\t" \ + "addi s3, s1, 16 \n\t" \ + "addi s4, s1, 24 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \ + "vle8.v v4, (s1) \n\t" \ + "addi s1, s1, 128 \n\t" \ + "vle8.v v5, (s2) \n\t" \ + "addi s2, s2, 128 \n\t" \ + "vle8.v v6, (s3) \n\t" \ + "addi s3, s3, 128 \n\t" \ + "vle8.v v7, (s4) \n\t" \ + "addi s4, s4, 128 \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vle8.v v14, (s5) \n\t" \ + "addi s5, s5, 16 \n\t" \ + "vle8.v v15, (s6) \n\t" \ + "addi s6, s6, 16 \n\t" \ + "addi t5, t5, -1 \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vand.vi v0, v4, 15 \n\t" \ + "vand.vi v1, v5, 15 \n\t" \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vsrl.vi v4, v4, 4 \n\t" \ + "vsrl.vi v5, v5, 4 \n\t" \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v1, (s7) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v8, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v9, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v10, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v11, v1, v13 \n\t" \ + "vadd.vi v13, v13, -12 \n\t" + +// using for M4Kernel +#define LOAD_B_16x8x2 \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vle8.v v6, (s1) \n\t" \ + "addi s1, s1, 32*4 \n\t" \ + "vle8.v v7, (s2) \n\t" \ + "addi s2, s2, 32*4 \n\t" \ + "vle8.v v8, (s3) \n\t" \ + "addi s3, s3, 32*4 \n\t" \ + "vle8.v v9, (s4) \n\t" \ + "addi s4, s4, 32*4 \n\t" \ + \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vand.vi v4, v8, 15 \n\t" \ + "vand.vi v5, v9, 15 \n\t" \ + \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" \ + "vsrl.vi v8, v8, 4 \n\t" \ + "vsrl.vi v9, v9, 4 \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16_FP16 \ + "addi s2, s5, -8 \n\t" \ + "addi s3, s5, 8 \n\t" \ + "addi s4, s5, 16 \n\t" \ + "addi s6, s5, 24 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e16, mf4 \n\t" \ + "vle16.v v9, (s5) \n\t" \ + "vle16.v v11, (s3) \n\t" \ + "vle16.v v13, (s4) \n\t" \ + "vle16.v v15, (s6) \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vle16.v v9, (s2), v0.t \n\t" \ + "vle16.v v11, (s5), v0.t \n\t" \ + "vle16.v 
v13, (s3), v0.t \n\t" \ + "vle16.v v15, (s4), v0.t \n\t" \ + "vfwcvt.f.f.v v8, v9 \n\t" \ + "vfwcvt.f.f.v v10, v11 \n\t" \ + "vfwcvt.f.f.v v12, v13 \n\t" \ + "vfwcvt.f.f.v v14, v15 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16 \ + "addi s2, s5, -16 \n\t" \ + "addi s3, s5, 16 \n\t" \ + "addi s4, s5, 32 \n\t" \ + "addi s6, s5, 48 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vle32.v v8, (s5) \n\t" \ + "vle32.v v10, (s3) \n\t" \ + "vle32.v v12, (s4) \n\t" \ + "vle32.v v14, (s6) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v8, (s2), v0.t \n\t" \ + "vle32.v v10, (s5), v0.t \n\t" \ + "vle32.v v12, (s3), v0.t \n\t" \ + "vle32.v v14, (s4), v0.t \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +//[s1| BIAS, s2, s3, s4] +#define LOAD_BIAS \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "addi s1, %[BIAS], -16 \n\t" \ + "addi s2, %[BIAS], 16 \n\t" \ + "addi s3, %[BIAS], 32 \n\t" \ + "addi s4, %[BIAS], 48 \n\t" \ + \ + "vle32.v v24, (%[BIAS]) \n\t" \ + "vle32.v v26, (s2) \n\t" \ + "vle32.v v28, (s3) \n\t" \ + "vle32.v v30, (s4) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v24, (s1), v0.t \n\t" \ + "vle32.v v26, (%[BIAS]), v0.t \n\t" \ + "vle32.v v28, (s2), v0.t \n\t" \ + "vle32.v v30, (s3), v0.t \n\t" \ + "vmv.v.v v25, v24 \n\t" \ + "vmv.v.v v27, v26 \n\t" \ + "vmv.v.v v29, v28 \n\t" \ + "vmv.v.v v31, v30 \n\t" + +#define SQ4BIT_KERNEL_COMP_4x16x16 \ + "vmadot v16, v10, v2 \n\t" \ + "vmadot v18, v10, v3 \n\t" \ + "vmadot v20, v10, v4 \n\t" \ + "vmadot v22, v10, v5 \n\t" \ + "vmadot v16, v11, v6 \n\t" \ + "vmadot v18, v11, v7 \n\t" \ + "vmadot v20, v11, v8 \n\t" \ + "vmadot v22, v11, v9 \n\t" + +#define SAVE_RESULT_4x16 \ + "addi a1, %[C], 0 \n\t" \ + "add a2, %[C], %[LDC] \n\t" \ + "add a3, a2, %[LDC] \n\t" \ + "add a4, a3, %[LDC] \n\t" \ + "addi a2, a2, -16 \n\t" \ + "addi a4, a4, -16 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, 
zero, e32, mf2 \n\t" \ + \ + "vse32.v v24, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v25, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v26, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v27, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v28, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v29, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v30, (a1) \n\t" \ + "vse32.v v31, (a3) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + \ + "vse32.v v24, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v25, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v26, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v27, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v28, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v29, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v30, (a2), v0.t \n\t" \ + "vse32.v v31, (a4), v0.t \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v11, (s6) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v12, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v13, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v14, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v15, v11, v1 \n\t" \ + "vadd.vi v1, v1, -12 \n\t" + +template +void +SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias, + const size_t ldc) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, 
v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + 
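+ // Each BLOCK_INNER_LOOP pass below consumes 16 k-values of one quantization
+ // block (assuming VLEN = 256, so an e8/m1 load covers 32 bytes): LOAD_B_16x8x2
+ // fetches 128 bytes of packed 4-bit B for 16 columns and splits the low/high
+ // nibbles into v2-v9, the two vle8.v loads fetch 64 int8 A values (4 rows x
+ // 16 k), the per-column zero points in v12-v15 are subtracted, and vmadot
+ // accumulates int32 partial sums into the v16/v18/v20/v22 tiles. INNER =
+ // BlkLen / 16 such passes finish the block before the fp16 scales are applied.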
"BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float* CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} +template +void +SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias, + const size_t ldc) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv 
v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" 
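+ // No-zero-point variant: the per-block B layout here is 64 bytes of fp32
+ // scales (16 columns) followed by the packed 4-bit data, and the nibbles are
+ // re-centred by subtracting 8 (the vadd.vi ..., -8 sequence below) instead of
+ // loading stored zero points, consistent with symmetric Q4_0-style blocks.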
+ + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float* CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} +template +void +SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t INNER = BlkLen / 16; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + // zp offset + "addi s7, %[B], 32 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ 
volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + 
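+ // The four vle16.v loads in this LOOP_K prologue gather the 16 fp16 column
+ // scales of the current block (4 per pointer, assuming VLEN = 256), widen
+ // them to fp32 with vfwcvt.f.f.v and pre-multiply them by the per-block A
+ // scale f1, so the int8 dot products of the inner loop only need a single
+ // vfmacc per accumulator when the block is finished.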
"vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + const size_t INNER = BlkLen / 16; + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + // zp offset + "addi s7, %[B], 64 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + + // load scale + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load a scale + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + // a scale * b scale + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", 
"s5", "s6", "s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s7, %[B], 64 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi 
t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +inline void +SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias, + const size_t ldc, + const size_t scalestride) +{ + if (scalestride == 4) { + SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias, ldc); + + } else if (scalestride == 2) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( + BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); + } +} + +template +inline void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias, + const size_t ldc, + const size_t scalestride) +{ + if (scalestride == 4) { + SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias); + } else if (scalestride == 2) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + } +} + +} // namespace + +size_t +SQ4BitGemmKernel_CompInt8(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias, + const size_t ScaleStride) +{ + GGML_UNUSED(CountM); + GGML_UNUSED(CountK); + GGML_UNUSED(ldc); + if 
(CountM >= 4) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 4; + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 1; + } +} +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h new file mode 100644 index 0000000000000..e112e7fd7d050 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include "traits.h" +#include "ggml.h" + +namespace sqnbitgemm_spacemit_ime +{ +size_t +SQ4BitGemmKernel_CompInt8(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias, + const size_t ScaleStride); +void +QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); + +void +QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); + +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 07b377bdd82a7..5130b5dfcb6ff 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -109,6 +109,28 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G sumf += x[i]*y[i]; } #endif +#elif defined(__riscv) && defined(__riscv_v) + float sumf = 0.0f; + __asm__ volatile( + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v16, v16 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "slli t1, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle32.v v8, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "vfmacc.vv v16, v0, v8 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vfredusum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[res], v24 \n\t" + : [ n ] "+r"(n), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ res ] "=f"(sumf) + : + : "cc", "t0", "t1"); #else // scalar ggml_float sumf = 0.0; @@ -224,6 +246,32 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G // if you hit this, you are likely running outside the FP range assert(!isnan(sumf) && !isinf(sumf)); +#elif defined(__riscv) && defined(__riscv_v) + float result = 0.0f; + __asm__ volatile( + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v16, v16 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "vle16.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle16.v v2, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vfwcvt.f.f.v v8, v2 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vfmacc.vv v16, v4, v8 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez 
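Note: the RVV paths added to ggml_vec_dot_f32 / ggml_vec_dot_f16 above all follow the same shape: strip-mine the input with vsetvli, keep a vector accumulator updated by vfmacc.vv under a tail-undisturbed policy, and reduce once at the end with vfredusum.vs. For reference, a minimal intrinsics rendering of the f32 variant is sketched below; the helper name is illustrative and it is not part of the patch.

```cpp
#include <riscv_vector.h>
#include <cstddef>

// Intrinsics rendering of the asm path added to ggml_vec_dot_f32: strip-mine with vsetvl,
// accumulate with a tail-undisturbed vfmacc (the asm requests "tu" for the same reason),
// then reduce once at the end.
static float vec_dot_f32_rvv_sketch(int n, const float * x, const float * y) {
    size_t avl = (size_t) n;
    vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, __riscv_vsetvlmax_e32m4());
    while (avl > 0) {
        const size_t       vl = __riscv_vsetvl_e32m4(avl);
        const vfloat32m4_t vx = __riscv_vle32_v_f32m4(x, vl);
        const vfloat32m4_t vy = __riscv_vle32_v_f32m4(y, vl);
        acc = __riscv_vfmacc_vv_f32m4_tu(acc, vx, vy, vl);   // acc[i] += x[i] * y[i], tail kept
        x += vl; y += vl; avl -= vl;
    }
    const vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
    const vfloat32m1_t red  = __riscv_vfredusum_vs_f32m4_f32m1(acc, zero, __riscv_vsetvlmax_e32m4());
    return __riscv_vfmv_f_s_f32m1_f32(red);                  // matches the final vfredusum.vs
}
```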
%[n], LOOP%= \n\t" + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vfredusum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[res], v24 \n\t" + : [ n ] "+r"(n), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ res ] "=f"(result) + : + : "cc", "t0", "t1"); + sumf += result; #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); @@ -251,6 +299,93 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { for (; i + 3 < n; i += 4) { vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); } +#elif defined(__riscv) && defined(__riscv_v) + int N = n; + i += n; + constexpr struct { + float LowerRange; + float UpperRange; + float alpha_9; + float alpha_7; + float alpha_5; + float alpha_3; + float alpha_1; + float beta_10; + float beta_8; + float beta_6; + float beta_4; + float beta_2; + float beta_0; + float one_half; + } LogisticConstants = { + -18.0f, + 18.0f, + 4.37031012579801e-11f, + 1.15627324459942e-07f, + 6.08574864600143e-05f, + 8.51377133304701e-03f, + 2.48287947061529e-01f, + 6.10247389755681e-13f, + 5.76102136993427e-09f, + 6.29106785017040e-06f, + 1.70198817374094e-03f, + 1.16817656904453e-01f, + 9.93151921023180e-01f, + 0.5f, + }; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m1,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vfmv.v.f v20, %[b0] \n\t" + "vfmv.v.f v21, %[a1] \n\t" + "vfmv.v.f v22, %[b2] \n\t" + "vfmv.v.f v23, %[a3] \n\t" + "vfmv.v.f v24, %[b4] \n\t" + "vfmv.v.f v25, %[a5] \n\t" + "vfmv.v.f v26, %[b6] \n\t" + "vfmv.v.f v27, %[a7] \n\t" + "vfmv.v.f v28, %[b8] \n\t" + "vle32.v v0, (%[x]) \n\t" + "add %[x], %[x], t0 \n\t" + "vfmax.vf v1, v0, %[lr] \n\t" + "vfmin.vf v1, v1, %[ur] \n\t" + "vfmul.vv v4, v1, v1 \n\t" + "vmv.v.v v8, v4 \n\t" + "vfmadd.vf v8, %[a9], v27 \n\t" + "vfmadd.vv v8, v4, v25 \n\t" + "vfmadd.vv v8, v4, v23 \n\t" + "vfmadd.vv v8, v4, v21 \n\t" + "vfmul.vv v8, v8, v1 \n\t" + "vmv.v.v v12, v4 \n\t" + "vfmadd.vf v12, %[b10], v28 \n\t" + "vfmadd.vv v12, v4, v26 \n\t" + "vfmadd.vv v12, v4, v24 \n\t" + "vfmadd.vv v12, v4, v22 \n\t" + "vfmadd.vv v12, v4, v20 \n\t" + "vfdiv.vv v12, v8, v12 \n\t" + "vfadd.vf v12, v12, %[onehalf] \n\t" + "vfmul.vv v12, v12, v0 \n\t" // sigmo + "vse32.v v12, (%[y]) \n\t" + "add %[y], %[y], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ x ] "+r"(x), [ y ] "+r"(y) + : [ lr ] "f"(LogisticConstants.LowerRange), + [ ur ] "f"(LogisticConstants.UpperRange), + [ a1 ] "f"(LogisticConstants.alpha_1), + [ a3 ] "f"(LogisticConstants.alpha_3), + [ a5 ] "f"(LogisticConstants.alpha_5), + [ a7 ] "f"(LogisticConstants.alpha_7), + [ a9 ] "f"(LogisticConstants.alpha_9), + [ b0 ] "f"(LogisticConstants.beta_0), + [ b2 ] "f"(LogisticConstants.beta_2), + [ b4 ] "f"(LogisticConstants.beta_4), + [ b6 ] "f"(LogisticConstants.beta_6), + [ b8 ] "f"(LogisticConstants.beta_8), + [ b10 ] "f"(LogisticConstants.beta_10), + [ onehalf ] "f"(LogisticConstants.one_half) + : "cc", "t0"); #endif for (; i < n; ++i) { y[i] = ggml_silu_f32(x[i]); @@ -325,6 +460,119 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv) && defined(__riscv_v) + int N = n; + i += n; + float* src = const_cast(reinterpret_cast(x)); + float* dst = reinterpret_cast(y); + float Accumulator = 0.0f; + const float Neg_Max = -max; + + const float LowerRange = -103.9720840454f; + const float UpperRange = 88.7762626647950f; + const float LowerRangeSumExp = -88.3762626647949f; + const 
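Note: the vectorized SiLU above evaluates a clamped rational polynomial approximation of the logistic function entirely in registers: y = x * sigmoid(x), with sigmoid ≈ P(x²)·x / Q(x²) + 0.5 on the clamped range [-18, 18]. A scalar model of the same math is sketched below purely for reference; constant names mirror the LogisticConstants operands fed to the asm, and the helper is not part of the patch.

```cpp
#include <algorithm>

// Scalar model of the vectorized SiLU above.
static inline float silu_f32_ref(float x) {
    const float xc = std::min(18.0f, std::max(-18.0f, x));   // clamp to [LowerRange, UpperRange]
    const float x2 = xc * xc;

    float p = 4.37031012579801e-11f;          // alpha_9
    p = p * x2 + 1.15627324459942e-07f;       // alpha_7
    p = p * x2 + 6.08574864600143e-05f;       // alpha_5
    p = p * x2 + 8.51377133304701e-03f;       // alpha_3
    p = p * x2 + 2.48287947061529e-01f;       // alpha_1
    p = p * xc;                               // odd numerator: P(x2) * xc

    float q = 6.10247389755681e-13f;          // beta_10
    q = q * x2 + 5.76102136993427e-09f;       // beta_8
    q = q * x2 + 6.29106785017040e-06f;       // beta_6
    q = q * x2 + 1.70198817374094e-03f;       // beta_4
    q = q * x2 + 1.16817656904453e-01f;       // beta_2
    q = q * x2 + 9.93151921023180e-01f;       // beta_0

    const float sigmoid = p / q + 0.5f;       // one_half
    return x * sigmoid;                       // multiplied by the unclamped input, as in the asm
}
```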
float UpperRangeSumExp = 88.3762626647949f; + const float RoundingBias = 12582912.f; + const float Log2Reciprocal = 1.44269504088896341f; + const float Log2High = -6.93145752e-1f; + const float Log2Low = -1.42860677e-6f; + const float poly_0 = 0x1.694000p-10; + const float poly_1 = 0x1.125edcp-7; + const float poly_2 = 0x1.555b5ap-5; + const float poly_3 = 0x1.555450p-3; + const float poly_4 = 0x1.fffff6p-2; + const float poly_56 = 0x1.000000p+0; + // int32_t MinimumExponent = int32_t(0xC1000000); //unused + const int32_t MaximumExponent = int32_t(0x3F800000); + + __asm__ volatile( + "mv t3, %[LEN] \n\t" + "mv s1, %[SRC] \n\t" + "mv s2, %[DST] \n\t" + + /* 2.0 Compute exp() and accumulate and store to cache_buffer */ + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v8, v8 \n\t" + "vxor.vv v0, v8, v8 \n\t" + + ".align 4 \n\t" + "_EXPACC_LEN_LPST: \n\t" + "vsetvli t0, t3, e32, m4,tu,mu \n\t" + + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + + /* 2.1 START exp() */ + "vfadd.vf v0, v0, %[NEG_MAX] \n\t" // v4 = x - max + + // Ensure that q = RN(x/log(2)) >= e_min, so that 2^q can be computed + // safely with a simple shift into the exponent field. xmin = + // round(-126.5 * log(2), single, RU) ~ -87.68311309814453125 const + // float xmin = -0x1.5ebb82p6; + "vfmax.vf v0, v0, %[LowerRangeSumExp] \n\t" + + // 2.1.0. Reduction x = s * q ln(2) + // const float r_ln2f = 0x1.715476p0f; // single(1/log(2)); + // const float l2uf = 0x1.62e4p-1f; // round(log(2), 24-8, RN); + // const float l2lf = 0x1.7f7d1cp-20f; // round(log(2) - l2uf, single, + // RN); + "vfmv.v.f v4, %[RoundingBias] \n\t" + "vfmacc.vf v4, %[Log2Reciprocal], v0 \n\t" // biased in mlas + "vfsub.vf v8, v4, %[RoundingBias] \n\t" // v12_a = float(x - n); + + // Use Cody-Waite range reduction method (note two constants to + // represent log(2)) to improve accuracy. + "vfmacc.vf v0, %[Log2High], v8 \n\t" + "vfmacc.vf v0, %[Log2Low], v8 \n\t" + "vfcvt.x.f.V v8, v4 \n\t" + + // 2.1.1. Approximate e^s by degree-6 polynomial approximation + "vfmv.v.f v4, %[poly_0] \n\t" + "vfmv.v.f v12, %[poly_1] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_2] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_3] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_4] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_56] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_56] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" // v8 = poly(input - max) + + // 2.1.2. 
Reconstruction: compute u = u*2^q + // const int16_t p = (24 - 1); + // const int16_t bias = (128 - 1); + "vsll.vi v8, v8, 23 \n\t" + "vadd.vx v8, v8, %[MaximumExponent] \n\t" + //"vfcvt.f.x.v v12, v8 \n\t" + + "vfmul.vv v0, v4, v8 \n\t" + /* 2.1 END exp() */ + + "vse32.v v0, (s2) \n\t" // exp(输入-max)输出 + "sh2add s2, t0, s2 \n\t" + "vfadd.vv v16, v16, v0 \n\t" + "sub t3, t3, t0 \n\t" + "bgtz t3, _EXPACC_LEN_LPST \n\t" + + "_EXPACC_LEN_LPND: \n\t" + + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v8, v8 \n\t" + "vfredosum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[RTN], v24 \n\t" // ft2 = sum(exp( )) + + : [ RTN ] "=f"(Accumulator), [ SRC ] "+r"(src), [ DST ] "+r"(dst) + : [ LEN ] "r"(N), [ NEG_MAX ] "f"(Neg_Max), [ LowerRange ] "f"(LowerRange), [ UpperRange ] "f"(UpperRange), + [ LowerRangeSumExp ] "f"(LowerRangeSumExp), [ UpperRangeSumExp ] "f"(UpperRangeSumExp), + [ RoundingBias ] "f"(RoundingBias), [ Log2Reciprocal ] "f"(Log2Reciprocal), [ Log2High ] "f"(Log2High), + [ Log2Low ] "f"(Log2Low), [ poly_0 ] "f"(poly_0), [ poly_1 ] "f"(poly_1), [ poly_2 ] "f"(poly_2), + [ poly_3 ] "f"(poly_3), [ poly_4 ] "f"(poly_4), [ poly_56 ] "f"(poly_56), + [ MaximumExponent ] "r"(MaximumExponent) + : "cc", "s1", "s2", "t0", "t3"); + sum += (ggml_float)Accumulator; #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2250d93cb00d1..c2b062b0b6fdd 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -51,7 +51,108 @@ inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { fo inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_cpy_i8(const int n, int8_t * y, const int8_t * x) { +#if defined(__riscv) && defined(__riscv_v) && defined(__riscv_v_intrinsic) + size_t vlenb = __riscv_vlenb(); + if (vlenb == 32) { + // 1024 bytes + __asm__ volatile( // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8,tu,mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } else if (vlenb == 128) { + // 2048 bytes + __asm__ volatile( // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8,tu,mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 1024 \n\t" + "vle8.v v8, (%[s]) 
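Note: the softmax path above computes exp(x - max) with a clamp, a Cody-Waite argument reduction against log(2), a degree-6 polynomial, and a direct reconstruction of 2^r through the IEEE-754 exponent field. The scalar sketch below restates that sequence for readers of the asm; the helper name is an assumption for illustration and is not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>

// Scalar model of the vectorized exp(x - max) above; constants mirror the asm operands.
static inline float softmax_exp_ref(float x, float max) {
    float v = std::max(x - max, -88.3762626647949f);               // LowerRangeSumExp clamp
    // round v/ln(2) to the nearest integer via the 1.5*2^23 bias trick
    const float biased = v * 1.44269504088896341f + 12582912.f;    // Log2Reciprocal, RoundingBias
    const float r      = biased - 12582912.f;                      // r = round(v / ln 2)
    // Cody-Waite: remove r*ln(2) in two pieces to keep the remainder accurate
    v += r * -6.93145752e-1f;                                      // Log2High
    v += r * -1.42860677e-6f;                                      // Log2Low
    // degree-6 polynomial approximation of e^v on the reduced range
    float p = 0x1.694000p-10f;                                     // poly_0
    p = p * v + 0x1.125edcp-7f;                                    // poly_1
    p = p * v + 0x1.555b5ap-5f;                                    // poly_2
    p = p * v + 0x1.555450p-3f;                                    // poly_3
    p = p * v + 0x1.fffff6p-2f;                                    // poly_4
    p = p * v + 1.0f;                                              // poly_56
    p = p * v + 1.0f;                                              // poly_56
    // reconstruct 2^r by writing r into the exponent field (the clamp keeps r >= -127)
    const int32_t  q    = (int32_t) r;                             // r is already integer-valued
    const uint32_t bits = (uint32_t) (q + 127) << 23;
    float two_q;
    std::memcpy(&two_q, &bits, sizeof(two_q));
    return p * two_q;
}
```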
\n\t" + "addi %[s], %[s], 1024 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 1024 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 1024 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } else { + __asm__ volatile( // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } +#else + memcpy(y, x, n); +#endif +} + +inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { +#if defined(__riscv) && defined(__riscv_v) && defined(__riscv_v_intrinsic) + ggml_vec_cpy_i8(n * sizeof(int32_t), (int8_t *)y, (const int8_t *)x); +#else + for (int i = 0; i < n; ++i) y[i] = x[i]; +#endif +} inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } @@ -65,6 +166,25 @@ inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, co __m256 vz = _mm256_add_ps(vx, vy); _mm256_storeu_ps(z + i, vz); } +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + i += n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t0 \n\t" + "vle32.v v8, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t0 \n\t" + "vfadd.vv v0, v0, v8 \n\t" + "vse32.v v0, (%[z]) \n\t" + "add %[z], %[z], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ z ] "+r"(z) + : + : "cc", "t0"); #endif for (; i < n; ++i) { z[i] = x[i] + y[i]; @@ -86,7 +206,13 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp } } inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { +#if defined(__riscv) && defined(__riscv_v) + ggml_vec_cpy_i8(n * sizeof(float), (int8_t *)y, (const int8_t *)x); +#else + for (int i = 0; i < n; ++i) y[i] = x[i]; +#endif +} inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { @@ -94,7 +220,29 @@ inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp } } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { +#if defined(__riscv) && defined(__riscv_v) + int N = n; + __asm__ 
volatile( + "LOOP%=: \t\n" + "vsetvli t0, %[n], e32, m4\t\n" + "sub %[n], %[n], t0 \t\n" + "slli t0, t0, 2 \t\n" + "vle32.v v0, (%[lhs]) \t\n" + "add %[lhs], %[lhs], t0 \t\n" + "vle32.v v8, (%[rhs]) \t\n" + "add %[rhs], %[rhs], t0 \t\n" + "vfmul.vv v0, v0, v8 \t\n" + "vse32.v v0, (%[z]) \t\n" + "add %[z], %[z], t0 \t\n" + "bnez %[n], LOOP%= \t\n" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ z ] "+r"(z) + : + : "cc", "t0"); +#else + for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; +#endif +} inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { for (int i = 0; i < n; ++i) { z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i])); @@ -297,6 +445,28 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "vle16.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle16.v v2, (%[rhs]) \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vfwcvt.f.f.v v8, v2 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vfmacc.vf v8, %[fs], v4 \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "vfncvt.f.f.w v12, v8 \n\t" + "vse16.v v12, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y) + : [ fs ] "f"(v) + : "cc", "t0", "t1"); #else // scalar for (int i = 0; i < n; ++i) { @@ -457,6 +627,23 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { y[i] *= v; } #endif +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + float* out = y; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t0 \n\t" + "vfmul.vf v0, v0, %[rhs] \n\t" + "vse32.v v0, (%[out]) \n\t" + "add %[out], %[out], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(y), [ out ] "+r"(out) + : [ rhs ] "f"(v) + : "cc", "t0"); #else // scalar for (int i = 0; i < n; ++i) { @@ -1090,6 +1277,27 @@ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16 } inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#if defined(__riscv_v_intrinsic) +#ifdef __cplusplus + float * src = const_cast(reinterpret_cast(x)); +#else + float * src = (float *) x; +#endif + float Maximum = -INFINITY; + int64_t N = n; + size_t vl; + vfloat32m8_t vmaxf = __riscv_vfmv_v_f_f32m8(Maximum, __riscv_vsetvlmax_e32m8()); + for (; N > 0; N -= vl, src += vl) { + vl = __riscv_vsetvl_e32m8(N); + vfloat32m8_t vsrc = __riscv_vle32_v_f32m8(src, vl); + vmaxf = __riscv_vfmax_vv_f32m8(vmaxf, vsrc, vl); + } + vfloat32m1_t vmaxf_init = __riscv_vfmv_v_f_f32m1(Maximum, __riscv_vsetvlmax_e32m1()); + vl = __riscv_vsetvlmax_e32m8(); + vmaxf_init = __riscv_vfredmax_vs_f32m8_f32m1(vmaxf, vmaxf_init, vl); + float riscv_max = __riscv_vfmv_f_s_f32m1_f32(vmaxf_init); + *s = MAX(riscv_max, *s); +#else #ifndef GGML_USE_ACCELERATE float max = -INFINITY; for (int i = 0; i < n; ++i) { @@ -1099,6 +1307,7 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #else vDSP_maxv(x, 1, s, n); #endif +#endif } inline static void ggml_vec_norm_inv_f32(const int n, 
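Note: the f32 helpers added above (add, mul, mad, scale) all share the same vsetvli-driven strip-mining loop in inline asm. An equivalent formulation using the ratified RVV intrinsics is sketched below for the add case, as a reference for what the asm loop does; the function name is illustrative and the mul/scale variants differ only in the arithmetic instruction.

```cpp
#include <riscv_vector.h>
#include <cstddef>

// Intrinsics rendering of the strip-mined loop used by the asm above (vector add shown).
static void vec_add_f32_rvv_sketch(int n, float * z, const float * x, const float * y) {
    size_t avl = (size_t) n;
    while (avl > 0) {
        const size_t       vl = __riscv_vsetvl_e32m4(avl);        // elements processed this strip
        const vfloat32m4_t vx = __riscv_vle32_v_f32m4(x, vl);
        const vfloat32m4_t vy = __riscv_vle32_v_f32m4(y, vl);
        __riscv_vse32_v_f32m4(z, __riscv_vfadd_vv_f32m4(vx, vy, vl), vl);
        x += vl; y += vl; z += vl;
        avl -= vl;
    }
}
```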
float * s, const float * x) { From 0f8d88f954b42ee0d2f73c49657f6716cb5a6357 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 14 Aug 2025 06:38:22 +0000 Subject: [PATCH 02/16] add new line at end of file Change-Id: I889ed1c85fb45e62350ecde0c06f70450cadfbe2 --- docs/build-riscv64-spacemit.md | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md index 87cd58d5053f6..b1356b8432097 100644 --- a/docs/build-riscv64-spacemit.md +++ b/docs/build-riscv64-spacemit.md @@ -84,4 +84,4 @@ Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08| Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04| -Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| \ No newline at end of file +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp index 24f6d328be3e3..3224b61306265 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp @@ -1053,4 +1053,4 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { }; return &ggml_backend_cpu_buffer_type_riscv64_spacemit; -} \ No newline at end of file +} diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h index 8020508ac2eef..ab71bfae4e7b0 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h @@ -6,4 +6,4 @@ // #include // GGML internal header -ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); \ No newline at end of file +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp index 947490c8d9ab0..c51a9bfb355d4 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp @@ -3216,4 +3216,4 @@ SQ4BitGemmKernel_CompInt8(size_t BlkLen, return 1; } } -} \ No newline at end of file +} diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h index e112e7fd7d050..af62472aabf84 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h @@ -27,4 +27,4 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q void QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); -} \ No newline at end of file +} From 2ad4450e5f63b61bbec47999cdfbf41338e60c55 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 14 Aug 2025 06:57:14 +0000 Subject: [PATCH 03/16] add riscv zba extension limit Change-Id: I321eb200f859751727afe5cae13074dfce2bb0ce --- ggml/src/ggml-cpu/vec.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 5130b5dfcb6ff..982cf5497ccfb 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ 
b/ggml/src/ggml-cpu/vec.cpp @@ -460,7 +460,7 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } -#elif defined(__riscv) && defined(__riscv_v) +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zba) int N = n; i += n; float* src = const_cast(reinterpret_cast(x)); From c99135471b883f24d94604b44e9699837e250c56 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 19 Sep 2025 08:29:16 +0000 Subject: [PATCH 04/16] fixed for review comments, file renamed and format Change-Id: Ia20b6ec24a36638e62e0fe07cf100916a7cce3ce --- cmake/riscv64-spacemit-linux-gnu-gcc.cmake | 31 +++++++++---------- ggml/src/ggml-cpu/CMakeLists.txt | 8 ++--- ggml/src/ggml-cpu/ggml-cpu.cpp | 2 +- .../{ggml_spacemit_ime.cpp => ime.cpp} | 9 ++++-- .../spacemit/{ggml_spacemit_ime.h => ime.h} | 4 --- ...acemit_ime_kernels.cpp => ime_kernels.cpp} | 5 ++- ...l_spacemit_ime_kernels.h => ime_kernels.h} | 5 +-- 7 files changed, 31 insertions(+), 33 deletions(-) rename ggml/src/ggml-cpu/spacemit/{ggml_spacemit_ime.cpp => ime.cpp} (99%) rename ggml/src/ggml-cpu/spacemit/{ggml_spacemit_ime.h => ime.h} (55%) rename ggml/src/ggml-cpu/spacemit/{ggml_spacemit_ime_kernels.cpp => ime_kernels.cpp} (99%) rename ggml/src/ggml-cpu/spacemit/{ggml_spacemit_ime_kernels.h => ime_kernels.h} (91%) diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake index e1df484c7c5b8..1ec6d19a7435e 100644 --- a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake @@ -1,24 +1,23 @@ -# Copyright (c) 2023 SpacemiT. All rights reserved. set(CMAKE_SYSTEM_NAME Linux) -SET(CMAKE_SYSTEM_PROCESSOR riscv64) +set(CMAKE_SYSTEM_PROCESSOR riscv64) set(CMAKE_SYSTEM_VERSION 1) -if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)") -message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}") +if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)") + message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}") else() -set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple") -if(DEFINED ENV{RISCV_ROOT_PATH}) - file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) -else() - message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") -endif() + set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple") + if (DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) + else() + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") + endif() -set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") -set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) -set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) -set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip) -set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu") -set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot") + set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") + set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) + set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) + set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip) + set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu") + set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot") endif() set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt 
index 1c81687dda9aa..17cbbf0e6a969 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -442,10 +442,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_CPU_RISCV64_SPACEMIT) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES - ggml-cpu/spacemit/ggml_spacemit_ime.cpp - ggml-cpu/spacemit/ggml_spacemit_ime.h - ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp - ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h + ggml-cpu/spacemit/ime.cpp + ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/ime_kernels.cpp + ggml-cpu/spacemit/ime_kernels.h ) endif() set(MARCH_STR "rv64gc") diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index c89a2b23ef6a4..4ee528bfc85e8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -19,7 +19,7 @@ #endif #ifdef GGML_USE_CPU_RISCV64_SPACEMIT -# include "spacemit/ggml_spacemit_ime.h" +# include "spacemit/ime.h" #endif #if defined(_WIN32) diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp similarity index 99% rename from ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp rename to ggml/src/ggml-cpu/spacemit/ime.cpp index f5841bfcee808..8b156bc084440 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -17,11 +17,14 @@ #include // for qsort #include // for GGML_ASSERT #include +#include +#include -#include "ggml_spacemit_ime.h" -#include "ggml_spacemit_ime_kernels.h" -#include "vec.h" +#include +#include "ime.h" +#include "ime_kernels.h" +#include "vec.h" #if defined(__riscv) diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h b/ggml/src/ggml-cpu/spacemit/ime.h similarity index 55% rename from ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h rename to ggml/src/ggml-cpu/spacemit/ime.h index ab71bfae4e7b0..723b3da6e0f58 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -1,9 +1,5 @@ #pragma once -#include "traits.h" #include "ggml.h" -#include -// #include -// GGML internal header ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime_kernels.cpp similarity index 99% rename from ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp rename to ggml/src/ggml-cpu/spacemit/ime_kernels.cpp index c51a9bfb355d4..51fcfe4b14481 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.cpp @@ -1,4 +1,7 @@ -#include "ggml_spacemit_ime_kernels.h" +#include +#include +#include "ime_kernels.h" +#include "ggml.h" #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Woverlength-strings" diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h similarity index 91% rename from ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h rename to ggml/src/ggml-cpu/spacemit/ime_kernels.h index af62472aabf84..22186a6a08152 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,9 +1,6 @@ #pragma once -#include -#include -#include "traits.h" -#include "ggml.h" +#include namespace sqnbitgemm_spacemit_ime { From dbc5bd62cf457a0111be3f6eba11543e8864f42c Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 23 Sep 2025 07:21:44 +0000 Subject: [PATCH 05/16] fixed for code 
format, after clang-format Change-Id: I5dc33a0412da3d3f2d77075d8939185d3009eca2 --- CODEOWNERS | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 2 +- ggml/src/ggml-cpu/spacemit/ime.cpp | 1669 ++++++++--------- ggml/src/ggml-cpu/spacemit/ime.h | 13 +- .../{ime_kernels.cpp => ime1_kernels.cpp} | 514 +++-- ggml/src/ggml-cpu/spacemit/ime_kernels.h | 43 +- 6 files changed, 1099 insertions(+), 1143 deletions(-) rename ggml/src/ggml-cpu/spacemit/{ime_kernels.cpp => ime1_kernels.cpp} (90%) diff --git a/CODEOWNERS b/CODEOWNERS index 18564a08b1d6c..7eaf552c4b86d 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -10,3 +10,4 @@ /ggml/src/gguf.cpp @JohannesGaessler /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-zdnn/ @taronaeo +/ggml/src/ggml-cpu/spacemit @alex-spacemit diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 17cbbf0e6a969..50bb9cac92bca 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -444,7 +444,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h - ggml-cpu/spacemit/ime_kernels.cpp + ggml-cpu/spacemit/ime1_kernels.cpp ggml-cpu/spacemit/ime_kernels.h ) endif() diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 8b156bc084440..42501c1cfe453 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -1,35 +1,30 @@ +// Copyright (c) 2025 SpacemiT. All rights reserved. +// Licensed under the MIT License. #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP -#include "ggml-common.h" -#include "ggml-backend-impl.h" +#include "ime.h" -#include "ggml-quants.h" -#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-common.h" #include "ggml-cpu.h" -#include "ggml-cpu-impl.h" +#include "ime_kernels.h" #include "traits.h" -#include -#include -#include -#include -#include // for qsort -#include // for GGML_ASSERT -#include #include +#include +#include +#include // for GGML_ASSERT #include +#include -#include - -#include "ime.h" -#include "ime_kernels.h" -#include "vec.h" - +// clang-format off #if defined(__riscv) #if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) #error "riscv v extension or v_intrinsic not enabled" +#else +#include #endif #if !defined(__riscv_zfh) @@ -53,48 +48,36 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif - #if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 +#define QGEMM_STRIDEN_THREAD_ALIGN 16 #else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 +#define QGEMM_STRIDEN_THREAD_ALIGN 32 #endif -typedef enum { - ScaleFp32 = 0, - ScaleFp16, -} QNBIT_GEMM_SCALE_TYPE; - -template -struct QNBIT_GEMM_DATA_PARAMS { - const T* A = nullptr; ///< address of A (float32/16 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const T* Bias = nullptr; ///< optional address of Bias, vector size N - T* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C - - QNBIT_GEMM_SCALE_TYPE ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp32; 
///< datatype of B scale(FP32 or FP16). +// clang-format on + +struct qnbitgemm_spacemit_ime_args { + const float * a_ptr = nullptr; + size_t lda = 0; + const std::byte * packed_quant_b_data = nullptr; + const float * quant_b_scale = nullptr; + const void * quant_b_zp = nullptr; + const float * quant_b_blksum = nullptr; + const float * bias = nullptr; + float * c_ptr = nullptr; + size_t ldc = 0; }; -constexpr -size_t -DivRoundup(size_t up, size_t down) -{ +constexpr size_t div_round_up(size_t up, size_t down) { return (up + down - 1) / down; } -constexpr size_t -Q8BlkSize(size_t BlkLen) -{ - const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + +constexpr size_t q8_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); // Currently, the strictest alignment requirement of a block is for a float. // Ensure contiguous blocks are suitably aligned. - assert(BlkSize % alignof(float) == 0); - return BlkSize; + assert(blk_size % alignof(float) == 0); + return blk_size; } namespace ggml::cpu::riscv64_spacemit { @@ -103,957 +86,941 @@ const int num_ai_cores = std::thread::hardware_concurrency() / 2; } // namespace ggml::cpu::riscv64_spacemit -static void SQ4BitGemm_CompInt8( - const size_t BlkLen, const size_t K, - const QNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, const size_t RangeStartM, - const size_t RangeCountM, const size_t RangeStartN, - const size_t RangeCountN) { - const size_t scale_stride = - DataParams->ScaleType == QNBIT_GEMM_SCALE_TYPE::ScaleFp16 - ? sizeof(uint16_t) - : sizeof(float); - - constexpr size_t BlkBitWidth = 4; - - const size_t k_blks = DivRoundup(K, BlkLen); - - const size_t lda = k_blks * Q8BlkSize(BlkLen); - const size_t ldc = DataParams->ldc; - const size_t ldb = k_blks * (BlkLen * BlkBitWidth / 8); - const std::byte* QuantA = - static_cast(PerGemmWorkspace) + RangeStartM * lda; - - const size_t zero_point_stride = DataParams->QuantBZeroPoint != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte* QuantBData = - static_cast(DataParams->PackedQuantBData) + - RangeStartN * packed_b_stride; - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - size_t CountN; - const size_t ComputeBlockCountN = RangeCountM == 1 ? RangeCountN : 16; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, ComputeBlockCountN); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * packed_b_stride; - const std::byte* b_col_zp = (zero_point_stride != 0) - ? 
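Note: q8_blk_size() above fixes the per-block layout of the quantized activations: one float scale immediately followed by blk_len int8 quants, tightly packed block after block. A small worked example is sketched below; the offset helper is illustrative and is not used by the patch.

```cpp
#include <cstddef>
#include <cstdint>

// Byte offset of quantized-A block `b` inside one row of the per-gemm workspace,
// matching q8_blk_size(): [ float scale | blk_len x int8 quants ] per block.
static inline size_t q8_blk_offset(size_t b, size_t blk_len) {
    return b * (sizeof(float) + blk_len * sizeof(int8_t));
}

// e.g. blk_len = 32, K = 4096:
//   q8_blk_size(32) = 4 + 32 = 36 bytes per block
//   k_blks          = div_round_up(4096, 32) = 128 blocks per row
//   lda             = 128 * 36 = 4608 bytes of workspace per row of A
```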
b_col - : nullptr; - float* c_blk = C + n; - - size_t RowsRemaining = RangeCountM; - - while (RowsRemaining > 0) { - const auto RowsHandled = sqnbitgemm_spacemit_ime::SQ4BitGemmKernel_CompInt8( - BlkLen, a_row, b_col, nullptr, b_col_zp, c_blk, - RowsRemaining, CountN, K, k_blks, ldc, nullptr, scale_stride); - - c_blk += RowsHandled * ldc; - a_row += RowsHandled * lda; - - RowsRemaining -= RowsHandled; - } - } -} +static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, + const size_t gemm_k, + const qnbitgemm_spacemit_ime_args * gemm_args, + void * const per_gemm_ws, + const size_t m_start, + const size_t m_count, + const size_t n_start, + const size_t n_count) { + constexpr size_t scale_stride = sizeof(uint16_t); + constexpr size_t blk_bitwidth = 4; + + const size_t k_blks = div_round_up(gemm_k, blk_len); + + const size_t lda = k_blks * q8_blk_size(blk_len); + const size_t ldc = gemm_args->ldc; + const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); + const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; + + const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; + const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); + const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; + + float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + + size_t count_n = 0; + const size_t compute_block_count_n = m_count == 1 ? n_count : 16; + for (size_t n = 0; n < n_count; n += count_n) { + count_n = std::min(n_count - n, compute_block_count_n); + + const std::byte * a_row = quant_a_ptr; + const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; + const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; + float * c_blk = c_ptr + n; + int32_t rows_remaining = m_count; + while (rows_remaining > 0) { + const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( + blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, + scale_stride); -template -constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; - } - return -1; + c_blk += rows_handled * ldc; + a_row += rows_handled * lda; + + rows_remaining -= rows_handled; + } + } } -template -struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks }; -template -struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks }; // control size static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), "wrong block_with_zp<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> 
size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); using block_q4_0x16 = block<4, 16>; using block_q4_1x16 = block_with_zp<4, 16>; using block_q8_0x16 = block<8, 16>; -static block_q4_0x16 make_block_q4_0x16(block_q4_0* in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } } - } - - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } } - } - return out; + return out; } -static block_q4_1x16 make_block_q4_1x16(block_q4_1* in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } - - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... 
[b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } } - } - - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } } - } - return out; + return out; } -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); - constexpr int nrows_interleaved = 16; + constexpr int nrows_interleaved = 16; - block_q4_0x16* dst = (block_q4_0x16*)t->data; - const block_q4_0* src = (const block_q4_0*)data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); } -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); +static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); - constexpr int nrows_interleaved = 16; + constexpr int nrows_interleaved = 16; - block_q4_1x16* dst = (block_q4_1x16*)t->data; - const block_q4_1* src = (const block_q4_1*)data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; - GGML_ASSERT(data_size == nrow * nblocks * 
sizeof(block_q4_1)); + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); } -static inline void get_scale_min_k4(int j, const uint8_t* GGML_RESTRICT q, uint8_t* GGML_RESTRICT d, uint8_t* GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); +static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); - constexpr int nrows_interleaved = 16; + constexpr int nrows_interleaved = 16; - block_q4_1x16* dst = (block_q4_1x16*)t->data; - const block_q4_K* src = (const block_q4_K*)data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... 
[b47, b63] - const uint8_t* q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } + src += nrows_interleaved * nblocks; } - src += nrows_interleaved * nblocks; - } - return 0; + return 0; - GGML_UNUSED(data_size); + GGML_UNUSED(data_size); } namespace ggml::cpu::riscv64_spacemit { template -int repack(struct ggml_tensor*, const void*, size_t); +int repack(struct ggml_tensor *, const void *, size_t); -template <> -int repack(struct ggml_tensor* t, const void* data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); } -template <> -int repack(struct ggml_tensor* t, const void* data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); } -template <> -int repack(struct ggml_tensor* t, const void* data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); } class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor* t, const void* data, size_t data_size) = 0; + public: + virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; }; -template -class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; +template class 
tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; + size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return false; - } - - bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_Q4_0 || // + op->src[0]->type == GGML_TYPE_Q4_1 || // + op->src[0]->type == GGML_TYPE_Q4_K) { + forward_mul_mat_q4(params, op); + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } - default: - // GGML_ABORT("fatal error"); - break; + return false; } - return false; - } - void forward_mul_mat_q4(ggml_compute_params* params, ggml_tensor* op) { - const ggml_tensor* src0 = op->src[0]; - const ggml_tensor* src1 = op->src[1]; - ggml_tensor* dst = op; + void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; - GGML_TENSOR_BINARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS - int ith = params->ith; - int nth = params->nth; + int ith = params->ith; + int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + [[maybe_unused]] const enum ggml_type type = src0->type; - void* w_data = (void*)src0->data; - const float* feature = (const float*)src1->data; - float* output = (float*)dst->data; + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - const auto BatchN = ne12 * ne13; - [[maybe_unused]] const auto BatchWeight = ne02 * ne03; - const auto M = ne11; - const auto K = ne10; - const auto N = ne01; + const size_t batch_feature = ne12 * ne13; + [[maybe_unused]] const size_t batch_weight = ne02 * ne03; + const size_t gemm_m = ne11; + const size_t gemm_k = ne10; + const size_t gemm_n = ne01; - assert(BatchWeight == 1); - // constexpr size_t BlkBitWidth = 4; - constexpr size_t BlkLen = QK4_0; + GGML_ASSERT(batch_weight == 1); - size_t BlockCountK = DivRoundup(K, BlkLen); - const size_t Size = M * BlockCountK * Q8BlkSize(BlkLen); - auto Alignment = alignof(double); - size_t PerGemmWorkspaceStride = DivRoundup(Size, Alignment) * Alignment; - const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; - const size_t desired_wsize = WorkspaceSize + Alignment - 1; + const size_t block_count_k = div_round_up(gemm_k, QK4_0); + const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); + const size_t per_gemm_workspace_stride = + div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); + const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; + const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; - if (params->wsize < desired_wsize && ith == 0) { - throw std::runtime_error( - "wsize less than MlasSQNBitGemmBatchWorkspaceSize"); - } - - std::vector> 
DataParams(BatchN); - - for (int i = 0; i < BatchN; i++) { - DataParams[i].A = feature + M * K * i; - DataParams[i].lda = K; - DataParams[i].QuantBDataWorkspace = w_data; - DataParams[i].PackedQuantBData = (const std::byte*)w_data; - DataParams[i].QuantBScale = nullptr; - - if constexpr (std::is_same_v) { - DataParams[i].QuantBZeroPoint = nullptr; - } else { - DataParams[i].QuantBZeroPoint = (const uint8_t*)w_data; - } - - DataParams[i].Bias = nullptr; - DataParams[i].C = output + M * N * i; - DataParams[i].ldc = N; - DataParams[i].ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp16; - } - Alignment = alignof(double);; - const uintptr_t WorkspaceAddress = reinterpret_cast(params->wdata); - void* Workspace = reinterpret_cast( - (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); - - BlockCountK = DivRoundup(K, BlkLen); - const auto PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - PerGemmWorkspaceStride = - DivRoundup(PerGemmWorkspaceSize, Alignment) * Alignment; - - const auto QuantizeARow = sqnbitgemm_spacemit_ime::QuantizeARow_CompInt8; - const auto QuantizeAM4Row = sqnbitgemm_spacemit_ime::QuantizeAM4Row_CompInt8; - - BlockCountK = DivRoundup(K, BlkLen); - const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - - { - const size_t BlockSizeM = 4; - size_t BlockCountM = DivRoundup(M, BlockSizeM); - int task_count = BatchN * BlockCountM; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - auto gemm_idx = compute_idx / BlockCountM; - auto m_idx = compute_idx % BlockCountM * BlockSizeM; - const auto& data = DataParams[gemm_idx]; - auto RowsTobeHandled = (M - m_idx) > 4 ? 
4 : (M - m_idx); - if (RowsTobeHandled == 4) { - const float* ARowPtr = data.A + m_idx * data.lda; - std::byte* QuantARowPtr = static_cast(Workspace) + - gemm_idx * PerGemmWorkspaceStride + - m_idx * QuantAStride; - QuantizeAM4Row(BlkLen, ARowPtr, K, QuantARowPtr); - - } else { - while (RowsTobeHandled) { - const float* ARowPtr = data.A + m_idx * data.lda; - std::byte* QuantARowPtr = static_cast(Workspace) + - gemm_idx * PerGemmWorkspaceStride + - m_idx * QuantAStride; - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - RowsTobeHandled -= 1; - m_idx += 1; - } + if (ith == 0 && params->wsize < desired_wsize) { + throw std::runtime_error("wsize less than desired_wsize"); } - } - } - ggml_barrier(params->threadpool); + std::vector qnbitgemm_args(batch_feature); - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) - return; + for (size_t i = 0; i < batch_feature; i++) { + qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; + qnbitgemm_args[i].lda = gemm_k; + qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; + qnbitgemm_args[i].quant_b_scale = nullptr; - nth = std::min(nth, int{ggml::cpu::riscv64_spacemit::num_ai_cores}); + if constexpr (std::is_same_v) { + qnbitgemm_args[i].quant_b_zp = nullptr; + } else { + qnbitgemm_args[i].quant_b_zp = w_data; + } - int ThreadsPerGemm = nth / BatchN; - constexpr size_t StrideM = 128; + qnbitgemm_args[i].bias = nullptr; + qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; + qnbitgemm_args[i].ldc = gemm_n; + } - size_t nc = N; - const size_t BlockedM = DivRoundup(M, StrideM); - const size_t max_nc = DivRoundup(N * BlockedM, ThreadsPerGemm); - if (max_nc < nc) { - nc = std::min(nc, DivRoundup(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * - QGEMM_STRIDEN_THREAD_ALIGN); - } + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); + const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + + { + constexpr size_t block_size_m = 4; + size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); + int32_t task_count = batch_feature * per_gemm_block_count_m; + int32_t task_per_thread = (task_count + nth - 1) / nth; + int32_t start = ith * task_per_thread; + int32_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { + int32_t gemm_idx = compute_idx / block_size_m; + int32_t m_idx = compute_idx % block_size_m * block_size_m; + const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; + int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? 
block_size_m : (gemm_m - m_idx); + + if (rows_tobe_handled == block_size_m) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = + static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = static_cast(ws) + + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } + } + } + + ggml_barrier(params->threadpool); - const size_t StrideN = nc; - const size_t ThreadCountM = DivRoundup(M, StrideM); - const size_t ThreadCountN = DivRoundup(N, StrideN); - ThreadsPerGemm = ThreadCountM * ThreadCountN; + if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { + return; + } + nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); + + size_t threads_per_gemm = nth / batch_feature; + constexpr size_t gemm_m_stride = 128; + size_t nc = gemm_n; + const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); + const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); + if (max_nc < nc) { + nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); + } + const size_t gemm_n_stride = nc; + const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); + const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); + threads_per_gemm = thread_count_m * thread_count_n; - { - int task_count = BatchN * ThreadsPerGemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / ThreadsPerGemm; - const auto blk_i = compute_idx % ThreadsPerGemm; - const auto* Data = &DataParams[gemm_i]; + { + int task_count = batch_feature * threads_per_gemm; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + const auto gemm_i = compute_idx / threads_per_gemm; + const auto blk_i = compute_idx % threads_per_gemm; + const auto * data = &qnbitgemm_args[gemm_i]; - const auto ThreadIdN = blk_i / ThreadCountM; - const auto ThreadIdM = blk_i % ThreadCountM; + const auto tid_n = blk_i / thread_count_m; + const auto tid_m = blk_i % thread_count_m; - const size_t RangeStartM = ThreadIdM * StrideM; - const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + const size_t m_start = tid_m * gemm_m_stride; + const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); - const size_t RangeStartN = ThreadIdN * StrideN; - const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + const size_t n_start = tid_n * gemm_n_stride; + const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); - void* PerGemmWorkspace = reinterpret_cast(Workspace) + - gemm_i * PerGemmWorkspaceStride; + void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; - SQ4BitGemm_CompInt8(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, - RangeCountM, RangeStartN, RangeCountN); - } + sqnbitgemm_spacemit_ime_i8i4(QK4_0, 
gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + } + } } - } - int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { - GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), - (int)NB_COLS, (int)INTER_SIZE); - return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); - } + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int) NB_COLS, (int) INTER_SIZE); + return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); + } }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { - switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; - return true; - default: - // GGML_ABORT("fatal error"); - break; + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + size = 0; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return false; - } - - bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { - switch (op->op) { - case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; - case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); - return true; - default: - // GGML_ABORT("fatal error"); - break; + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_NORM: + forward_norm_f32(params, op); + return true; + case GGML_OP_RMS_NORM: + forward_rms_norm_f32(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return false; - } - - void forward_norm_f32(ggml_compute_params* params, ggml_tensor* op) { - const ggml_tensor* src0 = op->src[0]; - ggml_tensor* dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto* input = (float*)src0->data; - auto* output = (float*)dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto* p_input = const_cast(input + offset); - - auto* p_output = output + offset; - auto* p_temp_output = p_output; - auto* p_gamma_data = (const float*) nullptr; - auto* p_beta_data = (const float*) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); - - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); - - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); - - p_input += gvl; - p_temp_output += gvl; - 
length -= gvl; - } - - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + + void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; 
+ const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = 
__riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } } - } - - void forward_rms_norm_f32(ggml_compute_params* params, ggml_tensor* op) { - const ggml_tensor* src0 = op->src[0]; - ggml_tensor* dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto* input = (float*)src0->data; - auto* output = (float*)dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto* p_input = const_cast(input + offset); - auto* p_output = output + offset; - auto* p_temp_output = p_output; - auto* p_gamma_data = (const float*)nullptr; - auto* p_beta_data = (const float*)nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); - - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); - - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); - - p_input += gvl; - p_temp_output += gvl; - length -= gvl; - } - - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - - mean_square = sqrt(mean_square + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, 
gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + + void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + // float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + 
__riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } } - } - int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } }; static const tensor_traits q4_0_16x8_q8_0; static const tensor_traits q4_1_16x8_q8_0; static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; +static const tensor_traits_common rvv_impl; } // namespace ggml::cpu::riscv64_spacemit -static const ggml::cpu::tensor_traits* ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor* cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_1) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_F32) { + return &ggml::cpu::riscv64_spacemit::rvv_impl; } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; - } - return nullptr; + return nullptr; } -static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor) { - tensor->extra = (void*)const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); +static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor) { + tensor->extra = + (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); - 
GGML_UNUSED(buffer); + GGML_UNUSED(buffer); - return GGML_STATUS_SUCCESS; + return GGML_STATUS_SUCCESS; } -static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor, - const void* data, size_t offset, size_t size) { - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); - - auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base*)tensor->extra; - if (tensor_traits) { - auto OK = tensor_traits->repack(tensor, data, size); - GGML_ASSERT(OK == 0); - } +static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra; + if (tensor_traits) { + auto OK = tensor_traits->repack(tensor, data, size); + GGML_ASSERT(OK == 0); + } - GGML_UNUSED(buffer); + GGML_UNUSED(buffer); } -static const char* ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_RISCV64_SPACEMIT"; +static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_RISCV64_SPACEMIT"; - GGML_UNUSED(buft); + GGML_UNUSED(buft); } -static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); +static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - if (buffer == nullptr) { - return nullptr; - } - - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 64; + return 64; - GGML_UNUSED(buft); + GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const struct ggml_tensor* tensor) { - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - if (tensor->ne[i] <= 0) { - return 0; - } - } - - size_t nbytes; - const size_t blck_size = ggml_blck_size(tensor->type); - if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, + const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + if (tensor->ne[i] <= 0) { + return 0; + } } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < 
GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + if (tensor->type == GGML_TYPE_Q4_K) { + GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); + nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + } + } else { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } } - } - GGML_UNUSED(buft); - return nbytes; + GGML_UNUSED(buft); + return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor* op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && - ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } - } - break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + if (op->src[0]->type == GGML_TYPE_F32) { + return true; + } + break; + default: + // GGML_ABORT("fatal error"); + break; } - break; - default: - // GGML_ABORT("fatal error"); - break; + return false; } - return false; - } - - ggml::cpu::tensor_traits* get_tensor_traits(const struct ggml_tensor* op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { - return (ggml::cpu::tensor_traits*)op->src[0]->extra; + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); + default: + // GGML_ABORT("fatal error"); + break; } - break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - return (ggml::cpu::tensor_traits*)(&ggml::cpu::riscv64_spacemit::rvv_impl); - default: - // GGML_ABORT("fatal error"); - break; - } - return nullptr; - } + return nullptr; + } }; } // namespace 
ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { - /* .iface = */ { - /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, - /* .get_max_size = */ nullptr, // defaults to SIZE_MAX - /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, - /* .is_host = */ nullptr, - }, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), - /* .context = */ new ggml::cpu::riscv64_spacemit::extra_buffer_type(), - }; - - return &ggml_backend_cpu_buffer_type_riscv64_spacemit; + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + /* .iface = */ + { + /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, + /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, + /* .is_host = */ nullptr, + }, + /* .device = */ + ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ + new ggml::cpu::riscv64_spacemit::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 723b3da6e0f58..d98dd1c62d7bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -1,5 +1,16 @@ +// Copyright (c) 2025 SpacemiT. All rights reserved. +// Licensed under the MIT License. + #pragma once -#include "ggml.h" +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp similarity index 90% rename from ggml/src/ggml-cpu/spacemit/ime_kernels.cpp rename to ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index 51fcfe4b14481..438c7e5e9d1f1 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,31 +1,19 @@ -#include -#include -#include "ime_kernels.h" +// Copyright (c) 2025 SpacemiT. All rights reserved. +// Licensed under the MIT License. 
#include "ggml.h" +#include "ime_kernels.h" + +#include +#include +// clang-format off #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Woverlength-strings" #pragma GCC diagnostic ignored "-Wcast-qual" #pragma GCC diagnostic ignored "-Wunused-parameter" #endif - -// HPMADOT_OFF = 0 is not supported when ref_C used -#define HPMADOT_OFF 0 -namespace sqnbitgemm_spacemit_ime -{ -template -struct Q4_0X16 { - T scale[16]; - // uint8_t zp[8]; - // blklen = 16 - // uint8_t[16][8] - // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 - - // blklen = 32 - // uint8_t[16][2][8] - // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 - // b16 b24, b17 b25, b18 b26, b19 b27, b20 b28, b21 b29, b22 b30, b23 b31 -}; +// clang-format on +namespace sqnbitgemm_spacemit_ime { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -89,16 +77,15 @@ struct Q4_0X16 { "vsetvli t0, t1, e8, mf4 \n\t" \ "vse8.v v31, (s1) \n\t" -void -QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) -{ +namespace ime1 { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); - const float fone = 1.0f; + const float fone = 1.0f; if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { for (size_t row_index = 0; row_index < 4; ++row_index) { - const float* SRC = A + row_index * CountK; - std::byte* DST = QuantA + row_index * sizeof(float); + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); const size_t offset = (4 - row_index) * 4 + row_index * 8; const size_t stride = 4 * (sizeof(float) + BlkLen); @@ -145,15 +132,15 @@ QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QUANTIZEM4ROW_STORE "QUIT%=: \n\t" - : [ SRC ] "+r"(SRC) - : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), - [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } else if (BlkLen == 128) { for (size_t row_index = 0; row_index < 4; ++row_index) { - const float* SRC = A + row_index * CountK; - std::byte* DST = QuantA + row_index * sizeof(float); + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); const size_t offset = (4 - row_index) * 4 + row_index * 8; const size_t stride = 4 * (sizeof(float) + BlkLen); @@ -215,17 +202,17 @@ QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* "jal x0, QUANTIZE%= \n\t" "QUIT%=: \n\t" - : [ SRC ] "+r"(SRC) - : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), - [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); } } else if (BlkLen == 256) { for (size_t row_index = 0; row_index < 4; ++row_index) { - const float* SRC = A + row_index * CountK; - std::byte* DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * 
(sizeof(float) + BlkLen); + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); __asm__ volatile( "vsetvli t0, zero, e32, m8 \n\t" "li t6, 32 \n\t" @@ -348,23 +335,21 @@ QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* "bnez t2, TAIL_LOOP%= \n\t" "QUIT%=: \n\t" - : [ SRC ] "+r"(SRC) - : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), - [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); } } } -void -QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) -{ - const float* SRC = A; - std::byte* DST = QuantA; +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { + const float * SRC = A; + std::byte * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); - const float fone = 1.0f; - std::byte* QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); - size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; + const float fone = 1.0f; + std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; if (CountK <= BlkLen) { float max_abs_A = 0.0f; @@ -373,14 +358,14 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q } float scale_A = max_abs_A * range_max_reciprocal; - ((float*)QuantA)[0] = scale_A; + ((float *) QuantA)[0] = scale_A; - auto* QuantAData_offset = (int8_t*)(QuantA + sizeof(float)); + auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); for (size_t k = 0; k < CountK; k++) { QuantAData_offset[k] = - (int8_t)std::clamp(roundf(A[k] / scale_A), (float)std::numeric_limits::lowest(), - (float)std::numeric_limits::max()); + (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), + (float) std::numeric_limits::max()); } for (size_t k = CountK; k < BlkLen; k++) { QuantAData_offset[k] = 0; @@ -399,12 +384,12 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q "addi %[DST], %[DST], 128 \n\t" "sub %[CNT], %[CNT], t0 \n\t" "bnez %[CNT], LOOP%= \n\t" - : [ DST ] "+r"(QuantA_offset), [ CNT ] "+r"(offset) + : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) : : "cc", "t0"); } if (BlkLen == 16) { - float buffer[64] = {0.0f}; + float buffer[64] = { 0.0f }; __asm__ volatile( "addi t3, zero, 16*8 \n\t" "addi t2, zero, 16 \n\t" @@ -720,8 +705,8 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q "vxor.vv v16, v16, v16 \n\t" "jal x0, LOOP_K%= \n\t" "END%=: \n\t" - : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) - : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BUFFER ] "r"(buffer) + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17"); } else if (BlkLen == 32) { @@ -856,8 +841,8 @@ QuantizeARow_CompInt8(size_t BlkLen, const 
float* A, size_t CountK, std::byte* Q "vxor.vv v16, v16, v16 \n\t" "jal x0, LOOP_K%= \n\t" "END%=: \n\t" - : [ K ] "+r"(CountK) - : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC), [ DST ] "r"(DST) + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } else if (BlkLen == 64) { __asm__ volatile( @@ -951,8 +936,8 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q "vxor.vv v16, v16, v16 \n\t" "jal x0, LOOP_K%= \n\t" "END%=: \n\t" - : [ K ] "+r"(CountK) - : [ SRC ] "r"(SRC), [ DST ] "r"(DST), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : [K] "+r"(CountK) + : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); } else if (BlkLen == 128) { __asm__ volatile( @@ -1011,12 +996,12 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q "jal x0, QUANT%= \n\t" "END%=: \n\t" - : [ DST ] "+r"(DST), [ K ] "+r"(CountK) - : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC) + : [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); } else { - float buffer[8] = {0.0f}; - size_t cnt = BlkLen / 256; + float buffer[8] = { 0.0f }; + size_t cnt = BlkLen / 256; __asm__ volatile( "slli t3, %[BLK], 2 \n\t" @@ -1175,15 +1160,16 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q "addi %[DST], %[DST], 64 \n\t" "bnez t6, TAIL_QUANT%= \n\t" "END%=: \n\t" - : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) - : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BLK ] "r"(BlkLen), [ BUFFER ] "r"(buffer), - [ CNT ] "r"(cnt) + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), + [CNT] "r"(cnt) : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); } } -namespace -{ +} // namespace ime1 + +namespace { #define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \ "vmadot v16, v14, v0 \n\t" \ "vmadot v18, v14, v1 \n\t" \ @@ -1467,45 +1453,43 @@ namespace "vadd.vi v1, v1, -12 \n\t" template -void -SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias, - const size_t ldc) -{ +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { GGML_UNUSED(QuantBScale); GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); + size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; + float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(__fp16); // scale - float* CPtr = C + n; + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; - LDC = 16 * sizeof(float); + LDC = 16 * sizeof(float); } if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; if (NBLKS < 16) { __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" "vse32.v v0, (%[DST]) \n\t" : - : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) : "cc", "t0"); bias = tmp; } @@ -1583,8 +1567,8 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); @@ -1661,32 +1645,32 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); } } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(__fp16); // scale - float* CPtr = C + n; + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; - LDC = 16 * sizeof(float); + LDC = 16 * sizeof(float); } if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; if (NBLKS < 16) { __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" "vse32.v v0, (%[DST]) \n\t" : - : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) : "cc", "t0"); bias = tmp; } @@ -1745,8 +1729,8 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); @@ -1807,8 +1791,8 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); } @@ -1816,9 +1800,9 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, } if (CountN % 16 != 0) { // stroe output from tmp to C when NBLKS less than 16. - float* CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); + float * CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" @@ -1836,50 +1820,49 @@ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, "vse32.v v4, (t3) \n\t" "vse32.v v6, (t4) \n\t" : - : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); } } + template -void -SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias, - const size_t ldc) -{ +void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { GGML_UNUSED(QuantBScale); GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); + size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; + float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float* CPtr = C + n; + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; - LDC = 16 * sizeof(float); + LDC = 16 * sizeof(float); } if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; if (NBLKS < 16) { __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" "vse32.v v0, (%[DST]) \n\t" : - : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) : "cc", "t0"); bias = tmp; } @@ -1953,8 +1936,8 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); @@ -2031,32 +2014,32 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); } } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale - float* CPtr = C + n; + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; - LDC = 16 * sizeof(float); + LDC = 16 * sizeof(float); } if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; if (NBLKS < 16) { __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" "vse32.v v0, (%[DST]) \n\t" : - : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) : "cc", "t0"); bias = tmp; } @@ -2115,8 +2098,8 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); @@ -2179,8 +2162,8 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, SAVE_RESULT_4x16 : - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), - [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", "s5", "s6"); } @@ -2188,9 +2171,9 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, } if (CountN % 16 != 0) { // stroe output from tmp to C when NBLKS less than 16. - float* CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); + float * CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); __asm__ volatile( "vsetvli t0, %[N], e32, m2 \n\t" "vle32.v v0, (%[SRC]) \n\t" @@ -2208,37 +2191,36 @@ SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, "vse32.v v4, (t3) \n\t" "vse32.v v6, (t4) \n\t" : - : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); } } + template -void -SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias) -{ +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias) { GGML_UNUSED(QuantBScale); GGML_UNUSED(QuantBZeroPoint); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(__fp16); // scale - float* CPtr = C + n; - size_t cnt = BlockCountK; + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; __asm__ volatile( "addi t3, %[NBLKS], 0 \n\t" "vsetvli t0, zero, e8, m1 \n\t" @@ -2356,8 +2338,8 @@ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } else { __asm__ volatile( @@ -2463,21 +2445,21 @@ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(__fp16); // scale - float* CPtr = C + n; - size_t cnt = BlockCountK; + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; __asm__ volatile( "addi t3, %[NBLKS], 0 \n\t" "addi s1, %[B], 0 \n\t" @@ -2578,8 +2560,8 @@ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); } else { __asm__ volatile( @@ -2669,8 +2651,8 @@ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); } } @@ -2678,31 +2660,29 @@ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, } template -void -SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t BlockCountK, - const float* Bias) -{ +void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const 
std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias) { GGML_UNUSED(QuantBScale); GGML_UNUSED(QuantBZeroPoint); const size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float* CPtr = C + n; - size_t cnt = BlockCountK; + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; __asm__ volatile( "addi t3, %[NBLKS], 0 \n\t" "vsetvli t0, zero, e8, m1 \n\t" @@ -2823,8 +2803,8 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } else { __asm__ volatile( @@ -2927,21 +2907,21 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte* QuantBDataPtr = (std::byte*)QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale - float* CPtr = C + n; - size_t cnt = BlockCountK; + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; if (Bias != nullptr) { - const float* bias = Bias + n; + const float * bias = Bias + n; __asm__ volatile( "addi t3, %[NBLKS], 0 \n\t" "addi s1, %[B], 0 \n\t" @@ -3035,8 +3015,8 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); } else { __asm__ volatile( @@ -3120,8 +3100,8 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, "vse32.v v31, (s3) \n\t" "END%=: \n\t" - : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) - : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); } } @@ -3129,20 +3109,18 @@ SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, } template -inline void -SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float* Bias, - const size_t ldc, - const size_t scalestride) -{ +inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { if (scalestride == 4) { SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); @@ -3154,20 +3132,18 @@ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, } template -inline void -SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float* Bias, - const size_t ldc, - const size_t scalestride) -{ +inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { if (scalestride == 4) { SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); @@ -3179,21 +3155,20 @@ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, } // namespace -size_t -SQ4BitGemmKernel_CompInt8(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float* 
Bias, - const size_t ScaleStride) -{ +namespace ime1 { +size_t gemm_kernel_i8i4(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float * Bias, + const size_t ScaleStride) { GGML_UNUSED(CountM); GGML_UNUSED(CountK); GGML_UNUSED(ldc); @@ -3219,4 +3194,5 @@ SQ4BitGemmKernel_CompInt8(size_t BlkLen, return 1; } } -} +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 22186a6a08152..1c65efb46ba29 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,27 +1,28 @@ +// Copyright (c) 2025 SpacemiT. All rights reserved. +// Licensed under the MIT License. #pragma once #include -namespace sqnbitgemm_spacemit_ime -{ -size_t -SQ4BitGemmKernel_CompInt8(size_t BlkLen, - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float* Bias, - const size_t ScaleStride); -void -QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); +namespace sqnbitgemm_spacemit_ime { +namespace ime1 { +size_t gemm_kernel_i8i4(size_t blk_len, + const std::byte * quant_a_ptr, + const std::byte * quant_b_data, + const float * quant_b_scale, + const std::byte * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t count_k, + size_t block_count_k, + size_t ldc, + const float * bias, + const size_t scale_stride); -void -QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); -} +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); + +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime From bd62ed278b08dc894a5fde809d8efb3c205ee93f Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 23 Sep 2025 07:59:33 +0000 Subject: [PATCH 06/16] use _Float16 instead of __fp16 Change-Id: I039fb02bb95270e641bc4442204e658735859d43 --- cmake/riscv64-spacemit-linux-gnu-gcc.cmake | 1 - ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake index 1ec6d19a7435e..08fdbf506304e 100644 --- a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake @@ -27,4 +27,3 @@ set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic") -add_definitions(-D__fp16=_Float16) diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index 438c7e5e9d1f1..95e9f263d501b 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1475,7 +1475,7 @@ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, std::byte * QuantBDataPtr = (std::byte *) QuantBData + // n * BlockCountK * BlkLen / 2 + // b data n * BlockCountK * 
sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(__fp16); // scale + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; @@ -1656,7 +1656,7 @@ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; std::byte * QuantBDataPtr = (std::byte *) QuantBData + // n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(__fp16); // scale + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; @@ -2216,7 +2216,7 @@ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, std::byte * QuantBDataPtr = (std::byte *) QuantBData + // n * BlockCountK * BlkLen / 2 + // b data n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(__fp16); // scale + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; if (Bias != nullptr) { @@ -2455,7 +2455,7 @@ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; std::byte * QuantBDataPtr = (std::byte *) QuantBData + // n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(__fp16); // scale + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; if (Bias != nullptr) { From 986d8bd3fecfc6d05d7d60450019d8ee61c64b2d Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 23 Sep 2025 08:04:06 +0000 Subject: [PATCH 07/16] add ci for riscv64-spacemit-ime-native Change-Id: I711c1033061df1a289ea77891b2997599dfe8279 --- .github/workflows/build-riscv-native.yml | 59 ++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index 86dc0ff76e59c..548cc6dab8c4c 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -58,3 +58,62 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 + runs-on: self-hosted + + steps: + - name: Install prerequisites + run: | + sudo apt-get update || true + sudo apt-get install -y libatomic1 + - uses: actions/checkout@v4 + - name: Setup Riscv + run: | + sudo apt-get update || true + sudo apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-riscv64-linux-gnu \ + g++-14-riscv64-linux-gnu \ + ccache \ + cmake + + - name: Setup ccache + run: | + mkdir -p $HOME/.ccache + ccache -M 5G -d $HOME/.ccache + export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log + export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" + echo "$GITHUB_WORKSPACE" + echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV + echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV + echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV + echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV + + - name: Build + run: | + cmake -B build \ + -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + 
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 + + cmake --build build --config Release -j $(nproc) From dba48d8ea24e4499f0b3ad5449ec20d7135b94fd Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 25 Sep 2025 08:12:40 +0000 Subject: [PATCH 08/16] update debian-13-riscv64-spacemit-ime-native ci label Change-Id: Ifb2b891e2fca57b5da604fce2ac255f27731179a --- .github/workflows/build-riscv-native.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index 87e6005eebf30..8ee3875cea66b 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -60,7 +60,7 @@ jobs: cmake --build build --config Release -j $(nproc) debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 - runs-on: self-hosted + runs-on: [self-hosted, RISCV64] steps: - name: Install prerequisites From 8dcbfc6d46021f59089f1fd0aa6c3694705b5ea4 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 25 Sep 2025 08:21:34 +0000 Subject: [PATCH 09/16] remove license comment for spacemit ime Change-Id: If0dc3ca30a958631ccca0a28b62e0b825f9fb0c3 --- ggml/src/ggml-cpu/spacemit/ime.cpp | 2 -- ggml/src/ggml-cpu/spacemit/ime.h | 3 --- ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp | 2 -- ggml/src/ggml-cpu/spacemit/ime_kernels.h | 2 -- 4 files changed, 9 deletions(-) diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 42501c1cfe453..54d3dece0e03a 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -1,5 +1,3 @@ -// Copyright (c) 2025 SpacemiT. All rights reserved. -// Licensed under the MIT License. #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index d98dd1c62d7bd..800d91acdaef6 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -1,6 +1,3 @@ -// Copyright (c) 2025 SpacemiT. All rights reserved. -// Licensed under the MIT License. - #pragma once #include "ggml-alloc.h" diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index 95e9f263d501b..cbbb6cd91607f 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,5 +1,3 @@ -// Copyright (c) 2025 SpacemiT. All rights reserved. -// Licensed under the MIT License. #include "ggml.h" #include "ime_kernels.h" diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 1c65efb46ba29..7570634150539 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,5 +1,3 @@ -// Copyright (c) 2025 SpacemiT. All rights reserved. -// Licensed under the MIT License. 
#pragma once #include From aea33c0186640eecff21633c70da1168d06ca062 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 25 Sep 2025 08:36:02 +0000 Subject: [PATCH 10/16] upgrade binutils for gcc ime Change-Id: Ibf2fa74c1064408974cb5b45f044d40987e5fb45 --- .github/workflows/build-riscv-native.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index 8ee3875cea66b..45e979af8aec2 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -77,6 +77,7 @@ jobs: g++-14-riscv64-linux-gnu \ ccache \ cmake + sudo apt-get upgrade binutils -y - name: Setup ccache run: | From 4e7ab9e6508b5a72309529a37953bbfb0d4b6576 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 26 Sep 2025 02:41:27 +0000 Subject: [PATCH 11/16] add spacemit ime cross jobs Change-Id: I80d74909941d41cb9cd09e51d8baf01c985cbfc6 --- .github/workflows/build-linux-cross.yml | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 04ad187d35c09..e62b35faff90b 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -344,3 +344,33 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + ubuntu-24-riscv64-cpu-spacemit-ime-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup Riscv + run: | + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz + rm -rf spacemit-toolchain-linux-glibc-x86_64 + mkdir spacemit-toolchain-linux-glibc-x86_64 + tar xf spacemit-toolchain-linux-glibc-x86_64-*.tar.xz -C spacemit-toolchain-linux-glibc-x86_64 --strip-components=1 + + - name: Build + run: | + export RISCV_ROOT_PATH=${PWD}/spacemit-toolchain-linux-glibc-x86_64 + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake + + cmake --build build --config Release -j $(nproc) From 3e6ba8473482c79259247451fbeec9f809cfd995 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 26 Sep 2025 09:15:29 +0000 Subject: [PATCH 12/16] remove native compile for riscv64-spacemit-ime Change-Id: I01920afafdc73fa7424014fd648d243f8ec9e25e --- .github/workflows/build-riscv-native.yml | 110 +++++++++++------------ 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index 45e979af8aec2..a3a0b0d6638ca 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -59,62 +59,62 @@ jobs: cmake --build build --config Release -j $(nproc) - debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 - runs-on: [self-hosted, RISCV64] + # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 + # runs-on: [self-hosted, RISCV64] - steps: - - name: Install prerequisites - run: | - sudo apt-get update || true - sudo apt-get install -y libatomic1 - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo apt-get update || true - sudo apt-get install -y --no-install-recommends \ - build-essential \ - 
gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - ccache \ - cmake - sudo apt-get upgrade binutils -y + # steps: + # - name: Install prerequisites + # run: | + # sudo apt-get update || true + # sudo apt-get install -y libatomic1 + # - uses: actions/checkout@v4 + # - name: Setup Riscv + # run: | + # sudo apt-get update || true + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # gcc-14-riscv64-linux-gnu \ + # g++-14-riscv64-linux-gnu \ + # ccache \ + # cmake + # sudo apt-get upgrade binutils -y - - name: Setup ccache - run: | - mkdir -p $HOME/.ccache - ccache -M 5G -d $HOME/.ccache - export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log - export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" - echo "$GITHUB_WORKSPACE" - echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV - echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV - echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV - echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV + # - name: Setup ccache + # run: | + # mkdir -p $HOME/.ccache + # ccache -M 5G -d $HOME/.ccache + # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log + # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" + # echo "$GITHUB_WORKSPACE" + # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV + # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV + # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV + # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV - - name: Build - run: | - cmake -B build \ - -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ - -DGGML_RVV=ON \ - -DGGML_RV_ZFH=ON \ - -DGGML_RV_ZICBOP=ON \ - -DGGML_CPU_RISCV64_SPACEMIT=ON \ - -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 + # - name: Build + # run: | + # cmake -B build \ + # -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ + # -DGGML_RVV=ON \ + # -DGGML_RV_ZFH=ON \ + # -DGGML_RV_ZICBOP=ON \ + # -DGGML_CPU_RISCV64_SPACEMIT=ON \ + # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 - cmake --build build --config Release -j $(nproc) + # cmake --build build --config Release -j $(nproc) From 46fe5a15c0ca87154f21efc4a9f51f3f64a14dfa Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 29 Sep 2025 09:16:24 +0000 Subject: [PATCH 13/16] ci : add caching for spacemit ime cross toolchain Change-Id: Ic54a192019a2fd982bbd58225ce3bbc38f4053de --- 
.github/workflows/build-linux-cross.yml | 26 +++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index e62b35faff90b..e1d7d4dcf164e 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -348,18 +348,32 @@ jobs: ubuntu-24-riscv64-cpu-spacemit-ime-cross: runs-on: ubuntu-24.04 + env: + SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" + SPACEMIT_IME_TOOLCHAIN_PATH: "/tmp/spacemit-toolchain-linux-glibc-x86_64" + steps: - uses: actions/checkout@v4 - - name: Setup Riscv + + - name: Cache Toolchain + uses: actions/cache@v4 + id: cache-spacemit-ime-cross-toolchain + with: + path: /tmp/spacemit-toolchain-linux-glibc-x86_64 + key: ${{ runner.os }}-${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + + - name: Setup Toolchain + if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit!= 'true' run: | - wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz - rm -rf spacemit-toolchain-linux-glibc-x86_64 - mkdir spacemit-toolchain-linux-glibc-x86_64 - tar xf spacemit-toolchain-linux-glibc-x86_64-*.tar.xz -C spacemit-toolchain-linux-glibc-x86_64 --strip-components=1 + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${env.SPACEMIT_IME_TOOLCHAIN_VERSION}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} --strip-components=1 + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz - name: Build run: | - export RISCV_ROOT_PATH=${PWD}/spacemit-toolchain-linux-glibc-x86_64 + export RISCV_ROOT_PATH=${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} cmake -B build -DLLAMA_CURL=OFF \ -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ From 33ced1f052e05e0de150626d100a01656507910b Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 29 Sep 2025 09:55:04 +0000 Subject: [PATCH 14/16] ci: bug fixed for cache path and env Change-Id: I28c42e10b6fff053bb6580926ca2353448cb042a --- .github/workflows/build-linux-cross.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index e1d7d4dcf164e..d2ef7b4344b8f 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -350,7 +350,7 @@ jobs: env: SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" - SPACEMIT_IME_TOOLCHAIN_PATH: "/tmp/spacemit-toolchain-linux-glibc-x86_64" + SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64" steps: - uses: actions/checkout@v4 @@ -359,21 +359,21 @@ jobs: uses: actions/cache@v4 id: cache-spacemit-ime-cross-toolchain with: - path: /tmp/spacemit-toolchain-linux-glibc-x86_64 - key: ${{ runner.os }}-${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + path: spacemit-toolchain-linux-glibc-x86_64 + key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - name: Setup Toolchain - if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit!= 'true' + if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' run: | - wget --quiet --no-check-certificate 
https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${env.SPACEMIT_IME_TOOLCHAIN_VERSION}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} --strip-components=1 - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${env.SPACEMIT_IME_TOOLCHAIN_VERSION}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1 + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz - name: Build run: | - export RISCV_ROOT_PATH=${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} cmake -B build -DLLAMA_CURL=OFF \ -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ From 40c8d7e6a8bafc7c04295f310965dae4d76f6a63 Mon Sep 17 00:00:00 2001 From: alex-spacemit Date: Mon, 29 Sep 2025 18:17:48 +0800 Subject: [PATCH 15/16] Update .github/workflows/build-linux-cross.yml for cache path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- .github/workflows/build-linux-cross.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index d2ef7b4344b8f..7c44c7ae7dfd4 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -359,7 +359,7 @@ jobs: uses: actions/cache@v4 id: cache-spacemit-ime-cross-toolchain with: - path: spacemit-toolchain-linux-glibc-x86_64 + path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - name: Setup Toolchain From b3eec832c2abd13a037dbd42c385ac220b50fd84 Mon Sep 17 00:00:00 2001 From: alex-spacemit Date: Mon, 29 Sep 2025 20:16:12 +0800 Subject: [PATCH 16/16] bugfixed for build-linux-cross.yml, syntax error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- .github/workflows/build-linux-cross.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 7c44c7ae7dfd4..05bee55e49d11 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -365,7 +365,7 @@ jobs: - name: Setup Toolchain if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' run: | - wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${env.SPACEMIT_IME_TOOLCHAIN_VERSION}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} 
--strip-components=1
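
For readers following the kernel renames above, here is a minimal, illustrative sketch (not part of any patch in this series) of how the `ime1` entry points declared in `ggml/src/ggml-cpu/spacemit/ime_kernels.h` are intended to be chained for a single-row (M=1) int8 x int4 GEMV. The packed layout of `packed_b`, the quantized-A workspace size, and the meaning of `scale_stride` (4 for fp32 scales, 2 for `_Float16` scales) are assumptions read off the kernels in these patches; the real repacking of B is done by the backend code in `ime.cpp` and is not reproduced here.

```cpp
#include <cstddef>
#include <vector>

#include "ime_kernels.h"  // sqnbitgemm_spacemit_ime::ime1 declarations from this patch series

// Illustrative sketch only: computes one output row C[0..count_n) = A(1xK) * B(KxN) + bias,
// where packed_b has already been repacked by the backend (Q4 data + optional zero points +
// per-block scales), as ime.cpp does before dispatching to these kernels.
static void gemv_q8_q4_sketch(const float *     a_row,        // 1 x count_k activations
                              const std::byte * packed_b,     // repacked weights (assumed layout, see above)
                              float *           c_row,        // 1 x count_n output
                              const float *     bias,         // may be nullptr
                              size_t            count_n,
                              size_t            count_k,
                              size_t            blk_len,      // e.g. 32 for Q4_0
                              size_t            scale_stride) // 4 = fp32 scales, 2 = _Float16 scales (assumption)
{
    namespace ime1 = sqnbitgemm_spacemit_ime::ime1;

    const size_t block_count_k = (count_k + blk_len - 1) / blk_len;

    // Per-row workspace for quantized A: one fp32 scale plus blk_len int8 values per K-block.
    // This sizing is inferred from quantize_a_row_i8 above; treat it as an assumption, not a stable ABI.
    std::vector<std::byte> quant_a(block_count_k * (sizeof(float) + blk_len));

    // Quantize the activation row to int8 with per-block scales.
    ime1::quantize_a_row_i8(blk_len, a_row, count_k, quant_a.data());

    // Run the int8 x int4 GEMM kernel for a single row of A.
    ime1::gemm_kernel_i8i4(blk_len,
                           quant_a.data(),
                           packed_b,
                           /*quant_b_scale=*/nullptr,  // unused by the kernel (scales live inside packed_b)
                           /*quant_b_zp=*/nullptr,     // unused by the kernel (zero points live inside packed_b)
                           c_row,
                           /*count_m=*/1,
                           count_n,
                           count_k,
                           block_count_k,
                           /*ldc=*/count_n,
                           bias,
                           scale_stride);
}
```

Passing `nullptr` for the scale and zero-point arguments mirrors the kernels above, which mark `QuantBScale`/`QuantBZeroPoint` as unused and read scales and zero points directly from the packed B buffer.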