From aeaeba47bc722d9b18f13f8a78e02092c0a6bb5b Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Tue, 16 Jan 2024 00:32:41 +0000 Subject: [PATCH] [release/2.2] Fuse gate_proj and up_proj in MLP of LLaMA (#2469) * Fuse gate_proj and up_proj in MLP of LLaMA (#2430) * Fuse gate_proj and up_proj in MLP of LLaMA * fix clang-format * Update run_quantization.py (#2471) --------- Co-authored-by: jianan-gu --- csrc/cpu/aten/TPPGEMM.cpp | 21 +++ csrc/cpu/aten/TPPGEMM.h | 18 +++ csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp | 35 +++++ csrc/cpu/tpp/kernels/TPPGEMMKrnl.h | 138 ++++++++++++++++++ .../llm/single_instance/run_quantization.py | 3 +- .../models/cpu/fusions/linear_fusion.py | 8 +- tests/cpu/test_tpp_linear.py | 50 +++++++ 7 files changed, 265 insertions(+), 8 deletions(-) diff --git a/csrc/cpu/aten/TPPGEMM.cpp b/csrc/cpu/aten/TPPGEMM.cpp index e4e4c6a94..98497abf7 100644 --- a/csrc/cpu/aten/TPPGEMM.cpp +++ b/csrc/cpu/aten/TPPGEMM.cpp @@ -9,6 +9,7 @@ namespace cpu { IPEX_DEFINE_DISPATCH(tpp_linear_nobias_kernel_stub); IPEX_DEFINE_DISPATCH(tpp_linear_bias_kernel_stub); IPEX_DEFINE_DISPATCH(tpp_linear_gelu_kernel_stub); +IPEX_DEFINE_DISPATCH(tpp_fused_gate_up_proj_kernel_stub); IPEX_DEFINE_DISPATCH(tpp_linear_silu_kernel_stub); IPEX_DEFINE_DISPATCH(tpp_linear_relu_kernel_stub); IPEX_DEFINE_DISPATCH(tpp_linear_add_kernel_stub); @@ -38,6 +39,17 @@ at::Tensor tpp_linear_gelu_forward_cpu( return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias); } +at::Tensor tpp_fused_gate_up_proj_forward_cpu( + const at::Tensor& t_in, + const at::Tensor& t_wt_gate, + const at::Tensor& t_bias_gate, + const at::Tensor& t_wt_up, + const at::Tensor& t_bias_up, + c10::optional out_features) { + return tpp_fused_gate_up_proj_kernel_stub( + kCPU, t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up); +} + at::Tensor tpp_linear_silu_forward_cpu( const at::Tensor& t_in, const at::Tensor& t_wt, @@ -117,6 +129,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { torch_ipex::cpu::tpp_linear_gelu_forward_cpu); } +TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { + m.def( + "tpp_fused_gate_up_proj(Tensor t_in, Tensor t_wt_gate, Tensor t_bias_gate, Tensor t_wt_up, Tensor t_bias_up,int? out_features=None)-> Tensor out"); + m.impl( + "tpp_fused_gate_up_proj", + c10::DispatchKey::CPU, + torch_ipex::cpu::tpp_fused_gate_up_proj_forward_cpu); +} + TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { m.def( "tpp_linear_add_add(Tensor t_in, Tensor t_in1, Tensor t_in2, Tensor t_wt, Tensor t_bias, float scale, int? 
out_features=None)-> Tensor out"); diff --git a/csrc/cpu/aten/TPPGEMM.h b/csrc/cpu/aten/TPPGEMM.h index f72d1496b..4ee2bf420 100644 --- a/csrc/cpu/aten/TPPGEMM.h +++ b/csrc/cpu/aten/TPPGEMM.h @@ -24,6 +24,14 @@ at::Tensor tpp_linear_gelu_forward_cpu( const at::Tensor& t_bias, c10::optional out_features); +at::Tensor tpp_fused_gate_up_proj_forward_cpu( + const at::Tensor& t_in, + const at::Tensor& t_wt_gate, + const at::Tensor& t_bias_gate, + const at::Tensor& t_wt_up, + const at::Tensor& t_bias_up, + c10::optional out_features); + at::Tensor tpp_linear_silu_forward_cpu( const at::Tensor& t_in, const at::Tensor& t_wt, @@ -71,6 +79,13 @@ using tpp_linear_bias_kernel_impl_fn = using tpp_linear_gelu_kernel_impl_fn = at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&); +using tpp_fused_gate_up_proj_kernel_impl_fn = at::Tensor (*)( + const at::Tensor&, + const at::Tensor&, + const at::Tensor&, + const at::Tensor&, + const at::Tensor&); + using tpp_linear_silu_kernel_impl_fn = at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&); @@ -105,6 +120,9 @@ IPEX_DECLARE_DISPATCH( IPEX_DECLARE_DISPATCH( tpp_linear_gelu_kernel_impl_fn, tpp_linear_gelu_kernel_stub); +IPEX_DECLARE_DISPATCH( + tpp_fused_gate_up_proj_kernel_impl_fn, + tpp_fused_gate_up_proj_kernel_stub); IPEX_DECLARE_DISPATCH( tpp_linear_silu_kernel_impl_fn, tpp_linear_silu_kernel_stub); diff --git a/csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp b/csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp index e7c89947a..988b605fa 100644 --- a/csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp +++ b/csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp @@ -87,6 +87,38 @@ at::Tensor tpp_linear_gelu_kernel_impl( return t_out; } +at::Tensor tpp_fused_gate_up_proj_kernel_impl( + const at::Tensor& t_in, + const at::Tensor& t_wt_gate, + const at::Tensor& t_bias_gate, + const at::Tensor& t_wt_up, + const at::Tensor& t_bias_up) { + auto sizes = t_in.sizes().vec(); + AT_ASSERT( + t_wt_gate.sizes() == t_wt_up.sizes(), + "Expect t_wt_gate.sizes() == t_wt_up.sizes()"); + auto wt_sizes = t_wt_gate.sizes(); + sizes[2] = wt_sizes[0] * wt_sizes[3]; + + auto t_out = t_in.new_empty(sizes); + + auto dt = t_wt_gate.dtype(); + if (dt == at::kFloat) { + torch_ipex::tpp::tpp_fused_gate_up_proj( + t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out); + } else if (dt == at::kBFloat16) { + torch_ipex::tpp::tpp_fused_gate_up_proj( + t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out); + } else { + AT_ASSERT( + 0, + "TPP does not support current weight dtype %s:%d\n", + __FILE__, + __LINE__); + } + return t_out; +} + at::Tensor tpp_linear_silu_kernel_impl( const at::Tensor& t_in, const at::Tensor& t_wt, @@ -219,6 +251,9 @@ IPEX_REGISTER_DISPATCH( IPEX_REGISTER_DISPATCH( tpp_linear_gelu_kernel_stub, &tpp_linear_gelu_kernel_impl); +IPEX_REGISTER_DISPATCH( + tpp_fused_gate_up_proj_kernel_stub, + &tpp_fused_gate_up_proj_kernel_impl); IPEX_REGISTER_DISPATCH( tpp_linear_relu_kernel_stub, &tpp_linear_relu_kernel_impl); diff --git a/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h b/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h index 03727b4af..37e97c615 100644 --- a/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h +++ b/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h @@ -42,6 +42,9 @@ REGISTER_LOCAL_SCOPE( REGISTER_LOCAL_SCOPE( tpp_linear_silu_krnl, "tpp_linear_silu_krnl"); // linear bias + silu +REGISTER_LOCAL_SCOPE( + tpp_fused_gate_up_proj_krnl, + "tpp_fused_gate_up_proj_krnl"); // fused gate_proj and up_proj REGISTER_LOCAL_SCOPE( tpp_linear_relu_krnl, "tpp_linear_relu_krnl"); // linear bias + relu @@ -521,6 +524,141 @@ inline void 
tpp_linear_gelu( } } +// Fused kernel for the gate_proj and the up_proj related computation in the MLP +// of LLaMA. The ref computation of the kernel is: +// act_fn(gate_proj(x)) * up_proj(x) where act_fn is silu, gate_proj and +// up_proj are two nn.Linear with the same weight shapes and bias = False. +// t_in is the input activation +// t_wt_gate is the prepacked weight of the gate_proj +// t_wt_up is the prepacked weight of the up_proj +// t_bias_gate is the bias of the gate_proj +// t_bias_up is the bias of the up_proj +// t_out is the output result of the kernel +template +inline void tpp_fused_gate_up_proj( + const at::Tensor& t_in, + const at::Tensor& t_wt_gate, + const at::Tensor& t_bias_gate, + const at::Tensor& t_wt_up, + const at::Tensor& t_bias_up, + at::Tensor& t_out) { + auto t_wt_gate_ = t_wt_gate; + auto t_wt_up_ = t_wt_up; + auto in_sizes = t_in.sizes(); + auto BS = in_sizes[0] * in_sizes[1]; + if (BS > FT_OPT_SIZE) { // first token compute + t_wt_gate_ = wt_tensor_for_first_token(t_wt_gate_); + t_wt_up_ = wt_tensor_for_first_token(t_wt_up_); + large_cache_opt = true; + } + + auto wt_sizes = t_wt_gate_.sizes(); + auto C = in_sizes[2]; + + auto Nc = wt_sizes[1]; + auto Hc = C / Nc; + auto Nk = wt_sizes[0]; + auto Hk = wt_sizes[3]; + auto K = Nk * Hk; + + auto t_wt_gate_V = + torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_gate_); + auto t_wt_up_V = torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_up_); + + // This is used to store the intermediate result of the up_proj layer + auto t_out_tmp = at::empty_like(t_out); + + auto in = GetVLAPtr(t_in, {Nc, Hc}); + auto wt_gate_V = GetVLAPtr(t_wt_gate_V, {Nc, Hc * Hk}); + auto wt_up_V = GetVLAPtr(t_wt_up_V, {Nc, Hc * Hk}); + auto bias_gate = GetVLAPtr(t_bias_gate, {Hk}); + auto bias_up = GetVLAPtr(t_bias_up, {Hk}); + auto out = GetVLAPtr(t_out, {Nk, Hk}); + auto out_tmp = GetVLAPtr(t_out_tmp, {Nk, Hk}); + + auto Ncb = Nc; + auto BSb = 64L; + auto rem = BS % 64; + if (large_cache_opt) + Ncb = NCB_BLOCK_SIZE; + + bool with_bias_gate = (t_bias_gate.numel() > 0); + bool with_bias_up = (t_bias_up.numel() > 0); + auto copy_bias_tpp = SCOPEIT(CpyBiasTPP(BSb, Hk, K), BIAS); + auto copy_bias_tpp_rem = SCOPEIT(CpyBiasTPP(rem, Hk, K), BIAS); + auto zero_tpp = SCOPEIT(SetZeroTPP(BSb, Hk, K), EW_ZERO); + auto zero_tpp_rem = SCOPEIT(SetZeroTPP(rem, Hk, K), EW_ZERO); + auto brgemm_tpp = SCOPEITGEMM( + (BrgemmTPP(BSb, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb))); + auto brgemm_tpp_rem = SCOPEITGEMM( + (BrgemmTPP(rem, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb))); + auto silu_fwd_tpp = SCOPEIT(SiLUFwdTPP(BSb, Hk, K, K), ACT); + auto silu_fwd_tpp_rem = SCOPEIT(SiLUFwdTPP(rem, Hk, K, K), ACT); + auto mul_tpp = SCOPEIT((MulTPP(BSb, Hk, K, K)), EW_MUL); + auto mul_tpp_rem = SCOPEIT((MulTPP(rem, Hk, K, K)), EW_MUL); + + { + RECORD_SCOPE(tpp_fused_gate_up_proj_krnl, {t_in, t_wt_gate_V}); + + auto loop_scheme = large_cache_opt ? GEMM_LOOP_SCHEME : "aCb"; + auto igemm_loop = torch_ipex::tpp::ThreadedLoop<3>( + {{0, Nc, Ncb, false}, {0, BS, BSb}, {Nk}}, loop_scheme); + igemm_loop( + [&](int* ind) { + int nc = ind[0], s1 = ind[1], nk = ind[2]; + auto count = nc + Ncb < Nc ? 
Ncb : Nc - nc; + bool is_rem = (s1 + BSb > BS); + if (!is_rem) { + if (nc == 0) { + if (with_bias_gate) { + copy_bias_tpp(bias_gate[nk], out[s1][nk]); + } else { + zero_tpp(out[s1][nk]); + } + + if (with_bias_up) { + copy_bias_tpp(bias_up[nk], out_tmp[s1][nk]); + } else { + zero_tpp(out_tmp[s1][nk]); + } + } + brgemm_tpp(in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, true); + brgemm_tpp( + in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, true); + if (!(nc + Ncb < Nc)) { // last nc iter + silu_fwd_tpp(out[s1][nk], out[s1][nk]); + mul_tpp(out[s1][nk], out_tmp[s1][nk], out[s1][nk]); + } + } else { + if (nc == 0) { + if (with_bias_gate) { + copy_bias_tpp_rem(bias_gate[nk], out[s1][nk]); + } else { + zero_tpp_rem(out[s1][nk]); + } + + if (with_bias_up) { + copy_bias_tpp_rem(bias_up[nk], out_tmp[s1][nk]); + } else { + zero_tpp_rem(out_tmp[s1][nk]); + } + } + brgemm_tpp_rem( + in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, false); + brgemm_tpp_rem( + in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, false); + brgemm_tpp.config(); + if (!(nc + Ncb < Nc)) { // last nc iter + silu_fwd_tpp_rem(out[s1][nk], out[s1][nk]); + mul_tpp_rem(out[s1][nk], out_tmp[s1][nk], out[s1][nk]); + } + } + }, + [&]() { brgemm_tpp.config(); }, + [&]() { brgemm_tpp.release(); }); + } +} + template inline void tpp_linear_add( const at::Tensor t_in, diff --git a/examples/cpu/inference/python/llm/single_instance/run_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_quantization.py index b40cc65d5..5139d0cf8 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_quantization.py @@ -536,10 +536,11 @@ def calib_func(prepared_model): op_type_dict=op_type_dict, smoothquant_args=smoothquant_args ) + pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json") else: - qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha) + qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=float(args.alpha)) user_model = ipex.llm.optimize( user_model.eval(), dtype=amp_dtype, diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py index 5c18d7484..1c4782689 100644 --- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py +++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py @@ -379,22 +379,16 @@ def forward(self, x): and not self.linear_m.tpp_fallback ): x = x.to(self.dtype).contiguous() - x1 = torch.ops.torch_ipex.tpp_linear_silu( + return torch.ops.torch_ipex.tpp_fused_gate_up_proj( x, self.linear_s.weight.detach(), self.linear_s.bias.detach() if self.linear_s.bias is not None else x.new_empty(0), - self.linear_s.out_features, - ) - return torch.ops.torch_ipex.tpp_linear_mul( - x, - x1, self.linear_m.weight.detach(), self.linear_m.bias.detach() if self.linear_m.bias is not None else x.new_empty(0), - self.linear_m.out_features, ) else: # fallback path return nn.functional.silu(self.linear_s(x)) * self.linear_m(x) diff --git a/tests/cpu/test_tpp_linear.py b/tests/cpu/test_tpp_linear.py index 308f10b44..8b1e65af8 100644 --- a/tests/cpu/test_tpp_linear.py +++ b/tests/cpu/test_tpp_linear.py @@ -46,6 +46,16 @@ def forward(self, x): return torch.nn.functional.silu(self.mlp(x)) +class Linear_Gate_Up(torch.nn.Module): + def __init__(self, in_feature, 
out_feature, bias_gate, bias_up): + super(Linear_Gate_Up, self).__init__() + self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_gate) + self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_up) + + def forward(self, x): + return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x) + + class Linear_relu(torch.nn.Module): def __init__(self): super(Linear_relu, self).__init__() @@ -172,6 +182,46 @@ def test_tpp_linear_torchcompile(self): self.assertTrue(out.dtype == dtype) _disable_tpp() + def test_tpp_fused_gate_up_proj(self): + in_feature = 64 + out_feature = 32 + + x = torch.randn(1, 4, in_feature) + x_tpp = copy.deepcopy(x) + + with torch.no_grad(): + for dtype, bias_gate, bias_up in itertools.product( + [torch.float, torch.bfloat16], [False, True], [False, True] + ): + model = Linear_Gate_Up( + in_feature, out_feature, bias_gate, bias_up + ).eval() + if dtype == torch.bfloat16: + x = x.to(torch.bfloat16) + x_tpp = x_tpp.to(torch.bfloat16) + model = model.to(torch.bfloat16) + ref_out = model(x) + + _enable_tpp() + model = ipex.optimize(model, dtype=dtype) + out = torch.ops.torch_ipex.tpp_fused_gate_up_proj( + x_tpp, + model.gate_proj.weight, + model.gate_proj.bias, + model.up_proj.weight, + model.up_proj.bias, + ) + + out_linear_silu = torch.ops.torch_ipex.tpp_linear_silu( + x_tpp, model.gate_proj.weight, model.gate_proj.bias + ) + out_tpp_ref = torch.ops.torch_ipex.tpp_linear_mul( + x_tpp, out_linear_silu, model.up_proj.weight, model.up_proj.bias + ) + self.assertEqual(out, out_tpp_ref) + self.assertEqual(out, ref_out) + _disable_tpp() + def test_tpp_linear_gelu(self): x1 = torch.rand(1, 4, 4096) x2 = copy.deepcopy(x1)
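
Editor's note: the patch above introduces torch.ops.torch_ipex.tpp_fused_gate_up_proj, which evaluates silu(gate_proj(x)) * up_proj(x) in a single fused pass instead of the previous two-op path (tpp_linear_silu followed by tpp_linear_mul). The following is a minimal, hedged sketch of the reference computation the fused kernel is validated against; it mirrors the Linear_Gate_Up module in the test above. GateUpMLP and the tensor sizes are illustrative names chosen here, not part of the patch, and actually invoking the fused op requires IPEX with TPP linear enabled as shown in test_tpp_fused_gate_up_proj.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GateUpMLP(nn.Module):
        # Reference for what the fused TPP kernel computes:
        #   silu(gate_proj(x)) * up_proj(x)
        # In LLaMA, gate_proj and up_proj have identical weight shapes and bias=False.
        def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.silu(self.gate_proj(x)) * self.up_proj(x)

    # Illustrative check of the fused op against this reference, following the
    # call pattern of test_tpp_fused_gate_up_proj above (requires IPEX with TPP
    # linear enabled; see that test for the _enable_tpp()/ipex.optimize() setup):
    #   out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
    #       x, m.gate_proj.weight, m.gate_proj.bias,
    #       m.up_proj.weight, m.up_proj.bias)
    #   torch.testing.assert_close(out, m(x))

    if __name__ == "__main__":
        m = GateUpMLP(64, 32).eval()   # same in/out features as the test above
        x = torch.randn(1, 4, 64)
        ref = m(x)
        print(ref.shape)  # torch.Size([1, 4, 32])

The fusion is profitable because gate_proj and up_proj read the same activation x: the fused kernel reuses the input tile for both GEMMs and applies silu and the elementwise multiply on the last Nc block while the output tile is still hot in cache, which is why linear_fusion.py in this patch drops the separate tpp_linear_silu/tpp_linear_mul calls in favor of the single fused op.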