Skip to content

Commit

Permalink
Add filter to linear mul fusion (#2704)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianan-gu committed Mar 27, 2024
1 parent f57307d commit eae477f
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
21 changes: 19 additions & 2 deletions csrc/cpu/jit/passes/graph_rewrite_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,23 @@ void fuseLinearMulAdd(std::shared_ptr<Graph>& graph) {
%res = ipex_prepack::linear_mul_run(%input, %operand, %packed_weight)
return (%res))";

auto filter_scalar = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
Node* node = match.anchor;
if (utils::is_scalar(node->input(1)) || utils::is_scalar(node->input(0))) {
return false;
}
if (node->input(1)->type()->cast<TensorType>()->dim().has_value() &&
node->input(1)->type()->cast<TensorType>()->dim().value() == 0) {
return false;
}
if (node->input(0)->type()->cast<TensorType>()->dim().has_value() &&
node->input(0)->type()->cast<TensorType>()->dim().value() == 0) {
return false;
}
return true;
};

for (const auto& mul : mul_operators) {
TemplateEnv env;
env.s("mul", mul);
Expand All @@ -460,8 +477,8 @@ void fuseLinearMulAdd(std::shared_ptr<Graph>& graph) {
linear_mul_operand_on_the_left_rstring.format(env), linear_mul_fused);
}

rewriter_mul_operand_on_the_right.runOnGraph(graph);
rewriter_mul_operand_on_the_left.runOnGraph(graph);
rewriter_mul_operand_on_the_right.runOnGraph(graph, filter_scalar);
rewriter_mul_operand_on_the_left.runOnGraph(graph, filter_scalar);

// linear + mul + add
// linear_mul Y
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ def calib_func(prepared_model):
op_type_dict=op_type_dict,
smoothquant_args=smoothquant_args
)
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json")

else:
Expand Down
62 changes: 62 additions & 0 deletions tests/cpu/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,19 @@ def forward(self, input):
return x_l


class LinearMulAdd_v2(nn.Module):
def __init__(self, in_features, out_features):
super(LinearMulAdd_v2, self).__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
self.mul_tensor = torch.tensor(1)
self.mul_scalar = 0.5

def forward(self, input):
x_add = input
result = self.mul_tensor * self.linear(input) * self.mul_scalar
return result + (x_add).to(result.dtype)


class LinearMul(nn.Module):
def __init__(self, in_features, num_layers, low_rank):
super(LinearMul, self).__init__()
Expand All @@ -841,6 +854,17 @@ def forward(self, input):
return x_l


class LinearMul_v2(nn.Module):
def __init__(self, in_features, out_features):
super(LinearMul_v2, self).__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
self.mul_tensor = torch.tensor(1)
self.mul_scalar = 0.5

def forward(self, input):
return self.mul_scalar * self.linear(input) * self.mul_tensor


class Linear_Reshape_Relu(nn.Module):
def __init__(self, in_channels, out_channels, dest_shape, **kwargs):
super(Linear_Reshape_Relu, self).__init__()
Expand Down Expand Up @@ -4536,6 +4560,25 @@ def test_output_linear_mul_add(self):
prec=5e-2,
)

def test_output_linear_mul_add_v2(self):
m = LinearMulAdd_v2(4, 4)
x = torch.ones(2, 4)
self._test_output(
m,
x,
kind_in_graph="aten::linear",
kind_not_in_graph="ipex_prepack::linear_mul_add_run",
)
self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run")
self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run")
self._test_output_lowp(
m,
x,
kind_in_graph="ipex_prepack::linear_run",
kind_not_in_graph="ipex_prepack::linear_mul_add_run",
prec=5e-2,
)

def test_output_linear_mul(self):
m = LinearMul(4, 2, 8)
x = torch.ones(2, 4)
Expand All @@ -4549,6 +4592,25 @@ def test_output_linear_mul(self):
prec=5e-2,
)

def test_output_linear_mul_v2(self):
m = LinearMul_v2(4, 4)
x = torch.ones(2, 4)
self._test_output(
m,
x,
kind_in_graph="aten::linear",
kind_not_in_graph="ipex_prepack::linear_mul_run",
)
self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run")
self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run")
self._test_output_lowp(
m,
x,
kind_in_graph="ipex_prepack::linear_run",
kind_not_in_graph="ipex_prepack::linear_mul_run",
prec=5e-2,
)

def test_output_linear_reshape_relu(self):
self._test_output(
Linear_Reshape_Relu(3, 32, (64, 16), bias=True),
Expand Down

0 comments on commit eae477f

Please sign in to comment.