Port woq kernel & tpp optimizations (#2412)
* Port woq kernel & tpp optimizations: avoid unnecessary tensor creation; use uint64 as hash key for tpp

* Fix clang-format issue
Xia-Weiwen committed Dec 28, 2023
1 parent f20b8b5 commit e039910
Showing 2 changed files with 321 additions and 213 deletions.
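
The diff shown below covers the first of the two files, csrc/cpu/aten/kernels/WoqTppKrnl.cpp. Its main change is the "avoid unnecessary tensor creation" part of the commit message: the one-element at::Tensor arguments that carried the per-tensor activation scale and zero point become plain float*/int32_t* parameters, so callers pass the address of a scalar (or a pointer into an existing tensor) instead of allocating temporary tensors on every call. A minimal sketch of the pattern, with a simplified signature assumed for illustration (the real qlinear_woq_affine_impl takes many more arguments):

#include <cstdint>
#include <cstdio>

// Sketch of the callee after the change: the activation scale/zero point
// arrive as raw pointers, defaulting to nullptr when the activation is not
// quantized. (The real function also takes packed weights, bias, scales, ...)
void qlinear_woq_affine_impl_sketch(const float* scales_a_ptr = nullptr,
                                    const int32_t* zps_a_ptr = nullptr) {
  if (scales_a_ptr && zps_a_ptr) {
    // Per-tensor path: read the scalars directly -- no at::full({1}, ...) /
    // data_ptr<float>() round trip as in the old code.
    std::printf("scale_a=%f zp_a=%d\n", *scales_a_ptr, *zps_a_ptr);
  }
}

int main() {
  // Caller side: the quantization step produced these scalars; pass their
  // addresses instead of wrapping them in 1-element tensors.
  float scale_a = 0.05f;
  int32_t zp_a = 128;
  qlinear_woq_affine_impl_sketch(&scale_a, &zp_a);
  return 0;
}
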
34 changes: 20 additions & 14 deletions csrc/cpu/aten/kernels/WoqTppKrnl.cpp
@@ -1323,7 +1323,7 @@ class DequantGemmTPP<
static_assert(N % 16 == 0, "N must be a multiple of 16");
if (std::is_same<Tin, bfloat16>())
TLA_ASSERT(K % 2 == 0, "Kb must be a multiple of 2 for bfloat16");
- pgemm = std::make_shared<BrgemmTPP<Tin, Tout>>(
+ pgemm = new BrgemmTPP<Tin, Tout>(
M,
N,
K,
@@ -1338,6 +1338,10 @@
/*b_vnni*/ std::is_same<Tin, bfloat16>());
}

+ ~DequantGemmTPP() {
+   delete pgemm;
+ }

inline void operator()(
Tin* A,
uint8_t* qB,
@@ -1430,7 +1434,7 @@
}

private:
- std::shared_ptr<BrgemmTPP<Tin, Tout>> pgemm;
+ BrgemmTPP<Tin, Tout>* pgemm;
long M;
long K;
long lda;
@@ -1466,7 +1470,7 @@
static_assert(N % 16 == 0, "N must be a multiple of 16");
TLA_ASSERT(K % 4 == 0, "Kb must be a multiple of 4 for int8 VNNI");
// TODO(jgong5): output fp32 directly
- pgemm = std::make_shared<TBrgemmTPP>(
+ pgemm = new TBrgemmTPP(
M,
N,
K,
@@ -1481,6 +1485,10 @@
/*b_vnni*/ true);
}

+ ~DequantGemmTPP() {
+   delete pgemm;
+ }

inline void operator()(
uint8_t* A,
uint8_t* qB,
@@ -1620,7 +1628,7 @@
}

private:
- std::shared_ptr<TBrgemmTPP> pgemm;
+ TBrgemmTPP* pgemm;
long M;
long K;
long lda;
@@ -1654,8 +1662,8 @@ void qlinear_woq_affine_impl(
const TensorList& others_list,
int64_t quant_block_k,
const std::optional<at::Tensor>& zps = std::nullopt, // dtype is TComp
- at::Tensor t_scale_a = at::empty({1}, at::kFloat),
- at::Tensor t_zp_a = at::empty({1}, at::kInt)) {
+ float* scales_a_ptr = nullptr,
+ int32_t* zps_a_ptr = nullptr) {
const bool is_4bit_flag = is_4bit(qw_type);
const bool sym_quant = is_sym_quant(qw_type);
auto x_sizes = x.sizes();
@@ -1708,8 +1716,6 @@
auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb;
auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb;

- auto scales_a_ptr = t_scale_a.data_ptr<float>();
- auto zps_a_ptr = t_zp_a.data_ptr<int32_t>();
auto px = GetVLAPtr<T>(x, {Kc, Kb});
auto pw = GetVLAPtr<uint8_t>(
(uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)});
@@ -3273,8 +3279,6 @@ at::Tensor qlinear_woq_affine(
x_reshape_contig, &scale_a, &zp_a);
auto x_quantized = quantize_per_tensor<act_type>(
x_reshape_contig, scale_a, zp_a);
- auto scale_a_t = at::full({1}, scale_a, at::kFloat);
- auto zp_a_t = at::full({1}, zp_a, at::kInt);
qlinear_woq_affine_impl<
uint8_t,
uint8_t,
@@ -3296,8 +3300,8 @@
others_list,
quant_block_k,
zp_list[int8_idx],
- scale_a_t,
- zp_a_t);
+ &scale_a,
+ &zp_a);
} else {
auto block_k = w_sizes[2];
if (quant_block_k <= 0)
@@ -3312,6 +3316,8 @@
zp_a,
quant_block_k,
quant_a_mode);
+ float* scale_a_ptr = (float*)scale_a.data_ptr();
+ int32_t* zp_a_ptr = (int32_t*)zp_a.data_ptr();
range_dispatcher<
long,
QUANT_A_PER_K_BLOCK,
@@ -3340,8 +3346,8 @@
others_list,
quant_block_k,
zp_list[int8_idx],
- scale_a,
- zp_a);
+ scale_a_ptr,
+ zp_a_ptr);
},
[&](auto quant_a_mode_) { failing_fallback(); });
}
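
The other change in the DequantGemmTPP hunks above swaps the std::shared_ptr holding the BRGEMM TPP kernel for a raw pointer created with new and released in an explicit destructor, presumably to avoid the shared_ptr control-block allocation and atomic reference counting on this frequently constructed helper. The pattern, reduced to a minimal standalone class (FakeBrgemmTPP and the deleted copy operations are illustrative additions, not part of the diff):

#include <cstdio>

// Stand-in for the TPP BRGEMM kernel wrapper; the real class is BrgemmTPP<Tin, Tout>.
struct FakeBrgemmTPP {
  FakeBrgemmTPP(long M, long N, long K) : M(M), N(N), K(K) {}
  long M, N, K;
};

class DequantGemmSketch {
 public:
  DequantGemmSketch(long M, long N, long K)
      // Before: pgemm = std::make_shared<FakeBrgemmTPP>(M, N, K);
      : pgemm(new FakeBrgemmTPP(M, N, K)) {}

  // The raw pointer now needs an explicit destructor, mirroring the added
  // ~DequantGemmTPP() { delete pgemm; } in the diff.
  ~DequantGemmSketch() { delete pgemm; }

  // Copying would double-delete pgemm, so forbid it here; the diff itself does
  // not add these, presumably because the object is never copied in practice.
  DequantGemmSketch(const DequantGemmSketch&) = delete;
  DequantGemmSketch& operator=(const DequantGemmSketch&) = delete;

 private:
  FakeBrgemmTPP* pgemm;  // was std::shared_ptr<FakeBrgemmTPP>
};

int main() {
  DequantGemmSketch gemm(/*M=*/32, /*N=*/64, /*K=*/128);
  std::printf("constructed sketch kernel holder\n");
  return 0;
}
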
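
The commit message also mentions switching to a uint64 hash key for the TPP kernel cache; that change lives in the second changed file, which is not shown above. As a purely illustrative sketch of the technique (the struct, field widths, and cache layout below are assumptions, not the actual key used by the kernel cache), small integer kernel parameters can be packed into a single 64-bit key so the cache hashes an integer instead of a string or tuple:

#include <cstdint>
#include <unordered_map>

// Hypothetical kernel parameters; the real cache keys on the TPP's shape/flags.
struct GemmShape {
  uint16_t M, N, K;
  uint8_t flags;  // e.g. vnni / transpose bits
};

// Pack the fields into one uint64_t so the cache can use a trivial integer hash.
inline uint64_t make_key(const GemmShape& s) {
  return (uint64_t(s.M) << 48) | (uint64_t(s.N) << 32) |
         (uint64_t(s.K) << 16) | uint64_t(s.flags);
}

int main() {
  std::unordered_map<uint64_t, int> kernel_cache;  // value type is a stand-in
  GemmShape s{64, 128, 256, 0b01};
  kernel_cache.emplace(make_key(s), /*compiled kernel id=*/42);
  return kernel_cache.count(make_key(s)) == 1 ? 0 : 1;
}
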
