From 0080a85e3fc3bd364baa8d78bdf33b18b5836b8b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 09:55:10 -0700 Subject: [PATCH 1/8] Fix CI tests for the latest onnx-ir version Signed-off-by: Justin Chu --- onnxscript/rewriter/_fusion_utils.py | 13 ------------- .../ort_fusions/fused_matmul_rule_sets_test.py | 12 ++++++------ 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index dbf16ae3d3..aa732767a2 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -13,19 +13,6 @@ Dim = Union[int, ir.SymbolicDim] -def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: - if val.shape is None: - return False - if val.shape.rank() != len(shape): - return False - for actual, expected in zip(val.shape, shape): - if expected not in bindings: - bindings[expected] = actual # type: ignore[assignment] - elif actual != bindings[expected]: - return False - return True - - def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): if val.shape is None: raise MatchFailureError(f"The shape of {val} is unknown.", val) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index 527d4826d5..f82702d557 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -284,7 +284,7 @@ def _check_model( opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul]) expected = ref.run(None, feeds) got = opt.run(None, feeds) - self.assertEqual(len(expected), len(got)) + self.assertEqual(len(got), len(expected)) for a, b in zip(expected, got): np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) @@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func): ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand([("should_not_match", _should_not_match)]) @@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func): self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) self.assertEqual( - ["Transpose", "MatMul", "Transpose"], [n.op_type for n in ir_model.graph], + ["Transpose", "MatMul", "Transpose"], ) self._check_model(model_proto, rewritten_model, atol=1e-6) @@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func): common_passes.ShapeInferencePass()(ir_model) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func): ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) From 5917922decf106d454d7615d79a254fbbe8c95e5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 10:00:54 -0700 Subject: [PATCH 2/8] ints Signed-off-by: Justin Chu --- onnxscript/rewriter/_rewrite_rule.py | 4 ++-- onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 9481ca5077..af0165dea0 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -392,7 +392,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False """ @@ -463,7 +463,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 5082c20464..36ba95b3f6 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -188,7 +188,7 @@ def check( trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) transposed_node = _get_node(transposed, "Transpose") - perm = transposed_node.attributes["perm"].as_ints() + perm = list(transposed_node.attributes["perm"].as_ints()) if not perm: return check_result.fail("Permutation values for Transpose are not correct.") From 5e9c766b9871da49a776a26311902cf62cb50303 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 15:04:52 -0700 Subject: [PATCH 3/8] Fix the matcher Signed-off-by: Justin Chu --- onnxscript/rewriter/_pattern_ir.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index f64d3fca3c..e4128c3bdd 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -123,10 +123,17 @@ class AttrConstantPattern(AttrPattern): def __init__(self, value: SupportedAttrTypes): super().__init__(None) + if isinstance(value, Sequence): + value = tuple(value) self._value = value def matches(self, attr: ir.Attr) -> bool: - return isinstance(attr, ir.Attr) and attr.value == self._value + if not isinstance(attr, ir.Attr): + return False + if attr.type in {ir.AttributeType.INTS, ir.AttributeType.FLOATS, ir.AttributeType.STRINGS}: + # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. + return tuple(attr.value) == self._value + return attr.value == self._value def __str__(self) -> str: return str(self._value) From 763c85d7808fd6cf10d3ae112661095bb721994d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 15:07:48 -0700 Subject: [PATCH 4/8] test Signed-off-by: Justin Chu --- onnxscript/rewriter/_pattern_ir.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index e4128c3bdd..8ab6109259 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -128,8 +128,6 @@ def __init__(self, value: SupportedAttrTypes): self._value = value def matches(self, attr: ir.Attr) -> bool: - if not isinstance(attr, ir.Attr): - return False if attr.type in {ir.AttributeType.INTS, ir.AttributeType.FLOATS, ir.AttributeType.STRINGS}: # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. return tuple(attr.value) == self._value From e64083f192089ad5c348060f9ba68b86a12441b5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 15:12:05 -0700 Subject: [PATCH 5/8] Rename function Signed-off-by: Justin Chu --- onnxscript/rewriter/_fusion_utils.py | 13 +++++++++++++ onnxscript/rewriter/ort_fusions/attention.py | 2 +- onnxscript/rewriter/ort_fusions/gqa.py | 2 +- onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py | 2 +- onnxscript/rewriter/ort_fusions/mha.py | 2 +- onnxscript/rewriter/ort_fusions/mha_bias.py | 2 +- .../rewriter/ort_fusions/skip_normalization.py | 4 ++-- 7 files changed, 20 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index aa732767a2..f6a7204ac8 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -13,6 +13,19 @@ Dim = Union[int, ir.SymbolicDim] +def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual # type: ignore[assignment] + elif actual != bindings[expected]: + return False + return True + + def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]): if val.shape is None: raise MatchFailureError(f"The shape of {val} is unknown.", val) diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py index 4a4cd0ad8e..ce234bbb63 100644 --- a/onnxscript/rewriter/ort_fusions/attention.py +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -160,7 +160,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 99852f712a..5fff910bcf 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -247,7 +247,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return False diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py index 0d404b2754..51355fc8cf 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py +++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py @@ -84,7 +84,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) # Check that if x is being split into q, k, v correctly # based on hidden sizes diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index e2987cfc5e..321e895f44 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -157,7 +157,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py index 28b9646ddc..9ecf2ce017 100644 --- a/onnxscript/rewriter/ort_fusions/mha_bias.py +++ b/onnxscript/rewriter/ort_fusions/mha_bias.py @@ -78,7 +78,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if query_matmul.dtype not in valid_float_types: return check_result.fail("Query is not a float or float16 type.", query_matmul) diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index f7a376aef9..c76a7454cb 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -60,7 +60,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( @@ -184,7 +184,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( From 57294478775f9d0fe8e540f507c26b0bfe491d33 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 16:43:02 -0700 Subject: [PATCH 6/8] Fix rules Signed-off-by: Justin Chu --- onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 36ba95b3f6..cdc50c99ae 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -79,7 +79,7 @@ def check( # Check that last two dimensions are swapped expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - if perm != expected_perm: + if list(perm) != expected_perm: return check_result.fail("Permutation values for Transpose are not correct.") elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or ( self._pos == 2 and not _ir_utils.has_rank(y, 2) @@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): if perm: # Check that the two dimensions are swapped - if perm != [1, 0]: + if tuple(perm) != (1, 0): return check_result.fail( "Permutation values for Transpose are not correct." ) From 9340dcfb164b15365303c9ac1ca8fa71d44fd052 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 16:51:48 -0700 Subject: [PATCH 7/8] Fix pattern matcher Signed-off-by: Justin Chu --- onnxscript/rewriter/_pattern_ir.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index 8ab6109259..9b81e33581 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -123,14 +123,16 @@ class AttrConstantPattern(AttrPattern): def __init__(self, value: SupportedAttrTypes): super().__init__(None) - if isinstance(value, Sequence): - value = tuple(value) self._value = value def matches(self, attr: ir.Attr) -> bool: - if attr.type in {ir.AttributeType.INTS, ir.AttributeType.FLOATS, ir.AttributeType.STRINGS}: + if attr.type in { + ir.AttributeType.INTS, + ir.AttributeType.FLOATS, + ir.AttributeType.STRINGS, + }: # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. - return tuple(attr.value) == self._value + return tuple(attr.value) == tuple(self._value) return attr.value == self._value def __str__(self) -> str: From ecbcaae87300c996350fd8407205c59c17b4b831 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 17:21:22 -0700 Subject: [PATCH 8/8] test_extract_function Signed-off-by: Justin Chu --- onnxscript/rewriter/pattern_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 49ace2fb81..0a29080b4d 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -674,7 +674,7 @@ def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]: function = model.functions[function_id] self.assertEqual([x.op_type for x in function], ["Add", "Transpose"]) transpose_node = function[1] - self.assertEqual(transpose_node.attributes["perm"].value, [1, 0]) + self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0]) onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"])