Skip to content

Commit

Permalink
Merge pull request #185 from daquexian/speed_up_and_avoid_bloating
Browse files Browse the repository at this point in the history
speed up and avoid size bloating
  • Loading branch information
daquexian committed May 6, 2022
2 parents ec97e5b + c534b02 commit 55b0b58
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions onnxsim/onnx_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def check_node(graph):
# Skip QuantizeLinear and DequantizeLinear to preserve quantization info
pass
elif all([x in const_tensors for x in node.input]) and not is_non_deterministic_node(node):
# Skip these nodes to avoid bloating the model size
if node.op_type in ['ConstantOfShape', 'Tile']:
continue
const_nodes.append(node)
const_tensors.extend(node.output)

Expand Down Expand Up @@ -305,13 +308,6 @@ def forward_for_node_outputs(model: onnx.ModelProto,
return res


def insert_elem(repeated_container, index: int, element):
repeated_container.extend([repeated_container[-1]])
for i in reversed(range(index + 1, len(repeated_container) - 1)):
repeated_container[i].CopyFrom(repeated_container[i - 1])
repeated_container[index].CopyFrom(element)


def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: Sequence[onnx.NodeProto],
res: Tensors) -> onnx.ModelProto:
"""
Expand All @@ -321,6 +317,7 @@ def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: Sequence[onnx.Nod
:return: the simplified onnx model. Redundant ops are all removed.
"""
def recursive_eliminate_const_nodes_in_graph(graph, const_nodes, res):
new_nodes = []
for i, node in enumerate(graph.node):
if node in const_nodes:
for output in node.output:
Expand All @@ -336,12 +333,17 @@ def recursive_eliminate_const_nodes_in_graph(graph, const_nodes, res):
del new_node.output[:]
new_node.output.extend([output])
new_node.attribute.extend([new_attr])
insert_elem(graph.node, i + 1, new_node)
del graph.node[i]
if has_subgraph_in_node(node):
for attr in node.attribute:
recursive_eliminate_const_nodes_in_graph(
attr.g, const_nodes, res)
new_nodes.append(new_node)
else:
new_nodes.append(node)
if has_subgraph_in_node(node):
for attr in node.attribute:
if attr.g is None:
continue
recursive_eliminate_const_nodes_in_graph(
attr.g, const_nodes, res)
del graph.node[:]
graph.node.extend(new_nodes)
recursive_eliminate_const_nodes_in_graph(model.graph, const_nodes, res)

return model
Expand Down Expand Up @@ -471,8 +473,7 @@ def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T:
"""
x = func_a(x)
x = func_b(x)
count = 0
for _ in range(64):
for _ in range(int(os.getenv('ONNXSIM_FIXED_POINT_MAX_ITER', '5'))):
y = func_a(x)
if y == x:
# Since func_b(func_b(x)) == func_b(x),
Expand All @@ -484,7 +485,6 @@ def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T:
if y == x:
return x
x = y
print("Warning: The simplifying takes too long. Stopping..")
return x


Expand Down

0 comments on commit 55b0b58

Please sign in to comment.