Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxscript/rewriter/matmul_add_to_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import abc
from typing import ClassVar

from onnxscript.rewriter import _ir_utils
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet

Expand All @@ -30,7 +31,7 @@ def check(self, context, input_a, input_b, **_):
del context # Not used
check_result = MatchResult()
# Rank of input_a and input_b must be 2
if len(input_a.shape) != 2 or len(input_b.shape) != 2:
if not (_ir_utils.has_rank(input_a, 2) and _ir_utils.has_rank(input_b, 2)):
return check_result.fail("Rank of input_a and input_b must be 2")
return check_result

Expand Down
Loading