Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# NOTE: This will eventually replace the existing constant_folding.py and evaluator.py files.

from __future__ import annotations

__all__ = [
Expand Down Expand Up @@ -81,8 +79,6 @@
# The API below works only for non-control-flow ops (ops without any graph-attributes).
# This currently used ONNX's reference implementation. But we could also
# use ORT's implementation if we want to.


def _process_constant_node(node: ir.Node) -> None:
"""Sets const_value of output value of a Constant op node."""
if not _is_onnx_op(node, "Constant"):
Expand Down Expand Up @@ -126,7 +122,6 @@

def basic_constant_propagation(nodes: Iterable[ir.Node]) -> None:
"""Performs basic constant propagation for a sequence of nodes.

Just marks the output values of Constant op nodes with their const_value.
"""
for node in nodes:
Expand Down Expand Up @@ -210,12 +205,10 @@

# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).

# A partial-evaluator function takes a node, a RewriterContext, OptimizerState and returns
# a Replacement for the node or None (if no replacement is needed). It may also return just
# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
# can be inferred from the RewriterContext used to build the new nodes.

RewriterContext = _tape.Builder
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]
Expand Down Expand Up @@ -471,7 +464,6 @@
@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Reshape node by Identity when applicable.

Also propagate symbolic shape values.
"""
input = _get_input(node, 0)
Expand Down Expand Up @@ -562,6 +554,27 @@
return op.Constant(value_int=size)


def _move_initializers_to_graph(src: ir.Graph, dst: ir.Graph) -> None:
    """Move all initializers from ``src`` graph to ``dst`` graph, ensuring name uniqueness.

    When an If branch is inlined into the main graph, the branch subgraph may
    hold initializers (e.g. a constant axes tensor for a Squeeze node) that were
    folded in a prior pass. Those initializers must be migrated to the main graph
    so that the inlined nodes can still reference them; failing to do so leaves the
    references dangling and produces an invalid model.

    Args:
        src: The subgraph whose initializers are moved (emptied on success).
        dst: The destination (enclosing) graph that receives the initializers.
    """
    # ONNX value names must be unique across the whole graph namespace, not
    # just among initializers. Collect every name already visible in the
    # destination graph (initializers, graph inputs, node outputs) so renaming
    # cannot clash with a non-initializer value.
    used_names: set[str] = set(dst.initializers)
    used_names.update(inp.name for inp in dst.inputs if inp.name is not None)
    for node in dst:
        used_names.update(out.name for out in node.outputs if out.name is not None)

    counter: dict[str, int] = {}
    for name in list(src.initializers):
        initializer = src.initializers.pop(name)
        # Ensure name uniqueness in the destination graph namespace.
        new_name = name
        while new_name in used_names:
            counter[name] = counter.get(name, 0) + 1
            new_name = f"{name}_{counter[name]}"
        if new_name != name:
            initializer.name = new_name
        try:
            dst.register_initializer(initializer)
        except Exception:
            # Re-register in the source graph so a registration failure cannot
            # irreversibly drop the initializer.
            initializer.name = name
            src.register_initializer(initializer)
            raise
        used_names.add(new_name)
Comment on lines +566 to +575
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper guarantees uniqueness only against dst.initializers, but ONNX value names must be unique across the whole graph namespace (e.g., node outputs). If name collides with an existing non-initializer value name in dst, register_initializer may raise or the graph may become invalid. Also, popping from src.initializers before successfully registering can irreversibly drop the initializer if an exception occurs. Consider: (1) building a used_names set for the destination graph covering all existing value names (not just initializers), (2) choosing new_name against that set, and (3) only removing from src.initializers after the destination registration succeeds.

Copilot uses AI. Check for mistakes.


@register("If")
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
cond_input = _get_input(node, 0)
Expand All @@ -586,7 +599,6 @@
if actual is not None
}
# TODO: Extend renaming to intermediate values.

def rename(name):
return renamings.get(name, name)

Expand All @@ -599,6 +611,15 @@
# Avoid name collision.
sub_node.name = f"{node.name}_{sub_node.name}"

# Move initializers from the subgraph to the main graph to avoid losing them.
# When the If branch was processed in a prior constant-folding pass, any
# constants inside the branch (e.g. the 'axes' tensor for a Squeeze node)
# may have been folded into subgraph initializers. Without this step those
# initializers would be orphaned once the branch nodes are inlined here.
main_graph = node.graph
if main_graph is not None:
_move_initializers_to_graph(graph, main_graph)

# Initializers are handled above by _move_initializers_to_graph.
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TODO on line 623 is now stale/contradictory because the function is explicitly handling initializers. Please remove or update that TODO to avoid misleading future readers.

Suggested change
# TODO: we should handle initializers as well!

Copilot uses AI. Check for mistakes.
return Replacement(formal_outs, graph_nodes)
return None
Expand Down Expand Up @@ -787,7 +808,6 @@
@register("SplitToSequence")
def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Rewriting pattern.

From

splits = onnx::SplitToSequence(input, split, axis=axis)
Expand Down Expand Up @@ -965,14 +985,13 @@

class FoldConstantsPass(ir.passes.InPlacePass):
"""A pass that folds constant expressions in the model.

Attributes:
shape_inference: Whether to perform shape inference.
input_size_limit: Maximum size of input tensors to fold.
output_size_limit: Maximum size of output tensors to fold.
should_fold: An optional function that takes a node and returns True if
the node should be considered for folding.
The function should return True/False value to indicate if this particular
The function should return True/False value to indicate if this particular
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the class docstring, the sentence on line 994 is mis-indented relative to the should_fold: attribute description, which makes the rendered doc confusing (it reads like a new attribute). Re-indent it so it’s clearly part of should_fold’s description.

Suggested change
The function should return True/False value to indicate if this particular
The function should return True/False value to indicate if this particular

Copilot uses AI. Check for mistakes.
node should be folded, or None to use the default folding rules.
"""

Expand Down Expand Up @@ -1201,7 +1220,6 @@
node.domain,
node.op_type,
)

return None

if _is_non_deterministic_op(node):
Expand Down Expand Up @@ -1240,8 +1258,7 @@
for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST:
if _is_onnx_op(node, op_type):
logger.info(
"Skipping constant folding for node %r because "
"%s is preserved by default",
"Skipping constant folding for node %r because %s is preserved by default",
node.name,
op_type,
)
Expand Down Expand Up @@ -1464,12 +1481,11 @@

Returns:
An instance of `FoldConstantsResult`.

"""
folder_pass = FoldConstantsPass(
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
should_fold=should_fold,
)
return folder_pass(model) # type: ignore[return-value]
return folder_pass(model)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

Check warning

Code scanning / lintrunner

RUFF/W292 Warning

Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There’s trailing whitespace on the return line. Also, removing # type: ignore[return-value] may reintroduce static type-check failures if InPlacePass.__call__ isn’t typed to return FoldConstantsResult. If the repo runs mypy/pyright, consider either restoring the targeted ignore or adjusting the pass/call typing so folder_pass(model) is correctly typed without ignores.

Suggested change
return folder_pass(model)
return typing.cast(FoldConstantsResult, folder_pass(model))

Copilot uses AI. Check for mistakes.
99 changes: 79 additions & 20 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,76 @@


class FoldConstantsTest(unittest.TestCase):
def test_fold_if_cond_with_subgraph_initializer(self):
    """Constants folded inside an If branch survive inlining of that branch."""
    model = ir.from_onnx_text("""
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[16, 16] x, bool cond) => (float[16, 16] z) {
            two = Constant <value_float=2.0> ()
            three = Constant <value_float=3.0> ()
            z = If (cond) <
                then_branch = then_graph () => (then_z) {
                    temp = Add (two, three)
                    then_z = Mul (temp, x)
                },
                else_branch = else_graph () => (else_z) {
                    else_z = Identity (x)
                }
            >
        }
    """)

    # First folding pass: Add(2.0, 3.0) is folded into a subgraph initializer
    # named 'temp'. The If condition is still unknown, so no inlining occurs.
    _constant_folding.fold_constants(model)
    optimizer.remove_unused_nodes(model)
    node = next(n for n in model.graph if n.op_type == "If")
    then_graph = node.attributes["then_branch"].as_graph()
    self.assertIn("temp", then_graph.initializers)
    self.assertNotIn("temp", model.graph.initializers)

    # Force the condition to a known-True constant so the branch is inlined.
    true_value = ir.Value(name="const_true")
    true_value.const_value = ir.Tensor(np.array(True))
    node.replace_input_with(0, true_value)

    # Second folding pass inlines the If branch; 'temp' must have been
    # migrated from the subgraph into the main graph.
    _constant_folding.fold_constants(model)
    optimizer.remove_unused_nodes(model)
    self.assertIn("temp", model.graph.initializers)
    onnx.checker.check_model(ir.serde.serialize_model(model))

