Skip to content

Commit

Permalink
Jira 688/s4 vnni kernel (#1084)
Browse files Browse the repository at this point in the history
* add vnni execution path

* use avx inst

* add compression type for vnni format

* enable jit decompress s4f32

* use new kernel for vnni s4 weight

* support bf16 scales.

* add s4_vnni_bf16 quant type

* add b128

* auto set max valid threads

* update threading of dynamic quant

* sync thread number with ggml and jblas

* optimize thread usage

* update kernel

---------

Co-authored-by: Meng, Hengyu <hengyu.meng@intel.com>
  • Loading branch information
luoyu-intel and airMeng committed Jul 3, 2023
1 parent 512a283 commit 3b7665c
Show file tree
Hide file tree
Showing 17 changed files with 3,227 additions and 420 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ static const std::map<std::string, model_ftype> NE_FTYPE_MAP = {
{"q4_j_b128", MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128},
{"q4_j_b1024", MODEL_FTYPE_MOSTLY_Q4_JBLAS_B1024},
{"q4_j_bf16_b32", MODEL_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32},
{"q4_j_vnni_b32", MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32},
{"q4_j_vnni_b128", MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128},
{"q4_j_vnni_bf16_b32", MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32},

};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,30 @@ using namespace jblas;

void jblas_weights4block_f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda,
                                     int ldo) {
  // Forward GEMM for 4-bit block-compressed weights: output = activation * W.
  // Deserializes the packed weight blob and dispatches to the kernel that
  // matches the core type the weights were packed for.
  //
  // activation: fp32 input of shape (_m x _k), leading dimension lda.
  // weiptr:     serialized CompressedPackedWeight buffer.
  // output:     fp32 result of shape (_m x _n), leading dimension ldo.
  auto wtmp = prologue::weight_comp::gemm::CompressedPackedWeight::deserialBuffer(weiptr, 0);
  if (wtmp->mCoreType == jblas::gemm::GemmCoreType::AVX512_VNNI_8X48 ||
      wtmp->mCoreType == jblas::gemm::GemmCoreType::AVX512_VNNI_3X48_KBLOCK) {
    // VNNI path: dynamic-quantized int8 kernel for VNNI-packed s4 weights.
    using GemmKernel = jblas::wrapper::gemm_default::weight_comp::avx512_vnni::GemmSKernelDynamicS4KBlock;
    static GemmKernel kernel;
    (void)kernel.compute({_m, _n, _k, activation, lda, wtmp, output, ldo});
  } else if (wtmp->mCoreType == jblas::gemm::GemmCoreType::AVX512F_8X48) {
    // AVX-512F fp32 path: C = alpha*A*B + beta*D with alpha=1, beta=0,
    // i.e. a plain overwrite of the output buffer.
    using GemmKernel = jblas::wrapper::gemm_default::weight_comp::avx512f::GemmKernelS4KBlock;
    float alpha = 1.f, beta = 0.f;
    static GemmKernel kernel;
    (void)kernel.compute({_m, _n, _k, activation, lda, wtmp, output, output, ldo, ldo, alpha, beta});
  }
  // NOTE(review): an unrecognized core type falls through silently and leaves
  // `output` untouched — confirm whether an assert or error log is wanted here.
  delete wtmp;
}

void jblas_timer(bool _init) {
static utils::timer<utils::microseconds> tr;
if (_init)
tr.start();
else
printf("time :%f us\n", tr.stop());
}

