diff --git a/onnxscript/rewriter/matmul_add_to_gemm.py b/onnxscript/rewriter/matmul_add_to_gemm.py index 6b63a83e44..dc0364a778 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm.py +++ b/onnxscript/rewriter/matmul_add_to_gemm.py @@ -10,6 +10,7 @@ import abc from typing import ClassVar +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._basics import MatchResult from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet @@ -30,7 +31,7 @@ def check(self, context, input_a, input_b, **_): del context # Not used check_result = MatchResult() # Rank of input_a and input_b must be 2 - if len(input_a.shape) != 2 or len(input_b.shape) != 2: + if not (_ir_utils.has_rank(input_a, 2) and _ir_utils.has_rank(input_b, 2)): return check_result.fail("Rank of input_a and input_b must be 2") return check_result