def test_fold_if_cond_with_subgraph_initializer_name_collision(self):
    """Subgraph initializer names that clash with main-graph names get a unique suffix."""
    # NOTE(review): this test only asserts that inlining neither crashes nor
    # produces an invalid model; it does not verify that a collision actually
    # happened or that a suffixed name was chosen — consider strengthening.
    model = ir.from_onnx_text("""
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[1, 4] x, bool cond) => (float[4] z) {
            axes_val = Constant <value_ints=[0]> ()
            z = If (cond) <
                then_branch = then_branch_graph () => (then_z) {
                    axes_val_inner = Constant <value_ints=[0]> ()
                    then_z = Squeeze (x, axes_val_inner)
                },
                else_branch = else_branch_graph () => (else_z) {
                    else_z = Squeeze (x, axes_val)
                }
            >
        }
    """)

    # First pass folds the Constant nodes into initializers (subgraph and main).
    _constant_folding.fold_constants(model)
    optimizer.remove_unused_nodes(model)

    # Make the If condition a known-True constant to trigger inlining.
    node = next(n for n in model.graph if n.op_type == "If")
    cond_value = ir.Value(name="const_true_collision")
    cond_value.const_value = ir.Tensor(np.array(True))
    node.replace_input_with(0, cond_value)

    # Must not crash or silently overwrite on name collision.
    _constant_folding.fold_constants(model)
    optimizer.remove_unused_nodes(model)
    onnx.checker.check_model(ir.serde.serialize_model(model))

def _fold(
self,
model: ir.Model | str,
Expand Down Expand Up @@ -236,9 +306,7 @@
self.assertEqual(len(optimized.graph), 1)
self.assertIn("C", optimized.graph.initializers)

def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
self,
):
def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(self):
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'squence' appears to be a misspelling of 'sequence' in these updated test names. Since these lines were modified in this PR, it’s a good opportunity to correct the spelling for clarity and searchability.

Suggested change
def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(self):
def test_static_split_to_sequence_with_scalar_split_and_sequence_at_is_folded_as_split(self):

Copilot uses AI. Check for mistakes.
model = """
<
ir_version: 8,
Expand All @@ -260,8 +328,7 @@

# TODO: There is an unrelated limitation that `symbolic_value` is not
# utilized when the value is only referenced by graph output.
# E.g., the following test model will not have this optimization
# applied.
# E.g., the following test model will not have this optimization applied.
#
# <
# ir_version: 8,
Expand All @@ -284,9 +351,7 @@
self.assertEqual(len(optimized.graph[-2].outputs), 4)
self.assertEqual(optimized.graph[-2].op_type, "Split")

def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split(
self,
):
def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split(self):
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'squence' appears to be a misspelling of 'sequence' in these updated test names. Since these lines were modified in this PR, it’s a good opportunity to correct the spelling for clarity and searchability.

Copilot uses AI. Check for mistakes.
model = """
<
ir_version: 8,
Expand All @@ -309,9 +374,7 @@
self.assertEqual(len(optimized.graph[-2].outputs), 3)
self.assertEqual(optimized.graph[-2].op_type, "Split")

def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze(
self,
):
def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze(self):
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The word 'squence' appears to be a misspelling of 'sequence' in these updated test names. Since these lines were modified in this PR, it’s a good opportunity to correct the spelling for clarity and searchability.

Copilot uses AI. Check for mistakes.
model = """
<
ir_version: 8,
Expand All @@ -334,9 +397,7 @@
self.assertEqual(optimized.graph[1].op_type, "Split")
self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3)

def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(
self,
):
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0(self):
model = """
<
ir_version: 8,
Expand All @@ -352,9 +413,7 @@
self.assertEqual(len(optimized.graph), 3)
self.assertEqual(optimized.graph[2].op_type, "Concat")

def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
self,
):
def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(self):
model = """
<
ir_version: 8,
Expand Down Expand Up @@ -736,8 +795,8 @@
self.assertEqual([input.name for input in optimized.graph.inputs], ["x"])

# This should not be constant-foldable as the constant references an
# attribute and thus the shape cannot be resolved. At the same time it
# should not fail due to the attribute value being None in
# attribute and thus the shape cannot be resolved.
# At the same time it should not fail due to the attribute value being None in
# _process_constant_node
def test_attribute_reference(self):
model = """
Expand Down Expand Up @@ -798,4 +857,4 @@


if __name__ == "__main__":
unittest.main()
unittest.main()

Check warning

Code scanning / lintrunner

RUFF/W292 Warning

Loading