Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 75 additions & 19 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,11 @@
return x
if isinstance(x, (int, float)):
return Constant(x)
# TODO(rama): support lists of int/float
# if isinstance(x, list):
# if all(isinstance(i, (int, float)) for i in x):
# return Constant(x)
# raise ValueError("Only lists of int/float can be used as a ValuePattern")
# TODO(titaiwang): Could this be wrapped Constant?
if isinstance(x, Sequence):
if all(isinstance(i, (int, float)) for i in x):
return Constant(x)
raise ValueError("Only lists of int/float can be used as a ValuePattern")

Check warning on line 288 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L287-L288

Added lines #L287 - L288 were not covered by tests

raise TypeError(f"Cannot convert {type(x)} to ValuePattern")


Expand Down Expand Up @@ -602,10 +601,13 @@
"""Represents a pattern that matches against a scalar constant value."""

def __init__(
self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8
self,
value: int | float | Sequence[int] | Sequence[float],
rel_tol: float = 1e-5,
abs_tol: float = 1e-8,
) -> None:
super().__init__(None)
self._value = value
self._value = list(value) if isinstance(value, Sequence) else value
self._rel_tol = rel_tol
self._abs_tol = abs_tol

Expand All @@ -614,7 +616,7 @@
return Constant(self._value, self._rel_tol, self._abs_tol)

@property
def value(self) -> int | float:
def value(self) -> int | float | list[int] | list[float]:
return self._value

def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
Expand All @@ -623,6 +625,24 @@
return match.fail(f"Value is not a constant, expecting {self.value}.")

constant_value_numpy = constant_value.numpy()
if isinstance(self._value, list):
if constant_value_numpy.shape != (len(self._value),):
return match.fail(f"Value has mismatching shape, expecting ({self.value},).")

Check warning on line 630 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L630

Added line #L630 was not covered by tests
if not all(
math.isclose(
constant_value_numpy.item(i),
self._value[i],
rel_tol=self._rel_tol,
abs_tol=self._abs_tol,
)
for i in range(len(self._value))
):
return match.fail(

Check warning on line 640 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L640

Added line #L640 was not covered by tests
f"Value mismatch: expected {self._value}, got {constant_value_numpy}."
)
return match

Check warning on line 643 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L643

Added line #L643 was not covered by tests

# Scalar constant case:
# TODO (rama): allow users to specify shape requirement, if desired.
if constant_value_numpy.size != 1:
return match.fail(f"Value is not a scalar, expecting {self.value}.")
Expand Down Expand Up @@ -664,6 +684,20 @@
return node_patterns


def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None:
"""Adds all nodes in the backward slice of given node to the set `backward_slice`.

The backward slice of a node is the set of all nodes that are reachable from the node
in a backward traversal from the given node.
"""
if node in backward_slice:
return
backward_slice.add(node)
for value_pattern in node.inputs:
if isinstance(value_pattern, NodeOutputPattern):
_add_backward_slice(value_pattern.producer(), backward_slice)


class GraphPattern:
"""Represents a pattern that can be matched against a subgraph."""

Expand All @@ -679,8 +713,10 @@
raise ValueError("GraphPattern must have at least one output")
self._nodes = nodes # _nodes_in_pattern(outputs)

# Check if all outputs are produced by the same node.
# Determine the output nodes of the pattern. These are a minimal set of nodes
# whose backward-slices cover the entire pattern.
output_nodes: set[NodePattern] = set()
covered: set[NodePattern] = set()
for value_pattern in outputs:
if not isinstance(value_pattern, ValuePattern):
raise TypeError(
Expand All @@ -691,7 +727,11 @@
"Constant values are not allowed as graph pattern outputs."
)
if isinstance(value_pattern, NodeOutputPattern):
output_nodes.add(value_pattern.producer())
candidate = value_pattern.producer()
if candidate not in covered:
output_nodes.add(candidate)
_add_backward_slice(candidate, covered)

self.output_nodes: list[NodePattern] = list(output_nodes)

@property
Expand Down Expand Up @@ -924,20 +964,41 @@
constant_value_numpy = constant_value.numpy()
except FileNotFoundError:
return self.fail(f"Constant value of {value.name} not available.")

pattern_constant_value = pattern_constant._value

if isinstance(pattern_constant_value, list):
expected_shape = (len(pattern_constant_value),)

Check warning on line 971 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L971

Added line #L971 was not covered by tests
if constant_value_numpy.shape != expected_shape:
return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")

Check warning on line 973 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L973

Added line #L973 was not covered by tests
if not all(
math.isclose(
constant_value_numpy.item(i),
pattern_constant_value[i],
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
)
for i in range(len(pattern_constant_value))
):
return self.fail(

Check warning on line 983 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L983

Added line #L983 was not covered by tests
f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
)
return True

Check warning on line 986 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L986

Added line #L986 was not covered by tests

# TODO (rama): allow users to specify shape requirement, if desired.
if constant_value_numpy.size != 1:
return self.fail(
f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.",
f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
)

if not math.isclose(
constant_value_numpy.item(),
pattern_constant._value,
pattern_constant_value,
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
):
return self.fail(
f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.",
f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
)

return True
Expand Down Expand Up @@ -1079,11 +1140,6 @@
if not _valid_to_replace(match.nodes, output_values):
return match.fail("Matched nodes have other uses preventing replacement.")

if len(node.outputs) != pattern.num_outputs:
return match.fail(
f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}."
)

match.outputs.extend(output_values)
return match

Expand Down
Loading