From bfb26eb111f5652e5e666a7d35c1921442a52646 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 2 Oct 2025 22:01:33 -0700
Subject: [PATCH 1/3] Create default values for the constant folding pass

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 8317d2be6..38397715a 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -959,9 +959,9 @@ class FoldConstantsPass(ir.passes.InPlacePass):
     def __init__(
         self,
         *,
-        shape_inference: bool,
-        input_size_limit: int,
-        output_size_limit: int,
+        shape_inference: bool = True,
+        input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+        output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
         should_fold: Callable[[ir.Node], bool | None] = lambda node: None,
     ) -> None:
         self.shape_inference = shape_inference
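With these defaults in place, the pass can be constructed with no arguments. A minimal sketch of the resulting call site, assuming the private module path touched by this patch is importable directly (callers would normally go through the public onnxscript.optimizer entry points):

    # Sketch only: the import path mirrors the file this patch touches;
    # direct use of the private module is an assumption.
    from onnxscript.optimizer._constant_folding import FoldConstantsPass

    # Before this patch, all three keyword arguments were required. Now the
    # defaults apply: shape_inference=True, and the two size limits fall back
    # to the DEFAULT_CONSTANT_FOLD_INPUT/OUTPUT_SIZE_LIMIT module constants.
    fold_pass = FoldConstantsPass()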
From b7f51ad06a1abe92c05df764b0a87ccb4faade72 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 2 Oct 2025 22:23:22 -0700
Subject: [PATCH 2/3] Support multiple outputs and store replacements as initializers

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 97 +++++++++++++----------
 1 file changed, 56 insertions(+), 41 deletions(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 38397715a..0d6daa10c 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -1038,51 +1038,34 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
-        irvalue = node.outputs[0]
-        if not isinstance(value, np.ndarray):
+    def new_initializer(self, old_value: ir.Value, array) -> ir.Value | None:
+        if not isinstance(array, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
             # to optimize subsequent operations when possible.
             logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
-                irvalue.name,
-                type(value),
+                old_value.name,
+                type(array),
             )
             return None
 
-        tensor = ir.tensor(value)
-        tensor.name = irvalue.name
-        irvalue.const_value = tensor
-
-        if value.size > self.output_size_limit:
-            # Handle examples like Transpose(weight) to be folded even if the size is large,
-            # as long as weight has no other uses. This won't increase model size.
-            removed_input_size = 0
-            for input in node.inputs:
-                if (input is not None) and (len(input.uses()) == 1):
-                    array = _get_numpy_value(input)
-                    if array is not None:
-                        removed_input_size += array.size
-            increased_size = value.size - removed_input_size
-            if increased_size > 0:
-                logger.info(
-                    "Skip storing constant folded nvalue %s due to large size %s.",
-                    irvalue.name,
-                    value.size,
-                )
-                return None
+        tensor = ir.tensor(array)
+        tensor.name = old_value.name
+        new_value = ir.Value(
+            name=old_value.name,
+            type=ir.TensorType(ir.DataType(tensor.dtype)),
+            shape=tensor.shape,
+            const_value=tensor,
+        )
 
         logger.debug(
             "New constant for value %s dtype: %s shape: %s",
-            irvalue.name,
-            value.dtype,
-            value.shape,
+            old_value.name,
+            new_value.dtype,
+            new_value.shape,
         )
-
-        attributes = ir.convenience.convert_attributes({"value": tensor})
-        node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
-        return node
+        return new_value
 
     def process_node(self, node: ir.Node) -> Replacement | None:
         """Process a node and return a Replacement if the node can be replaced."""
@@ -1221,16 +1204,48 @@ def convert(av):
 
         if outputs is None:
             return None
-        if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
-            replacement = self.new_constant(node, outputs)
-            if replacement is None:
-                return None
-            return Replacement(replacement.outputs, [replacement])
-        else:
+
+        if not isinstance(outputs, (list, tuple)):
+            outputs = [outputs]
+
+        if len(outputs) != len(node.outputs):
             logger.warning(
-                "Skipping constant folding for op %s with multiple outputs.", node.op_type
+                "Skipping constant folding for op %s because the number of outputs does not match: %d => %d",
+                node.op_type,
+                len(node.outputs),
+                len(outputs),
             )
-            return None
+            return None
+
+        # Whether we will fold the node regardless of sizes
+        can_ignore_output_limit = (
+            should_fold is True or (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS
+        )
+        replacement_values: list[ir.Value] = []
+        for i, array in enumerate(outputs):
+            new_initializer = self.new_initializer(node.outputs[i], array)
+            if new_initializer is None:
+                # Could not create a new initializer for the output
+                return None
+            if (
+                new_initializer.const_value.size > self.output_size_limit
+                and not can_ignore_output_limit
+            ):
+                logger.info(
+                    "Skipping constant folding for node %r because output size %d exceeds limit %d",
+                    node.name,
+                    new_initializer.const_value.size,
+                    self.output_size_limit,
+                )
+                return None
+
+            replacement_values.append(new_initializer)
+
+        for value in replacement_values:
+            assert node.graph is not None
+            node.graph.initializers.add(value)
+
+        return Replacement(replacement_values, [])
 
     def replace_node(
         self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function
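With this change, a folded result is materialized as a graph initializer instead of a synthetic Constant node, which is what makes multi-output folding possible: each output simply becomes its own initializer. A minimal sketch of what new_initializer builds for one folded output, using only the ir calls that appear in the diff (the array contents and value name are illustrative):

    import numpy as np
    from onnxscript import ir

    array = np.array([[1.0, 2.0]], dtype=np.float32)  # one folded output
    tensor = ir.tensor(array)
    tensor.name = "folded_output"  # takes the name of the value it replaces
    new_value = ir.Value(
        name=tensor.name,
        type=ir.TensorType(ir.DataType(tensor.dtype)),
        shape=tensor.shape,
        const_value=tensor,
    )
    # process_node then registers each such value on the owning graph:
    #     node.graph.initializers.add(new_value)
    # and returns Replacement(replacement_values, []) with no new nodes,
    # so the folded node is removed rather than replaced by a Constant op.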
From c62f45fee5f891dda8aec3178536f494aef421a7 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 2 Oct 2025 22:31:35 -0700
Subject: [PATCH 3/3] WIP create initializers

Signed-off-by: Justin Chu
---
 onnxscript/optimizer/_constant_folding.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 0d6daa10c..bd81a8eec 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -935,6 +935,8 @@ def _record_contributing_values(original_node: ir.Node, replacement: Replacement
         assert input.name is not None
         folded_from.add(input.name)
 
+
+
     for new_output in replacement.new_outputs:
         if new_output is None:
             continue
@@ -1238,7 +1240,7 @@ def convert(av):
                     self.output_size_limit,
                 )
                 return None
-
+            assert new_initializer.const_value is not None
             replacement_values.append(new_initializer)
 
         for value in replacement_values:
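Taken together, the series lets the pass run end to end with its defaults. A hedged usage sketch (the file paths are placeholders, and the callable-pass pattern assumes the standard ir.passes protocol that InPlacePass follows):

    from onnxscript import ir
    from onnxscript.optimizer._constant_folding import FoldConstantsPass

    model = ir.load("model.onnx")  # placeholder path
    result = FoldConstantsPass()(model)  # passes are callable and return a PassResult
    if result.modified:
        ir.save(result.model, "model_folded.onnx")  # placeholder path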