diff --git a/onnxsim/onnx_simplifier.py b/onnxsim/onnx_simplifier.py index 20056f3..ad72904 100644 --- a/onnxsim/onnx_simplifier.py +++ b/onnxsim/onnx_simplifier.py @@ -228,6 +228,9 @@ def check_node(graph): # Skip QuantizeLinear and DequantizeLinear to preserve quantization info pass elif all([x in const_tensors for x in node.input]) and not is_non_deterministic_node(node): + # Skip these nodes to avoid bloating the model size + if node.op_type in ['ConstantOfShape', 'Tile']: + continue const_nodes.append(node) const_tensors.extend(node.output) @@ -305,13 +308,6 @@ def forward_for_node_outputs(model: onnx.ModelProto, return res -def insert_elem(repeated_container, index: int, element): - repeated_container.extend([repeated_container[-1]]) - for i in reversed(range(index + 1, len(repeated_container) - 1)): - repeated_container[i].CopyFrom(repeated_container[i - 1]) - repeated_container[index].CopyFrom(element) - - def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: Sequence[onnx.NodeProto], res: Tensors) -> onnx.ModelProto: """ @@ -321,6 +317,7 @@ def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: Sequence[onnx.Nod :return: the simplified onnx model. Redundant ops are all removed. """ def recursive_eliminate_const_nodes_in_graph(graph, const_nodes, res): + new_nodes = [] for i, node in enumerate(graph.node): if node in const_nodes: for output in node.output: @@ -336,12 +333,17 @@ def recursive_eliminate_const_nodes_in_graph(graph, const_nodes, res): del new_node.output[:] new_node.output.extend([output]) new_node.attribute.extend([new_attr]) - insert_elem(graph.node, i + 1, new_node) - del graph.node[i] - if has_subgraph_in_node(node): - for attr in node.attribute: - recursive_eliminate_const_nodes_in_graph( - attr.g, const_nodes, res) + new_nodes.append(new_node) + else: + new_nodes.append(node) + if has_subgraph_in_node(node): + for attr in node.attribute: + if attr.g is None: + continue + recursive_eliminate_const_nodes_in_graph( + attr.g, const_nodes, res) + del graph.node[:] + graph.node.extend(new_nodes) recursive_eliminate_const_nodes_in_graph(model.graph, const_nodes, res) return model @@ -471,8 +473,7 @@ def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T: """ x = func_a(x) x = func_b(x) - count = 0 - for _ in range(64): + for _ in range(int(os.getenv('ONNXSIM_FIXED_POINT_MAX_ITER', '5'))): y = func_a(x) if y == x: # Since func_b(func_b(x)) == func_b(x), @@ -484,7 +485,6 @@ def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T: if y == x: return x x = y - print("Warning: The simplifying takes too long. Stopping..") return x