Added optimization for dequantized constant folding (#1316)

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>

TomWildenhain-Microsoft committed Feb 4, 2021
1 parent 4becde0 commit 96e1a03

Showing 4 changed files with 99 additions and 2 deletions.
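For context, a minimal sketch of the rewrite this commit enables (the sketch and its names are illustrative, not part of the commit): with a per-tensor scale and zero point, DequantizeLinear is elementwise, so a const → DequantizeLinear → Reshape chain can be replaced by reshaping the quantized const first and dequantizing last.

# Illustrative sketch (not from the commit): per-tensor DequantizeLinear
# commutes with Reshape, which is what lets the optimizer fold the reshape
# onto the quantized constant.
import numpy as np

q = np.random.randint(0, 100, (2, 3, 4, 5), dtype=np.uint8)  # quantized const
scale, zero_point = np.float32(0.75), np.uint8(3)

def dequantize(x):
    # Per-tensor DequantizeLinear semantics: (x - zero_point) * scale
    return (x.astype(np.float32) - np.float32(zero_point)) * scale

# Before optimization: dequantize, then reshape.
before = dequantize(q).reshape(6, 20)
# After optimization: reshape the quantized const, then dequantize.
after = dequantize(q.reshape(6, 20))
assert np.array_equal(before, after)
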
26 changes: 25 additions & 1 deletion tests/test_optimizers.py
@@ -9,7 +9,7 @@

 import unittest
 import numpy as np
-from onnx import helper, TensorProto, OperatorSetIdProto
+from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
 from tf2onnx import utils, constants
@@ -1241,6 +1241,30 @@ def test_const_fold_cast_with_const(self):

     # Const Fold Optimizer Tests End

+    # Const Dequantize Optimizer Tests Start
+
+    @check_opset_min_version(10, "DequantizeLinear")
+    def test_const_dequantize_reshape(self):
+        inputval = numpy_helper.from_array(np.random.randint(0, 100, (2, 3, 4, 5), np.uint8), name='X')
+        scale = numpy_helper.from_array(np.array(0.75, dtype=np.float32), name='scale')
+        zero_point = numpy_helper.from_array(np.array(3, dtype=np.uint8), name='zero_point')
+        shape = numpy_helper.from_array(np.array([6, 20], dtype=np.int64), name='shape')
+        node1 = helper.make_node("DequantizeLinear", ["X", "scale", "zero_point"], ["Y"], name="dequantize")
+        node2 = helper.make_node("Reshape", ["Y", "shape"], ["Z"], name="reshape")
+
+        graph = helper.make_graph(
+            [node1, node2],
+            "const-dequantize-test",
+            [],
+            [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (6, 20))],
+            [inputval, scale, zero_point, shape]
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["Z"], {}, model_proto, "Reshape", 0)
+
+    # Const Dequantize Optimizer Tests End

     def test_transpose_back_to_back_non_const(self):

         node0 = helper.make_node("Transpose", ["u"], ["v"], perm=[0, 2, 3, 1], name="trans_0")
2 changes: 2 additions & 0 deletions tf2onnx/optimizer/__init__.py
@@ -16,13 +16,15 @@
 from .loop_optimizer import LoopOptimizer
 from .back_to_back_optimizer import BackToBackOptimizer
 from .upsample_optimizer import UpsampleOptimizer
+from .const_dequantize_optimizer import ConstDequantizeOptimizer
 from .. import logging

 # the optimizer sequence needs to be considered carefully
 _optimizers = OrderedDict([
     ("optimize_transpose", TransposeOptimizer),
     ("remove_redundant_upsample", UpsampleOptimizer),
     ("fold_constants", ConstFoldOptimizer),
+    ("const_dequantize_optimizer", ConstDequantizeOptimizer),
     ("loop_optimizer", LoopOptimizer),
     # merge_duplication should be used after optimize_transpose,
     # because optimize_transpose may leave transpose nodes that can be merged
67 changes: 67 additions & 0 deletions tf2onnx/optimizer/const_dequantize_optimizer.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""const dequantize Optimizer.
+   If a dequantize op's inputs are const, we may be able to fold it through the next op.
+"""
+
+from .optimizer_base import GraphOptimizerBase
+from .const_fold_optimizer import ConstFoldOptimizer
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+
+class ConstDequantizeOptimizer(GraphOptimizerBase):
+
+    def __init__(self):  # pylint: disable=useless-super-delegation
+        super(ConstDequantizeOptimizer, self).__init__()
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        graph_changed = True
+        while graph_changed:
+            graph_changed = False
+            ops = graph.get_nodes()
+            for op in ops:
+                if self._fold_node(op, graph):
+                    graph_changed = True
+                    self.graph_been_opt = True
+        return graph
+
+    def _fold_node(self, node, graph):
+        """If a dequantize op's inputs are const and it feeds a tensor-reshaping op, we can apply the op
+        directly to the quantized inputs. Returns True if the graph is changed.
+        """
+        if node.type not in ["Transpose", "Reshape", "Unsqueeze"]:
+            return False
+        dequant_node = node.inputs[0]
+        if dequant_node.type != "DequantizeLinear":
+            return False
+        if len(graph.find_output_consumers(dequant_node.output[0])) > 1:
+            return False
+        if not self._all_inputs_are_const(node.inputs[1:]) or self._is_graph_output(node, graph):
+            return False
+        if not self._all_inputs_are_const(dequant_node.inputs):
+            return False
+        graph.replace_input(node, node.input[0], dequant_node.input[0], 0)
+        const_outputs = ConstFoldOptimizer.compute_const_folding(node, graph)
+        graph.replace_all_inputs(node.output[0], dequant_node.output[0])
+        graph.remove_node(node.name)
+        dequant_const = dequant_node.inputs[0]
+        if len(graph.find_output_consumers(dequant_const.output[0])) > 1:
+            dequant_const = graph.copy_const(dequant_const)
+            graph.replace_input(dequant_node, dequant_node.input[0], dequant_const.output[0], 0)
+        dequant_const.set_tensor_value(const_outputs[0])
+        return True
+
+    @staticmethod
+    def _all_inputs_are_const(nodes):
+        return all(node.is_const() for node in nodes if node)
+
+    @staticmethod
+    def _is_graph_output(node, graph):
+        node_out_set = set(node.output)
+        graph_out_set = set(graph.outputs)
+        return node_out_set.intersection(graph_out_set)
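
The whitelist in _fold_node is limited to Transpose, Reshape, and Unsqueeze because those ops only rearrange elements, and per-tensor dequantization is elementwise, so the two operations commute. A hedged illustration of why that restriction matters (illustrative numpy, not tf2onnx code):

# Illustrative sketch (not tf2onnx code): rearranging ops commute with
# per-tensor dequantization, but value-computing ops generally do not.
import numpy as np

q = np.random.randint(0, 100, (2, 3), dtype=np.uint8)
scale, zp = np.float32(0.5), np.uint8(7)
deq = lambda x: (np.asarray(x, dtype=np.float32) - np.float32(zp)) * scale

# Transpose merely rearranges elements, so it can be folded through.
assert np.array_equal(deq(q).T, deq(q.T))

# A reducing op such as Sum does not commute: folding it before dequantization
# would apply the zero-point shift once instead of once per element.
assert not np.isclose(deq(q).sum(), deq(q.sum(dtype=np.int64)))
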
6 changes: 5 additions & 1 deletion tf2onnx/optimizer/const_fold_optimizer.py
@@ -54,7 +54,7 @@ def _should_skip(node):
         if node.is_const() or node.is_graph_input():
             return True

-        skip_type = ["Identity"]
+        skip_type = ["Identity", "DequantizeLinear"]
         if node.type in skip_type:
             return True

@@ -73,6 +73,10 @@ def _fold_node(self, node, graph):
         self.logger.debug("need to add function to fold op %s whose op_type is %s", node.name, node.type)
         return False

+    @staticmethod
+    def compute_const_folding(node, graph):
+        return _func_map[node.type](node, graph)
+
     @staticmethod
     def _all_inputs_are_const(nodes):
         return all(node.is_const() for node in nodes if node)
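
The new compute_const_folding static method exposes the optimizer's existing _func_map dispatch so that ConstDequantizeOptimizer can reuse the per-op folding handlers. A self-contained sketch of that dispatch pattern (a simplified stand-in under assumed names, not the actual tf2onnx implementation):

# Simplified stand-in (an assumption, not tf2onnx's actual code) for the
# _func_map dispatch behind compute_const_folding: handlers registered per
# op type compute a node's constant outputs.
import numpy as np

_func_map = {}

def _register_func(op_type):
    def _internal(func):
        _func_map[op_type] = func
        return func
    return _internal

@_register_func("Reshape")
def _fold_reshape(data, shape):
    # Fold a Reshape by applying it directly to the constant input.
    return [np.reshape(data, shape)]

def compute_const_folding(op_type, *const_inputs):
    # Dispatch to the registered handler, mirroring the new static method.
    return _func_map[op_type](*const_inputs)

folded, = compute_const_folding("Reshape", np.arange(120, dtype=np.uint8), [6, 20])
assert folded.shape == (6, 20)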