diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 78eb4398f3..fb93bc703f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -5,6 +5,7 @@ from typing import Sequence, TypeVar, Union __all__ = [ + "merge_metadata", "pattern", "rewrite", "RewritePass", @@ -31,6 +32,7 @@ RewriteRule, RewriteRuleClassBase, RewriteRuleSet, + merge_metadata, ) from onnxscript.rewriter.rules.common import ( _basic_rules, diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 9c88aa848e..7c73a738ce 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -18,6 +18,7 @@ import onnxscript.rewriter._ir_utils as _ir_utils import onnxscript.rewriter._matcher as _matcher import onnxscript.rewriter._pattern_ir as _pattern_ir +import onnxscript.utils.metadata_merger as metadata_merger from onnxscript import ir from onnxscript.ir import _tape, convenience @@ -614,6 +615,15 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: overload += 1 +_default_metadata_merger: metadata_merger.MetadataMerger = metadata_merger.MetadataMerger( + {RULE_NAME_TAG: metadata_merger.comma_separator_merger} +) + +# TODO(rama): Generalize this to support custom metadata mergers. For now, we just allow +# enabling/disabling the default merger. +merge_metadata: bool = True + + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if not rules: @@ -740,6 +750,11 @@ def _apply_to_graph_or_function( delta.new_outputs, ) + if merge_metadata: + _default_metadata_merger.copy_merged_metadata( + delta.match.nodes, delta.new_nodes + ) + count += 1 break diff --git a/onnxscript/utils/metadata_merger.py b/onnxscript/utils/metadata_merger.py new file mode 100644 index 0000000000..121d8db8c8 --- /dev/null +++ b/onnxscript/utils/metadata_merger.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. 
# Licensed under the MIT License.
"""Merging metadata_props.

Utilities for combining the string-valued ``metadata_props`` dictionaries of
several IR nodes into the ``metadata_props`` of one or more other nodes.
"""

from __future__ import annotations

from typing import Callable, Iterable

import onnx_ir as ir

# Utilities for merging metadata properties, represented as strings.
# The merging logic takes care of special cases like missing metadata or
# empty-string metadata, so the functions defined below need not handle
# special cases like the empty string. (This does assume that an empty string
# is the same as no metadata, which is a reasonable assumption for most
# metadata.)

StringMerger = Callable[[str, str], str]


def overwrite(_: str, new: str) -> str:
    """Return ``new``, discarding the existing value.

    Args:
        _: The existing value (ignored).
        new: The incoming value.

    Returns:
        The incoming value, unchanged.
    """
    return new


def join(separator: str) -> StringMerger:
    """Creates a StringMerger that joins two strings with the given separator.

    Args:
        separator (str): The separator to use when joining the strings.

    Returns:
        StringMerger: A function that joins two strings with the specified separator.
    """

    def merger(first: str, second: str) -> str:
        return f"{first}{separator}{second}"

    return merger


comma_separator_merger = join(", ")


class MetadataMerger:
    """Merges metadata properties using specified merging logic.

    Attributes:
        mergers: A mapping from metadata property keys to their corresponding merging functions.
        default: The default merging function to use when a specific key does not have a defined merger.
            If None, the first value is used. (Specify `overwrite` to always use the second value.)
    """

    def __init__(
        self, mergers: dict[str, StringMerger], default: StringMerger | None = None
    ) -> None:
        self.mergers = mergers
        self.default = default

    def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None:
        """Updates the first metadata property dictionary with values from the second.

        Empty-string values are treated as missing: an empty incoming value is
        skipped, and an empty (or absent) existing value is simply replaced by
        the incoming one. When both values are non-empty, the merger registered
        for the key (or the default merger) combines them; if no merger
        applies, the existing value is kept.

        Args:
            updated: The metadata dictionary to be updated (modified in place).
            updates: The updates metadata dictionary.
        """
        for key, new_value in updates.items():
            if new_value == "":
                continue
            current_value = updated.get(key, "")
            if current_value == "":
                # Nothing meaningful to merge with: take the incoming value.
                updated[key] = new_value
                continue
            merger = self.mergers.get(key, self.default)
            if merger is not None:
                updated[key] = merger(current_value, new_value)
            # With no merger available, the existing (first) value wins.

    def copy_merged_metadata(
        self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node]
    ) -> None:
        """Merges metadata from multiple nodes and assigns it to one or more target nodes.

        With a single target (a bare node or an iterable yielding exactly one
        node), each source node's metadata is merged directly into the
        target's ``metadata_props``. Otherwise the source metadata is first
        merged into a temporary dictionary, which is then merged into every
        target.

        Args:
            from_nodes: The source nodes from which to merge metadata.
            to: The target node(s) to which the merged metadata will be assigned.
                Any iterable is accepted, including generators and other
                iterables that do not support ``len()``.
        """
        # Materialize the targets once: an arbitrary Iterable need not support
        # len(), and a one-shot iterator must not be consumed twice.
        targets = [to] if isinstance(to, ir.Node) else list(to)

        if len(targets) == 1:
            target_props = targets[0].metadata_props
            for node in from_nodes:
                self.update_dict(target_props, node.metadata_props)
            return

        merged_metadata: dict[str, str] = {}
        for node in from_nodes:
            self.update_dict(merged_metadata, node.metadata_props)
        for target_node in targets:
            self.update_dict(target_node.metadata_props, merged_metadata)