From c662194f5170b3e5e1743f6e3dd17b54dcbed938 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 2 Apr 2025 22:23:13 +0000 Subject: [PATCH 01/11] support lift constants to initializers pass --- .../common/lift_constants_to_initializers.py | 50 +++++++++++++++++ .../lift_constants_to_initializers_test.py | 55 +++++++++++++++++++ onnxscript/optimizer/_optimizer.py | 1 + 3 files changed, 106 insertions(+) create mode 100644 onnxscript/ir/passes/common/lift_constants_to_initializers.py create mode 100644 onnxscript/ir/passes/common/lift_constants_to_initializers_test.py diff --git a/onnxscript/ir/passes/common/lift_constants_to_initializers.py b/onnxscript/ir/passes/common/lift_constants_to_initializers.py new file mode 100644 index 0000000000..fa35151227 --- /dev/null +++ b/onnxscript/ir/passes/common/lift_constants_to_initializers.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Lift constants to initializers.""" + +from __future__ import annotations + +__all__ = [ + "LiftConstantsToInitializersPass", +] + +import logging + +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +class LiftConstantsToInitializersPass(ir.passes.InPlacePass): + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Convert constant nodes in main graph to initializers.""" + count = 0 + for node in model.graph: + if node.op_type != "Constant": + continue + if "value" not in node.attributes: + logger.debug("Constant node '%s' has no 'value' attribute", node.name) + continue + # The value of attribute can only be ir.Attr, as + # ir.RefAttr is only defined in Functions. + tensor = node.attributes["value"].as_tensor() # type: ignore[union-attr] + # Register an initializer with the tensor value + initializer_name = node.outputs[0].name + assert initializer_name is not None + initializer = ir.Value( + name=initializer_name, + shape=tensor.shape, # type: ignore[arg-type] + type=ir.TensorType(tensor.dtype), + const_value=tensor, + ) + model.graph.initializers[initializer_name] = initializer + # Replace the constant node with the initilizer + ir.convenience.replace_all_uses_with(node.outputs[0], initializer) + model.graph.remove(node, safe=True) + count += 1 + logger.info( + "Converted constant node '%s' to initializer '%s'", node.name, initializer_name + ) + if count: + logger.info("Lifted %s constants to initializers", count) + return ir.passes.PassResult(model, modified=bool(count)) diff --git a/onnxscript/ir/passes/common/lift_constants_to_initializers_test.py b/onnxscript/ir/passes/common/lift_constants_to_initializers_test.py new file mode 100644 index 0000000000..2fdf558073 --- /dev/null +++ b/onnxscript/ir/passes/common/lift_constants_to_initializers_test.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.ir.passes.common import lift_constants_to_initializers + + +class TestLiftConstantsToInitializersPass(unittest.TestCase): + def test_pass_with_lifting_constants_to_initializers(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 3)), + ), + ] + + constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + attribute = ir.convenience.convert_attributes({"value": constant_tensor}) + const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1) + add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]]) + mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]]) + + model = ir.Model( + graph=ir.Graph( + inputs=inputs, + outputs=mul_node.outputs, + nodes=[const_node, add_node, mul_node], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + # Check that the initializer is not in the graph yet + self.assertEqual(len(model.graph.initializers), 0) + # And 1 constant node + self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) + + # Perform lift constants to initializers + result = lift_constants_to_initializers.LiftConstantsToInitializersPass()(model) + self.assertTrue(result.modified) + # Check that the constant node is lifted to an initializer + self.assertEqual(len(result.model.graph.initializers), 1) + # And 0 constant node + self.assertEqual( + len([node for node in result.model.graph if node.op_type == "Constant"]), 0 + ) diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index dd3c8563c2..6df0b513b6 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -62,3 +62,4 @@ def optimize_ir( ) rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) onnxscript.optimizer.remove_unused_nodes(model) + ir.passes.common.lift_constants_to_initializers.LiftConstantsToInitializersPass()(model) From c0ea8e62fda736526239332c1981c5c1c8e7f72b Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 2 Apr 2025 22:58:19 +0000 Subject: [PATCH 02/11] lint --- onnxscript/optimizer/_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 496e787b36..df94ef97e7 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -4,6 +4,7 @@ import logging +import onnxscript.ir.passes.common.lift_constants_to_initializers import onnxscript.ir.passes.common.unused_removal import onnxscript.optimizer from onnxscript import ir, rewriter From 7db9be02e41b5826c0fc06acc8b708882db17381 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 3 Apr 2025 00:17:27 +0000 Subject: [PATCH 03/11] update name authority --- onnxscript/ir/_core.py | 10 ++++++++++ .../ir/passes/common/lift_constants_to_initializers.py | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index ddb0e80309..2da442158c 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1924,6 +1924,10 @@ def __init__( # Be sure the initialize the name authority before extending the nodes # because it is used to name the nodes and their outputs self._name_authority = _name_authority.NameAuthority() + # NOTE: input and initializer value names could be duplicated + # to auto-generated value names from name authority and crash ort + # https://github.com/microsoft/onnxruntime/blob/bc7b07dbb41a2f441dbed1a91855563ba0dd8a31/onnxruntime/core/graph/graph.cc#L1536 + self._set_input_and_initializer_value_names_into_name_authority() # Call self.extend not self._nodes.extend so the graph reference is added to the nodes self.extend(nodes) @@ -1999,6 +2003,12 @@ def __iter__(self) -> Iterator[Node]: def __reversed__(self) -> Iterator[Node]: return reversed(self._nodes) + def _set_input_and_initializer_value_names_into_name_authority(self): + for value in self.inputs: + self._name_authority.register_or_name_value(value) + for value in self.initializers.values(): + self._name_authority.register_or_name_value(value) + def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node: """Set the graph reference for the node and assign names to it and its outputs if they don't have one.""" if node.graph is not None and node.graph is not self: diff --git a/onnxscript/ir/passes/common/lift_constants_to_initializers.py b/onnxscript/ir/passes/common/lift_constants_to_initializers.py index fa35151227..6864e64b03 100644 --- a/onnxscript/ir/passes/common/lift_constants_to_initializers.py +++ b/onnxscript/ir/passes/common/lift_constants_to_initializers.py @@ -20,7 +20,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: """Convert constant nodes in main graph to initializers.""" count = 0 for node in model.graph: - if node.op_type != "Constant": + if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue if "value" not in node.attributes: logger.debug("Constant node '%s' has no 'value' attribute", node.name) @@ -37,6 +37,8 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: type=ir.TensorType(tensor.dtype), const_value=tensor, ) + # TODO(titaiwang): Is it possible that the initializer name has + # been taken? model.graph.initializers[initializer_name] = initializer # Replace the constant node with the initilizer ir.convenience.replace_all_uses_with(node.outputs[0], initializer) From 1956c7ffafebb3cf43da73457ea6b0f8eed41b07 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 4 Apr 2025 23:33:35 +0000 Subject: [PATCH 04/11] update --- onnxscript/ir/_core.py | 4 +--- ...s_to_initializers.py => constant_manipulation.py} | 2 +- ...alizers_test.py => constant_manipulation_test.py} | 12 +++++++++--- onnxscript/optimizer/_optimizer.py | 4 ++-- 4 files changed, 13 insertions(+), 9 deletions(-) rename onnxscript/ir/passes/common/{lift_constants_to_initializers.py => constant_manipulation.py} (96%) rename onnxscript/ir/passes/common/{lift_constants_to_initializers_test.py => constant_manipulation_test.py} (84%) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 2da442158c..10447ec4b5 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1924,9 +1924,7 @@ def __init__( # Be sure the initialize the name authority before extending the nodes # because it is used to name the nodes and their outputs self._name_authority = _name_authority.NameAuthority() - # NOTE: input and initializer value names could be duplicated - # to auto-generated value names from name authority and crash ort - # https://github.com/microsoft/onnxruntime/blob/bc7b07dbb41a2f441dbed1a91855563ba0dd8a31/onnxruntime/core/graph/graph.cc#L1536 + # TODO(justinchuby): Trigger again if inputs or initializers are modified. self._set_input_and_initializer_value_names_into_name_authority() # Call self.extend not self._nodes.extend so the graph reference is added to the nodes self.extend(nodes) diff --git a/onnxscript/ir/passes/common/lift_constants_to_initializers.py b/onnxscript/ir/passes/common/constant_manipulation.py similarity index 96% rename from onnxscript/ir/passes/common/lift_constants_to_initializers.py rename to onnxscript/ir/passes/common/constant_manipulation.py index 6864e64b03..668a88f470 100644 --- a/onnxscript/ir/passes/common/lift_constants_to_initializers.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -39,7 +39,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: ) # TODO(titaiwang): Is it possible that the initializer name has # been taken? - model.graph.initializers[initializer_name] = initializer + model.graph.register_initializer(initializer) # Replace the constant node with the initilizer ir.convenience.replace_all_uses_with(node.outputs[0], initializer) model.graph.remove(node, safe=True) diff --git a/onnxscript/ir/passes/common/lift_constants_to_initializers_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py similarity index 84% rename from onnxscript/ir/passes/common/lift_constants_to_initializers_test.py rename to onnxscript/ir/passes/common/constant_manipulation_test.py index 2fdf558073..d9b1ef3743 100644 --- a/onnxscript/ir/passes/common/lift_constants_to_initializers_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -7,8 +7,9 @@ import numpy as np from onnxscript import ir -from onnxscript.ir.passes.common import lift_constants_to_initializers +from onnxscript.ir.passes.common import constant_manipulation +import onnx class TestLiftConstantsToInitializersPass(unittest.TestCase): def test_pass_with_lifting_constants_to_initializers(self): @@ -45,11 +46,16 @@ def test_pass_with_lifting_constants_to_initializers(self): self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) # Perform lift constants to initializers - result = lift_constants_to_initializers.LiftConstantsToInitializersPass()(model) + result = constant_manipulation.LiftConstantsToInitializersPass()(model) self.assertTrue(result.modified) # Check that the constant node is lifted to an initializer self.assertEqual(len(result.model.graph.initializers), 1) + # Check the value + self.assertEqual( + result.model.graph.initializers["val_0"].const_value, # name created by name_authority + constant_tensor, + ) # And 0 constant node self.assertEqual( len([node for node in result.model.graph if node.op_type == "Constant"]), 0 - ) + ) \ No newline at end of file diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index df94ef97e7..a1d1a65208 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -4,7 +4,7 @@ import logging -import onnxscript.ir.passes.common.lift_constants_to_initializers +import onnxscript.ir.passes.common.constant_manipulation import onnxscript.ir.passes.common.unused_removal import onnxscript.optimizer from onnxscript import ir, rewriter @@ -71,7 +71,7 @@ def optimize_ir( early_stop=stop_if_no_change, ), onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(), - onnxscript.ir.passes.common.lift_constants_to_initializers.LiftConstantsToInitializersPass(), + onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(), ) assert optimizer_pass.in_place result = optimizer_pass(model) From 0ba7025d02955a37e25551cd8fd04482ce4a6c2c Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Sat, 5 Apr 2025 00:05:51 +0000 Subject: [PATCH 05/11] add constant attribute variations --- .../ir/passes/common/constant_manipulation.py | 58 +++++++++++++++++-- .../common/constant_manipulation_test.py | 7 ++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 668a88f470..3637faa920 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -10,6 +10,8 @@ import logging +import numpy as np + from onnxscript import ir logger = logging.getLogger(__name__) @@ -22,15 +24,61 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: for node in model.graph: if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue - if "value" not in node.attributes: - logger.debug("Constant node '%s' has no 'value' attribute", node.name) + + allowed_constant_attributes = { + "value", + "value_int", + "value_ints", + "value_float", + "value_floats", + "value_string", + "value_strings", + } + constant_node_attribute = set(node.attributes.keys()) + if len(constant_node_attribute) != 1: + logger.debug( + "Invalid constant node '%s' has more than one attribute", node.name + ) + continue + if constant_node_attribute not in allowed_constant_attributes: + logger.debug("Invalid constant node '%s' has unsupported attribute", node.name) continue + + initializer_name = node.outputs[0].name + assert initializer_name is not None # The value of attribute can only be ir.Attr, as # ir.RefAttr is only defined in Functions. - tensor = node.attributes["value"].as_tensor() # type: ignore[union-attr] + attr_value = node.attributes[constant_node_attribute] + if constant_node_attribute == "value": + tensor = attr_value.as_tensor() # type: ignore[union-attr] + elif constant_node_attribute == "value_int": + tensor = ir.Tensor( + np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name + ) + elif constant_node_attribute == "value_ints": + tensor = ir.Tensor( + np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name + ) + elif constant_node_attribute == "value_float": + tensor = ir.Tensor( + np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name + ) + elif constant_node_attribute == "value_floats": + tensor = ir.Tensor( + np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name + ) + elif constant_node_attribute == "value_string": + tensor = ir.Tensor( + np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name + ) + elif constant_node_attribute == "value_strings": + tensor = ir.Tensor( + np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name + ) + else: + logger.debug("Invalid constant node '%s' has unsupported attribute", node.name) + continue # Register an initializer with the tensor value - initializer_name = node.outputs[0].name - assert initializer_name is not None initializer = ir.Value( name=initializer_name, shape=tensor.shape, # type: ignore[arg-type] diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index d9b1ef3743..5a7c9abf55 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -9,7 +9,6 @@ from onnxscript import ir from onnxscript.ir.passes.common import constant_manipulation -import onnx class TestLiftConstantsToInitializersPass(unittest.TestCase): def test_pass_with_lifting_constants_to_initializers(self): @@ -52,10 +51,12 @@ def test_pass_with_lifting_constants_to_initializers(self): self.assertEqual(len(result.model.graph.initializers), 1) # Check the value self.assertEqual( - result.model.graph.initializers["val_0"].const_value, # name created by name_authority + result.model.graph.initializers[ + "val_0" + ].const_value, # name created by name_authority constant_tensor, ) # And 0 constant node self.assertEqual( len([node for node in result.model.graph if node.op_type == "Constant"]), 0 - ) \ No newline at end of file + ) From c83097215532d06111f9daf82cd232fa4ac1ecd8 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Mon, 7 Apr 2025 17:24:05 +0000 Subject: [PATCH 06/11] add tests --- .../ir/passes/common/constant_manipulation.py | 85 ++++++++++--------- .../common/constant_manipulation_test.py | 17 ++-- 2 files changed, 54 insertions(+), 48 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 3637faa920..d154f295dc 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -25,58 +25,24 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue - allowed_constant_attributes = { - "value", - "value_int", - "value_ints", - "value_float", - "value_floats", - "value_string", - "value_strings", - } constant_node_attribute = set(node.attributes.keys()) if len(constant_node_attribute) != 1: logger.debug( "Invalid constant node '%s' has more than one attribute", node.name ) continue - if constant_node_attribute not in allowed_constant_attributes: - logger.debug("Invalid constant node '%s' has unsupported attribute", node.name) - continue + attr_name, attr_value = next(iter(node.attributes.items())) initializer_name = node.outputs[0].name assert initializer_name is not None - # The value of attribute can only be ir.Attr, as - # ir.RefAttr is only defined in Functions. - attr_value = node.attributes[constant_node_attribute] - if constant_node_attribute == "value": - tensor = attr_value.as_tensor() # type: ignore[union-attr] - elif constant_node_attribute == "value_int": - tensor = ir.Tensor( - np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name - ) - elif constant_node_attribute == "value_ints": - tensor = ir.Tensor( - np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name - ) - elif constant_node_attribute == "value_float": - tensor = ir.Tensor( - np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name - ) - elif constant_node_attribute == "value_floats": - tensor = ir.Tensor( - np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name - ) - elif constant_node_attribute == "value_string": - tensor = ir.Tensor( - np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name - ) - elif constant_node_attribute == "value_strings": - tensor = ir.Tensor( - np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name + assert isinstance(attr_value, ir.Attr) + tensor = _constant_node_attribute_to_tensor( + attr_name, attr_value, initializer_name + ) + if tensor is None: + logger.debug( + "Invalid constant node '%s' has unsupported attribute value", node.name ) - else: - logger.debug("Invalid constant node '%s' has unsupported attribute", node.name) continue # Register an initializer with the tensor value initializer = ir.Value( @@ -98,3 +64,38 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: if count: logger.info("Lifted %s constants to initializers", count) return ir.passes.PassResult(model, modified=bool(count)) + + +def _constant_node_attribute_to_tensor( + attr_name: str, attr_value: ir.Attr, initializer_name: str +) -> ir.Tensor | None: + """Convert constant node attribute to tensor.""" + if attr_name == "value": + tensor = attr_value.as_tensor() # type: ignore[union-attr] + elif attr_name == "value_int": + tensor = ir.Tensor( + np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name + ) + elif attr_name == "value_ints": + tensor = ir.Tensor( + np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name + ) + elif attr_name == "value_float": + tensor = ir.Tensor( + np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name + ) + elif attr_name == "value_floats": + tensor = ir.Tensor( + np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name + ) + elif attr_name == "value_string": + tensor = ir.Tensor( + np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name + ) + elif attr_name == "value_strings": + tensor = ir.Tensor( + np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name + ) + else: + tensor = None + return tensor # type: ignore[return-value] diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 5a7c9abf55..32e859a98f 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -5,25 +5,30 @@ import unittest import numpy as np +import parameterized from onnxscript import ir from onnxscript.ir.passes.common import constant_manipulation class TestLiftConstantsToInitializersPass(unittest.TestCase): - def test_pass_with_lifting_constants_to_initializers(self): + @parameterized.parameterized.expand( + [ + (ir.DataType.FLOAT, np.float32), + (ir.DataType.INT64, np.int64), + ] + ) + def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype): inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), + ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))), ir.Value( name="input_b", - type=ir.TensorType(ir.DataType.FLOAT), + type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3)), ), ] - constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype)) attribute = ir.convenience.convert_attributes({"value": constant_tensor}) const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1) add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]]) From a285ceb97918447d7a34d159236bae4785a3f7de Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Wed, 9 Apr 2025 00:24:31 +0000 Subject: [PATCH 07/11] add support to subgraph --- .../ir/passes/common/constant_manipulation.py | 41 +++++----- .../common/constant_manipulation_test.py | 75 +++++++++++++++++++ 2 files changed, 92 insertions(+), 24 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index d154f295dc..b13600a3ff 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -19,9 +19,9 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass): def call(self, model: ir.Model) -> ir.passes.PassResult: - """Convert constant nodes in main graph to initializers.""" + """Convert constant nodes from node belonged graph to its initializers.""" count = 0 - for node in model.graph: + for node in ir.traversal.RecursiveGraphIterator(model.graph): if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue @@ -51,18 +51,17 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: type=ir.TensorType(tensor.dtype), const_value=tensor, ) - # TODO(titaiwang): Is it possible that the initializer name has - # been taken? - model.graph.register_initializer(initializer) + assert node.graph is not None + node.graph.register_initializer(initializer) # Replace the constant node with the initilizer ir.convenience.replace_all_uses_with(node.outputs[0], initializer) - model.graph.remove(node, safe=True) + node.graph.remove(node, safe=True) count += 1 - logger.info( + logger.debug( "Converted constant node '%s' to initializer '%s'", node.name, initializer_name ) if count: - logger.info("Lifted %s constants to initializers", count) + logger.debug("Lifted %s constants to initializers", count) return ir.passes.PassResult(model, modified=bool(count)) @@ -73,28 +72,22 @@ def _constant_node_attribute_to_tensor( if attr_name == "value": tensor = attr_value.as_tensor() # type: ignore[union-attr] elif attr_name == "value_int": - tensor = ir.Tensor( - np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name - ) + tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name) elif attr_name == "value_ints": - tensor = ir.Tensor( - np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name + tensor = ir.tensor( + attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name ) elif attr_name == "value_float": - tensor = ir.Tensor( - np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name + tensor = ir.tensor( + attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name ) elif attr_name == "value_floats": - tensor = ir.Tensor( - np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name - ) - elif attr_name == "value_string": - tensor = ir.Tensor( - np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name + tensor = ir.tensor( + attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name ) - elif attr_name == "value_strings": - tensor = ir.Tensor( - np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name + elif attr_name in ("value_string", "value_strings"): + tensor = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=initializer_name ) else: tensor = None diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 32e859a98f..3ae6470aa0 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -65,3 +65,78 @@ def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype self.assertEqual( len([node for node in result.model.graph if node.op_type == "Constant"]), 0 ) + + def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): + input_value = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) + ) + + then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + attribute = ir.convenience.convert_attributes({"value": then_constant_tensor}) + then_const_node = ir.Node( + "", "Constant", inputs=[], attributes=attribute, num_outputs=1 + ) + # then branch adds the constant to the input + # else branch multiplies the input by the constant + add_node = ir.Node("", "Add", inputs=[input_value, then_const_node.outputs[0]]) + then_graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[then_const_node, add_node], + opset_imports={"": 20}, + ) + else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + attribute = ir.convenience.convert_attributes({"value": else_constant_tensor}) + else_const_node = ir.Node( + "", "Constant", inputs=[], attributes=attribute, num_outputs=1 + ) + mul_node = ir.Node("", "Mul", inputs=[input_value, else_const_node.outputs[0]]) + else_graph = ir.Graph( + inputs=[input_value], + outputs=[mul_node.outputs[0]], + nodes=[else_const_node, mul_node], + opset_imports={"": 20}, + ) + # create a conditional node that uses the then and else graphs + attribute = ir.convenience.convert_attributes( + {"then_branch": then_graph, "else_branch": else_graph} + ) + cond_node = ir.Node( + "", + "If", + inputs=[input_value], + attributes=attribute, + num_outputs=1, + ) + # construnct the model + main_graph = ir.Graph( + inputs=[input_value], + outputs=cond_node.outputs, + nodes=[cond_node], + opset_imports={"": 20}, + ) + main_graph.sort() + model = ir.Model( + graph=main_graph, + ir_version=10, + ) + result = constant_manipulation.LiftConstantsToInitializersPass()(model) + self.assertTrue(result.modified) + # Check that the constant node is lifted to the subgraph initializers + for node in ir.traversal.RecursiveGraphIterator(result.model.graph): + if node.op_type == "Constant": + raise AssertionError( + f"Constant node '{node.name}' was not lifted to initializers" + ) + if node.op_type == "Add": + self.assertEqual(len(node.graph.initializers), 1) + self.assertEqual( + node.graph.initializers["val_0"].const_value, + then_constant_tensor, + ) + if node.op_type == "Mul": + self.assertEqual(len(node.graph.initializers), 1) + self.assertEqual( + node.graph.initializers["val_0"].const_value, + else_constant_tensor, + ) From 22df6745f7d8bfdef37ba18c843e280fecb9f808 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 10 Apr 2025 00:07:37 +0000 Subject: [PATCH 08/11] add new tests --- .../ir/passes/common/constant_manipulation.py | 1 + .../common/constant_manipulation_test.py | 113 +++++++++++++----- 2 files changed, 83 insertions(+), 31 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index b13600a3ff..3032b33d44 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -52,6 +52,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: const_value=tensor, ) assert node.graph is not None + assert isinstance(node.graph, ir.Graph) node.graph.register_initializer(initializer) # Replace the constant node with the initilizer ir.convenience.replace_all_uses_with(node.outputs[0], initializer) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 3ae6470aa0..12b13ee2ce 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -18,7 +18,9 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase): (ir.DataType.INT64, np.int64), ] ) - def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype): + def test_pass_with_lifting_float_and_int_constants_to_initializers( + self, ir_dtype, numpy_dtype + ): inputs = [ ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))), ir.Value( @@ -29,10 +31,11 @@ def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype ] constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype)) - attribute = ir.convenience.convert_attributes({"value": constant_tensor}) - const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1) - add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]]) - mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]]) + const_node = ir.node( + "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 + ) + add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]]) + mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) model = ir.Model( graph=ir.Graph( @@ -72,13 +75,12 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): ) then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - attribute = ir.convenience.convert_attributes({"value": then_constant_tensor}) - then_const_node = ir.Node( - "", "Constant", inputs=[], attributes=attribute, num_outputs=1 + then_const_node = ir.node( + "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1 ) # then branch adds the constant to the input # else branch multiplies the input by the constant - add_node = ir.Node("", "Add", inputs=[input_value, then_const_node.outputs[0]]) + add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]]) then_graph = ir.Graph( inputs=[input_value], outputs=[add_node.outputs[0]], @@ -86,11 +88,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): opset_imports={"": 20}, ) else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - attribute = ir.convenience.convert_attributes({"value": else_constant_tensor}) - else_const_node = ir.Node( - "", "Constant", inputs=[], attributes=attribute, num_outputs=1 + else_const_node = ir.node( + "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1 ) - mul_node = ir.Node("", "Mul", inputs=[input_value, else_const_node.outputs[0]]) + mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]]) else_graph = ir.Graph( inputs=[input_value], outputs=[mul_node.outputs[0]], @@ -98,14 +99,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): opset_imports={"": 20}, ) # create a conditional node that uses the then and else graphs - attribute = ir.convenience.convert_attributes( - {"then_branch": then_graph, "else_branch": else_graph} - ) - cond_node = ir.Node( - "", + cond_node = ir.node( "If", inputs=[input_value], - attributes=attribute, + attributes={"then_branch": then_graph, "else_branch": else_graph}, num_outputs=1, ) # construnct the model @@ -128,15 +125,69 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): raise AssertionError( f"Constant node '{node.name}' was not lifted to initializers" ) - if node.op_type == "Add": - self.assertEqual(len(node.graph.initializers), 1) - self.assertEqual( - node.graph.initializers["val_0"].const_value, - then_constant_tensor, - ) - if node.op_type == "Mul": - self.assertEqual(len(node.graph.initializers), 1) - self.assertEqual( - node.graph.initializers["val_0"].const_value, - else_constant_tensor, - ) + self.assertEqual(len(else_graph.initializers), 1) + self.assertEqual(len(then_graph.initializers), 1) + self.assertEqual( + else_graph.initializers["val_0"].const_value, + else_constant_tensor, + ) + self.assertEqual( + then_graph.initializers["val_0"].const_value, + then_constant_tensor, + ) + + @parameterized.parameterized.expand( + [ + (1.0, "value_float", np.float32), + (1, "value_int", np.int64), + ("hello world!", "value_string", np.bytes_), + ([1.0, 2.0, 3.0], "value_floats", np.float32), + ([1, 2, 3], "value_ints", np.int64), + (["hello world!", "thank you."], "value_strings", np.bytes_), + ] + ) + def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( + self, value, constant_attribute, np_dtype + ): + input_value = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) + ) + + constant_value = value + const_node = ir.node( + "Constant", + inputs=[], + attributes={constant_attribute: constant_value}, + num_outputs=1, + ) + identity_node_constant = ir.node( + "Identity", inputs=[const_node.outputs[0]], num_outputs=1 + ) + identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) + + model = ir.Model( + graph=ir.Graph( + inputs=[input_value], + outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]], + nodes=[identity_node_input, const_node, identity_node_constant], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + # Check that the initializer is not in the graph yet + assert len(model.graph.initializers) == 0 + # And 1 constant node + assert len([node for node in model.graph if node.op_type == "Constant"]) == 1 + + # Perform lift constants to initializers + result = constant_manipulation.LiftConstantsToInitializersPass()(model) + assert result.modified + # Check that the constant node is lifted to an initializer + assert len(result.model.graph.initializers) == 1 + self.assertTrue( + np.array_equal( + result.model.graph.initializers["val_1"].const_value.raw, + np.array(constant_value, dtype=np_dtype), + ) + ) From 82c4016178abac8771a2fc2c901c00be906849c0 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 10 Apr 2025 16:40:28 +0000 Subject: [PATCH 09/11] address reviews --- .../ir/passes/common/constant_manipulation_test.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 12b13ee2ce..66819792c8 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -14,13 +14,11 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase): @parameterized.parameterized.expand( [ - (ir.DataType.FLOAT, np.float32), - (ir.DataType.INT64, np.int64), + (ir.DataType.FLOAT,), + (ir.DataType.INT64,), ] ) - def test_pass_with_lifting_float_and_int_constants_to_initializers( - self, ir_dtype, numpy_dtype - ): + def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype): inputs = [ ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))), ir.Value( @@ -30,7 +28,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers( ), ] - constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype)) + constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy())) const_node = ir.node( "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 ) @@ -127,11 +125,11 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): ) self.assertEqual(len(else_graph.initializers), 1) self.assertEqual(len(then_graph.initializers), 1) - self.assertEqual( + self.assertIs( else_graph.initializers["val_0"].const_value, else_constant_tensor, ) - self.assertEqual( + self.assertIs( then_graph.initializers["val_0"].const_value, then_constant_tensor, ) From a2e9d2afe377787890ea2adcb8eaf119356d2d20 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 10 Apr 2025 17:15:59 +0000 Subject: [PATCH 10/11] assert to self.assert --- .../passes/common/constant_manipulation_test.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 66819792c8..f2e610d816 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -174,18 +174,16 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( ) # Check that the initializer is not in the graph yet - assert len(model.graph.initializers) == 0 + self.assertEqual(len(model.graph.initializers), 0) # And 1 constant node - assert len([node for node in model.graph if node.op_type == "Constant"]) == 1 + self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) # Perform lift constants to initializers result = constant_manipulation.LiftConstantsToInitializersPass()(model) - assert result.modified + self.assertTrue(result.modified) # Check that the constant node is lifted to an initializer - assert len(result.model.graph.initializers) == 1 - self.assertTrue( - np.array_equal( - result.model.graph.initializers["val_1"].const_value.raw, - np.array(constant_value, dtype=np_dtype), - ) + self.assertEqual(len(result.model.graph.initializers), 1) + np.testing.assert_array_equal( + result.model.graph.initializers["val_1"].const_value.raw, + np.array(constant_value, dtype=np_dtype), ) From e64816710fff8f5d0982a36585e915816fb284f1 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 10 Apr 2025 10:46:54 -0700 Subject: [PATCH 11/11] Update onnxscript/ir/passes/common/constant_manipulation_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/constant_manipulation_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index f2e610d816..2d1696e7fd 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -184,6 +184,6 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( # Check that the constant node is lifted to an initializer self.assertEqual(len(result.model.graph.initializers), 1) np.testing.assert_array_equal( - result.model.graph.initializers["val_1"].const_value.raw, + result.model.graph.initializers["val_1"].const_value.numpy(), np.array(constant_value, dtype=np_dtype), )