diff --git a/csrc/cpu/aten/Linear.cpp b/csrc/cpu/aten/Linear.cpp
index 3f8ee4e26..f0404af52 100644
--- a/csrc/cpu/aten/Linear.cpp
+++ b/csrc/cpu/aten/Linear.cpp
@@ -429,7 +429,6 @@ at::Tensor woq_linear_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
@@ -442,7 +441,6 @@ at::Tensor woq_linear_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_NONE, // no post op fusion
       std::vector<at::Tensor>(),
       act_quant_mode,
@@ -472,7 +470,6 @@ at::Tensor woq_linear_eltwise_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
@@ -493,7 +490,6 @@ at::Tensor woq_linear_eltwise_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       post_op_fusion_type,
       std::vector<at::Tensor>(),
       act_quant_mode,
@@ -532,7 +528,6 @@ at::Tensor woq_linear_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -546,7 +541,6 @@ at::Tensor woq_linear_add_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_ADD, // post op add
       others,
       act_quant_mode,
@@ -563,7 +557,6 @@ at::Tensor woq_linear_add_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -577,7 +570,6 @@ at::Tensor woq_linear_add_add_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_ADD_ADD, // post op add-add
       others,
       act_quant_mode,
diff --git a/csrc/cpu/aten/Linear.h b/csrc/cpu/aten/Linear.h
index 969c58b14..605395293 100644
--- a/csrc/cpu/aten/Linear.h
+++ b/csrc/cpu/aten/Linear.h
@@ -105,7 +105,6 @@ at::Tensor woq_linear_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);

 at::Tensor woq_linear_eltwise_kernel(
@@ -120,7 +119,6 @@ at::Tensor woq_linear_eltwise_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);

 at::Tensor woq_linear_add_kernel(
@@ -132,7 +130,6 @@ at::Tensor woq_linear_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode);

@@ -145,7 +142,6 @@ at::Tensor woq_linear_add_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode);

@@ -220,7 +216,6 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)(
     const int,
     int64_t,
     int64_t,
-    int64_t,
     const std::vector<at::Tensor>&,
     int64_t,
     int64_t,
diff --git a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
index 6d7cf4c0d..da298f1f5 100644
--- a/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
+++ b/csrc/cpu/aten/kernels/WoqTppKrnl.cpp
@@ -1658,7 +1658,6 @@ void qlinear_woq_affine_impl(
     at::Tensor y,
     const int qw_type,
     int k_splits,
-    int num_concats,
     int fusion_type,
     const TensorList& others_list,
     int64_t quant_block_k,
@@ -1681,9 +1680,6 @@ void qlinear_woq_affine_impl(
       quant_block_k == 0 ? 1 : (K + quant_block_k - 1) / quant_block_k;

   TLA_ASSERT(Nb % 16 == 0, "Nb must be a multiple of 16");
-  TLA_ASSERT(
-      num_concats <= 1 || Nc % num_concats == 0,
-      "Nc must be a multiple of num_concats");

   // select BLOCK_M according to M
   // TODO(jgong5): improve the heuristic
@@ -1700,7 +1696,7 @@ void qlinear_woq_affine_impl(
   auto BLOCK_M_rem = M % BLOCK_M;

   // TODO(jgong5): use heuristics to decide k_splits
-  if (k_splits <= 0 || num_concats > 1 || M >= 32 || BLOCK_M_rem) {
+  if (k_splits <= 0 || M >= 32 || BLOCK_M_rem) {
     k_splits = 1;
   }
   TLA_ASSERT(Kc % k_splits == 0, "Kc must be a multiple of k_splits");
@@ -1713,15 +1709,13 @@ void qlinear_woq_affine_impl(
       k_splits == 1;
   auto lda = no_x_buf ? K : Kb;
-  auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb;
+  auto ldy = N;
   auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb;

   auto px = GetVLAPtr(x, {Kc, Kb});
   auto pw = GetVLAPtr(
       (uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)});
   auto py = GetVLAPtr(y, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto py_concat = GetVLAPtr(
-      y, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
   int scales_kc = quant_w_mode == QUANT_W_PER_CHANNEL ? QUANT_W_PER_K_BLOCK
                                                       : quant_k_blocks;
   auto pscales = GetVLAPtr(scales, {scales_kc, Nb});
@@ -1730,12 +1724,8 @@ void qlinear_woq_affine_impl(
   auto pb = GetVLAPtr(b, {Nb});
   auto tin0 = others_list.size() > 0 ? others_list[0] : at::Tensor{};
   auto pin0 = GetVLAPtr(tin0, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto pin0_concat = GetVLAPtr(
-      tin0, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
   auto tin1 = others_list.size() > 1 ? others_list[1] : at::Tensor{};
   auto pin1 = GetVLAPtr(tin1, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto pin1_concat = GetVLAPtr(
-      tin1, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/

   auto copy_bias_out_tpp = CpyBiasTPP(BLOCK_M, Nb, ldy);
   auto copy_bias_buf_tpp = CpyBiasTPP(BLOCK_M, Nb, Nb);
@@ -1754,19 +1744,9 @@ void qlinear_woq_affine_impl(
   bool is_fusion_type_addrelated =
       fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD;
   auto post_ops_fn = [&](int m, int nc) {
-    Tout* y_ptr = num_concats <= 1
-        ? (Tout*)py[m][nc]
-        : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
-    Tout* tin0_ptr = is_fusion_type_addrelated ? num_concats <= 1
-        ? (Tout*)pin0[m][nc]
-        : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
-                            [nc % (Nc / num_concats)]
-        : nullptr;
-    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
-        ? (Tout*)pin1[m][nc]
-        : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
-                            [nc % (Nc / num_concats)]
-        : nullptr;
+    Tout* y_ptr = (Tout*)py[m][nc];
+    Tout* tin0_ptr = is_fusion_type_addrelated ? (Tout*)pin0[m][nc] : nullptr;
+    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
     if (fusion_type == FUSE_GELU_ERF) {
       gelu_erf_fwd_tpp(y_ptr, y_ptr);
     } else if (fusion_type == FUSE_ADD) {
@@ -1779,19 +1759,11 @@ void qlinear_woq_affine_impl(
     }
   };
   auto post_ops_rem_fn = [&](int m, int nc) {
-    Tout* y_ptr = num_concats <= 1
-        ? (Tout*)py[m][nc]
-        : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
+    Tout* y_ptr = (Tout*)py[m][nc];
     Tout* tin0_ptr = (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD)
-        ? num_concats <= 1 ? (Tout*)pin0[m][nc]
-                           : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
-                                               [nc % (Nc / num_concats)]
+        ? (Tout*)pin0[m][nc]
         : nullptr;
-    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
-        ? (Tout*)pin1[m][nc]
-        : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
-                            [nc % (Nc / num_concats)]
-        : nullptr;
+    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
     if (fusion_type == FUSE_GELU_ERF) {
       gelu_erf_fwd_rem_tpp(y_ptr, y_ptr);
     } else if (fusion_type == FUSE_ADD) {
@@ -1961,10 +1933,7 @@ void qlinear_woq_affine_impl(
               }
             }
             bool is_rem = (m + BLOCK_M > M);
-            TGemmOut* y_ptr = num_concats <= 1
-                ? (TGemmOut*)py[m][nc]
-                : (TGemmOut*)py_concat[nc / (Nc / num_concats)][m]
-                                      [nc % (Nc / num_concats)];
+            TGemmOut* y_ptr = (TGemmOut*)py[m][nc];
             if (!is_rem) {
               if (kc == 0) {
                 if (b.defined()) {
@@ -2073,10 +2042,7 @@ void qlinear_woq_affine_impl(
               int kc_end = kc_start + Kc / k_splits;
               int m = idx[2];
               bool is_rem = (m + BLOCK_M > M);
-              auto y_out_ptr = num_concats <= 1
-                  ? py[m][nc]
-                  : py_concat[nc / (Nc / num_concats)][m]
-                             [nc % (Nc / num_concats)];
+              auto y_out_ptr = py[m][nc];
               alignas(64) TGemmOut y_buf[BLOCK_M][Nb];
               TGemmOut* y_ptr = y_private_ptr[my_id][m][nc];
               if (k_splits > 1) {
@@ -3389,7 +3355,6 @@ at::Tensor qlinear_woq_affine(
     const TensorList& bias_list,
     const int qw_type,
     int64_t lowp_mode,
-    int64_t num_concats,
    int64_t fusion_type,
     const TensorList& others_list,
     int64_t quant_a_mode = -1,
@@ -3449,7 +3414,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k);
@@ -3470,7 +3434,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3494,7 +3457,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k);
@@ -3515,7 +3477,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3544,7 +3505,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k);
@@ -3565,7 +3525,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3589,7 +3548,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k);
@@ -3610,7 +3568,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3639,7 +3596,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k);
@@ -3660,7 +3616,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3697,7 +3652,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3738,7 +3692,6 @@ at::Tensor qlinear_woq_affine(
                 y,
                 qw_type,
                 k_splits,
-                num_concats,
                 fusion_type,
                 others_list,
                 quant_block_k,
@@ -3858,12 +3811,6 @@ at::Tensor qlinear_woq_affine(
           : bf16_idx;
       y = at::add(y, biases[b_index]);
     }
-    if (num_concats > 1) {
-      y = y.view({-1, num_concats, y.size(-1) / num_concats})
-              .transpose(0, 1)
-              .contiguous()
-              .view({-1, y.size(-1)});
-    }
     if (fusion_type == FUSE_GELU_ERF) {
       y = at::gelu(y);
     } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
@@ -3892,7 +3839,6 @@ at::Tensor qlinear_woq_affine(
     const TensorList& bias_list,
     const int qw_type,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t fusion_type,
     const TensorList& others_list,
     int64_t quant_a_mode = -1,
@@ -4007,12 +3953,6 @@ at::Tensor qlinear_woq_affine(
           : bf16_idx;
       y = at::add(y, biases[b_index]);
     }
-    if (num_concats > 1) {
-      y = y.view({-1, num_concats, y.size(-1) / num_concats})
-              .transpose(0, 1)
-              .contiguous()
-              .view({-1, y.size(-1)});
-    }
     if (fusion_type == FUSE_GELU_ERF) {
       y = at::gelu(y);
     } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
diff --git a/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h b/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h
index 8f01f4db8..43ec36df2 100644
--- a/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h
+++ b/csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h
@@ -19,7 +19,6 @@ struct ContextLinearWoq final {
   bool is_int4_;
   int64_t group_size_;
   int64_t lowp_mode_;
-  int64_t num_concats_;
   int64_t act_quant_mode_;

   ContextLinearWoq() = delete;
@@ -34,7 +33,6 @@ struct ContextLinearWoq final {
       bool is_int4 = false,
       int64_t group_size = -1,
       int64_t lowp_mode = 0,
-      int64_t num_concats = 1,
       int64_t act_quant_mode = 0)
       : at_weight_(std::move(at_weight)),
         weight_shape_(std::move(weight_shape)),
@@ -43,7 +41,6 @@ struct ContextLinearWoq final {
         is_int4_(is_int4),
         group_size_(group_size),
         lowp_mode_(lowp_mode),
-        num_concats_(num_concats),
         act_quant_mode_(act_quant_mode) {
     // Make three dtype versions of scale, zp and bias
     // There is one more dtype for zp
diff --git a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp
index 48efb5da0..e96e04638 100644
--- a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp
+++ b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp
@@ -22,7 +22,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   RECORD_FUNCTION(
       "ipex_prepack::createWoqLinearPrePackOpContext",
@@ -39,7 +38,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
       is_int4,
       group_size,
       lowp_mode,
-      num_concats,
       act_quant_mode);
 }

@@ -52,7 +50,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
     c10::optional<int64_t> batch_size,
     int64_t group_size, // group_size along input channel
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   RECORD_FUNCTION(
       "ipex_prepack::createWoqLinearPrePackOpContextInt4",
@@ -169,7 +166,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
       /*is_int4*/ true,
       group_size,
       lowp_mode,
-      num_concats,
       act_quant_mode);
 }

@@ -193,7 +189,6 @@ ContextLinearWoq create(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   at::Tensor packed_weight;
   int64_t N = weight_shape[0];
@@ -242,7 +237,6 @@ ContextLinearWoq create(
           is_int4,
           group_size,
           lowp_mode,
-          num_concats,
           act_quant_mode);
     } else {
       return ContextLinearWoq(
@@ -255,7 +249,6 @@ ContextLinearWoq create(
           is_int4,
           group_size,
           lowp_mode,
-          num_concats,
           act_quant_mode);
     }
   }
@@ -269,7 +262,6 @@ ContextLinearWoq create(
       is_int4,
       group_size,
       lowp_mode,
-      num_concats,
      act_quant_mode);
 }

@@ -311,7 +303,6 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
       context.is_int4_,
       context.group_size_,
       context.lowp_mode_,
-      context.num_concats_,
       context.act_quant_mode_);
   if (res.size(-1) != context.weight_shape_[0]) {
     int64_t N = context.weight_shape_[0];
@@ -351,7 +342,6 @@ at::Tensor run_eltwise(
       context.is_int4_,
       context.group_size_,
       context.lowp_mode_,
-      context.num_concats_,
       context.act_quant_mode_);
 }

@@ -381,7 +371,6 @@ at::Tensor run_add(
       context.is_int4_,
       context.group_size_,
       context.lowp_mode_,
-      context.num_concats_,
       others,
       context.act_quant_mode_);
 }
@@ -412,7 +401,6 @@ at::Tensor run_add_add(
       context.is_int4_,
       context.group_size_,
       context.lowp_mode_,
-      context.num_concats_,
       others,
       context.act_quant_mode_);
 }
diff --git a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h
index cc0081353..9fcf8138d 100644
--- a/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h
+++ b/csrc/cpu/jit/cpu/kernels/LinearWoqPacked.h
@@ -21,7 +21,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);

 c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
@@ -33,7 +32,6 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
     c10::optional<int64_t> batch_size,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);

 at::Tensor woq_linear_run(
@@ -51,7 +49,6 @@ ContextLinearWoq create(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);

 at::Tensor run(ContextLinearWoq& context, const at::Tensor& input);
diff --git a/csrc/cpu/jit/cpu/kernels/OpContext.cpp b/csrc/cpu/jit/cpu/kernels/OpContext.cpp
index f0f9daa3a..1038eab04 100644
--- a/csrc/cpu/jit/cpu/kernels/OpContext.cpp
+++ b/csrc/cpu/jit/cpu/kernels/OpContext.cpp
@@ -372,7 +372,6 @@ c10::intrusive_ptr<WoqLinearOpContext> IpexWoqLinearOpContext::create_context(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   auto op_context = torch_ipex::cpu::detail::woq_linear::create(
       weight,
@@ -385,7 +384,6 @@ c10::intrusive_ptr<WoqLinearOpContext> IpexWoqLinearOpContext::create_context(
       is_int4,
       group_size,
       lowp_mode,
-      num_concats,
       act_quant_mode);
   return c10::make_intrusive<IpexWoqLinearOpContext>(
       batch_size, std::move(op_context));
diff --git a/csrc/cpu/jit/cpu/kernels/OpContext.h b/csrc/cpu/jit/cpu/kernels/OpContext.h
index 9a3b99966..754938361 100644
--- a/csrc/cpu/jit/cpu/kernels/OpContext.h
+++ b/csrc/cpu/jit/cpu/kernels/OpContext.h
@@ -371,7 +371,6 @@ using SerializationTypeWoqLinearPrePack = std::tuple<
     bool, // is_int4
     int64_t, // group size
     int64_t, // lowp_mode
-    int64_t, // num_concats
     int64_t>; // act_quant_mode

 class WoqLinearOpContext : public torch::jit::CustomClassHolder {
@@ -397,7 +396,6 @@ class WoqLinearOpContext : public torch::jit::CustomClassHolder {
         this->get_context().is_int4_,
         this->get_context().group_size_,
         this->get_context().lowp_mode_,
-        this->get_context().num_concats_,
         this->get_context().act_quant_mode_);
   }

@@ -506,7 +504,6 @@ class IpexWoqLinearOpContext final : public WoqLinearOpContext {
       bool is_int4,
       int64_t group_size,
       int64_t lowp_mode,
-      int64_t num_concats,
       int64_t act_quant_mode);

   virtual void load_from_ctx(
diff --git a/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp b/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp
index 79e2e1e83..b20f6a926 100644
--- a/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp
+++ b/csrc/cpu/jit/cpu/kernels/RegisterOpContextClass.cpp
@@ -141,8 +141,7 @@ TORCH_LIBRARY(ipex_prepack, m) {
                 std::move(std::get<7>(state)), // is_int4
                 std::move(std::get<8>(state)), // group size
                 std::move(std::get<9>(state)), // lowp_mode
-                std::move(std::get<10>(state)), // num_concats
-                std::move(std::get<11>(state))); // act_quant_mode
+                std::move(std::get<10>(state))); // act_quant_mode
           })
       .def(
           "get_weight",
@@ -182,10 +181,10 @@ TORCH_LIBRARY(ipex_prepack, m) {
       "-> __torch__.torch.classes.ipex_prepack.ConvTransposeOpContext");
 #ifdef USE_LIBXSMM
   m.def(
-      "weight_only_qlinear_prepack(Tensor W, int[] W_shape, Tensor scales, Tensor zero_points, Tensor? B, Tensor? g_idx, int? batch_size, bool is_int4, int group_size, int lowp_mode, int num_concats, int act_quant_mode) "
+      "weight_only_qlinear_prepack(Tensor W, int[] W_shape, Tensor scales, Tensor zero_points, Tensor? B, Tensor? g_idx, int? batch_size, bool is_int4, int group_size, int lowp_mode, int act_quant_mode) "
       "-> __torch__.torch.classes.ipex_prepack.WoqLinearOpContext");
   m.def(
-      "weight_only_qlinear_prepack_int4(Tensor W, Tensor scales, Tensor zero_points, Tensor? B, Tensor? g_idx, int? batch_size, int group_size, int lowp_mode, int num_concats, int act_quant_mode) "
+      "weight_only_qlinear_prepack_int4(Tensor W, Tensor scales, Tensor zero_points, Tensor? B, Tensor? g_idx, int? batch_size, int group_size, int lowp_mode, int act_quant_mode) "
       "-> __torch__.torch.classes.ipex_prepack.WoqLinearOpContext");
 #endif
 }
diff --git a/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py b/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py
index c72273fd1..f7d0a8355 100644
--- a/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py
+++ b/intel_extension_for_pytorch/nn/modules/weight_only_quantization.py
@@ -31,7 +31,6 @@ def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
         self.weight = None
         self._op_context = None
         self._lowp_mode = 0
-        self._num_concats = 1
         self._act_quant_mode = 0
         self._group_size = -1

@@ -57,7 +56,6 @@ def extra_repr(self):
         )
         extra_repr_str += ", bias={}".format(self.bias)
         extra_repr_str += ", lowp_mode={}".format(self._lowp_mode)
-        extra_repr_str += ", num_concats={}".format(self._num_concats)
         extra_repr_str += ", act_quant_mode={}".format(self._act_quant_mode)
         extra_repr_str += ", group_size={}".format(self._group_size)
         return extra_repr_str
@@ -101,9 +99,6 @@ def from_float(cls, mod, scales=None, zero_points=None):
                 "Falling back to 2(BF16)."
             )
            act_quant_mode = qconfig.act_quant_mode
-        num_concats = 1
-        if hasattr(mod, "_num_concats"):
-            num_concats = mod._num_concats
         dtype = qconfig.weight_dtype
         is_int4 = dtype == torch.quint4x2
         group_size = qconfig.group_size
@@ -130,7 +125,6 @@ def from_float(cls, mod, scales=None, zero_points=None):
             None,  # g_idx
             group_size,
             lowp_mode,
-            num_concats,
             act_quant_mode,
         )
         del qweight
@@ -174,9 +168,6 @@ def from_float_and_int4_weight(
             lowp_mode = mod.qconfig.lowp_mode
         if hasattr(mod.qconfig, "act_quant_mode"):
             act_quant_mode = mod.qconfig.act_quant_mode
-        num_concats = 1
-        if hasattr(mod, "_num_concats"):
-            num_concats = mod._num_concats
         w_dtype = qweight.dtype

         supported_qw_dtype = [
@@ -208,12 +199,10 @@ def from_float_and_int4_weight(
             None,
             group_size,
             int(lowp_mode),
-            num_concats,
             act_quant_mode,
         )
         qlinear.weight = qlinear._op_context.get_weight()
         qlinear._lowp_mode = lowp_mode
-        qlinear._num_concats = num_concats
         qlinear._act_quant_mode = act_quant_mode
         qlinear._group_size = group_size
         del qweight
@@ -230,7 +219,6 @@ def _init_cls(
         g_idx,
         group_size,
        lowp_mode,
-        num_concats,
         act_quant_mode,
     ):
         qlinear = cls(
@@ -248,12 +236,10 @@ def _init_cls(
             is_int4,
             group_size,
             int(lowp_mode),
-            num_concats,
             act_quant_mode,
         )
         qlinear.weight = qlinear._op_context.get_weight()
         qlinear._lowp_mode = lowp_mode
-        qlinear._num_concats = num_concats
         qlinear._act_quant_mode = act_quant_mode
         qlinear._group_size = group_size
         return qlinear
@@ -296,7 +282,6 @@ def _init_cls(
         g_idx,
         group_size,
         lowp_mode,
-        num_concats,
         act_quant_mode,
     ):
         qlinear = cls._init_from_mod(mod, dtype)
@@ -313,7 +298,6 @@ def _init_cls(
             is_int4,
             group_size,
             lowp_mode,
-            num_concats,
             act_quant_mode,
         )
         qlinear.weight = qlinear._op_context.get_weight()
diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
index c305d2712..1ec82fa49 100644
--- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
@@ -317,8 +317,6 @@ def __init__(self, module, tpp=False, woq=False):
                 mod.weight = nn.Parameter(concat_weight)
                 mod.bias = nn.Parameter(concat_bias) if use_bias else None
                 mod.qconfig = qconfig
-                mod._num_concats = len(weights_list)
-                self._num_concats = mod._num_concats
                 if w_dtype == torch.quint4x2:
                     self.concat_linear = IpexWoqLinear.from_float_and_int4_weight(
                         mod,
@@ -332,32 +330,20 @@ def __init__(self, module, tpp=False, woq=False):
                    self.concat_linear = IpexWoqLinear.from_float(
                        mod, concat_scales, concat_zeros
                    )
+            elif (
+                self.tpp
+                and hasattr(module, "concat_linear")
+                and module.concat_linear is not None
+            ):
+                self.concat_linear = module.concat_linear
             else:
-                self._num_concats = module._num_concats
-                if (
-                    self.tpp
-                    and hasattr(module, "concat_linear")
-                    and module.concat_linear is not None
-                ):
-                    self.concat_linear = module.concat_linear
-                else:
-                    for i in range(self.num_concat):
-                        attr_name = f"linear_{i}"
-                        setattr(self, attr_name, getattr(module, attr_name))
+                for i in range(self.num_concat):
+                    attr_name = f"linear_{i}"
+                    setattr(self, attr_name, getattr(module, attr_name))

     def forward(self, x):
         if self.concat_linear is not None:
-            concat_output = self.concat_linear(x)
-            if self.woq:
-                num_concats = self._num_concats
-                hidden_size = concat_output.shape[-1] // num_concats
-                concat_output = concat_output.view(num_concats, -1, hidden_size)
-                expected_shape = list(x.shape)[:-1] + [hidden_size]
-                return tuple(
-                    [concat_output[i].view(expected_shape) for i in range(num_concats)]
-                )
-            else:
-                return concat_output
+            return self.concat_linear(x)
         output_list = []
         for i in range(self.num_concat):
diff --git a/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
index f39e3e2f5..762fa15f6 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
@@ -87,7 +87,6 @@ def __init__(self, linear_list: list):
             attr_name = f"linear_{i}"
             setattr(self, attr_name, copy.deepcopy(linear_list[i]))
         self.concat_linear = None
-        self._num_concats = None
         if all(not isinstance(linear, IpexWoqLinear) for linear in linear_list):
             weights_list = []
             bias_list = []
@@ -103,7 +102,6 @@ def __init__(self, linear_list: list):
             )
             self.concat_linear.weight = nn.Parameter(concat_weight)
             self.concat_linear.bias = nn.Parameter(concat_bias) if use_bias else None
-            self._num_concats = len(weights_list)

     def forward(self, x):
         output_list = []
diff --git a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
index 830c792c7..75421f78a 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
@@ -42,7 +42,7 @@ def _GPTJAttention_forward(
             1,  # neighbor elements
             64,
             None,
-            self.concat_qkv._num_concats,
+            self.concat_qkv.num_concat,
         )
     else:
         if concat_qkv is not None:
@@ -152,7 +152,7 @@ def _LlamaAttention_forward(
             self.head_dim // 2,
             self.head_dim,
             kv_seq_len,
-            self.concat_qkv._num_concats,
+            self.concat_qkv.num_concat,
         )
     else:
         if concat_qkv is not None:
@@ -1012,7 +1012,7 @@ def _MistralAttention_forward(
             self.head_dim // 2,
             self.head_dim,
             kv_seq_len,
-            self.concat_qkv._num_concats,
+            self.concat_qkv.num_concat,
         )
     else:
         if concat_qkv is not None:
@@ -1116,7 +1116,7 @@ def _MixtralAttention_forward(
             self.head_dim // 2,
             self.head_dim,
             kv_seq_len,
-            self.concat_qkv._num_concats,
+            self.concat_qkv.num_concat,
         )
     else:
         if concat_qkv is not None:
@@ -1429,7 +1429,7 @@ def _StableLMEpochAttention_forward(
             self.pos_embd_dim // 2,
             self.pos_embd_dim,
             kv_seq_len,
-            self.concat_qkv._num_concats,
+            self.concat_qkv.num_concat,
         )
     else:
         if concat_qkv is not None:
diff --git a/tests/cpu/test_quantization_default_recipe.py b/tests/cpu/test_quantization_default_recipe.py
index 5f4573240..03ff9fd01 100644
--- a/tests/cpu/test_quantization_default_recipe.py
+++ b/tests/cpu/test_quantization_default_recipe.py
@@ -1152,50 +1152,6 @@ def forward(self, x):
             y_ref = y_ref.to(act_dtype)
             torch.testing.assert_close(y, y_ref, atol=0.005, rtol=0.01)

-    def test_weight_only_quantization_num_concats(self):
-        class Mod(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.q = nn.Linear(64, 64, bias=False)
-                self.k = nn.Linear(64, 64, bias=False)
-                self.v = nn.Linear(64, 64, bias=False)
-
-            def forward(self, x):
-                q = self.q(x)
-                k = self.k(x)
-                v = self.v(x)
-                return q, k, v
-
-        class Mod2(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.qkv = nn.Linear(64, 64 * 3, bias=False)
-                self.qkv._num_concats = 3
-
-            def forward(self, x):
-                qkv = self.qkv(x).view(3, -1, 64)
-                q, k, v = qkv[0], qkv[1], qkv[2]
-                return q, k, v
-
-        m = Mod().eval()
-        m2 = Mod2().eval()
-        m2.qkv.weight = nn.Parameter(
-            torch.cat([m.q.weight, m.k.weight, m.v.weight], dim=0)
-        )
-        data = torch.rand(4, 64)
-        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(lowp_mode=2)
-        prepared = prepare(m, qconfig, example_inputs=data, inplace=True)
-        prepared2 = prepare(m2, qconfig, example_inputs=data, inplace=True)
-        for bf16 in [False, True]:
-            with torch.no_grad(), torch.cpu.amp.autocast(
-                enabled=bf16, dtype=torch.bfloat16 if bf16 else None
-            ):
-                qm = convert(prepared)
-                qm2 = convert(prepared2)
-                output1 = qm(data)
-                output2 = qm2(data)
-                torch.testing.assert_close(output1, output2, atol=1e-2, rtol=1e-4)
-
     def _fakequant_by_group(self, t, quant_a_mode, groupsize):
         assert quant_a_mode >= 0 and quant_a_mode <= 3
         if quant_a_mode == 0:
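Note (illustration only, not part of the diff): with num_concats removed, the fused concat linear returns a single tensor whose last dimension is the plain concatenation of the individual projections, and consumers split it themselves; the attention forwards above now pass self.concat_qkv.num_concat downstream instead of the deleted _num_concats attribute. A minimal sketch of that splitting for a 2-D activation is below; the helper name split_concat_output is hypothetical, and the chunking is equivalent to the view/transpose reshuffle plus per-index view that the deleted fallback path performed.

import torch


def split_concat_output(y: torch.Tensor, num_concat: int):
    # y has shape [batch, num_concat * hidden]: the per-projection outputs
    # (e.g. Q, K, V) sit side by side along the last dimension, as produced
    # by a linear whose weights were concatenated along dim 0.
    assert y.size(-1) % num_concat == 0
    # Chunking the last dimension recovers the individual projections; for a
    # 2-D y this gives the same tensors as the removed view/transpose path.
    return torch.chunk(y, num_concat, dim=-1)


# Example: a concatenated QKV projection with hidden size 64, num_concat == 3.
y = torch.randn(4, 3 * 64)
q, k, v = split_concat_output(y, 3)
assert q.shape == (4, 64) and k.shape == (4, 64) and v.shape == (4, 64)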