Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Sequence, TypeVar, Union

__all__ = [
"merge_metadata",
"pattern",
"rewrite",
"RewritePass",
Expand All @@ -31,6 +32,7 @@
RewriteRule,
RewriteRuleClassBase,
RewriteRuleSet,
merge_metadata,
)
from onnxscript.rewriter.rules.common import (
_basic_rules,
Expand Down
15 changes: 15 additions & 0 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import onnxscript.rewriter._ir_utils as _ir_utils
import onnxscript.rewriter._matcher as _matcher
import onnxscript.rewriter._pattern_ir as _pattern_ir
import onnxscript.utils.metadata_merger as metadata_merger
from onnxscript import ir
from onnxscript.ir import _tape, convenience

Expand Down Expand Up @@ -614,6 +615,15 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
overload += 1


_default_metadata_merger: metadata_merger.MetadataMerger = metadata_merger.MetadataMerger(
{RULE_NAME_TAG: metadata_merger.comma_separator_merger}
)

# TODO(rama): Generalize this to support custom metadata mergers. For now, we just allow
# enabling/disabling the default merger.
merge_metadata: bool = True


class RewriteRuleSet:
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
if not rules:
Expand Down Expand Up @@ -740,6 +750,11 @@ def _apply_to_graph_or_function(
delta.new_outputs,
)

if merge_metadata:
_default_metadata_merger.copy_merged_metadata(
delta.match.nodes, delta.new_nodes
)

count += 1
break

Expand Down
99 changes: 99 additions & 0 deletions onnxscript/utils/metadata_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Merging metadata_props"""

from __future__ import annotations

from typing import Callable, Iterable

import onnx_ir as ir

# Utilities for merging metadata properties, represented as strings.
# The merging-logic will take care of special cases like missing metadata or
# empty string metadata, and so the functions defined below need not handle
# special cases like empty string. (This does assume that an empty string is
# the same as no metadata, which is a reasonable assumption for most metadata.)

StringMerger = Callable[[str, str], str]


def overwrite(_: str, new: str) -> str:
return new


def join(separator: str) -> StringMerger:
"""Creates a StringMerger that joins two strings with the given separator.

Args:
separator (str): The separator to use when joining the strings.

Returns:
StringMerger: A function that joins two strings with the specified separator.
"""

def merger(first: str, second: str) -> str:
return f"{first}{separator}{second}"

return merger


comma_separator_merger = join(", ")


class MetadataMerger:
"""Merges metadata properties using specified merging logic.

Attributes:
mergers: A mapping from metadata property keys to their corresponding merging functions.
default: The default merging function to use when a specific key does not have a defined merger.
If None, the first value is used. (Specify `overwrite` to always use the second value.)
"""

def __init__(
self, mergers: dict[str, StringMerger], default: StringMerger | None = None
) -> None:
self.mergers = mergers
self.default = default

def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None:
"""Updates the first metadata property dictionary with values from the second.

Args:
updated: The metadata dictionary to be updated.
updates: The updates metadata dictionary.
"""
for key, new_value in updates.items():
if new_value == "":
continue
if (key in updated) and ((updated_value := updated[key]) != ""):
merger = self.mergers.get(key, self.default)
if merger is not None:
updated[key] = merger(updated_value, new_value)
else:
updated[key] = new_value

def copy_merged_metadata(
self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node]
) -> None:
"""Merges metadata from multiple nodes and assigns it to one or more target nodes.

Args:
from_nodes: The source nodes from which to merge metadata.
to: The target node(s) to which the merged metadata will be assigned.
"""
if isinstance(to, ir.Node):
updated = to.metadata_props
for node in from_nodes:
self.update_dict(updated, node.metadata_props)
elif len(to) == 1:
# Handle single node in iterable case
target_node = next(iter(to))
updated = target_node.metadata_props
for node in from_nodes:
self.update_dict(updated, node.metadata_props)
else:
merged_metadata: dict[str, str] = {}
for node in from_nodes:
self.update_dict(merged_metadata, node.metadata_props)
for target_node in to:
self.update_dict(target_node.metadata_props, merged_metadata)
Loading