Skip to content

Commit

Permalink
[CPP Graph] add s8 perchannel quant and kernel. (#181)
Browse files Browse the repository at this point in the history
* add s8 perchannel quant and kernel.

* add QKV, add fusion support for s8 PerN

* add amx_int8 pern gelu fusion

* add gelu add fusion for vnni

* split jblas file. add compute type fp32.

* add comp_type fp32 for ffn fusion

* add bf16 for s4 and s4 ffn fusion

* add workspace for jblas functions

* keep one jblas code

* disable mmap by default. change arg --no_mmap to --use_mmap.
  • Loading branch information
luoyu-intel committed Sep 1, 2023
1 parent fc3ca18 commit 6ce8b13
Show file tree
Hide file tree
Showing 48 changed files with 2,459 additions and 22,617 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,26 @@ class JitAmxbf16 : protected JitAmxtile {
void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); }
};

class JitAmxint8 : protected JitAmxtile {};
class JitAmxint8 : protected JitAmxtile {
protected:
template <class, class>
void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3);
};
// Explicit specializations of JitAmxint8::_tdpb. Each one forwards its three
// tile registers, in the order received, to a single tdpb* emitter chosen by
// the signedness of the two template arguments.

// signed x signed -> tdpbssd
template <>
inline void JitAmxint8::_tdpb<int8_t, int8_t>(const Xbyak::Tmm& dst, const Xbyak::Tmm& src1,
                                              const Xbyak::Tmm& src2) {
  tdpbssd(dst, src1, src2);
}

// signed x unsigned -> tdpbsud
template <>
inline void JitAmxint8::_tdpb<int8_t, uint8_t>(const Xbyak::Tmm& dst, const Xbyak::Tmm& src1,
                                               const Xbyak::Tmm& src2) {
  tdpbsud(dst, src1, src2);
}

// unsigned x signed -> tdpbusd
template <>
inline void JitAmxint8::_tdpb<uint8_t, int8_t>(const Xbyak::Tmm& dst, const Xbyak::Tmm& src1,
                                               const Xbyak::Tmm& src2) {
  tdpbusd(dst, src1, src2);
}

// unsigned x unsigned -> tdpbuud
template <>
inline void JitAmxint8::_tdpb<uint8_t, uint8_t>(const Xbyak::Tmm& dst, const Xbyak::Tmm& src1,
                                                const Xbyak::Tmm& src2) {
  tdpbuud(dst, src1, src2);
}
} // namespace xbyak
} // namespace jblas
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <tuple>

#include "jit_base.hpp"
#include "jit_blas.h"
#include "jit_blas_utils.h"
Expand All @@ -24,8 +26,10 @@ namespace gemm {
template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T>
class AccumulatorWriteBack {
public:
using SType = _SRC_T;
using DType = _DST_T;
struct Param {
_DST_T* C;
DType* C;
int ldc;
void* elt_const_v;
};
Expand All @@ -35,20 +39,24 @@ class AccumulatorWriteBack {
const int N, const Param& _param, Eltops... ops) {
auto COffset = M_offset * _param.ldc + N_offset;
auto cptr = _param.C + COffset;
bool constexpr Valid = !std::is_same<_DST_T, jblas::utils::bf16>::value ? true : std::is_same<_SRC_T, float>::value;
bool constexpr Valid = !std::is_same<DType, utils::bf16>::value || std::is_same<SType, float>::value;
static_assert(Valid, "fp32 to bf16 conversion only.");
if (std::is_same<_DST_T, jblas::utils::bf16>::value) {
if constexpr (std::is_same<DType, utils::bf16>::value) {
return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
(void*)cacheptr, (void*)cptr, M, N, cachestep * sizeof(_SRC_T), _param.ldc * sizeof(_DST_T));
} else if (sizeof(_SRC_T) == sizeof(_DST_T)) {
return kernel::wrapper::Memcpy2D::template forward<ISA_T, _SRC_T, _DST_T>(
(void*)cacheptr, (void*)cptr, M, N * sizeof(_DST_T), cachestep * sizeof(_SRC_T), _param.ldc * sizeof(_DST_T),
(void*)cacheptr, (void*)cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
} else if constexpr (std::is_same<std::tuple<SType, DType>, std::tuple<utils::fp16, float>>::value) {
return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward<ISA_T>(
(void*)cacheptr, (void*)cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
} else if constexpr (sizeof(SType) == sizeof(DType)) {
return kernel::wrapper::Memcpy2D::template forward<ISA_T, SType, DType>(
(void*)cacheptr, (void*)cptr, M, N * sizeof(DType), cachestep * sizeof(SType), _param.ldc * sizeof(DType),
_param.elt_const_v, ops...);
} else {
assert(false);
}
}
};

template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T, JBLAS_ELTWISEOP _OP>
class CustomAccumulatorWriteBackWithEltop {
public:
Expand All @@ -61,7 +69,7 @@ class CustomAccumulatorWriteBackWithEltop {
const int N, const Param& _param) {
auto COffset = M_offset * _param.ldc + N_offset;
auto cptr = _param.C + COffset;
if (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) {
if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) {
return kernel::jit::CustomMemCpy::template forward<_OP>(cacheptr, cptr, M, N * sizeof(_DST_T),
cachestep * sizeof(_SRC_T), _param.ldc * sizeof(_DST_T),
_param.elt_const_v);
Expand All @@ -77,6 +85,8 @@ using AccumulatorWriteBackBf16 = AccumulatorWriteBack<ISA_T, utils::bf16, utils:
// Convenience aliases for AccumulatorWriteBack<ISA_T, source type, destination type>.

// fp16 source written to an fp16 destination.
template <JBLAS_ISA ISA_T>
using AccumulatorWriteBackFp16 = AccumulatorWriteBack<ISA_T, utils::fp16, utils::fp16>;
// fp16 source written to a float destination (the pair AccumulatorWriteBack
// routes to Memcpy2DFp16CvtFp32).
template <JBLAS_ISA ISA_T>
using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<ISA_T, utils::fp16, float>;
// float source written to a bf16 destination (AccumulatorWriteBack statically
// requires a float source whenever the destination is bf16).
template <JBLAS_ISA ISA_T>
using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<ISA_T, float, utils::bf16>;

template <JBLAS_ISA ISA_T>
Expand Down
Loading

0 comments on commit 6ce8b13

Please sign in to comment.