Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

virtual float calculate_score() {
Expand Down Expand Up @@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
}
const float DensityThres = 32;
const float DensityThres = 16;

float calculate_score() {
int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
Expand Down Expand Up @@ -492,10 +492,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
assert(this->mBlock[0]>0);
assert(this->mBlock[1]>0);
assert(this->mBlock[2]>0);
assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
}

protected:
const float DensityThres = 32;
const float DensityThres = 16;
static size_t constexpr ReservedSize = 32ULL * 1024ULL;

void cache_blocking_compute() override {
Expand Down Expand Up @@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
(this->mStep[0] * this->mEleSize[0] +
float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
this->mBlock[1] * this->mEleSize[1]));
if (rawk < this->mKBlock) {
rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
(this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
}
rawk = std::min(rawk, this->mSizePadded[2]);
this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
if (this->mBlock[2] > this->mKBlock) {
Expand Down Expand Up @@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
while (mKBlock % tmp != 0) tmp++; // TODO(Yu) optimize
this->mBlock[2] = utils::downdiv(mKBlock, tmp);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
for (; j < align_col; j += 8) quant();
for (; j < col; j++) {
auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
if constexpr (WITH_SCALE) {
if constexpr (std::is_same_v<_S_T, utils::f8>) {
dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
} else if constexpr (std::is_same_v<_S_T, float>) {
dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
}
} else {
dstptr[i * ld_dst + j] = fp_v;
}
}
}
Expand Down Expand Up @@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
vzps[iv] = _mm256_cvtepi8_epi32(tmp);
}
}
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter16 = 0; iter16 < Loop16; iter16++)
pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
for (int iterr = 0; iterr < UnrollRow; iterr++)
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
}
for (; irow < row; irow++) {
if constexpr (_NCOL == 24) {
pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
vzps[iv] = _mm512_cvtepi8_epi32(tmp);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
auto rowre = row - irow;
int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
for (; irow < rowpad4; irow += UnrollRow) {
for (int iter64 = 0; iter64 < Loop64; iter64++) {
pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
LoadMask64);
}
for (int iterr = 0; iterr < UnrollRow; iterr++) {
if constexpr (_IS_SYM) {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
} else {
dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
}
}
}
for (; irow < row; irow++) {
pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
if constexpr (_IS_SYM) {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
} else {
dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
}
}
}
return JblasSuccess;
Expand Down Expand Up @@ -563,9 +578,8 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
auto sptr = scales + kpos * NPad;
int j = 0;
auto quant = [&](__mmask16 mask) {
__m128i f8_src;
auto sign_revert =
_mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
_mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
auto e_revert = sign_revert;
auto mantissa_revert = sign_revert;
sign_revert = _mm512_slli_epi32(sign_revert, 24);
Expand Down Expand Up @@ -888,10 +902,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm2 = _mm512_add_ps(zmm2, zmm_zp);
zmm3 = _mm512_add_ps(zmm3, zmm_zp);
} else {
mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);

zmm0 = _mm512_abs_ps(zmm0);
zmm1 = _mm512_abs_ps(zmm1);
Expand All @@ -908,10 +922,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
zmm5 = _mm512_sub_ps(zmm1, sub_v);
zmm6 = _mm512_sub_ps(zmm2, sub_v);
zmm7 = _mm512_sub_ps(zmm3, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
Expand Down Expand Up @@ -949,7 +963,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
auto zp = _mm512_set1_ps(0.8480964004993439f);
zmm0 = _mm512_add_ps(zmm0, zp);
} else {
mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
zmm0 = _mm512_abs_ps(zmm0);
}
constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
Expand All @@ -959,7 +973,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
zmm1 = _mm512_sub_ps(zmm0, sub_v);
mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<const __m128i*>(broadcast_f4_v + i * 16)));
zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
dstptr[7] = tmp;
}

inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
auto tmp = static_cast<int>(src32 & 0xf);
dstptr[0] = static_cast<int8_t>(tmp);
tmp = static_cast<int>(src32 & 0xf0) >> 4;
dstptr[1] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf00) >> 8);
dstptr[2] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf000) >> 12);
dstptr[3] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf0000) >> 16);
dstptr[4] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf00000) >> 20);
dstptr[5] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf000000) >> 24);
dstptr[6] = static_cast<int8_t>(tmp);
tmp = static_cast<int>((src32 & 0xf0000000) >> 28);
dstptr[7] = static_cast<int8_t>(tmp);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
auto tmp = static_cast<int8_t>(src32 & 0xf);
dstptr[0] = tmp - 8;
tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
dstptr[1] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
dstptr[2] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
dstptr[3] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
dstptr[4] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
dstptr[5] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
dstptr[6] = tmp - 8;
tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
dstptr[7] = tmp - 8;
convert_s4_s8_8_lowbits(dstptr, srcptr);
for (size_t i = 0; i < 8; i++) {
dstptr[i] -= 8;
}
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <>
inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
convert_s4_s8_8_lowbits(dstptr, srcptr);
}

template <JBLAS_DTYPE S4_T>
Expand Down