fix(optimizer): move subgraph initializers to main graph when inlining If branches #2887
```diff
@@ -1,8 +1,6 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.

 # NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.

 from __future__ import annotations

 __all__ = [
```
```diff
@@ -81,8 +79,6 @@
 # The API below works only for non-control-flow ops (ops without any graph-attributes).
 # This currently used ONNX's reference implementation. But we could also
 # use ORT's implementation if we want to.

 def _process_constant_node(node: ir.Node) -> None:
     """Sets const_value of output value of a Constant op node."""
     if not _is_onnx_op(node, "Constant"):
```
```diff
@@ -126,7 +122,6 @@
 def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
     """Performs basic constant propagation for a sequence of nodes.

     Just marks the output values of Constant op nodes with their const_value.
     """
     for node in nodes:
```
```diff
@@ -210,12 +205,10 @@
 # The "partial evaluators" below are non-standard evaluators. They are used to perform
 # partial evaluation and/or static program analysis (abstract interpretation).

 # A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns
 # a Replacement for the node or None (if no replacement is needed). It may also return just
 # the ir.Value or ir.Values to replace the output values of the node, when the new nodes
 # can be inferred from the RewriterContext used to build the new nodes.

 RewriterContext = _tape.Builder
 ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
 PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
```
```diff
@@ -471,7 +464,6 @@
 @register("Reshape")
 def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     """Replace a Reshape node by Identity when applicable.

     Also propagate symbolic shape values.
     """
     input = _get_input(node, 0)
```
```diff
@@ -562,6 +554,27 @@
     return op.Constant(value_int=size)


+def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
+    """Move all initializers from src graph to dst graph, ensuring name uniqueness.
+
+    When an If branch is inlined into the main graph, the branch subgraph may
+    hold initializers (e.g. a constant axes tensor for a Squeeze node) that were
+    folded in a prior pass. Those initializers must be migrated to the main graph
+    so that the inlined nodes can still reference them; failing to do so leaves the
+    references dangling and produces an invalid model.
+    """
+    counter: dict[str, int] = {}
+    for name in list(src.initializers):
+        initializer = src.initializers.pop(name)
+        # Ensure name uniqueness in the destination graph
+        new_name = name
+        while new_name in dst.initializers:
+            counter[name] = counter.get(name, 0) + 1
+            new_name = f"{name}_{counter[name]}"
+        if new_name != name:
+            initializer.name = new_name
+        dst.register_initializer(initializer)
+
+
 @register("If")
 def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     cond_input = _get_input(node, 0)
```
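The renaming loop in the new helper can be sanity-checked in isolation. A minimal sketch, assuming plain dicts stand in for the graphs' initializer maps (the function name and sample values below are hypothetical, not part of the library):

```python
def move_with_unique_names(src: dict, dst: dict) -> None:
    """Mirror of the PR's renaming loop, with dicts as graph stand-ins."""
    counter: dict[str, int] = {}
    for name in list(src):
        tensor = src.pop(name)
        # Append a numeric suffix until the name is free in the destination.
        new_name = name
        while new_name in dst:
            counter[name] = counter.get(name, 0) + 1
            new_name = f"{name}_{counter[name]}"
        dst[new_name] = tensor


src = {"axes_val": "subgraph-tensor"}
dst = {"axes_val": "main-graph-tensor"}
move_with_unique_names(src, dst)
print(sorted(dst))  # ['axes_val', 'axes_val_1']
```

Note that the colliding subgraph initializer lands in the destination under the suffixed name while the pre-existing main-graph entry is untouched.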
```diff
@@ -586,7 +599,6 @@
         if actual is not None
     }
     # TODO: Extend renaming to intermediate values.

     def rename(name):
         return renamings.get(name, name)
```
```diff
@@ -599,6 +611,15 @@
         # Avoid name collision.
         sub_node.name = f"{node.name}_{sub_node.name}"

+    # Move initializers from the subgraph to the main graph to avoid losing them.
+    # When the If branch was processed in a prior constant-folding pass, any
+    # constants inside the branch (e.g. the 'axes' tensor for a Squeeze node)
+    # may have been folded into subgraph initializers. Without this step those
+    # initializers would be orphaned once the branch nodes are inlined here.
+    main_graph = node.graph
+    if main_graph is not None:
+        _move_initializers_to_graph(graph, main_graph)
+
     # TODO: we should handle initializers as well!
```
**Copilot AI** (Apr 14, 2026): In the class docstring, the sentence on line 994 is mis-indented relative to the `should_fold:` attribute description, which makes the rendered doc confusing (it reads like a new attribute). Re-indent it so it's clearly part of `should_fold`'s description. Suggested change (whitespace only): re-indent the line `The function should return True/False value to indicate if this particular` under the `should_fold:` description.
**lintrunner** (code scanning) warnings:
- EDITORCONFIG-CHECKER/editorconfig
- RUFF/W291, trailing whitespace (see https://docs.astral.sh/ruff/rules/trailing-whitespace)
- RUFF/W292, missing newline at end of file (see https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file)
**Copilot AI** (Apr 14, 2026): There's trailing whitespace on the return line. Also, removing `# type: ignore[return-value]` may reintroduce static type-check failures if `InPlacePass.__call__` isn't typed to return `FoldConstantsResult`. If the repo runs mypy/pyright, consider either restoring the targeted ignore or adjusting the pass/call typing so `folder_pass(model)` is correctly typed without ignores. Suggested change:

```diff
-    return folder_pass(model)
+    return typing.cast(FoldConstantsResult, folder_pass(model))
```
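As background for the suggestion above: `typing.cast` is a runtime no-op that only informs static type checkers, so the suggested change cannot alter behavior. A minimal sketch, where the `FoldConstantsResult` class and `folder_pass` function below are hypothetical stand-ins, not the library's actual API:

```python
import typing


class FoldConstantsResult:
    """Hypothetical stand-in for the real result type."""


def folder_pass(model: object) -> object:
    # Statically typed to return `object`, mimicking a pass whose
    # __call__ is not annotated to return FoldConstantsResult.
    return FoldConstantsResult()


# typing.cast changes only the static type; the value passes through unchanged.
result = typing.cast(FoldConstantsResult, folder_pass(None))
assert isinstance(result, FoldConstantsResult)
```

Compared with restoring `# type: ignore[return-value]`, the cast is narrower: it silences only this expression and documents the intended type.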
```diff
@@ -14,6 +14,76 @@
 class FoldConstantsTest(unittest.TestCase):
+    def test_fold_if_cond_with_subgraph_initializer(self):
+        model = ir.from_onnx_text("""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
+                two = Constant <value_float=2.0> ()
+                three = Constant <value_float=3.0> ()
+                z = If (cond) <
+                    then_branch = then_graph () => (then_z) {
+                        temp = Add (two, three)
+                        then_z = Mul (temp, x)
+                    },
+                    else_branch = else_graph () => (else_z) {
+                        else_z = Identity (x)
+                    }
+                >
+            }
+        """)
+
+        # Pass 1: fold Add(2.0, 3.0) into a subgraph initializer called 'temp'.
+        # The If condition is still non-constant so the branch is NOT inlined yet.
+        _constant_folding.fold_constants(model)
+        optimizer.remove_unused_nodes(model)
+        if_node = next(n for n in model.graph if n.op_type == "If")
+        then_branch = if_node.attributes["then_branch"].as_graph()
+        self.assertIn("temp", then_branch.initializers)
+        self.assertNotIn("temp", model.graph.initializers)
+
+        # Make the condition a known True constant to trigger branch inlining.
+        const_true = ir.Value(name="const_true")
+        const_true.const_value = ir.Tensor(np.array(True))
+        if_node.replace_input_with(0, const_true)
+
+        # Pass 2: inline the If branch.
+        # 'temp' must be migrated from the subgraph to the main graph.
+        _constant_folding.fold_constants(model)
+        optimizer.remove_unused_nodes(model)
+        self.assertIn("temp", model.graph.initializers)
+        onnx.checker.check_model(ir.serde.serialize_model(model))
+
+    def test_fold_if_cond_with_subgraph_initializer_name_collision(self):
+        """Subgraph initializer names that clash with main-graph names get a unique suffix."""
+        model = ir.from_onnx_text("""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 4] x, bool cond) => (float[4] z) {
+                axes_val = Constant <value_ints=[0]> ()
+                z = If (cond) <
+                    then_branch = then_branch_graph () => (then_z) {
+                        axes_val_inner = Constant <value_ints=[0]> ()
+                        then_z = Squeeze (x, axes_val_inner)
+                    },
+                    else_branch = else_branch_graph () => (else_z) {
+                        else_z = Squeeze (x, axes_val)
+                    }
+                >
+            }
+        """)
+
+        _constant_folding.fold_constants(model)
+        optimizer.remove_unused_nodes(model)
+
+        if_node = next(n for n in model.graph if n.op_type == "If")
+        const_true = ir.Value(name="const_true_collision")
+        const_true.const_value = ir.Tensor(np.array(True))
+        if_node.replace_input_with(0, const_true)
+
+        # Must not crash or silently overwrite on name collision.
+        _constant_folding.fold_constants(model)
+        optimizer.remove_unused_nodes(model)
+        onnx.checker.check_model(ir.serde.serialize_model(model))
+
     def _fold(
         self,
         model: ir.Model | str,
```
```diff
@@ -236,9 +306,7 @@
         self.assertEqual(len(optimized.graph), 1)
         self.assertIn("C", optimized.graph.initializers)

-    def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
-        self,
-    ):
+    def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(self):
```

Suggested change:

```diff
-    def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(self):
+    def test_static_split_to_sequence_with_scalar_split_and_sequence_at_is_folded_as_split(self):
```
**Copilot AI** (Apr 14, 2026): The word 'squence' appears to be a misspelling of 'sequence' in these updated test names. Since these lines were modified in this PR, it's a good opportunity to correct the spelling for clarity and searchability.
**lintrunner** (code scanning) warning: RUFF/W292, missing newline at end of file (see https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file)
**Review comment:** This helper guarantees uniqueness only against `dst.initializers`, but ONNX value names must be unique across the whole graph namespace (e.g., node outputs). If `name` collides with an existing non-initializer value name in `dst`, `register_initializer` may raise or the graph may become invalid. Also, popping from `src.initializers` before successfully registering can irreversibly drop the initializer if an exception occurs. Consider: (1) building a `used_names` set for the destination graph covering all existing value names (not just initializers), (2) choosing `new_name` against that set, and (3) only removing from `src.initializers` after the destination registration succeeds.
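The three-step fix this comment describes can be sketched with plain dicts and a set standing in for the graph structures. Names such as `move_initializers_safely` and `used_names` are illustrative only, not the library API:

```python
def move_initializers_safely(
    src: dict, dst_initializers: dict, dst_value_names: set
) -> None:
    # (1) Check uniqueness against ALL destination value names
    #     (node outputs included), not just initializers.
    used_names = set(dst_value_names) | set(dst_initializers)
    counter: dict[str, int] = {}
    for name in list(src):
        new_name = name
        # (2) Choose the new name against the full used-name set.
        while new_name in used_names:
            counter[name] = counter.get(name, 0) + 1
            new_name = f"{name}_{counter[name]}"
        # (3) Register in the destination first, then remove from the
        #     source, so a failure cannot silently drop the initializer.
        dst_initializers[new_name] = src[name]
        used_names.add(new_name)
        del src[name]


src = {"temp": 1}
dst_inits: dict = {}
dst_names = {"temp"}  # 'temp' is already a node-output name in the destination
move_initializers_safely(src, dst_inits, dst_names)
print(dst_inits)  # {'temp_1': 1}
```

Here the collision is with a node-output name rather than another initializer, the case the original helper would miss.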