From 58adee5b043a52e0c0a60320d48eae82de557074 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Tue, 6 Jun 2023 05:40:20 +0000 Subject: [PATCH] Fix the size of the saved TorchScript model for prepacked Linear and ConvTranspose (#1688) * add insertBias util function * use empty weight & bias for linear * update UT to include linear * fix ConvTranspose (no unpack for ConvTranspose in jit pass for now thus no need to fix jit side) * use a common parent class to handle empty weight and bias * remove print --- csrc/jit/passes/graph_rewrite_conv.cpp | 13 +---- csrc/jit/passes/graph_rewrite_helper.cpp | 17 ++++++ csrc/jit/passes/graph_rewrite_helper.h | 5 ++ csrc/jit/passes/graph_rewrite_linear.cpp | 19 +++++-- .../nn/utils/_weight_prepack.py | 38 +++++++++---- tests/cpu/test_jit.py | 53 ++++++++++++------- 6 files changed, 99 insertions(+), 46 deletions(-) diff --git a/csrc/jit/passes/graph_rewrite_conv.cpp b/csrc/jit/passes/graph_rewrite_conv.cpp index e54e9d17f..5316fd2d5 100644 --- a/csrc/jit/passes/graph_rewrite_conv.cpp +++ b/csrc/jit/passes/graph_rewrite_conv.cpp @@ -2,6 +2,7 @@ #include "aten/WeightPack.h" #include "cpu/kernels/OpContext.h" #include "graph_rewrite.h" +#include "graph_rewrite_helper.h" #include "graph_rewrite_utils.h" #include "passes/utils.h" @@ -66,18 +67,8 @@ void replaceFrozenIPEXConvWithAtenConv( // empty tensor. Need to get the real bias tensor from the op context. // Please refer to [ Note -- Fix the size of the saved TorchScript model ] // for the details. - at::Tensor bias_tensor; auto may_get_bias_tensor = conv_op_ctx->get_at_bias(); - if (may_get_bias_tensor.has_value()) { - bias_tensor = may_get_bias_tensor.value().set_requires_grad(false); - IValue bias_value(bias_tensor); - auto bias = graph->insertConstant(bias_value); - aten_conv->addInput(bias); - } else { - auto n = graph->createNone(); - auto v = n->insertBefore(aten_conv)->output(); - aten_conv->addInput(v); - } + graph_rewrite_helper::insertBias(graph, aten_conv, may_get_bias_tensor); IValue stride_value(conv_op_ctx->get_stride()); auto stride = graph->insertConstant(stride_value); diff --git a/csrc/jit/passes/graph_rewrite_helper.cpp b/csrc/jit/passes/graph_rewrite_helper.cpp index 26e8f2406..8b7e34819 100644 --- a/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/csrc/jit/passes/graph_rewrite_helper.cpp @@ -289,6 +289,23 @@ bool isClampFusable( return is_fusable; } +void insertBias( + torch::jit::Graph* graph, + torch::jit::Node* node, + c10::optional may_get_bias_tensor) { + at::Tensor bias_tensor; + if (may_get_bias_tensor.has_value()) { + bias_tensor = may_get_bias_tensor.value().set_requires_grad(false); + IValue bias_value(bias_tensor); + auto bias = graph->insertConstant(bias_value); + node->addInput(bias); + } else { + auto n = graph->createNone(); + auto v = n->insertBefore(node)->output(); + node->addInput(v); + } +} + } // namespace graph_rewrite_helper } // namespace jit } // namespace torch_ipex diff --git a/csrc/jit/passes/graph_rewrite_helper.h b/csrc/jit/passes/graph_rewrite_helper.h index e0cf769d8..04f2e1e8d 100644 --- a/csrc/jit/passes/graph_rewrite_helper.h +++ b/csrc/jit/passes/graph_rewrite_helper.h @@ -29,6 +29,11 @@ bool isClampFusable( const torch::jit::Match& match, const std::unordered_map& vmap); +void insertBias( + torch::jit::Graph* graph, + torch::jit::Node* node, + c10::optional bias); + // This struct contains a compiled IR patterns slated for use in the // findPatternMatches function. 
The struct encapsulates the common // information from parseIR that is used in conjunction with the diff --git a/csrc/jit/passes/graph_rewrite_linear.cpp b/csrc/jit/passes/graph_rewrite_linear.cpp index 85637a752..dbf60bd7d 100644 --- a/csrc/jit/passes/graph_rewrite_linear.cpp +++ b/csrc/jit/passes/graph_rewrite_linear.cpp @@ -4,6 +4,7 @@ #include "auto_opt_config.h" #include "graph_rewrite.h" +#include "graph_rewrite_helper.h" #include "graph_rewrite_utils.h" namespace torch_ipex { @@ -54,16 +55,19 @@ void replaceFrozenIPEXLinearWithAtenLinear( if (!toIValue(prepack_node).has_value()) continue; at::Tensor weight_tensor; + c10::optional may_get_bias_tensor; if (use_mkl_sgemm) { auto linear_op_ctx = toIValue(prepack_node).value().toCustomClass(); - weight_tensor = linear_op_ctx->to_public( - constant_as(n->namedInput("weight")).value()); + weight_tensor = + linear_op_ctx->to_public(linear_op_ctx->get_at_packed_weight()); + may_get_bias_tensor = linear_op_ctx->get_at_bias(); } else { auto linear_op_ctx = toIValue(prepack_node).value().toCustomClass(); - weight_tensor = linear_op_ctx->to_public( - constant_as(n->namedInput("weight")).value()); + weight_tensor = + linear_op_ctx->to_public(linear_op_ctx->get_at_packed_weight()); + may_get_bias_tensor = linear_op_ctx->get_at_bias(); } WithInsertPoint guard(n); auto graph = n->owningGraph(); @@ -73,7 +77,12 @@ void replaceFrozenIPEXLinearWithAtenLinear( IValue weight_value(weight_tensor); auto weight = graph->insertConstant(weight_value); aten_linear->addInput(weight); - aten_linear->addInput(n->inputs().at(2)); + + // bias + // Please refer to [ Note -- Fix the size of the saved TorchScript model ] + // for the details. + graph_rewrite_helper::insertBias(graph, aten_linear, may_get_bias_tensor); + aten_linear->output()->setType(n->output()->type()->cast()); n->output()->replaceAllUsesWith(aten_linear->output()); get_data_handle_nodes.emplace_back(n->inputs().at(3)->node()); diff --git a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py index 98fd15a80..2e4e9919a 100644 --- a/intel_extension_for_pytorch/nn/utils/_weight_prepack.py +++ b/intel_extension_for_pytorch/nn/utils/_weight_prepack.py @@ -47,7 +47,15 @@ def _ipex_module_load_from_state_dict_(self, state_dict, prefix): self.weight_wrapper.load(self, loaded_weight) -class _IPEXConvNd(nn.Module): +class _IPEXPrepackModule(nn.Module): + def _get_forward_weight(self): + return self.weight if self.training else self._ipex_module_empty_tensor + + def _get_forward_bias(self): + return self.bias if self.training else self._ipex_module_empty_tensor + + +class _IPEXConvNd(_IPEXPrepackModule): __constants__ = [ "stride", "padding", @@ -98,8 +106,8 @@ def forward(self, x): if self.padding_mode != "zeros": return torch.ops.torch_ipex.convolution_forward( F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.weight if self.training else self._ipex_module_empty_tensor, - self.bias if self.training else self._ipex_module_empty_tensor, + self._get_forward_weight(), + self._get_forward_bias(), self.ctx.get_data_handle(), self.weight_size, self._real_padding, @@ -109,8 +117,8 @@ def forward(self, x): ) return torch.ops.torch_ipex.convolution_forward( x, - self.weight if self.training else self._ipex_module_empty_tensor, - self.bias if self.training else self._ipex_module_empty_tensor, + self._get_forward_weight(), + self._get_forward_bias(), self.ctx.get_data_handle(), self.weight_size, self._real_padding, @@ -135,7 
+143,7 @@ def __init__(self): super(_IPEXConv3d, self).__init__() -class _IPEXLinear(torch.nn.Module): +class _IPEXLinear(_IPEXPrepackModule): def __init__(self): super(_IPEXLinear, self).__init__() @@ -145,11 +153,19 @@ def post_ipex_gemm(self, output): def forward(self, x): if self.use_dnnl: output = torch.ops.torch_ipex.ipex_linear( - x, self.weight, self.bias, self.ctx.get_data_handle(), self.out_features + x, + self._get_forward_weight(), + self._get_forward_bias(), + self.ctx.get_data_handle(), + self.out_features, ) else: output = torch.ops.torch_ipex.ipex_MKLSGEMM( - x, self.weight, self.bias, self.ctx.get_data_handle(), self.out_features + x, + self._get_forward_weight(), + self._get_forward_bias(), + self.ctx.get_data_handle(), + self.out_features, ) return self.post_ipex_gemm(output) @@ -192,7 +208,7 @@ def post_ipex_gemm(self, output): return output -class _IPEXConvTransposeNd(nn.Module): +class _IPEXConvTransposeNd(_IPEXPrepackModule): __constants__ = [ "stride", "padding", @@ -230,8 +246,8 @@ def _load_from_state_dict( def forward(self, x): return torch.ops.torch_ipex.conv_transpose( x, - self.weight, - self.bias, + self._get_forward_weight(), + self._get_forward_bias(), self.ctx.get_data_handle(), self.weight_size, self.padding, diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py index 97926aa86..743f82a2f 100644 --- a/tests/cpu/test_jit.py +++ b/tests/cpu/test_jit.py @@ -5387,30 +5387,45 @@ def test_replace_PythonGELU_with_AtenGELU(self): def test_empty_weight_bias_inference(self): class M(nn.Module): - def __init__(self): + def __init__(self, module): super(M, self).__init__() - self.conv = nn.Conv2d(3, 5, 3) + self.module = module def forward(self, x): - x = self.conv(x) + x = self.module(x) return x - model = M() - model.eval() - data = torch.randn(1, 3, 56, 56) - optimized = ipex.optimize(model) - with torch.no_grad(): - traced_model = torch.jit.trace(optimized, data) - traced_model = torch.jit.freeze(traced_model) - traced_model(data) - - graph = traced_model.graph - FileCheck().check_not("self.conv.weight").check_not("self.conv.bias").check( - "_ipex_module_empty_tensor" - ).run(graph) - y_ref = model(data) - y_traced = traced_model(data) - self.assertEqual(y_ref, y_traced) + modules = [nn.Conv2d(3, 5, 3), nn.Linear(3, 7), nn.ConvTranspose2d(3, 5, 3)] + inputs = [ + torch.randn(1, 3, 56, 56), + torch.randn(2, 3), + torch.randn(1, 3, 56, 56), + ] + auto_kernel_selection_config = [True, False] + + for module, data in zip(modules, inputs): + for auto_kernel_selection in auto_kernel_selection_config: + # Currently auto_kernel_selection only shows different behavior for nn.Linear + if auto_kernel_selection and not isinstance(module, nn.Linear): + continue + + model = M(module) + model.eval() + optimized = ipex.optimize( + model, auto_kernel_selection=auto_kernel_selection + ) + with torch.no_grad(): + traced_model = torch.jit.trace(optimized, data) + traced_model = torch.jit.freeze(traced_model) + traced_model(data) + + graph = traced_model.graph + FileCheck().check_not("self.module.weight").check_not( + "self.module.bias" + ).check("_ipex_module_empty_tensor").run(graph) + y_ref = model(data) + y_traced = traced_model(data) + self.assertEqual(y_ref, y_traced) if __name__ == "__main__":
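
A minimal end-to-end sketch of the behavior this patch targets, assuming intel_extension_for_pytorch is installed and imported as ipex. It mirrors test_empty_weight_bias_inference above: in eval mode the prepacked module feeds _ipex_module_empty_tensor to the custom op instead of the real weight/bias, so the frozen graph (and therefore the saved TorchScript file) no longer embeds a duplicate copy of the unpacked weight. The module choice, shapes, and file-size check below are illustrative additions, not taken from the patch:

    import os
    import tempfile

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch as ipex

    # Simple eval-mode Linear model; ipex.optimize() swaps nn.Linear for the
    # prepacked _IPEXLinear module that this patch modifies.
    model = nn.Sequential(nn.Linear(1024, 1024)).eval()
    data = torch.randn(8, 1024)
    optimized = ipex.optimize(model)

    with torch.no_grad():
        traced = torch.jit.trace(optimized, data)
        traced = torch.jit.freeze(traced)
        traced(data)  # warm up so the JIT rewrite passes run

    # With this fix the frozen graph should reference _ipex_module_empty_tensor
    # rather than the real weight/bias constants.
    print(traced.graph)

    # Saving the traced model: the file should stay close to the packed size
    # instead of carrying an extra unpacked weight copy.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.jit.save(traced, path)
        print("saved TorchScript size:", os.path.getsize(path), "bytes")

    # Numerical check against the eager model, as in the unit test.
    with torch.no_grad():
        torch.testing.assert_close(model(data), traced(data))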