From eae477f76356b5a83640941787a168f680334775 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 27 Mar 2024 19:06:43 +0800 Subject: [PATCH] Add filter to linear mul fusion (#2704) --- csrc/cpu/jit/passes/graph_rewrite_linear.cpp | 21 ++++++- .../llm/single_instance/run_quantization.py | 1 + tests/cpu/test_jit.py | 62 +++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/csrc/cpu/jit/passes/graph_rewrite_linear.cpp b/csrc/cpu/jit/passes/graph_rewrite_linear.cpp index f808842a8..00eb17228 100644 --- a/csrc/cpu/jit/passes/graph_rewrite_linear.cpp +++ b/csrc/cpu/jit/passes/graph_rewrite_linear.cpp @@ -451,6 +451,23 @@ void fuseLinearMulAdd(std::shared_ptr& graph) { %res = ipex_prepack::linear_mul_run(%input, %operand, %packed_weight) return (%res))"; + auto filter_scalar = [](const Match& match, + const std::unordered_map& vmap) { + Node* node = match.anchor; + if (utils::is_scalar(node->input(1)) || utils::is_scalar(node->input(0))) { + return false; + } + if (node->input(1)->type()->cast()->dim().has_value() && + node->input(1)->type()->cast()->dim().value() == 0) { + return false; + } + if (node->input(0)->type()->cast()->dim().has_value() && + node->input(0)->type()->cast()->dim().value() == 0) { + return false; + } + return true; + }; + for (const auto& mul : mul_operators) { TemplateEnv env; env.s("mul", mul); @@ -460,8 +477,8 @@ void fuseLinearMulAdd(std::shared_ptr& graph) { linear_mul_operand_on_the_left_rstring.format(env), linear_mul_fused); } - rewriter_mul_operand_on_the_right.runOnGraph(graph); - rewriter_mul_operand_on_the_left.runOnGraph(graph); + rewriter_mul_operand_on_the_right.runOnGraph(graph, filter_scalar); + rewriter_mul_operand_on_the_left.runOnGraph(graph, filter_scalar); // linear + mul + add // linear_mul Y diff --git a/examples/cpu/inference/python/llm/single_instance/run_quantization.py b/examples/cpu/inference/python/llm/single_instance/run_quantization.py index 1c562f3dc..583ae0d74 100644 --- a/examples/cpu/inference/python/llm/single_instance/run_quantization.py +++ b/examples/cpu/inference/python/llm/single_instance/run_quantization.py @@ -653,6 +653,7 @@ def calib_func(prepared_model): op_type_dict=op_type_dict, smoothquant_args=smoothquant_args ) + pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json") else: diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py index f3e559a78..cb9469eb2 100644 --- a/tests/cpu/test_jit.py +++ b/tests/cpu/test_jit.py @@ -815,6 +815,19 @@ def forward(self, input): return x_l +class LinearMulAdd_v2(nn.Module): + def __init__(self, in_features, out_features): + super(LinearMulAdd_v2, self).__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.mul_tensor = torch.tensor(1) + self.mul_scalar = 0.5 + + def forward(self, input): + x_add = input + result = self.mul_tensor * self.linear(input) * self.mul_scalar + return result + (x_add).to(result.dtype) + + class LinearMul(nn.Module): def __init__(self, in_features, num_layers, low_rank): super(LinearMul, self).__init__() @@ -841,6 +854,17 @@ def forward(self, input): return x_l +class LinearMul_v2(nn.Module): + def __init__(self, in_features, out_features): + super(LinearMul_v2, self).__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.mul_tensor = torch.tensor(1) + self.mul_scalar = 0.5 + + def forward(self, input): + return self.mul_scalar * self.linear(input) * self.mul_tensor + + class Linear_Reshape_Relu(nn.Module): def __init__(self, in_channels, out_channels, dest_shape, **kwargs): super(Linear_Reshape_Relu, self).__init__() @@ -4536,6 +4560,25 @@ def test_output_linear_mul_add(self): prec=5e-2, ) + def test_output_linear_mul_add_v2(self): + m = LinearMulAdd_v2(4, 4) + x = torch.ones(2, 4) + self._test_output( + m, + x, + kind_in_graph="aten::linear", + kind_not_in_graph="ipex_prepack::linear_mul_add_run", + ) + self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run") + self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run") + self._test_output_lowp( + m, + x, + kind_in_graph="ipex_prepack::linear_run", + kind_not_in_graph="ipex_prepack::linear_mul_add_run", + prec=5e-2, + ) + def test_output_linear_mul(self): m = LinearMul(4, 2, 8) x = torch.ones(2, 4) @@ -4549,6 +4592,25 @@ def test_output_linear_mul(self): prec=5e-2, ) + def test_output_linear_mul_v2(self): + m = LinearMul_v2(4, 4) + x = torch.ones(2, 4) + self._test_output( + m, + x, + kind_in_graph="aten::linear", + kind_not_in_graph="ipex_prepack::linear_mul_run", + ) + self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run") + self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run") + self._test_output_lowp( + m, + x, + kind_in_graph="ipex_prepack::linear_run", + kind_not_in_graph="ipex_prepack::linear_mul_run", + prec=5e-2, + ) + def test_output_linear_reshape_relu(self): self._test_output( Linear_Reshape_Relu(3, 32, (64, 16), bias=True),