def _replace_nodes_and_values(
    graph_or_function: ir.Graph | ir.Function,
    /,
    insertion_point: ir.Node,
    old_nodes: Sequence[ir.Node],
    new_nodes: Sequence[ir.Node],
    old_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    Args:
        graph_or_function: The graph or function to replace nodes and values in.
        insertion_point: The node to insert the new nodes after.
        old_nodes: The nodes to replace.
        new_nodes: The nodes to replace with.
        old_values: The values to replace.
        new_values: The values to replace with.
    """
    for old_value, new_value in zip(old_values, new_values):
        # Carry over metadata that the replacement value is missing so that
        # downstream consumers keep seeing the same type/shape/name info.
        if new_value.type is None:
            new_value.type = old_value.type
        if new_value.shape is None:
            new_value.shape = old_value.shape
        if new_value.const_value is None:
            new_value.const_value = old_value.const_value
        if new_value.name is None:
            new_value.name = old_value.name

    # Redirect every consumer of the old values to the corresponding new value.
    ir.convenience.replace_all_uses_with(old_values, new_values)

    # Graph/function outputs are not tracked as uses, so patch them explicitly.
    replacement_mapping = dict(zip(old_values, new_values))
    for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
        if graph_or_function_output in replacement_mapping:
            graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

    # Splice the replacement nodes in after the insertion point, then drop the
    # matched nodes (safe=True verifies they are no longer used).
    graph_or_function.insert_after(insertion_point, new_nodes)
    graph_or_function.remove(old_nodes, safe=True)
+"""Rules to collapse Transpose nodes into initializers.""" + +from __future__ import annotations + +import logging + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter import pattern as orp + +logger = logging.getLogger(__name__) + + +class TransposeInitializer(orp.RewriteRuleClassBase): + """Folds Transpose nodes into initializers.""" + + def __init__(self): + super().__init__("TransposeInitializer", remove_nodes=True) + + def pattern(self, op, initializer): + return op.Transpose(initializer, _allow_other_attributes=True) + + def rewrite(self, op, initializer: ir.Value) -> ir.Value: + original_transpose = initializer.consumers()[0] + perm_attr = original_transpose.attributes.get("perm") + + if perm_attr is not None: + perm = perm_attr.as_ints() + else: + perm = None + + array = ir_utils.get_numpy_value(initializer) + if array is None: + # Do nothing + logger.debug("Failed to obtain the initializer value. Do nothing") + # perm=None is filtered out when the attribute is constructed so we are ok + return op.Transpose(initializer, perm=perm_attr) + + # np.transpose does not create a copy. So we don't need to use LazyTensors. 
+ transposed = np.transpose(array, axes=perm) + new_name = f"{initializer.name}_transposed" + return op.initializer(ir.tensor(transposed, name=new_name)) + + def check(self, context, initializer: ir.Value) -> orp.MatchResult: + del context # Unused + check_result = orp.MatchResult() + if not initializer.is_initializer(): + return check_result.fail("Value is not an initializer") + if initializer.is_graph_input(): + return check_result.fail("Value is a graph input") + if initializer.const_value is None: + return check_result.fail("Value.const_value is None") + if len(initializer.uses()) != 1: + return check_result.fail("Initializer is used by more than one node") + # TODO(justinchuby): Avoid matching when it is a graph input + return check_result + + +rule = TransposeInitializer.rule()