// Requests _nth worker threads from the jblas CPU device singleton and
// returns the thread count the device actually settled on (which the
// device may clamp — callers should use the returned value).
int jblas_set_threads(int _nth) {
  auto device = jblas::utils::parallel::CpuDevice::getInstance();
  device->setThreads(_nth);
  return device->getThreads();
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ extern "C" {
#endif
void jblas_weights4block_f32_forward(float* activation, void* weiptr, float* output, int _m, int _n, int _k, int lda,
int ldo);

void jblas_timer(bool _init);

int jblas_set_threads(int _nth);

#ifdef __cplusplus
}
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9387,7 +9387,7 @@ static thread_ret_t ne_graph_compute_thread(void* data) {
#include <omp.h>
#endif
void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
const int n_threads = cgraph->n_threads;
int n_threads = cgraph->n_threads;

struct ne_compute_state_shared state_shared = {
/*.spin =*/NE_LOCK_INITIALIZER,
Expand Down Expand Up @@ -9425,6 +9425,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
}
}
#else
n_threads = jblas_set_threads(n_threads);
omp_set_num_threads(n_threads);
#endif

Expand All @@ -9437,7 +9438,15 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
struct ne_tensor* node = cgraph->nodes[i];

switch (node->op) {
case NE_OP_CPY:
case NE_OP_CPY: {
node->n_tasks = node->ne[0] == 1 ? n_threads : 1;

size_t cur = 0;
if (ne_is_quantized(node->type)) {
cur = NE_TYPE_SIZE[NE_TYPE_F32] * node->ne[0] * n_threads;
}
work_size = MAX(work_size, cur);
} break;
case NE_OP_DUP: {
node->n_tasks = n_threads;

Expand All @@ -9450,7 +9459,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
} break;
case NE_OP_ADD:
case NE_OP_ADD1: {
node->n_tasks = n_threads;
node->n_tasks = 1;

size_t cur = 0;

Expand Down Expand Up @@ -9484,15 +9493,15 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
case NE_OP_SGN:
case NE_OP_NEG:
case NE_OP_STEP:
case NE_OP_MUL:
case NE_OP_RMS_NORM:
case NE_OP_RELU: {
node->n_tasks = 1;
} break;
case NE_OP_MUL:
case NE_OP_GELU:
case NE_OP_SILU:
case NE_OP_SILU_BACK:
case NE_OP_NORM:
case NE_OP_RMS_NORM:
case NE_OP_RMS_NORM_BACK: {
node->n_tasks = n_threads;
} break;
Expand Down Expand Up @@ -9527,7 +9536,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
work_size = MAX(work_size, cur);
} break;
case NE_OP_SCALE: {
node->n_tasks = n_threads;
node->n_tasks = 1;
} break;
case NE_OP_SET:
case NE_OP_CONT:
Expand All @@ -9544,6 +9553,8 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
case NE_OP_DIAG_MASK_INF:
case NE_OP_SOFT_MAX:
case NE_OP_ROPE:
node->n_tasks = 1;
break;
case NE_OP_ROPE_BACK: {
node->n_tasks = n_threads;
} break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ class JitAvx512f : protected JitAvx2 {
vpsrld(_fp32, _fp32, 16);
vpmovdw(_bf16, _fp32);
}

// Emit code that loads packed bf16 values from memory and widens them to
// fp32 in `dst`: zero-extend each 16-bit lane to 32 bits, then shift left
// by 16 so the bf16 bits land in the high half of each fp32 lane.
void loadbf16_f32(const Xbyak::Zmm &dst, const Xbyak::Address &addr) {
  vpmovzxwd(dst, addr);
  vpslld(dst, dst, 16);
}

// Emit code that broadcasts a single bf16 scalar from memory to all fp32
// lanes of `dst`. Loads 16 bits into the scratch register `tmp`, shifts it
// into the fp32 high half, then broadcasts the 32-bit pattern. `tmp` is
// clobbered.
void broadcastbf16_f32(const Xbyak::Zmm &dst, const Xbyak::Reg64 &tmp,
                       const Xbyak::Address &addr) {
  mov(tmp.cvt16(), addr);
  shl(tmp.cvt32(), 16);
  vpbroadcastd(dst, tmp.cvt32());
}

// Emit code that narrows the fp32 lanes of `_fp32` to bf16 (via
// cvt_fp32_bf16) and stores the packed result to memory. Reuses the low
// ymm alias of `_fp32`, so the source register is clobbered.
void store_fp32_bf16(const Xbyak::Zmm &_fp32, const Xbyak::Address &_add) {
  auto bf16 = Xbyak::Ymm(_fp32.getIdx());
  cvt_fp32_bf16(bf16, _fp32);
  vmovups(_add, bf16);
}
};

class JitAvx512vnni : protected JitAvx512f {
Expand Down Expand Up @@ -203,21 +221,9 @@ class JitAmxtile : protected JitAvx512f {

// bf16 <-> fp32 conversion emitters layered on the AMX tile helpers.
class JitAmxbf16 : protected JitAmxtile {
 protected:
  // Widen bf16-in-low-halves to fp32: shift each 32-bit lane left by 16.
  void cvt_bf16_fp32(const Xbyak::Zmm &dst, const Xbyak::Zmm &src) {
    vpslld(dst, src, 16);
  }
  // Load packed bf16 from memory, zero-extend to dwords, then widen to fp32.
  void load_bf16_fp32(const Xbyak::Zmm &dst, const Xbyak::Address &addr) {
    vpmovzxwd(dst, addr);
    cvt_bf16_fp32(dst, dst);
  }
  // Narrow fp32 lanes to bf16 using the hardware round-to-nearest-even
  // convert instruction.
  void cvt_fp32_bf16(const Xbyak::Ymm &_bf16, const Xbyak::Zmm &_fp32) {
    vcvtneps2bf16(_bf16, _fp32);
  }
  // Convert an fp32 zmm to bf16 and store it; clobbers the low ymm alias
  // of the source register.
  void store_fp32_bf16(const Xbyak::Zmm &_fp32, const Xbyak::Address &_add) {
    auto bf16 = Xbyak::Ymm(_fp32.getIdx());
    cvt_fp32_bf16(bf16, _fp32);
    vmovups(_add, bf16);
  }
};

class JitAmxint8 : protected JitAmxtile {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,27 @@
namespace jblas {
namespace epilogue {
namespace gemm {
class AlphaBetaProcessBase {};
class AlphaBetaProcessFp32 : protected AlphaBetaProcessBase {
// Epilogue that writes the fp32 accumulator tile straight into the output
// matrix C with no scaling — a strided 2-D copy via Memcpy2D.
template <typename _T>
class AccumulateWriteBack {
 public:
  struct Param {
    _T *C;    // destination matrix
    int ldc;  // leading dimension of C, in elements
  };

  // Copy the M x N cache tile (row stride `cachestep` floats) to C at
  // (M_offset, N_offset). Returns the Memcpy2D status code.
  template <JBLAS_ISA ISA_T>
  JBLAS_CODE forward(const float *cacheptr, const int cachestep,
                     const int M_offset, const int N_offset, const int M,
                     const int N, const Param &_param) {
    auto dstptr = _param.C + M_offset * _param.ldc + N_offset;
    return kernel::wrapper::Memcpy2D::template forward<ISA_T>(
        (void *)cacheptr, (void *)dstptr, M, N * sizeof(_T),
        cachestep * sizeof(float), _param.ldc * sizeof(_T));
  }
};

class AlphaBetaProcessFp32 {
public:
struct Param {
float *C, *D;
Expand All @@ -24,12 +43,34 @@ class AlphaBetaProcessFp32 : protected AlphaBetaProcessBase {
auto COffset = M_offset * _param.ldc + N_offset;
auto cptr = _param.C + COffset;
auto dptr = _param.D + DOffset;
return kernel::wrapper::AlphaBetaF32F32::forward<ISA_T>(
return kernel::wrapper::AlphaBetaF32F32::template forward<ISA_T>(
_param.alpha, cacheptr, cachestep, _param.beta, dptr, _param.ldd, cptr,
_param.ldc, M, N);
}
};

// Epilogue that converts the int32 accumulator tile to uint8 output via
// QuanOutS32U32 — presumably alpha/scaleAcc/scaleC/zpC implement the
// requantization affine map; confirm against the kernel implementation.
class AlphaBetaProcessS32U8 {
 public:
  struct Param {
    uint8_t *C;             // quantized destination matrix
    int ldc;                // leading dimension of C, in elements
    float alpha;            // accumulator multiplier
    float scaleAcc, scaleC; // accumulator and output scales
    int zpC;                // output zero point
  };

  // Requantize the M x N cache tile (row stride `cachestep` ints) into C
  // at (M_offset, N_offset). Returns the QuanOutS32U32 status code.
  template <JBLAS_ISA ISA_T>
  JBLAS_CODE forward(const int32_t *cacheptr, const int cachestep,
                     const int M_offset, const int N_offset, const int M,
                     const int N, const Param &_param) {
    auto dstptr = _param.C + M_offset * _param.ldc + N_offset;
    return kernel::wrapper::QuanOutS32U32::template forward<ISA_T>(
        _param.alpha, cacheptr, cachestep, dstptr, _param.ldc, M, N,
        _param.scaleAcc, _param.scaleC, _param.zpC);
  }
};

} // namespace gemm
} // namespace epilogue
} // namespace jblas

0 comments on commit 3b7665c

Please sign in to comment.