From c60d048c50e246e08d11667a9b25cdb1d8f7ff73 Mon Sep 17 00:00:00 2001 From: xuzhenqi Date: Fri, 27 Sep 2024 14:20:04 +0800 Subject: [PATCH 1/4] Check inputs num for node matching --- onnxscript/rewriter/pattern.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index be265963c2..ef61a8e3b5 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,6 +959,9 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node + if len(node.inputs) != len(pattern_node.inputs): + return self.fail("Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}") + for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: From 07bad1e06cf7dd1be981e3e6e15470b5b3c169d0 Mon Sep 17 00:00:00 2001 From: xuzhenqi Date: Tue, 8 Oct 2024 10:51:41 +0800 Subject: [PATCH 2/4] Add comment --- onnxscript/rewriter/pattern.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index ef61a8e3b5..ddff1b93b0 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,6 +959,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node + # Note: Need to revisit this to handle optional trailing inputs better. if len(node.inputs) != len(pattern_node.inputs): return self.fail("Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}") From 4419f276623a59ae9f680460babb8bf6b07f67f0 Mon Sep 17 00:00:00 2001 From: xuzhenqi Date: Wed, 9 Oct 2024 10:34:34 +0800 Subject: [PATCH 3/4] Fix lint errors Signed-off-by: xuzhenqi --- onnxscript/rewriter/pattern.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index ddff1b93b0..6788bc907e 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,9 +959,11 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node - # Note: Need to revisit this to handle optional trailing inputs better. + # Note: Need to revisit this to handle optional trailing inputs better. if len(node.inputs) != len(pattern_node.inputs): - return self.fail("Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}") + return self.fail( + "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" + ) for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): # arg_pattern could be a Var, if it's the original arg. From 3aaa86682d1523649100b09e98a4309491a7f81a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 9 Oct 2024 07:40:13 -0700 Subject: [PATCH 4/4] Update onnxscript/rewriter/pattern.py --- onnxscript/rewriter/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6788bc907e..1f00840d47 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,7 +959,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node - # Note: Need to revisit this to handle optional trailing inputs better. + # TODO: Revisit this to handle optional trailing inputs better. if len(node.inputs) != len(pattern_node.inputs): return self.fail( "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"