Fix the size of the saved TorchScript model for prepacked Linear and ConvTranspose (#1688)

* add insertBias util function

* use empty weight & bias for linear

* update UT to include linear

* fix ConvTranspose (there is currently no unpack for ConvTranspose in the JIT pass, so no JIT-side fix is needed)

* use a common parent class to handle empty weight and bias

* remove print
chunyuan-w committed Jun 6, 2023
1 parent 30b70e4 commit 58adee5
Showing 6 changed files with 99 additions and 46 deletions.
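
In short: at inference time the prepacked modules hand the IPEX kernels a shared 0-element placeholder instead of the real weight and bias, so tracing no longer bakes a second copy of those tensors into the saved TorchScript model; the real (packed) tensors stay in the op context, and the JIT rewrite passes fetch them back from there when converting the prepacked ops to aten::linear / aten::convolution. A minimal sketch of the placeholder pattern in plain PyTorch follows; the method and attribute names mirror the diff below, but how the shared empty tensor is attached to the module is an assumption made for illustration:

import torch
import torch.nn as nn

class _PrepackSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # one shared 0-element tensor stands in for weight and bias in eval
        # mode, so a traced graph only ever captures this tiny placeholder
        self.register_buffer("_ipex_module_empty_tensor", torch.empty(0))

    def _get_forward_weight(self):
        return self.weight if self.training else self._ipex_module_empty_tensor

    def _get_forward_bias(self):
        return self.bias if self.training else self._ipex_module_empty_tensor

class _LinearSketch(_PrepackSketch):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

m = _LinearSketch(8, 4).eval()
assert m._get_forward_weight().numel() == 0  # placeholder, not the real weight
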
13 changes: 2 additions & 11 deletions csrc/jit/passes/graph_rewrite_conv.cpp
@@ -2,6 +2,7 @@
#include "aten/WeightPack.h"
#include "cpu/kernels/OpContext.h"
#include "graph_rewrite.h"
#include "graph_rewrite_helper.h"
#include "graph_rewrite_utils.h"
#include "passes/utils.h"

@@ -66,18 +67,8 @@ void replaceFrozenIPEXConvWithAtenConv(
// empty tensor. Need to get the real bias tensor from the op context.
// Please refer to [ Note -- Fix the size of the saved TorchScript model ]
// for the details.
at::Tensor bias_tensor;
auto may_get_bias_tensor = conv_op_ctx->get_at_bias();
if (may_get_bias_tensor.has_value()) {
bias_tensor = may_get_bias_tensor.value().set_requires_grad(false);
IValue bias_value(bias_tensor);
auto bias = graph->insertConstant(bias_value);
aten_conv->addInput(bias);
} else {
auto n = graph->createNone();
auto v = n->insertBefore(aten_conv)->output();
aten_conv->addInput(v);
}
graph_rewrite_helper::insertBias(graph, aten_conv, may_get_bias_tensor);

IValue stride_value(conv_op_ctx->get_stride());
auto stride = graph->insertConstant(stride_value);
17 changes: 17 additions & 0 deletions csrc/jit/passes/graph_rewrite_helper.cpp
@@ -289,6 +289,23 @@ bool isClampFusable(
return is_fusable;
}

void insertBias(
torch::jit::Graph* graph,
torch::jit::Node* node,
c10::optional<at::Tensor> may_get_bias_tensor) {
at::Tensor bias_tensor;
if (may_get_bias_tensor.has_value()) {
bias_tensor = may_get_bias_tensor.value().set_requires_grad(false);
IValue bias_value(bias_tensor);
auto bias = graph->insertConstant(bias_value);
node->addInput(bias);
} else {
auto n = graph->createNone();
auto v = n->insertBefore(node)->output();
node->addInput(v);
}
}

} // namespace graph_rewrite_helper
} // namespace jit
} // namespace torch_ipex
5 changes: 5 additions & 0 deletions csrc/jit/passes/graph_rewrite_helper.h
@@ -29,6 +29,11 @@ bool isClampFusable(
const torch::jit::Match& match,
const std::unordered_map<std::string, torch::jit::Value*>& vmap);

void insertBias(
torch::jit::Graph* graph,
torch::jit::Node* node,
c10::optional<at::Tensor> bias);

// This struct contains a compiled IR patterns slated for use in the
// findPatternMatches function. The struct encapsulates the common
// information from parseIR that is used in conjunction with the
19 changes: 14 additions & 5 deletions csrc/jit/passes/graph_rewrite_linear.cpp
@@ -4,6 +4,7 @@

#include "auto_opt_config.h"
#include "graph_rewrite.h"
#include "graph_rewrite_helper.h"
#include "graph_rewrite_utils.h"

namespace torch_ipex {
@@ -54,16 +55,19 @@ void replaceFrozenIPEXLinearWithAtenLinear(
if (!toIValue(prepack_node).has_value())
continue;
at::Tensor weight_tensor;
c10::optional<at::Tensor> may_get_bias_tensor;
if (use_mkl_sgemm) {
auto linear_op_ctx =
toIValue(prepack_node).value().toCustomClass<MKLOpContext>();
weight_tensor = linear_op_ctx->to_public(
constant_as<at::Tensor>(n->namedInput("weight")).value());
weight_tensor =
linear_op_ctx->to_public(linear_op_ctx->get_at_packed_weight());
may_get_bias_tensor = linear_op_ctx->get_at_bias();
} else {
auto linear_op_ctx =
toIValue(prepack_node).value().toCustomClass<LinearOpContext>();
weight_tensor = linear_op_ctx->to_public(
constant_as<at::Tensor>(n->namedInput("weight")).value());
weight_tensor =
linear_op_ctx->to_public(linear_op_ctx->get_at_packed_weight());
may_get_bias_tensor = linear_op_ctx->get_at_bias();
}
WithInsertPoint guard(n);
auto graph = n->owningGraph();
@@ -73,7 +77,12 @@ void replaceFrozenIPEXLinearWithAtenLinear(
IValue weight_value(weight_tensor);
auto weight = graph->insertConstant(weight_value);
aten_linear->addInput(weight);
aten_linear->addInput(n->inputs().at(2));

// bias
// Please refer to [ Note -- Fix the size of the saved TorchScript model ]
// for the details.
graph_rewrite_helper::insertBias(graph, aten_linear, may_get_bias_tensor);

aten_linear->output()->setType(n->output()->type()->cast<TensorType>());
n->output()->replaceAllUsesWith(aten_linear->output());
get_data_handle_nodes.emplace_back(n->inputs().at(3)->node());
38 changes: 27 additions & 11 deletions intel_extension_for_pytorch/nn/utils/_weight_prepack.py
@@ -47,7 +47,15 @@ def _ipex_module_load_from_state_dict_(self, state_dict, prefix):
self.weight_wrapper.load(self, loaded_weight)


class _IPEXConvNd(nn.Module):
class _IPEXPrepackModule(nn.Module):
def _get_forward_weight(self):
return self.weight if self.training else self._ipex_module_empty_tensor

def _get_forward_bias(self):
return self.bias if self.training else self._ipex_module_empty_tensor


class _IPEXConvNd(_IPEXPrepackModule):
__constants__ = [
"stride",
"padding",
@@ -98,8 +106,8 @@ def forward(self, x):
if self.padding_mode != "zeros":
return torch.ops.torch_ipex.convolution_forward(
F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),
self.weight if self.training else self._ipex_module_empty_tensor,
self.bias if self.training else self._ipex_module_empty_tensor,
self._get_forward_weight(),
self._get_forward_bias(),
self.ctx.get_data_handle(),
self.weight_size,
self._real_padding,
@@ -109,8 +117,8 @@ def forward(self, x):
)
return torch.ops.torch_ipex.convolution_forward(
x,
self.weight if self.training else self._ipex_module_empty_tensor,
self.bias if self.training else self._ipex_module_empty_tensor,
self._get_forward_weight(),
self._get_forward_bias(),
self.ctx.get_data_handle(),
self.weight_size,
self._real_padding,
@@ -135,7 +143,7 @@ def __init__(self):
super(_IPEXConv3d, self).__init__()


class _IPEXLinear(torch.nn.Module):
class _IPEXLinear(_IPEXPrepackModule):
def __init__(self):
super(_IPEXLinear, self).__init__()

@@ -145,11 +153,19 @@ def post_ipex_gemm(self, output):
def forward(self, x):
if self.use_dnnl:
output = torch.ops.torch_ipex.ipex_linear(
x, self.weight, self.bias, self.ctx.get_data_handle(), self.out_features
x,
self._get_forward_weight(),
self._get_forward_bias(),
self.ctx.get_data_handle(),
self.out_features,
)
else:
output = torch.ops.torch_ipex.ipex_MKLSGEMM(
x, self.weight, self.bias, self.ctx.get_data_handle(), self.out_features
x,
self._get_forward_weight(),
self._get_forward_bias(),
self.ctx.get_data_handle(),
self.out_features,
)

return self.post_ipex_gemm(output)
@@ -192,7 +208,7 @@ def post_ipex_gemm(self, output):
return output


class _IPEXConvTransposeNd(nn.Module):
class _IPEXConvTransposeNd(_IPEXPrepackModule):
__constants__ = [
"stride",
"padding",
@@ -230,8 +246,8 @@ def _load_from_state_dict(
def forward(self, x):
return torch.ops.torch_ipex.conv_transpose(
x,
self.weight,
self.bias,
self._get_forward_weight(),
self._get_forward_bias(),
self.ctx.get_data_handle(),
self.weight_size,
self.padding,
53 changes: 34 additions & 19 deletions tests/cpu/test_jit.py
@@ -5387,30 +5387,45 @@ def test_replace_PythonGELU_with_AtenGELU(self):

def test_empty_weight_bias_inference(self):
class M(nn.Module):
def __init__(self):
def __init__(self, module):
super(M, self).__init__()
self.conv = nn.Conv2d(3, 5, 3)
self.module = module

def forward(self, x):
x = self.conv(x)
x = self.module(x)
return x

model = M()
model.eval()
data = torch.randn(1, 3, 56, 56)
optimized = ipex.optimize(model)
with torch.no_grad():
traced_model = torch.jit.trace(optimized, data)
traced_model = torch.jit.freeze(traced_model)
traced_model(data)

graph = traced_model.graph
FileCheck().check_not("self.conv.weight").check_not("self.conv.bias").check(
"_ipex_module_empty_tensor"
).run(graph)
y_ref = model(data)
y_traced = traced_model(data)
self.assertEqual(y_ref, y_traced)
modules = [nn.Conv2d(3, 5, 3), nn.Linear(3, 7), nn.ConvTranspose2d(3, 5, 3)]
inputs = [
torch.randn(1, 3, 56, 56),
torch.randn(2, 3),
torch.randn(1, 3, 56, 56),
]
auto_kernel_selection_config = [True, False]

for module, data in zip(modules, inputs):
for auto_kernel_selection in auto_kernel_selection_config:
# Currently auto_kernel_selection only shows different behavior for nn.Linear
if auto_kernel_selection and not isinstance(module, nn.Linear):
continue

model = M(module)
model.eval()
optimized = ipex.optimize(
model, auto_kernel_selection=auto_kernel_selection
)
with torch.no_grad():
traced_model = torch.jit.trace(optimized, data)
traced_model = torch.jit.freeze(traced_model)
traced_model(data)

graph = traced_model.graph
FileCheck().check_not("self.module.weight").check_not(
"self.module.bias"
).check("_ipex_module_empty_tensor").run(graph)
y_ref = model(data)
y_traced = traced_model(data)
self.assertEqual(y_ref, y_traced)


if __name__ == "__main__":
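
A quick, informal way to check the effect of this commit (not part of the test suite; ipex.optimize is used exactly as in the test above) is to compare the size of the serialized frozen model against the raw parameter size. With this fix the saved file should stay close to the parameter size rather than carrying an extra copy of the weights:

import io
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

def frozen_size_bytes(model, example):
    # trace + freeze as in test_empty_weight_bias_inference, then serialize
    with torch.no_grad():
        traced = torch.jit.freeze(torch.jit.trace(model, example))
    buf = io.BytesIO()
    torch.jit.save(traced, buf)
    return len(buf.getvalue())

model = nn.Linear(1024, 1024).eval()
data = torch.randn(8, 1024)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
optimized = ipex.optimize(model)
print(frozen_size_bytes(optimized, data), param_bytes)
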
