Added optimization for dequantized constant folding #1316

Merged · 1 commit · Feb 4, 2021
26 changes: 25 additions & 1 deletion tests/test_optimizers.py
@@ -9,7 +9,7 @@

import unittest
import numpy as np
-from onnx import helper, TensorProto, OperatorSetIdProto
+from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
from tf2onnx import utils, constants
@@ -1241,6 +1241,30 @@ def test_const_fold_cast_with_const(self):

    # Const Fold Optimizer Tests End

    # Const Dequantize Optimizer Tests Start

    @check_opset_min_version(10, "DequantizeLinear")
    def test_const_dequantize_reshape(self):
        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
        scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
        zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
        shape = numpy_helper.from_array(np.array([6, 20], dtype=np.int64), name='shape')
        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize")
        node2 = helper.make_node("Reshape", ["Y", "shape"], ["Z"], name="reshape")

        graph = helper.make_graph(
            [node1, node2],
            "const-dequantize-test",
            [],
            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (6, 20))],
            [inputval, scale, zero_point, shape]
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        # expect zero "Reshape" ops to remain: the reshape is folded into the
        # quantized constant and only DequantizeLinear survives
        self.run_and_compare(["Z"], {}, model_proto, "Reshape", 0)

    # Const Dequantize Optimizer Tests End

    def test_transpose_back_to_back_non_const(self):

        node0 = helper.make_node("Transpose", ["u"], ["v"], perm=[0, 2, 3, 1], name="trans_0")
2 changes: 2 additions & 0 deletions tf2onnx/optimizer/__init__.py
@@ -16,13 +16,15 @@
from .loop_optimizer import LoopOptimizer
from .back_to_back_optimizer import BackToBackOptimizer
from .upsample_optimizer import UpsampleOptimizer
from .const_dequantize_optimizer import ConstDequantizeOptimizer
from .. import logging

# the optimizer sequence needs to be considered carefully
_optimizers = OrderedDict([
    ("optimize_transpose", TransposeOptimizer),
    ("remove_redundant_upsample", UpsampleOptimizer),
    ("fold_constants", ConstFoldOptimizer),
("const_dequantize_optimizer", ConstDequantizeOptimizer),
("loop_optimizer", LoopOptimizer),
# merge_duplication should be used after optimize_transpose
# for optimize_transpose may have some trans nodes that can be merge
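For orientation, here is a hypothetical simplification of how this registry is consumed; the real `optimize_graph` in this same module also handles logging and repeats the suite until no optimizer reports a change:

    # illustrative sketch, not part of this PR
    def optimize_graph(graph):
        for _, optimizer_cls in _optimizers.items():
            graph = optimizer_cls().optimize(graph)
        return graph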
67 changes: 67 additions & 0 deletions tf2onnx/optimizer/const_dequantize_optimizer.py
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0


"""const dequantize Optimizer.
if a dequantize op's inputs are const we may be able to fold it through the next op
"""

from .optimizer_base import GraphOptimizerBase
from .const_fold_optimizer import ConstFoldOptimizer

# pylint: disable=logging-not-lazy,unused-argument,missing-docstring


class ConstDequantizeOptimizer(GraphOptimizerBase):

    def __init__(self):  # pylint: disable=useless-super-delegation
        super(ConstDequantizeOptimizer, self).__init__()

    def _optimize(self, graph):
        return self._apply_optimization(graph, self._optimize_at_current_graph_level)

    def _optimize_at_current_graph_level(self, graph):
        # keep folding until a full pass over the graph changes nothing
        graph_changed = True
        while graph_changed:
            graph_changed = False
            ops = graph.get_nodes()
            for op in ops:
                if self._fold_node(op, graph):
                    graph_changed = True
                    self.graph_been_opt = True
        return graph

    def _fold_node(self, node, graph):
        """ if a dequantize op's inputs are const and it is fed into a tensor reshaping op, we can apply the op
        directly to the quantized inputs. Returns True if the graph is changed.
        """
        if node.type not in ["Transpose", "Reshape", "Unsqueeze"]:
            return False
        dequant_node = node.inputs[0]
        if dequant_node.type != "DequantizeLinear":
            return False
        # only safe if this reshaping op is the sole consumer of the dequantized value
        if len(graph.find_output_consumers(dequant_node.output[0])) > 1:
            return False
        if not self._all_inputs_are_const(node.inputs[1:]) or self._is_graph_output(node, graph):
            return False
        if not self._all_inputs_are_const(dequant_node.inputs):
            return False
        # rewire the reshaping op onto the quantized constant and fold it
        graph.replace_input(node, node.input[0], dequant_node.input[0], 0)
        const_outputs = ConstFoldOptimizer.compute_const_folding(node, graph)
        # consumers of the reshaped value now read the dequantize output instead
        graph.replace_all_inputs(node.output[0], dequant_node.output[0])
        graph.remove_node(node.name)
        dequant_const = dequant_node.inputs[0]
        # copy the constant first if other nodes still consume the original value
        if len(graph.find_output_consumers(dequant_const.output[0])) > 1:
            dequant_const = graph.copy_const(dequant_const)
            graph.replace_input(dequant_node, dequant_node.input[0], dequant_const.output[0], 0)
        dequant_const.set_tensor_value(const_outputs[0])
        return True

    @staticmethod
    def _all_inputs_are_const(nodes):
        return all(node.is_const() for node in nodes if node)

    @staticmethod
    def _is_graph_output(node, graph):
        node_out_set = set(node.output)
        graph_out_set = set(graph.outputs)
        return node_out_set.intersection(graph_out_set)
6 changes: 5 additions & 1 deletion tf2onnx/optimizer/const_fold_optimizer.py
@@ -54,7 +54,7 @@ def _should_skip(node):
        if node.is_const() or node.is_graph_input():
            return True

skip_type = ["Identity"]
skip_type = ["Identity", "DequantizeLinear"]
        if node.type in skip_type:
            return True

@@ -73,6 +73,10 @@ def _fold_node(self, node, graph):
            self.logger.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
        return False

    @staticmethod
    def compute_const_folding(node, graph):
        # exposed so ConstDequantizeOptimizer can reuse the per-op folding functions
        return _func_map[node.type](node, graph)

    @staticmethod
    def _all_inputs_are_const(nodes):
        return all(node.is_const() for node in nodes if node)
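As an end-to-end sanity check outside the unit test, one could compare the original and optimized models with onnxruntime. This is an illustrative sketch, not part of the PR; it assumes onnxruntime is installed and that the pre- and post-optimization models have been saved as model.onnx and model_opt.onnx (hypothetical file names):

    import numpy as np
    import onnxruntime as ort

    # both graphs compute Z purely from initializers, so no inputs are fed
    out_orig = ort.InferenceSession("model.onnx").run(["Z"], {})
    out_opt = ort.InferenceSession("model_opt.onnx").run(["Z"], {})
    assert np.allclose(out_orig[0], out_opt[0])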