Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"RewriterContext",
"MatchingTracer",
"MatchStatus",
"RULE_NAME_TAG",
]

import onnx
Expand All @@ -25,6 +26,7 @@
from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
from onnxscript.rewriter._rewrite_rule import (
RULE_NAME_TAG,
RewriterContext,
RewriteRule,
RewriteRuleClassBase,
Expand Down
12 changes: 12 additions & 0 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

RewriterContext = _tape.Builder

# TODO(rama): Standardize metadata property keys. May be worth standardizing at ONNX level for
# source/producer metadata.

RULE_NAME_TAG = "pkg.onnxscript.rewriter.rule_name"


@dataclasses.dataclass
class ReplacementSubgraph:
Expand Down Expand Up @@ -719,6 +724,13 @@ def _apply_to_graph_or_function(
_ir_utils.display_nodes(delta.new_nodes)
print("++++End Replacement Nodes++++")

# Capture rewrite rule name as metadata.
# TODO(rama): This is just a basic version. We may wish to compose "source" metadata
# from multiple rules in future.
if rule.name:
for n in delta.new_nodes:
n.metadata_props[RULE_NAME_TAG] = rule.name

convenience.replace_nodes_and_values(
graph_or_function,
node,
Expand Down
39 changes: 39 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import onnx.parser

import onnxscript.optimizer
import onnxscript.rewriter
from onnxscript import FLOAT, ir, script
from onnxscript import opset17 as op
from onnxscript.rewriter import pattern
Expand Down Expand Up @@ -936,6 +937,44 @@ def add_pattern(op, x, y):
match_result = rule_pattern.match(model, model.graph, add_nodes[2])
self.assertFalse(bool(match_result))

def test_rule_name_metadata(self):
"""Test that RewriteRule carries name metadata."""

class ReciprocalMulRule(pattern.RewriteRuleClassBase):
def __init__(self, name: str | None = None):
super().__init__(name)

def pattern(self, op, x, y):
return (1 / x) * y

def rewrite(self, op, x, y):
return op.Div(y, x)

@script()
def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
return op.Mul(op.Div(op.Constant(value_float=1.0), x), y)

rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv")
model_proto = test_script.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
count = rule.apply_to_model(model)
self.assertEqual(count, 1)
for node in model.graph:
if node.op_type == "Div":
tag = onnxscript.rewriter.RULE_NAME_TAG
self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv")

# By default, the rule name is the class name (if not provided)
rule = ReciprocalMulRule.rule()
model_proto = test_script.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
count = rule.apply_to_model(model)
self.assertEqual(count, 1)
for node in model.graph:
if node.op_type == "Div":
tag = onnxscript.rewriter.RULE_NAME_TAG
self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule")


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading