WOQ: Remove concat-linear implementation from kernel (#2617)
* Remove num_concats for WOQ in the frontend to align with bf16 TPP linear

* Remove num_concats in WOQ linear kernel

* Fix clang-format issue
Xia-Weiwen committed Feb 28, 2024
1 parent d6130df commit adb5638
Showing 14 changed files with 28 additions and 201 deletions.
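
Background: the concat-linear optimization fuses several linear layers that share the same input activation (for example the Q/K/V projections in attention) by concatenating their weights along the output dimension. Before this commit the WOQ kernel also took a num_concats argument so it could write each sub-output to its own contiguous block; after it, the kernel always produces a single row-major [M, N] output and any concat bookkeeping stays in the frontend, matching the BF16 TPP linear path. Below is a minimal sketch of the idea in libtorch; it is illustrative only, not IPEX code, and all names and shapes in it are assumptions.

#include <torch/torch.h>
#include <vector>

// Illustrative concat-linear sketch: three projections sharing one input are
// computed as a single wide GEMM on concatenated weights and split afterwards.
// All names and shapes here are assumptions, not IPEX code.
int main() {
  const int64_t M = 8, K = 64, N = 32;  // rows, in-features, per-linear out-features
  auto x = torch::randn({M, K});
  std::vector<torch::Tensor> weights = {
      torch::randn({N, K}), torch::randn({N, K}), torch::randn({N, K})};

  // Frontend-side fusion: concatenate weights along the output dimension.
  auto w_cat = torch::cat(weights, /*dim=*/0);  // [3 * N, K]

  // One wide linear; after this commit the WOQ kernel only ever sees this
  // case and no longer needs a num_concats argument.
  auto y = torch::linear(x, w_cat);             // [M, 3 * N]

  // Splitting back into per-layer outputs is a frontend concern.
  auto outs = y.split(N, /*dim=*/1);            // three [M, N] tensors
  return 0;
}

Keeping the split in the frontend is presumably what allows the kernel-side concat paths (the py_concat indexing and the output reshuffle in WoqTppKrnl.cpp below) to be deleted outright.
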
8 changes: 0 additions & 8 deletions csrc/cpu/aten/Linear.cpp
@@ -429,7 +429,6 @@ at::Tensor woq_linear_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
int64_t act_quant_mode) {
int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
int64_t quant_w_mode = group_size > 0 ? 1 : 0;
@@ -442,7 +441,6 @@ at::Tensor woq_linear_kernel(
bias_list,
w_dtype,
lowp_mode,
num_concats,
WOQ_FUSE_NONE, // no post op fusion
std::vector<at::Tensor>(),
act_quant_mode,
@@ -472,7 +470,6 @@ at::Tensor woq_linear_eltwise_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
int64_t act_quant_mode) {
int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
int64_t post_op_fusion_type = WOQ_FUSE_NONE;
@@ -493,7 +490,6 @@ at::Tensor woq_linear_eltwise_kernel(
bias_list,
w_dtype,
lowp_mode,
num_concats,
post_op_fusion_type,
std::vector<at::Tensor>(),
act_quant_mode,
@@ -532,7 +528,6 @@ at::Tensor woq_linear_add_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
const std::vector<at::Tensor>& others,
int64_t act_quant_mode) {
int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -546,7 +541,6 @@ at::Tensor woq_linear_add_kernel(
bias_list,
w_dtype,
lowp_mode,
num_concats,
WOQ_FUSE_ADD, // post op add
others,
act_quant_mode,
@@ -563,7 +557,6 @@ at::Tensor woq_linear_add_add_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
const std::vector<at::Tensor>& others,
int64_t act_quant_mode) {
int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -577,7 +570,6 @@ at::Tensor woq_linear_add_add_kernel(
bias_list,
w_dtype,
lowp_mode,
num_concats,
WOQ_FUSE_ADD_ADD, // post op add-add
others,
act_quant_mode,
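
The four wrappers above now differ only in the fusion flag they forward (WOQ_FUSE_NONE, an elementwise post-op, WOQ_FUSE_ADD, or WOQ_FUSE_ADD_ADD) and, for the add variants, the extra "others" tensors; the concat argument is gone from all of them. Semantically, the add fusions fold one or two residual additions into the GEMM epilogue. A plain eager-mode reference is sketched below; it is illustrative libtorch only, ignores bias, dtype, and quantization handling, and is not the IPEX kernel.

#include <torch/torch.h>
#include <vector>

// Eager-mode reference for what the fused post-ops compute; the real kernels
// apply these additions inside the GEMM epilogue instead of as separate ops.
torch::Tensor woq_linear_add_reference(
    const torch::Tensor& x,
    const torch::Tensor& w,
    const std::vector<torch::Tensor>& others) {
  auto y = torch::linear(x, w);  // WOQ_FUSE_NONE would stop here
  return y + others[0];          // WOQ_FUSE_ADD: one residual add
}

torch::Tensor woq_linear_add_add_reference(
    const torch::Tensor& x,
    const torch::Tensor& w,
    const std::vector<torch::Tensor>& others) {
  auto y = torch::linear(x, w);
  return y + others[0] + others[1];  // WOQ_FUSE_ADD_ADD: two residual adds
}
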
5 changes: 0 additions & 5 deletions csrc/cpu/aten/Linear.h
@@ -105,7 +105,6 @@ at::Tensor woq_linear_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
int64_t act_quant_mode);

at::Tensor woq_linear_eltwise_kernel(
@@ -120,7 +119,6 @@ at::Tensor woq_linear_eltwise_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
int64_t act_quant_mode);

at::Tensor woq_linear_add_kernel(
@@ -132,7 +130,6 @@ at::Tensor woq_linear_add_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
const std::vector<at::Tensor>& others,
int64_t act_quant_mode);

@@ -145,7 +142,6 @@ at::Tensor woq_linear_add_add_kernel(
bool is_int4,
int64_t group_size,
int64_t lowp_mode,
int64_t num_concats,
const std::vector<at::Tensor>& others,
int64_t act_quant_mode);

@@ -220,7 +216,6 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)(
const int,
int64_t,
int64_t,
int64_t,
const std::vector<at::Tensor>&,
int64_t,
int64_t,
80 changes: 10 additions & 70 deletions csrc/cpu/aten/kernels/WoqTppKrnl.cpp
@@ -1658,7 +1658,6 @@ void qlinear_woq_affine_impl(
at::Tensor y,
const int qw_type,
int k_splits,
int num_concats,
int fusion_type,
const TensorList& others_list,
int64_t quant_block_k,
@@ -1681,9 +1680,6 @@ void qlinear_woq_affine_impl(
quant_block_k == 0 ? 1 : (K + quant_block_k - 1) / quant_block_k;

TLA_ASSERT(Nb % 16 == 0, "Nb must be a multiple of 16");
TLA_ASSERT(
num_concats <= 1 || Nc % num_concats == 0,
"Nc must be a multiple of num_concats");

// select BLOCK_M according to M
// TODO(jgong5): improve the heuristic
@@ -1700,7 +1696,7 @@ void qlinear_woq_affine_impl(
auto BLOCK_M_rem = M % BLOCK_M;

// TODO(jgong5): use heuristics to decide k_splits
if (k_splits <= 0 || num_concats > 1 || M >= 32 || BLOCK_M_rem) {
if (k_splits <= 0 || M >= 32 || BLOCK_M_rem) {
k_splits = 1;
}
TLA_ASSERT(Kc % k_splits == 0, "Kc must be a multiple of k_splits");
@@ -1713,15 +1709,13 @@ void qlinear_woq_affine_impl(
k_splits == 1;

auto lda = no_x_buf ? K : Kb;
auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb;
auto ldy = N;
auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb;

auto px = GetVLAPtr<T>(x, {Kc, Kb});
auto pw = GetVLAPtr<uint8_t>(
(uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)});
auto py = GetVLAPtr<Tout>(y, {Nc, Nb}); /*[M, Nc, Nb]*/
auto py_concat = GetVLAPtr<Tout>(
y, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
int scales_kc = quant_w_mode == QUANT_W_PER_CHANNEL ? QUANT_W_PER_K_BLOCK
: quant_k_blocks;
auto pscales = GetVLAPtr<TScale>(scales, {scales_kc, Nb});
@@ -1730,12 +1724,8 @@ void qlinear_woq_affine_impl(
auto pb = GetVLAPtr<TGemmOut>(b, {Nb});
auto tin0 = others_list.size() > 0 ? others_list[0] : at::Tensor{};
auto pin0 = GetVLAPtr<Tout>(tin0, {Nc, Nb}); /*[M, Nc, Nb]*/
auto pin0_concat = GetVLAPtr<Tout>(
tin0, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
auto tin1 = others_list.size() > 1 ? others_list[1] : at::Tensor{};
auto pin1 = GetVLAPtr<Tout>(tin1, {Nc, Nb}); /*[M, Nc, Nb]*/
auto pin1_concat = GetVLAPtr<Tout>(
tin1, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/

auto copy_bias_out_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, ldy);
auto copy_bias_buf_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, Nb);
@@ -1754,19 +1744,9 @@ void qlinear_woq_affine_impl(
bool is_fusion_type_addrelated =
fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD;
auto post_ops_fn = [&](int m, int nc) {
Tout* y_ptr = num_concats <= 1
? (Tout*)py[m][nc]
: (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
Tout* tin0_ptr = is_fusion_type_addrelated ? num_concats <= 1
? (Tout*)pin0[m][nc]
: (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)]
: nullptr;
Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
? (Tout*)pin1[m][nc]
: (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)]
: nullptr;
Tout* y_ptr = (Tout*)py[m][nc];
Tout* tin0_ptr = is_fusion_type_addrelated ? (Tout*)pin0[m][nc] : nullptr;
Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
if (fusion_type == FUSE_GELU_ERF) {
gelu_erf_fwd_tpp(y_ptr, y_ptr);
} else if (fusion_type == FUSE_ADD) {
@@ -1779,19 +1759,11 @@ void qlinear_woq_affine_impl(
}
};
auto post_ops_rem_fn = [&](int m, int nc) {
Tout* y_ptr = num_concats <= 1
? (Tout*)py[m][nc]
: (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
Tout* y_ptr = (Tout*)py[m][nc];
Tout* tin0_ptr = (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD)
? num_concats <= 1 ? (Tout*)pin0[m][nc]
: (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)]
? (Tout*)pin0[m][nc]
: nullptr;
Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
? (Tout*)pin1[m][nc]
: (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)]
: nullptr;
Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
if (fusion_type == FUSE_GELU_ERF) {
gelu_erf_fwd_rem_tpp(y_ptr, y_ptr);
} else if (fusion_type == FUSE_ADD) {
@@ -1961,10 +1933,7 @@ void qlinear_woq_affine_impl(
}
}
bool is_rem = (m + BLOCK_M > M);
TGemmOut* y_ptr = num_concats <= 1
? (TGemmOut*)py[m][nc]
: (TGemmOut*)py_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)];
TGemmOut* y_ptr = (TGemmOut*)py[m][nc];
if (!is_rem) {
if (kc == 0) {
if (b.defined()) {
@@ -2073,10 +2042,7 @@ void qlinear_woq_affine_impl(
int kc_end = kc_start + Kc / k_splits;
int m = idx[2];
bool is_rem = (m + BLOCK_M > M);
auto y_out_ptr = num_concats <= 1
? py[m][nc]
: py_concat[nc / (Nc / num_concats)][m]
[nc % (Nc / num_concats)];
auto y_out_ptr = py[m][nc];
alignas(64) TGemmOut y_buf[BLOCK_M][Nb];
TGemmOut* y_ptr = y_private_ptr[my_id][m][nc];
if (k_splits > 1) {
@@ -3389,7 +3355,6 @@ at::Tensor qlinear_woq_affine(
const TensorList& bias_list,
const int qw_type,
int64_t lowp_mode,
int64_t num_concats,
int64_t fusion_type,
const TensorList& others_list,
int64_t quant_a_mode = -1,
@@ -3449,7 +3414,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k);
@@ -3470,7 +3434,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3494,7 +3457,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k);
@@ -3515,7 +3477,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3544,7 +3505,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k);
@@ -3565,7 +3525,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3589,7 +3548,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k);
@@ -3610,7 +3568,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3639,7 +3596,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k);
@@ -3660,7 +3616,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3697,7 +3652,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3738,7 +3692,6 @@ at::Tensor qlinear_woq_affine(
y,
qw_type,
k_splits,
num_concats,
fusion_type,
others_list,
quant_block_k,
@@ -3858,12 +3811,6 @@ at::Tensor qlinear_woq_affine(
: bf16_idx;
y = at::add(y, biases[b_index]);
}
if (num_concats > 1) {
y = y.view({-1, num_concats, y.size(-1) / num_concats})
.transpose(0, 1)
.contiguous()
.view({-1, y.size(-1)});
}
if (fusion_type == FUSE_GELU_ERF) {
y = at::gelu(y);
} else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
@@ -3892,7 +3839,6 @@ at::Tensor qlinear_woq_affine(
const TensorList& bias_list,
const int qw_type,
int64_t lowp_mode,
int64_t num_concats,
int64_t fusion_type,
const TensorList& others_list,
int64_t quant_a_mode = -1,
@@ -4007,12 +3953,6 @@ at::Tensor qlinear_woq_affine(
: bf16_idx;
y = at::add(y, biases[b_index]);
}
if (num_concats > 1) {
y = y.view({-1, num_concats, y.size(-1) / num_concats})
.transpose(0, 1)
.contiguous()
.view({-1, y.size(-1)});
}
if (fusion_type == FUSE_GELU_ERF) {
y = at::gelu(y);
} else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
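
One concrete piece removed above is the fallback-path reshuffle that, when num_concats > 1, kept the nominal [M, N] output shape but reordered memory so that each concat slice became one contiguous [M, N / num_concats] block, the same layout the optimized path wrote directly through its py_concat pointer. The toy libtorch example below replays that removed reshuffle and shows the behaviour after this commit; the sizes are assumptions chosen only for illustration.

#include <torch/torch.h>

int main() {
  const int64_t M = 2, N = 6, num_concats = 3;  // toy sizes, assumptions only
  auto y = torch::arange(M * N, torch::kFloat).reshape({M, N});

  // Removed behaviour (num_concats > 1): same [M, N] shape, but memory is
  // regrouped so each of the num_concats column slices is contiguous.
  auto y_concat_major = y.view({M, num_concats, N / num_concats})
                            .transpose(0, 1)
                            .contiguous()
                            .view({M, N});

  // Behaviour after this commit: the kernel always returns plain row-major
  // [M, N]; if per-linear outputs are needed, the frontend can split them.
  auto per_linear = y.split(N / num_concats, /*dim=*/1);
  return 0;
}

With the kernel always returning the plain layout, both this reshuffle and the matching py_concat indexing become dead code, which is what the deletions above remove.
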
3 changes: 0 additions & 3 deletions csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h
@@ -19,7 +19,6 @@ struct ContextLinearWoq final {
bool is_int4_;
int64_t group_size_;
int64_t lowp_mode_;
int64_t num_concats_;
int64_t act_quant_mode_;

ContextLinearWoq() = delete;
@@ -34,7 +33,6 @@ struct ContextLinearWoq final {
bool is_int4 = false,
int64_t group_size = -1,
int64_t lowp_mode = 0,
int64_t num_concats = 1,
int64_t act_quant_mode = 0)
: at_weight_(std::move(at_weight)),
weight_shape_(std::move(weight_shape)),
@@ -43,7 +41,6 @@ struct ContextLinearWoq final {
is_int4_(is_int4),
group_size_(group_size),
lowp_mode_(lowp_mode),
num_concats_(num_concats),
act_quant_mode_(act_quant_mode) {
// Make three dtype versions of scale, zp and bias
// There is one more dtype for zp
