From 6ecd7ab304eae19db8ec9b036822fa89667e95f0 Mon Sep 17 00:00:00 2001 From: AyoubMDL Date: Thu, 28 Aug 2025 19:31:27 +0200 Subject: [PATCH] [Rewriter(matmul_add_to_gemm)]: check shapes Ensure that input shapes are not None before checking their rank. Used _ir_utils.has_rank to handle this safely. --- onnxscript/rewriter/matmul_add_to_gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/matmul_add_to_gemm.py b/onnxscript/rewriter/matmul_add_to_gemm.py index 6b63a83e44..dc0364a778 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm.py +++ b/onnxscript/rewriter/matmul_add_to_gemm.py @@ -10,6 +10,7 @@ import abc from typing import ClassVar +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._basics import MatchResult from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet @@ -30,7 +31,7 @@ def check(self, context, input_a, input_b, **_): del context # Not used check_result = MatchResult() # Rank of input_a and input_b must be 2 - if len(input_a.shape) != 2 or len(input_b.shape) != 2: + if not (_ir_utils.has_rank(input_a, 2) and _ir_utils.has_rank(input_b, 2)): return check_result.fail("Rank of input_a and input_b must be 2") return check_result