From 0f9a03938f8dc6a1e3cb8046ac6358b8784c342a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 10:25:56 -0700 Subject: [PATCH 01/12] Refactor rewrite rules into the rewriter.rules namespace Organize all rules into a directory that is not with the rewriter infrastructure Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 36 +++++++++---------- onnxscript/rewriter/no_op_test.py | 2 +- onnxscript/rewriter/ort_fusions/_core.py | 3 +- onnxscript/rewriter/pattern_test.py | 3 +- .../rewriter/{ => rules}/basic_rules.py | 0 .../rewriter/{ => rules}/basic_rules_test.py | 2 +- .../{ => rules}/broadcast_to_matmul.py | 0 .../{ => rules}/broadcast_to_matmul_test.py | 2 +- .../{ => rules}/cast_constant_of_shape.py | 0 .../cast_constant_of_shape_test.py | 2 +- .../rewriter/{ => rules}/collapse_slices.py | 0 .../{ => rules}/collapse_slices_test.py | 3 +- .../rewriter/{ => rules}/fuse_batchnorm.py | 0 .../{ => rules}/fuse_batchnorm_test.py | 3 +- .../{ => rules}/fuse_pad_into_conv.py | 0 .../{ => rules}/fuse_pad_into_conv_test.py | 2 +- .../rewriter/{ => rules}/fuse_relus_clips.py | 0 .../{ => rules}/fuse_relus_clips_test.py | 4 +-- .../{ => rules}/gemm_to_matmul_add.py | 2 +- .../{ => rules}/gemm_to_matmul_add_test.py | 2 +- .../{ => rules}/matmul_add_to_gemm.py | 0 .../{ => rules}/matmul_add_to_gemm_test.py | 5 +-- onnxscript/rewriter/{ => rules}/no_op.py | 0 .../{ => rules}/redundant_scatter_nd.py | 0 .../{ => rules}/redundant_scatter_nd_test.py | 2 +- 25 files changed, 38 insertions(+), 35 deletions(-) rename onnxscript/rewriter/{ => rules}/basic_rules.py (100%) rename onnxscript/rewriter/{ => rules}/basic_rules_test.py (99%) rename onnxscript/rewriter/{ => rules}/broadcast_to_matmul.py (100%) rename onnxscript/rewriter/{ => rules}/broadcast_to_matmul_test.py (99%) rename onnxscript/rewriter/{ => rules}/cast_constant_of_shape.py (100%) rename onnxscript/rewriter/{ => rules}/cast_constant_of_shape_test.py (96%) rename onnxscript/rewriter/{ => rules}/collapse_slices.py (100%) rename onnxscript/rewriter/{ => rules}/collapse_slices_test.py (98%) rename onnxscript/rewriter/{ => rules}/fuse_batchnorm.py (100%) rename onnxscript/rewriter/{ => rules}/fuse_batchnorm_test.py (98%) rename onnxscript/rewriter/{ => rules}/fuse_pad_into_conv.py (100%) rename onnxscript/rewriter/{ => rules}/fuse_pad_into_conv_test.py (99%) rename onnxscript/rewriter/{ => rules}/fuse_relus_clips.py (100%) rename onnxscript/rewriter/{ => rules}/fuse_relus_clips_test.py (99%) rename onnxscript/rewriter/{ => rules}/gemm_to_matmul_add.py (89%) rename onnxscript/rewriter/{ => rules}/gemm_to_matmul_add_test.py (99%) rename onnxscript/rewriter/{ => rules}/matmul_add_to_gemm.py (100%) rename onnxscript/rewriter/{ => rules}/matmul_add_to_gemm_test.py (98%) rename onnxscript/rewriter/{ => rules}/no_op.py (100%) rename onnxscript/rewriter/{ => rules}/redundant_scatter_nd.py (100%) rename onnxscript/rewriter/{ => rules}/redundant_scatter_nd_test.py (98%) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index d3e7a7891e..bb30c4237a 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -22,17 +22,15 @@ import onnx_ir.passes.common as common_passes from onnxscript import ir -from onnxscript.rewriter import ( - basic_rules, - broadcast_to_matmul, - cast_constant_of_shape, - collapse_slices, - fuse_pad_into_conv, - fuse_relus_clips, - no_op, - pattern, - redundant_scatter_nd, -) +from onnxscript.rewriter.rules import basic_rules as _basic_rules +from onnxscript.rewriter.rules import broadcast_to_matmul as _broadcast_to_matmul +from onnxscript.rewriter.rules import cast_constant_of_shape as _cast_constant_of_shape +from onnxscript.rewriter.rules import collapse_slices as _collapse_slices +from onnxscript.rewriter.rules import fuse_pad_into_conv as _fuse_pad_into_conv +from onnxscript.rewriter.rules import fuse_relus_clips as _fuse_relus_clips +from onnxscript.rewriter.rules import no_op as _no_op +from onnxscript.rewriter import pattern +from onnxscript.rewriter.rules import redundant_scatter_nd as _redundant_scatter_nd from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( RewriterContext, @@ -43,14 +41,14 @@ _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( - *no_op.rules.rules, # TODO: merge this rule into constant folding? - *broadcast_to_matmul.rules.rules, - *cast_constant_of_shape.rules.rules, - *collapse_slices.rules.rules, - *fuse_relus_clips.fuse_relus_clips_rules().rules, - *basic_rules.basic_optimization_rules().rules, - *redundant_scatter_nd.rules.rules, - *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules, + *_no_op.rules.rules, # TODO: merge this rule into constant folding? + *_broadcast_to_matmul.rules.rules, + *_cast_constant_of_shape.rules.rules, + *_collapse_slices.rules.rules, + *_fuse_relus_clips.fuse_relus_clips_rules().rules, + *_basic_rules.basic_optimization_rules().rules, + *_redundant_scatter_nd.rules.rules, + *_fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules, ) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py index 2b2a57f32a..bdfcfd857f 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/no_op_test.py @@ -5,7 +5,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import no_op +from onnxscript.rewriter.rules import no_op class NoOpTest(unittest.TestCase): diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index ed33807db9..72c18ebeca 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -8,7 +8,7 @@ import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize -from onnxscript.rewriter import gemm_to_matmul_add, rewrite +from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import ( instance_to_group_normalization, softmax, @@ -33,6 +33,7 @@ fuse_skip_layer_normalization, fuse_skip_rms_normalization, ) +from onnxscript.rewriter.rules import gemm_to_matmul_add ORT_PATTERN_REWRITE_RULES = [ *softmax.rules.rules, diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index bf5940e97c..30ea175984 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -12,7 +12,8 @@ import onnxscript.optimizer from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op -from onnxscript.rewriter import cast_constant_of_shape, pattern +from onnxscript.rewriter import pattern +from onnxscript.rewriter.rules import cast_constant_of_shape logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/basic_rules.py b/onnxscript/rewriter/rules/basic_rules.py similarity index 100% rename from onnxscript/rewriter/basic_rules.py rename to onnxscript/rewriter/rules/basic_rules.py diff --git a/onnxscript/rewriter/basic_rules_test.py b/onnxscript/rewriter/rules/basic_rules_test.py similarity index 99% rename from onnxscript/rewriter/basic_rules_test.py rename to onnxscript/rewriter/rules/basic_rules_test.py index bcb6db4aa8..f4bf00c617 100644 --- a/onnxscript/rewriter/basic_rules_test.py +++ b/onnxscript/rewriter/rules/basic_rules_test.py @@ -12,7 +12,7 @@ import onnxscript import onnxscript.onnx_types as ot -import onnxscript.rewriter.basic_rules as basic_rules +import onnxscript.rewriter.rules.basic_rules as basic_rules from onnxscript import ir from onnxscript.onnx_opset import opset18 diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/rules/broadcast_to_matmul.py similarity index 100% rename from onnxscript/rewriter/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/broadcast_to_matmul.py diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/broadcast_to_matmul_test.py similarity index 99% rename from onnxscript/rewriter/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/broadcast_to_matmul_test.py index c2f3b31f90..7d339521bd 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/broadcast_to_matmul_test.py @@ -9,7 +9,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import broadcast_to_matmul +from onnxscript.rewriter.rules import broadcast_to_matmul def _infer_shapes(model: ir.Model) -> ir.Model: diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/rules/cast_constant_of_shape.py similarity index 100% rename from onnxscript/rewriter/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/cast_constant_of_shape.py diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/cast_constant_of_shape_test.py similarity index 96% rename from onnxscript/rewriter/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/cast_constant_of_shape_test.py index 35151e17d9..24e4fb0d49 100644 --- a/onnxscript/rewriter/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/rules/cast_constant_of_shape_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import cast_constant_of_shape +from onnxscript.rewriter.rules import cast_constant_of_shape class CastConstantOfShapeTest(unittest.TestCase): diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/rules/collapse_slices.py similarity index 100% rename from onnxscript/rewriter/collapse_slices.py rename to onnxscript/rewriter/rules/collapse_slices.py diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/rules/collapse_slices_test.py similarity index 98% rename from onnxscript/rewriter/collapse_slices_test.py rename to onnxscript/rewriter/rules/collapse_slices_test.py index 52b59f9037..91fc86f6e2 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/rules/collapse_slices_test.py @@ -9,7 +9,8 @@ import onnx.shape_inference from onnxscript import ir -from onnxscript.rewriter import collapse_slices, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules import collapse_slices _INT64_MAX = 9223372036854775807 diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/rules/fuse_batchnorm.py similarity index 100% rename from onnxscript/rewriter/fuse_batchnorm.py rename to onnxscript/rewriter/rules/fuse_batchnorm.py diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/rules/fuse_batchnorm_test.py similarity index 98% rename from onnxscript/rewriter/fuse_batchnorm_test.py rename to onnxscript/rewriter/rules/fuse_batchnorm_test.py index 20d272abd7..17fc99e415 100644 --- a/onnxscript/rewriter/fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/fuse_batchnorm_test.py @@ -8,7 +8,8 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import fuse_batchnorm, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules import fuse_batchnorm class FuseBatchnormTest(unittest.TestCase): diff --git a/onnxscript/rewriter/fuse_pad_into_conv.py b/onnxscript/rewriter/rules/fuse_pad_into_conv.py similarity index 100% rename from onnxscript/rewriter/fuse_pad_into_conv.py rename to onnxscript/rewriter/rules/fuse_pad_into_conv.py diff --git a/onnxscript/rewriter/fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/fuse_pad_into_conv_test.py similarity index 99% rename from onnxscript/rewriter/fuse_pad_into_conv_test.py rename to onnxscript/rewriter/rules/fuse_pad_into_conv_test.py index dfbf117bd1..4f3fc6ca08 100644 --- a/onnxscript/rewriter/fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/fuse_pad_into_conv_test.py @@ -12,7 +12,7 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing -from onnxscript.rewriter.fuse_pad_into_conv import ( +from onnxscript.rewriter.rules.fuse_pad_into_conv import ( fuse_pad_into_conv, fuse_pad_into_conv_rule_set, normalize_pad_format_conv, diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/rules/fuse_relus_clips.py similarity index 100% rename from onnxscript/rewriter/fuse_relus_clips.py rename to onnxscript/rewriter/rules/fuse_relus_clips.py diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/fuse_relus_clips_test.py similarity index 99% rename from onnxscript/rewriter/fuse_relus_clips_test.py rename to onnxscript/rewriter/rules/fuse_relus_clips_test.py index d58b493fb4..4f992cfa03 100644 --- a/onnxscript/rewriter/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/fuse_relus_clips_test.py @@ -13,14 +13,14 @@ MatchingTracer, MatchStatus, RewriteRule, - fuse_relus_clips, testing, ) -from onnxscript.rewriter.fuse_relus_clips import ( +from onnxscript.rewriter.rules.fuse_relus_clips import ( fuse_successive_clip_relu_rule, fuse_successive_clip_rule, fuse_successive_relu_clip_rule, ) +from onnxscript.rewriter.rules import fuse_relus_clips class _FuseReluClipTestBase(unittest.TestCase): diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/gemm_to_matmul_add.py similarity index 89% rename from onnxscript/rewriter/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/gemm_to_matmul_add.py index 09666466d3..0654755235 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/gemm_to_matmul_add.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from onnxscript.rewriter._rewrite_rule import RewriteRule -from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape +from onnxscript.rewriter.rules.broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/gemm_to_matmul_add_test.py similarity index 99% rename from onnxscript/rewriter/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/gemm_to_matmul_add_test.py index aab56cc3fe..71fcf34a9d 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/gemm_to_matmul_add_test.py @@ -5,7 +5,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import gemm_to_matmul_add +from onnxscript.rewriter.rules import gemm_to_matmul_add class ReshapeGemmReshapeTest(unittest.TestCase): diff --git a/onnxscript/rewriter/matmul_add_to_gemm.py b/onnxscript/rewriter/rules/matmul_add_to_gemm.py similarity index 100% rename from onnxscript/rewriter/matmul_add_to_gemm.py rename to onnxscript/rewriter/rules/matmul_add_to_gemm.py diff --git a/onnxscript/rewriter/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py similarity index 98% rename from onnxscript/rewriter/matmul_add_to_gemm_test.py rename to onnxscript/rewriter/rules/matmul_add_to_gemm_test.py index fd08125807..69d076636f 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py @@ -9,8 +9,9 @@ from parameterized import parameterized from onnxscript import ir -from onnxscript.rewriter import MatchingTracer, MatchStatus, matmul_add_to_gemm, testing -from onnxscript.rewriter.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter.rules import matmul_add_to_gemm class _MatMulAddToGemmTestBase(unittest.TestCase): diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/rules/no_op.py similarity index 100% rename from onnxscript/rewriter/no_op.py rename to onnxscript/rewriter/rules/no_op.py diff --git a/onnxscript/rewriter/redundant_scatter_nd.py b/onnxscript/rewriter/rules/redundant_scatter_nd.py similarity index 100% rename from onnxscript/rewriter/redundant_scatter_nd.py rename to onnxscript/rewriter/rules/redundant_scatter_nd.py diff --git a/onnxscript/rewriter/redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/redundant_scatter_nd_test.py similarity index 98% rename from onnxscript/rewriter/redundant_scatter_nd_test.py rename to onnxscript/rewriter/rules/redundant_scatter_nd_test.py index d2ba51eec4..fd9c2ab300 100644 --- a/onnxscript/rewriter/redundant_scatter_nd_test.py +++ b/onnxscript/rewriter/rules/redundant_scatter_nd_test.py @@ -13,7 +13,7 @@ import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op -from onnxscript.rewriter import redundant_scatter_nd +from onnxscript.rewriter.rules import redundant_scatter_nd shape_inference = ShapeInferencePass() onnx_check = CheckerPass(True) From 72ea6665a0cc4f1d1a1e2cdffe9e43441cd6d2a5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 10:56:14 -0700 Subject: [PATCH 02/12] Format Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 16 ++++++++-------- .../rewriter/rules/fuse_relus_clips_test.py | 2 +- .../rewriter/rules/matmul_add_to_gemm_test.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index bb30c4237a..54dd1c231f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -22,15 +22,7 @@ import onnx_ir.passes.common as common_passes from onnxscript import ir -from onnxscript.rewriter.rules import basic_rules as _basic_rules -from onnxscript.rewriter.rules import broadcast_to_matmul as _broadcast_to_matmul -from onnxscript.rewriter.rules import cast_constant_of_shape as _cast_constant_of_shape -from onnxscript.rewriter.rules import collapse_slices as _collapse_slices -from onnxscript.rewriter.rules import fuse_pad_into_conv as _fuse_pad_into_conv -from onnxscript.rewriter.rules import fuse_relus_clips as _fuse_relus_clips -from onnxscript.rewriter.rules import no_op as _no_op from onnxscript.rewriter import pattern -from onnxscript.rewriter.rules import redundant_scatter_nd as _redundant_scatter_nd from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( RewriterContext, @@ -38,6 +30,14 @@ RewriteRuleClassBase, RewriteRuleSet, ) +from onnxscript.rewriter.rules import basic_rules as _basic_rules +from onnxscript.rewriter.rules import broadcast_to_matmul as _broadcast_to_matmul +from onnxscript.rewriter.rules import cast_constant_of_shape as _cast_constant_of_shape +from onnxscript.rewriter.rules import collapse_slices as _collapse_slices +from onnxscript.rewriter.rules import fuse_pad_into_conv as _fuse_pad_into_conv +from onnxscript.rewriter.rules import fuse_relus_clips as _fuse_relus_clips +from onnxscript.rewriter.rules import no_op as _no_op +from onnxscript.rewriter.rules import redundant_scatter_nd as _redundant_scatter_nd _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( diff --git a/onnxscript/rewriter/rules/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/fuse_relus_clips_test.py index 4f992cfa03..7dfa4de430 100644 --- a/onnxscript/rewriter/rules/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/fuse_relus_clips_test.py @@ -15,12 +15,12 @@ RewriteRule, testing, ) +from onnxscript.rewriter.rules import fuse_relus_clips from onnxscript.rewriter.rules.fuse_relus_clips import ( fuse_successive_clip_relu_rule, fuse_successive_clip_rule, fuse_successive_relu_clip_rule, ) -from onnxscript.rewriter.rules import fuse_relus_clips class _FuseReluClipTestBase(unittest.TestCase): diff --git a/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py index 69d076636f..c31d4ab6c6 100644 --- a/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py @@ -10,8 +10,8 @@ from onnxscript import ir from onnxscript.rewriter import MatchingTracer, MatchStatus, testing -from onnxscript.rewriter.rules.matmul_add_to_gemm import matmul_add_to_gemm_rule from onnxscript.rewriter.rules import matmul_add_to_gemm +from onnxscript.rewriter.rules.matmul_add_to_gemm import matmul_add_to_gemm_rule class _MatMulAddToGemmTestBase(unittest.TestCase): From 05b620e932fdaac8c1c8361a28668e5c54bb8962 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 15:18:43 -0700 Subject: [PATCH 03/12] Move around Signed-off-by: Justin Chu --- onnxscript/rewriter/rules/{ => common}/basic_rules.py | 0 onnxscript/rewriter/rules/{ => common}/basic_rules_test.py | 2 +- onnxscript/rewriter/rules/{ => common}/broadcast_to_matmul.py | 0 .../rewriter/rules/{ => common}/broadcast_to_matmul_test.py | 0 .../rewriter/rules/{ => common}/cast_constant_of_shape.py | 0 .../rewriter/rules/{ => common}/cast_constant_of_shape_test.py | 0 onnxscript/rewriter/rules/{ => common}/collapse_slices.py | 0 onnxscript/rewriter/rules/{ => common}/collapse_slices_test.py | 0 onnxscript/rewriter/rules/{ => common}/fuse_batchnorm.py | 0 onnxscript/rewriter/rules/{ => common}/fuse_batchnorm_test.py | 0 onnxscript/rewriter/rules/{ => common}/fuse_pad_into_conv.py | 0 .../rewriter/rules/{ => common}/fuse_pad_into_conv_test.py | 2 +- onnxscript/rewriter/rules/{ => common}/fuse_relus_clips.py | 0 onnxscript/rewriter/rules/{ => common}/fuse_relus_clips_test.py | 2 +- onnxscript/rewriter/rules/{ => common}/gemm_to_matmul_add.py | 2 +- .../rewriter/rules/{ => common}/gemm_to_matmul_add_test.py | 0 onnxscript/rewriter/rules/{ => common}/matmul_add_to_gemm.py | 0 .../rewriter/rules/{ => common}/matmul_add_to_gemm_test.py | 2 +- onnxscript/rewriter/rules/{ => common}/no_op.py | 0 onnxscript/rewriter/rules/{ => common}/redundant_scatter_nd.py | 0 .../rewriter/rules/{ => common}/redundant_scatter_nd_test.py | 0 .../rewriter/{onnx_fusions => rules/fusion}/_layer_norm.py | 0 .../rewriter/{onnx_fusions => rules/fusion}/_layer_norm_test.py | 2 +- .../rewriter/{onnx_fusions => rules/fusion}/_onnx_fusions.py | 2 +- .../{onnx_fusions => rules/fusion}/_onnx_fusions_test.py | 2 +- .../{onnx_fusions => rules/fusion}/_rms_normalization.py | 0 .../{onnx_fusions => rules/fusion}/_rotary_embedding.py | 0 27 files changed, 8 insertions(+), 8 deletions(-) rename onnxscript/rewriter/rules/{ => common}/basic_rules.py (100%) rename onnxscript/rewriter/rules/{ => common}/basic_rules_test.py (99%) rename onnxscript/rewriter/rules/{ => common}/broadcast_to_matmul.py (100%) rename onnxscript/rewriter/rules/{ => common}/broadcast_to_matmul_test.py (100%) rename onnxscript/rewriter/rules/{ => common}/cast_constant_of_shape.py (100%) rename onnxscript/rewriter/rules/{ => common}/cast_constant_of_shape_test.py (100%) rename onnxscript/rewriter/rules/{ => common}/collapse_slices.py (100%) rename onnxscript/rewriter/rules/{ => common}/collapse_slices_test.py (100%) rename onnxscript/rewriter/rules/{ => common}/fuse_batchnorm.py (100%) rename onnxscript/rewriter/rules/{ => common}/fuse_batchnorm_test.py (100%) rename onnxscript/rewriter/rules/{ => common}/fuse_pad_into_conv.py (100%) rename onnxscript/rewriter/rules/{ => common}/fuse_pad_into_conv_test.py (99%) rename onnxscript/rewriter/rules/{ => common}/fuse_relus_clips.py (100%) rename onnxscript/rewriter/rules/{ => common}/fuse_relus_clips_test.py (99%) rename onnxscript/rewriter/rules/{ => common}/gemm_to_matmul_add.py (89%) rename onnxscript/rewriter/rules/{ => common}/gemm_to_matmul_add_test.py (100%) rename onnxscript/rewriter/rules/{ => common}/matmul_add_to_gemm.py (100%) rename onnxscript/rewriter/rules/{ => common}/matmul_add_to_gemm_test.py (99%) rename onnxscript/rewriter/rules/{ => common}/no_op.py (100%) rename onnxscript/rewriter/rules/{ => common}/redundant_scatter_nd.py (100%) rename onnxscript/rewriter/rules/{ => common}/redundant_scatter_nd_test.py (100%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_layer_norm.py (100%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_layer_norm_test.py (98%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_onnx_fusions.py (95%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_onnx_fusions_test.py (97%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_rms_normalization.py (100%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_rotary_embedding.py (100%) diff --git a/onnxscript/rewriter/rules/basic_rules.py b/onnxscript/rewriter/rules/common/basic_rules.py similarity index 100% rename from onnxscript/rewriter/rules/basic_rules.py rename to onnxscript/rewriter/rules/common/basic_rules.py diff --git a/onnxscript/rewriter/rules/basic_rules_test.py b/onnxscript/rewriter/rules/common/basic_rules_test.py similarity index 99% rename from onnxscript/rewriter/rules/basic_rules_test.py rename to onnxscript/rewriter/rules/common/basic_rules_test.py index f4bf00c617..b876e4fce0 100644 --- a/onnxscript/rewriter/rules/basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/basic_rules_test.py @@ -12,7 +12,7 @@ import onnxscript import onnxscript.onnx_types as ot -import onnxscript.rewriter.rules.basic_rules as basic_rules +import onnxscript.rewriter.rules.common.basic_rules as basic_rules from onnxscript import ir from onnxscript.onnx_opset import opset18 diff --git a/onnxscript/rewriter/rules/broadcast_to_matmul.py b/onnxscript/rewriter/rules/common/broadcast_to_matmul.py similarity index 100% rename from onnxscript/rewriter/rules/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/common/broadcast_to_matmul.py diff --git a/onnxscript/rewriter/rules/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/broadcast_to_matmul_test.py similarity index 100% rename from onnxscript/rewriter/rules/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/common/broadcast_to_matmul_test.py diff --git a/onnxscript/rewriter/rules/cast_constant_of_shape.py b/onnxscript/rewriter/rules/common/cast_constant_of_shape.py similarity index 100% rename from onnxscript/rewriter/rules/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/common/cast_constant_of_shape.py diff --git a/onnxscript/rewriter/rules/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/common/cast_constant_of_shape_test.py similarity index 100% rename from onnxscript/rewriter/rules/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/common/cast_constant_of_shape_test.py diff --git a/onnxscript/rewriter/rules/collapse_slices.py b/onnxscript/rewriter/rules/common/collapse_slices.py similarity index 100% rename from onnxscript/rewriter/rules/collapse_slices.py rename to onnxscript/rewriter/rules/common/collapse_slices.py diff --git a/onnxscript/rewriter/rules/collapse_slices_test.py b/onnxscript/rewriter/rules/common/collapse_slices_test.py similarity index 100% rename from onnxscript/rewriter/rules/collapse_slices_test.py rename to onnxscript/rewriter/rules/common/collapse_slices_test.py diff --git a/onnxscript/rewriter/rules/fuse_batchnorm.py b/onnxscript/rewriter/rules/common/fuse_batchnorm.py similarity index 100% rename from onnxscript/rewriter/rules/fuse_batchnorm.py rename to onnxscript/rewriter/rules/common/fuse_batchnorm.py diff --git a/onnxscript/rewriter/rules/fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/fuse_batchnorm_test.py similarity index 100% rename from onnxscript/rewriter/rules/fuse_batchnorm_test.py rename to onnxscript/rewriter/rules/common/fuse_batchnorm_test.py diff --git a/onnxscript/rewriter/rules/fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/fuse_pad_into_conv.py similarity index 100% rename from onnxscript/rewriter/rules/fuse_pad_into_conv.py rename to onnxscript/rewriter/rules/common/fuse_pad_into_conv.py diff --git a/onnxscript/rewriter/rules/fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py similarity index 99% rename from onnxscript/rewriter/rules/fuse_pad_into_conv_test.py rename to onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py index 4f3fc6ca08..84c8fab4fb 100644 --- a/onnxscript/rewriter/rules/fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py @@ -12,7 +12,7 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing -from onnxscript.rewriter.rules.fuse_pad_into_conv import ( +from onnxscript.rewriter.rules.common.fuse_pad_into_conv import ( fuse_pad_into_conv, fuse_pad_into_conv_rule_set, normalize_pad_format_conv, diff --git a/onnxscript/rewriter/rules/fuse_relus_clips.py b/onnxscript/rewriter/rules/common/fuse_relus_clips.py similarity index 100% rename from onnxscript/rewriter/rules/fuse_relus_clips.py rename to onnxscript/rewriter/rules/common/fuse_relus_clips.py diff --git a/onnxscript/rewriter/rules/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/fuse_relus_clips_test.py similarity index 99% rename from onnxscript/rewriter/rules/fuse_relus_clips_test.py rename to onnxscript/rewriter/rules/common/fuse_relus_clips_test.py index 7dfa4de430..2979bd8458 100644 --- a/onnxscript/rewriter/rules/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/fuse_relus_clips_test.py @@ -16,7 +16,7 @@ testing, ) from onnxscript.rewriter.rules import fuse_relus_clips -from onnxscript.rewriter.rules.fuse_relus_clips import ( +from onnxscript.rewriter.rules.common.fuse_relus_clips import ( fuse_successive_clip_relu_rule, fuse_successive_clip_rule, fuse_successive_relu_clip_rule, diff --git a/onnxscript/rewriter/rules/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/gemm_to_matmul_add.py similarity index 89% rename from onnxscript/rewriter/rules/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/common/gemm_to_matmul_add.py index 0654755235..0deb0c7a23 100644 --- a/onnxscript/rewriter/rules/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/gemm_to_matmul_add.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from onnxscript.rewriter._rewrite_rule import RewriteRule -from onnxscript.rewriter.rules.broadcast_to_matmul import check_if_not_need_reshape +from onnxscript.rewriter.rules.common.broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against diff --git a/onnxscript/rewriter/rules/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/gemm_to_matmul_add_test.py similarity index 100% rename from onnxscript/rewriter/rules/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/common/gemm_to_matmul_add_test.py diff --git a/onnxscript/rewriter/rules/matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/matmul_add_to_gemm.py similarity index 100% rename from onnxscript/rewriter/rules/matmul_add_to_gemm.py rename to onnxscript/rewriter/rules/common/matmul_add_to_gemm.py diff --git a/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py similarity index 99% rename from onnxscript/rewriter/rules/matmul_add_to_gemm_test.py rename to onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py index c31d4ab6c6..206c5f1c48 100644 --- a/onnxscript/rewriter/rules/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py @@ -11,7 +11,7 @@ from onnxscript import ir from onnxscript.rewriter import MatchingTracer, MatchStatus, testing from onnxscript.rewriter.rules import matmul_add_to_gemm -from onnxscript.rewriter.rules.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter.rules.common.matmul_add_to_gemm import matmul_add_to_gemm_rule class _MatMulAddToGemmTestBase(unittest.TestCase): diff --git a/onnxscript/rewriter/rules/no_op.py b/onnxscript/rewriter/rules/common/no_op.py similarity index 100% rename from onnxscript/rewriter/rules/no_op.py rename to onnxscript/rewriter/rules/common/no_op.py diff --git a/onnxscript/rewriter/rules/redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/redundant_scatter_nd.py similarity index 100% rename from onnxscript/rewriter/rules/redundant_scatter_nd.py rename to onnxscript/rewriter/rules/common/redundant_scatter_nd.py diff --git a/onnxscript/rewriter/rules/redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/common/redundant_scatter_nd_test.py similarity index 100% rename from onnxscript/rewriter/rules/redundant_scatter_nd_test.py rename to onnxscript/rewriter/rules/common/redundant_scatter_nd_test.py diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm.py b/onnxscript/rewriter/rules/fusion/_layer_norm.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_layer_norm.py rename to onnxscript/rewriter/rules/fusion/_layer_norm.py diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py similarity index 98% rename from onnxscript/rewriter/onnx_fusions/_layer_norm_test.py rename to onnxscript/rewriter/rules/fusion/_layer_norm_test.py index 6c9734d058..6ea7f116fb 100644 --- a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py +++ b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py @@ -10,7 +10,7 @@ import onnxscript.rewriter.testing from onnxscript import FLOAT, OnnxFunction, script from onnxscript import opset18 as op -from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization +from onnxscript.rewriter.rules.fusion._layer_norm import fuse_layer_normalization @script() diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/rules/fusion/_onnx_fusions.py similarity index 95% rename from onnxscript/rewriter/onnx_fusions/_onnx_fusions.py rename to onnxscript/rewriter/rules/fusion/_onnx_fusions.py index 0a45f3017c..bd73cb1f6d 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/rules/fusion/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding +from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding def _get_onnx_opset_version(model: ir.Model) -> int | None: diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py similarity index 97% rename from onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py rename to onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py index 59a460005a..b7bd754d1f 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py @@ -8,7 +8,7 @@ from parameterized import parameterized import onnxscript -import onnxscript.rewriter.onnx_fusions as onnx_fusions +import onnxscript.rewriter.rules.fusion as onnx_fusions from onnxscript.rewriter.models import _rotary_embedding_models diff --git a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py b/onnxscript/rewriter/rules/fusion/_rms_normalization.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_rms_normalization.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization.py diff --git a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_rotary_embedding.py rename to onnxscript/rewriter/rules/fusion/_rotary_embedding.py From ae7fea1d8bfea24da6fe358c3f694a9a54b73063 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 15:21:02 -0700 Subject: [PATCH 04/12] no op test Signed-off-by: Justin Chu --- onnxscript/rewriter/{ => rules/common}/no_op_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename onnxscript/rewriter/{ => rules/common}/no_op_test.py (99%) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/rules/common/no_op_test.py similarity index 99% rename from onnxscript/rewriter/no_op_test.py rename to onnxscript/rewriter/rules/common/no_op_test.py index bdfcfd857f..23ecda96eb 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/rules/common/no_op_test.py @@ -5,7 +5,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter.rules import no_op +from onnxscript.rewriter.rules.common import no_op class NoOpTest(unittest.TestCase): From 0f9dbbdc36d5a1207e5fd10624b169e376d14d38 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 15:21:54 -0700 Subject: [PATCH 05/12] path Signed-off-by: Justin Chu --- onnxscript/rewriter/ort_fusions/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 72c18ebeca..5762d7329b 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -33,7 +33,7 @@ fuse_skip_layer_normalization, fuse_skip_rms_normalization, ) -from onnxscript.rewriter.rules import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import gemm_to_matmul_add ORT_PATTERN_REWRITE_RULES = [ *softmax.rules.rules, From a54164dfccef357b25471b464357cd396cbfa502 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 15:22:39 -0700 Subject: [PATCH 06/12] path Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 54dd1c231f..dc514d311b 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -30,14 +30,14 @@ RewriteRuleClassBase, RewriteRuleSet, ) -from onnxscript.rewriter.rules import basic_rules as _basic_rules -from onnxscript.rewriter.rules import broadcast_to_matmul as _broadcast_to_matmul -from onnxscript.rewriter.rules import cast_constant_of_shape as _cast_constant_of_shape -from onnxscript.rewriter.rules import collapse_slices as _collapse_slices -from onnxscript.rewriter.rules import fuse_pad_into_conv as _fuse_pad_into_conv -from onnxscript.rewriter.rules import fuse_relus_clips as _fuse_relus_clips -from onnxscript.rewriter.rules import no_op as _no_op -from onnxscript.rewriter.rules import redundant_scatter_nd as _redundant_scatter_nd +from onnxscript.rewriter.rules.common import basic_rules as _basic_rules +from onnxscript.rewriter.rules.common import broadcast_to_matmul as _broadcast_to_matmul +from onnxscript.rewriter.rules.common import cast_constant_of_shape as _cast_constant_of_shape +from onnxscript.rewriter.rules.common import collapse_slices as _collapse_slices +from onnxscript.rewriter.rules.common import fuse_pad_into_conv as _fuse_pad_into_conv +from onnxscript.rewriter.rules.common import fuse_relus_clips as _fuse_relus_clips +from onnxscript.rewriter.rules.common import no_op as _no_op +from onnxscript.rewriter.rules.common import redundant_scatter_nd as _redundant_scatter_nd _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( From 668c6008a3162f4b0cb86ea5b809d8803bc25265 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 29 Aug 2025 16:05:59 -0700 Subject: [PATCH 07/12] Update Signed-off-by: Justin Chu --- .../rewriter/{rules/fusion => onnx_fusions}/_onnx_fusions.py | 0 .../{rules/fusion => onnx_fusions}/_onnx_fusions_test.py | 2 +- onnxscript/rewriter/rules/__init__.py | 2 ++ onnxscript/rewriter/rules/common/__init__.py | 2 ++ onnxscript/rewriter/rules/fusion/__init__.py | 2 ++ 5 files changed, 7 insertions(+), 1 deletion(-) rename onnxscript/rewriter/{rules/fusion => onnx_fusions}/_onnx_fusions.py (100%) rename onnxscript/rewriter/{rules/fusion => onnx_fusions}/_onnx_fusions_test.py (97%) create mode 100644 onnxscript/rewriter/rules/__init__.py create mode 100644 onnxscript/rewriter/rules/common/__init__.py create mode 100644 onnxscript/rewriter/rules/fusion/__init__.py diff --git a/onnxscript/rewriter/rules/fusion/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py similarity index 100% rename from onnxscript/rewriter/rules/fusion/_onnx_fusions.py rename to onnxscript/rewriter/onnx_fusions/_onnx_fusions.py diff --git a/onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py similarity index 97% rename from onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py rename to onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py index b7bd754d1f..22d6120da1 100644 --- a/onnxscript/rewriter/rules/fusion/_onnx_fusions_test.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py @@ -8,7 +8,7 @@ from parameterized import parameterized import onnxscript -import onnxscript.rewriter.rules.fusion as onnx_fusions +from onnxscript.rewriter import onnx_fusions from onnxscript.rewriter.models import _rotary_embedding_models diff --git a/onnxscript/rewriter/rules/__init__.py b/onnxscript/rewriter/rules/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/fusion/__init__.py b/onnxscript/rewriter/rules/fusion/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. From e2a15b1e397d959b43a7a49aa2a443dedc65155b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 2 Sep 2025 11:24:37 -0700 Subject: [PATCH 08/12] Refactor Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 18 +++++++------ onnxscript/rewriter/ort_fusions/_core.py | 4 +-- .../{basic_rules.py => _basic_rules.py} | 0 ...sic_rules_test.py => _basic_rules_test.py} | 20 +++++++------- ...t_to_matmul.py => _broadcast_to_matmul.py} | 0 ...l_test.py => _broadcast_to_matmul_test.py} | 4 +-- ...of_shape.py => _cast_constant_of_shape.py} | 0 ...est.py => _cast_constant_of_shape_test.py} | 6 ++--- ...collapse_slices.py => _collapse_slices.py} | 0 ...lices_test.py => _collapse_slices_test.py} | 13 +++++----- .../{fuse_batchnorm.py => _fuse_batchnorm.py} | 0 ...chnorm_test.py => _fuse_batchnorm_test.py} | 12 ++++----- ...ad_into_conv.py => _fuse_pad_into_conv.py} | 0 ...nv_test.py => _fuse_pad_into_conv_test.py} | 2 +- ...se_relus_clips.py => _fuse_relus_clips.py} | 0 ...lips_test.py => _fuse_relus_clips_test.py} | 2 +- ...o_matmul_add.py => _gemm_to_matmul_add.py} | 2 +- ...dd_test.py => _gemm_to_matmul_add_test.py} | 26 +++++++++---------- ..._add_to_gemm.py => _matmul_add_to_gemm.py} | 0 ...mm_test.py => _matmul_add_to_gemm_test.py} | 15 +++++------ .../rules/common/{no_op.py => _no_op.py} | 0 .../common/{no_op_test.py => _no_op_test.py} | 4 +-- ...scatter_nd.py => _redundant_scatter_nd.py} | 0 ..._test.py => _redundant_scatter_nd_test.py} | 6 ++--- 24 files changed, 67 insertions(+), 67 deletions(-) rename onnxscript/rewriter/rules/common/{basic_rules.py => _basic_rules.py} (100%) rename onnxscript/rewriter/rules/common/{basic_rules_test.py => _basic_rules_test.py} (96%) rename onnxscript/rewriter/rules/common/{broadcast_to_matmul.py => _broadcast_to_matmul.py} (100%) rename onnxscript/rewriter/rules/common/{broadcast_to_matmul_test.py => _broadcast_to_matmul_test.py} (99%) rename onnxscript/rewriter/rules/common/{cast_constant_of_shape.py => _cast_constant_of_shape.py} (100%) rename onnxscript/rewriter/rules/common/{cast_constant_of_shape_test.py => _cast_constant_of_shape_test.py} (89%) rename onnxscript/rewriter/rules/common/{collapse_slices.py => _collapse_slices.py} (100%) rename onnxscript/rewriter/rules/common/{collapse_slices_test.py => _collapse_slices_test.py} (92%) rename onnxscript/rewriter/rules/common/{fuse_batchnorm.py => _fuse_batchnorm.py} (100%) rename onnxscript/rewriter/rules/common/{fuse_batchnorm_test.py => _fuse_batchnorm_test.py} (94%) rename onnxscript/rewriter/rules/common/{fuse_pad_into_conv.py => _fuse_pad_into_conv.py} (100%) rename onnxscript/rewriter/rules/common/{fuse_pad_into_conv_test.py => _fuse_pad_into_conv_test.py} (99%) rename onnxscript/rewriter/rules/common/{fuse_relus_clips.py => _fuse_relus_clips.py} (100%) rename onnxscript/rewriter/rules/common/{fuse_relus_clips_test.py => _fuse_relus_clips_test.py} (99%) rename onnxscript/rewriter/rules/common/{gemm_to_matmul_add.py => _gemm_to_matmul_add.py} (88%) rename onnxscript/rewriter/rules/common/{gemm_to_matmul_add_test.py => _gemm_to_matmul_add_test.py} (93%) rename onnxscript/rewriter/rules/common/{matmul_add_to_gemm.py => _matmul_add_to_gemm.py} (100%) rename onnxscript/rewriter/rules/common/{matmul_add_to_gemm_test.py => _matmul_add_to_gemm_test.py} (94%) rename onnxscript/rewriter/rules/common/{no_op.py => _no_op.py} (100%) rename onnxscript/rewriter/rules/common/{no_op_test.py => _no_op_test.py} (98%) rename onnxscript/rewriter/rules/common/{redundant_scatter_nd.py => _redundant_scatter_nd.py} (100%) rename onnxscript/rewriter/rules/common/{redundant_scatter_nd_test.py => _redundant_scatter_nd_test.py} (96%) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index dc514d311b..920bddb3d4 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -30,14 +30,16 @@ RewriteRuleClassBase, RewriteRuleSet, ) -from onnxscript.rewriter.rules.common import basic_rules as _basic_rules -from onnxscript.rewriter.rules.common import broadcast_to_matmul as _broadcast_to_matmul -from onnxscript.rewriter.rules.common import cast_constant_of_shape as _cast_constant_of_shape -from onnxscript.rewriter.rules.common import collapse_slices as _collapse_slices -from onnxscript.rewriter.rules.common import fuse_pad_into_conv as _fuse_pad_into_conv -from onnxscript.rewriter.rules.common import fuse_relus_clips as _fuse_relus_clips -from onnxscript.rewriter.rules.common import no_op as _no_op -from onnxscript.rewriter.rules.common import redundant_scatter_nd as _redundant_scatter_nd +from onnxscript.rewriter.rules.common import ( + _basic_rules, + _broadcast_to_matmul, + _cast_constant_of_shape, + _collapse_slices, + _fuse_pad_into_conv, + _fuse_relus_clips, + _no_op, + _redundant_scatter_nd, +) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 5762d7329b..a1c7d4df5a 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -33,7 +33,7 @@ fuse_skip_layer_normalization, fuse_skip_rms_normalization, ) -from onnxscript.rewriter.rules.common import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add ORT_PATTERN_REWRITE_RULES = [ *softmax.rules.rules, @@ -134,7 +134,7 @@ def optimize_for_ort( - The optimized `ir.Model` after applying transformer-specific fusions. - A dictionary with a count of each of the fusions applied. """ - rewrite(model, [gemm_to_matmul_add.rule]) + rewrite(model, [_gemm_to_matmul_add.rule]) model, fusion_count = fuse_xformers( model, debug=debug, diff --git a/onnxscript/rewriter/rules/common/basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py similarity index 100% rename from onnxscript/rewriter/rules/common/basic_rules.py rename to onnxscript/rewriter/rules/common/_basic_rules.py diff --git a/onnxscript/rewriter/rules/common/basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py similarity index 96% rename from onnxscript/rewriter/rules/common/basic_rules_test.py rename to onnxscript/rewriter/rules/common/_basic_rules_test.py index b876e4fce0..8709300763 100644 --- a/onnxscript/rewriter/rules/common/basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -12,9 +12,9 @@ import onnxscript import onnxscript.onnx_types as ot -import onnxscript.rewriter.rules.common.basic_rules as basic_rules from onnxscript import ir from onnxscript.onnx_opset import opset18 +from onnxscript.rewriter.rules.common import _basic_rules FLOAT = onnx.TensorProto.FLOAT @@ -98,7 +98,7 @@ def _check_model( ] ) def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -126,7 +126,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): ] ) def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -153,7 +153,7 @@ def cast_cast_model(x): ] ) def test_cast_cast_rule(self, _: str, type1, type2, type3): - rule = basic_rules.cast_cast_rule + rule = _basic_rules.cast_cast_rule model_proto = self._double_cast_model(type1, type2, type3) model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) @@ -172,7 +172,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3): ] ) def test_cast_identity_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -228,7 +228,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model): def test_expand_identity_rule( self, _: str, model: ir.Model, expected_nodes: tuple[str, ...] ): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -310,7 +310,7 @@ def test_expand_identity_rule( ] ) def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -369,7 +369,7 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): ] ) def test_reshape_reshape_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -420,7 +420,7 @@ def _slices_split_models(cls): def test_slices_split_rule(self): for model_proto in self._slices_split_models(): ir_model = ir.serde.deserialize_model(model_proto) - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) @@ -428,7 +428,7 @@ def test_slices_split_rule(self): self._check_model(model_proto, rewritten_model) def test_squeeze_reshape_1d_rule(self): - rule = basic_rules.squeeze_reshape_1d_rule + rule = _basic_rules.squeeze_reshape_1d_rule def check(model_script, expected_count) -> None: model_proto = model_script.to_model_proto() diff --git a/onnxscript/rewriter/rules/common/broadcast_to_matmul.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py similarity index 100% rename from onnxscript/rewriter/rules/common/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul.py diff --git a/onnxscript/rewriter/rules/common/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py similarity index 99% rename from onnxscript/rewriter/rules/common/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py index 7d339521bd..39c4eaeb6e 100644 --- a/onnxscript/rewriter/rules/common/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py @@ -9,7 +9,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter.rules import broadcast_to_matmul +from onnxscript.rewriter.rules.common import _broadcast_to_matmul def _infer_shapes(model: ir.Model) -> ir.Model: @@ -38,7 +38,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) diff --git a/onnxscript/rewriter/rules/common/cast_constant_of_shape.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py similarity index 100% rename from onnxscript/rewriter/rules/common/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape.py diff --git a/onnxscript/rewriter/rules/common/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py similarity index 89% rename from onnxscript/rewriter/rules/common/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py index 24e4fb0d49..794491024b 100644 --- a/onnxscript/rewriter/rules/common/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter.rules import cast_constant_of_shape +from onnxscript.rewriter.rules.common import _cast_constant_of_shape class CastConstantOfShapeTest(unittest.TestCase): @@ -23,7 +23,7 @@ def test_cast_after_constant_of_shape_is_fused(self): ) onnx.checker.check_model(input_model_proto, True) model = ir.serde.deserialize_model(input_model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) @@ -42,7 +42,7 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) diff --git a/onnxscript/rewriter/rules/common/collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py similarity index 100% rename from onnxscript/rewriter/rules/common/collapse_slices.py rename to onnxscript/rewriter/rules/common/_collapse_slices.py diff --git a/onnxscript/rewriter/rules/common/collapse_slices_test.py b/onnxscript/rewriter/rules/common/_collapse_slices_test.py similarity index 92% rename from onnxscript/rewriter/rules/common/collapse_slices_test.py rename to onnxscript/rewriter/rules/common/_collapse_slices_test.py index 91fc86f6e2..727240344d 100644 --- a/onnxscript/rewriter/rules/common/collapse_slices_test.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices_test.py @@ -6,11 +6,10 @@ import numpy as np import onnx.parser -import onnx.shape_inference from onnxscript import ir from onnxscript.rewriter import testing -from onnxscript.rewriter.rules import collapse_slices +from onnxscript.rewriter.rules.common import _collapse_slices _INT64_MAX = 9223372036854775807 @@ -31,7 +30,7 @@ def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -56,7 +55,7 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -81,7 +80,7 @@ def test_slice_unequal_dynamic_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 0) def test_slice_equal_dynamic_shape(self): @@ -99,7 +98,7 @@ def test_slice_equal_dynamic_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) def test_slice_equal_dynamic_shape_but_step_reverse(self): @@ -117,6 +116,6 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) # Should not change the output shape if we did not use the default step of 1 self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/rules/common/fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py similarity index 100% rename from onnxscript/rewriter/rules/common/fuse_batchnorm.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm.py diff --git a/onnxscript/rewriter/rules/common/fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py similarity index 94% rename from onnxscript/rewriter/rules/common/fuse_batchnorm_test.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 17fc99e415..90412e4fe2 100644 --- a/onnxscript/rewriter/rules/common/fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -9,7 +9,7 @@ from onnxscript import ir from onnxscript.rewriter import testing -from onnxscript.rewriter.rules import fuse_batchnorm +from onnxscript.rewriter.rules.common import _fuse_batchnorm class FuseBatchnormTest(unittest.TestCase): @@ -74,7 +74,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -133,7 +133,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -197,7 +197,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -224,7 +224,7 @@ def test_fuse_batchnorm_non_initializers(self): """) onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) # No changes were applied self.assertEqual(count, 0) @@ -248,7 +248,7 @@ def test_fuse_batchnorm_graph_inputs(self): onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) # No changes were applied as W is a graph input self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/rules/common/fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py similarity index 100% rename from onnxscript/rewriter/rules/common/fuse_pad_into_conv.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py diff --git a/onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py similarity index 99% rename from onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py index 84c8fab4fb..58df13ac0c 100644 --- a/onnxscript/rewriter/rules/common/fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -12,7 +12,7 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing -from onnxscript.rewriter.rules.common.fuse_pad_into_conv import ( +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( fuse_pad_into_conv, fuse_pad_into_conv_rule_set, normalize_pad_format_conv, diff --git a/onnxscript/rewriter/rules/common/fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py similarity index 100% rename from onnxscript/rewriter/rules/common/fuse_relus_clips.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips.py diff --git a/onnxscript/rewriter/rules/common/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py similarity index 99% rename from onnxscript/rewriter/rules/common/fuse_relus_clips_test.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py index 2979bd8458..60747a61b3 100644 --- a/onnxscript/rewriter/rules/common/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -16,7 +16,7 @@ testing, ) from onnxscript.rewriter.rules import fuse_relus_clips -from onnxscript.rewriter.rules.common.fuse_relus_clips import ( +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( fuse_successive_clip_relu_rule, fuse_successive_clip_rule, fuse_successive_relu_clip_rule, diff --git a/onnxscript/rewriter/rules/common/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py similarity index 88% rename from onnxscript/rewriter/rules/common/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index 0deb0c7a23..0bce4dd48f 100644 --- a/onnxscript/rewriter/rules/common/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from onnxscript.rewriter._rewrite_rule import RewriteRule -from onnxscript.rewriter.rules.common.broadcast_to_matmul import check_if_not_need_reshape +from onnxscript.rewriter.rules.common._broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against diff --git a/onnxscript/rewriter/rules/common/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py similarity index 93% rename from onnxscript/rewriter/rules/common/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py index 71fcf34a9d..4abaf5ca13 100644 --- a/onnxscript/rewriter/rules/common/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py @@ -5,7 +5,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter.rules import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add class ReshapeGemmReshapeTest(unittest.TestCase): @@ -25,7 +25,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -70,7 +70,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -94,7 +94,7 @@ def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -115,7 +115,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -136,7 +136,7 @@ def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -159,7 +159,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -182,7 +182,7 @@ def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_b """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -203,7 +203,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -226,7 +226,7 @@ def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -247,7 +247,7 @@ def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broad """ ) model = ir.serde.deserialize_model(model_proto) - replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + replacement_count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(replacement_count, 1) self.assertEqual(len(model.graph), 4) @@ -268,7 +268,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broad """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -289,7 +289,7 @@ def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) diff --git a/onnxscript/rewriter/rules/common/matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py similarity index 100% rename from onnxscript/rewriter/rules/common/matmul_add_to_gemm.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py diff --git a/onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py similarity index 94% rename from onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index 206c5f1c48..c43a21d9df 100644 --- a/onnxscript/rewriter/rules/common/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -10,8 +10,7 @@ from onnxscript import ir from onnxscript.rewriter import MatchingTracer, MatchStatus, testing -from onnxscript.rewriter.rules import matmul_add_to_gemm -from onnxscript.rewriter.rules.common.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter.rules.common import _matmul_add_to_gemm class _MatMulAddToGemmTestBase(unittest.TestCase): @@ -102,13 +101,13 @@ def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs): updated_model = self.clone_model(base_model) tracer = MatchingTracer() - count = matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer) + count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer) # Check that the model is unchanged self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[matmul_add_to_gemm_rule][0] + tracer_match = tracer.best_matches_map[_matmul_add_to_gemm.matmul_add_to_gemm_rule][0] self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) self.assertRegex( tracer_match.match_result.reason, "Rank of input_a and input_b must be 2" @@ -130,7 +129,7 @@ def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): bias_as_inputs=bias_as_inputs, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) # Check MatMul + Add are fused into Gemm self.assertEqual(count, 1) @@ -177,7 +176,7 @@ def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transA=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) # Check MatMul(Transpose, W) + Add are fused into Gemm self.assertEqual(count, 1) @@ -226,7 +225,7 @@ def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) # Check MatMul(X, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) @@ -276,7 +275,7 @@ def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inpu transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) # Check MatMul(Transpose, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) diff --git a/onnxscript/rewriter/rules/common/no_op.py b/onnxscript/rewriter/rules/common/_no_op.py similarity index 100% rename from onnxscript/rewriter/rules/common/no_op.py rename to onnxscript/rewriter/rules/common/_no_op.py diff --git a/onnxscript/rewriter/rules/common/no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py similarity index 98% rename from onnxscript/rewriter/rules/common/no_op_test.py rename to onnxscript/rewriter/rules/common/_no_op_test.py index 23ecda96eb..7815473e34 100644 --- a/onnxscript/rewriter/rules/common/no_op_test.py +++ b/onnxscript/rewriter/rules/common/_no_op_test.py @@ -5,13 +5,13 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter.rules.common import no_op +from onnxscript.rewriter.rules.common import _no_op class NoOpTest(unittest.TestCase): def _check(self, model_text: str) -> None: model = ir.from_onnx_text(model_text) - count = no_op.rules.apply_to_model(model) + count = _no_op.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") diff --git a/onnxscript/rewriter/rules/common/redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py similarity index 100% rename from onnxscript/rewriter/rules/common/redundant_scatter_nd.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd.py diff --git a/onnxscript/rewriter/rules/common/redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py similarity index 96% rename from onnxscript/rewriter/rules/common/redundant_scatter_nd_test.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py index fd9c2ab300..96e3bcc80c 100644 --- a/onnxscript/rewriter/rules/common/redundant_scatter_nd_test.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py @@ -13,7 +13,7 @@ import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op -from onnxscript.rewriter.rules import redundant_scatter_nd +from onnxscript.rewriter.rules.common import _redundant_scatter_nd shape_inference = ShapeInferencePass() onnx_check = CheckerPass(True) @@ -48,7 +48,7 @@ def model_script( onnx_check(model) shape_inference(model) onnxscript.optimizer.fold_constants(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) onnx_check(model) optimized_model_proto = ir.serde.serialize_model(model) @@ -94,7 +94,7 @@ def test_redundant_scatter_nd_static_indices(self): model.graph.initializers["indices"] = indices_value original_model_proto = ir.serde.serialize_model(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertIn("Identity", [node.op_type for node in model.graph]) From e14946cfa6066692200b47e322f91f957f35772d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 2 Sep 2025 11:29:41 -0700 Subject: [PATCH 09/12] init Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 920bddb3d4..a099a88dc8 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -43,14 +43,14 @@ _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( - *_no_op.rules.rules, # TODO: merge this rule into constant folding? - *_broadcast_to_matmul.rules.rules, - *_cast_constant_of_shape.rules.rules, - *_collapse_slices.rules.rules, - *_fuse_relus_clips.fuse_relus_clips_rules().rules, - *_basic_rules.basic_optimization_rules().rules, - *_redundant_scatter_nd.rules.rules, - *_fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules, + *_no_op.rules, # TODO: merge this rule into constant folding? + *_broadcast_to_matmul.rules, + *_cast_constant_of_shape.rules, + *_collapse_slices.rules, + *_fuse_relus_clips.fuse_relus_clips_rules(), + *_basic_rules.basic_optimization_rules(), + *_redundant_scatter_nd.rules, + *_fuse_pad_into_conv.fuse_pad_into_conv_rule_set(), ) From d39344c7dc777b6fc5f5d388f908ebdd7343c7ff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 2 Sep 2025 16:01:08 -0700 Subject: [PATCH 10/12] update Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 4 +- onnxscript/rewriter/ort_fusions/_core.py | 2 +- onnxscript/rewriter/rules/common/__init__.py | 101 ++++++++++++++++++ .../rewriter/rules/common/_basic_rules.py | 12 +-- .../rewriter/rules/common/_collapse_slices.py | 6 +- .../rewriter/rules/common/_fuse_batchnorm.py | 23 ++-- .../rules/common/_fuse_batchnorm_test.py | 10 +- .../rules/common/_fuse_pad_into_conv.py | 36 +++---- .../rules/common/_fuse_relus_clips.py | 36 +++---- .../rules/common/_fuse_relus_clips_test.py | 22 ++-- .../rules/common/_gemm_to_matmul_add.py | 2 +- .../rules/common/_gemm_to_matmul_add_test.py | 24 ++--- .../rules/common/_matmul_add_to_gemm.py | 25 ++--- .../rules/common/_matmul_add_to_gemm_test.py | 8 +- .../rules/common/_redundant_scatter_nd.py | 6 +- 15 files changed, 193 insertions(+), 124 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index a099a88dc8..1d07e9f5af 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -47,10 +47,10 @@ *_broadcast_to_matmul.rules, *_cast_constant_of_shape.rules, *_collapse_slices.rules, - *_fuse_relus_clips.fuse_relus_clips_rules(), + *_fuse_relus_clips.rules, *_basic_rules.basic_optimization_rules(), *_redundant_scatter_nd.rules, - *_fuse_pad_into_conv.fuse_pad_into_conv_rule_set(), + *_fuse_pad_into_conv.rules, ) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index a1c7d4df5a..15faf083c8 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -134,7 +134,7 @@ def optimize_for_ort( - The optimized `ir.Model` after applying transformer-specific fusions. - A dictionary with a count of each of the fusions applied. """ - rewrite(model, [_gemm_to_matmul_add.rule]) + rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule]) model, fusion_count = fuse_xformers( model, debug=debug, diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 59e481eb93..752e3c9430 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -1,2 +1,103 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +__all__ = [ + "add_0_rule", + "cast_cast_rule", + "cast_constant_of_shape_rule", + "cast_constant_of_shape_without_value_rule", + "collapse_slice_rule", + "collapse_slice2_rule", + "div_by_1_rule", + "dropout_inference_rule", + "dropout_zero_rule", + "fuse_batchnorm_into_conv_rule", + "fuse_batchnorm_into_conv_transpose_rule", + "fuse_batchnorm_into_gemm_rule", + "fuse_pad_into_conv_integer_rule", + "fuse_pad_into_conv_rule", + "gemm_to_matmul_add_rule", + "matmul_add_to_gemm_rule", + "mul_by_1_rule", + "no_op_cast_rule", + "no_op_dynamic_scatter_nd_rule", + "no_op_expand_rule", + "no_op_static_scatter_nd_rule", + "no_op_transpose_rule", + "normalize_pad_format_conv_integer_rule", + "normalize_pad_format_conv_rule", + "one_reshape_matmul_reshape_rule", + "reshape_reshape_rule", + "slice_split_rule", + "squeeze_reshape_1d_rule", + "sub_0_rule", + "successive_clip_relu_rule", + "successive_clip_rule", + "successive_relu_clip_rule", + "successive_relu_rule", + "transpose_a_matmul_add_to_gemm_rule", + "transpose_ab_matmul_add_to_gemm_rule", + "transpose_b_matmul_add_to_gemm_rule", + "transpose_transpose_rule", + "two_reshapes_matmul_reshape_rule", + "unsqueeze_unsqueeze_rule", +] + +from onnxscript.rewriter.rules.common._basic_rules import ( + cast_cast_rule, + no_op_cast_rule, + no_op_expand_rule, + no_op_transpose_rule, + reshape_reshape_rule, + slice_split_rule, + squeeze_reshape_1d_rule, + transpose_transpose_rule, + unsqueeze_unsqueeze_rule, +) +from onnxscript.rewriter.rules.common._broadcast_to_matmul import ( + one_reshape_matmul_reshape_rule, + two_reshapes_matmul_reshape_rule, +) +from onnxscript.rewriter.rules.common._cast_constant_of_shape import ( + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, +) +from onnxscript.rewriter.rules.common._collapse_slices import ( + collapse_slice2_rule, + collapse_slice_rule, +) +from onnxscript.rewriter.rules.common._fuse_batchnorm import ( + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, +) +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_integer_rule, + fuse_pad_into_conv_rule, + normalize_pad_format_conv_integer_rule, + normalize_pad_format_conv_rule, +) +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, + successive_relu_rule, +) +from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule +from onnxscript.rewriter.rules.common._matmul_add_to_gemm import ( + matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_ab_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, +) +from onnxscript.rewriter.rules.common._no_op import ( + add_0_rule, + div_by_1_rule, + dropout_inference_rule, + dropout_zero_rule, + mul_by_1_rule, + sub_0_rule, +) +from onnxscript.rewriter.rules.common._redundant_scatter_nd import ( + no_op_dynamic_scatter_nd_rule, + no_op_static_scatter_nd_rule, +) diff --git a/onnxscript/rewriter/rules/common/_basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py index 2788cb7cda..6f38050f3e 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules.py +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -281,11 +281,11 @@ def check(self, context, x, axes1, axes2) -> MatchResult: # Create rule instances cast_cast_rule = CastCast.rule() -cast_identity_rule = CastIdentity.rule() -expand_identity_rule = ExpandIdentity.rule() +no_op_cast_rule = CastIdentity.rule() +no_op_expand_rule = ExpandIdentity.rule() reshape_reshape_rule = ReshapeReshape.rule() slice_split_rule = SlicesSplit.rule() -transpose_identity_rule = TransposeIdentity.rule() +no_op_transpose_rule = TransposeIdentity.rule() transpose_transpose_rule = TransposeTranspose.rule() unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule() squeeze_reshape_1d_rule = SqueezeReshape.rule() @@ -309,11 +309,11 @@ def basic_optimization_rules() -> RewriteRuleSet: return RewriteRuleSet( [ cast_cast_rule, - cast_identity_rule, - expand_identity_rule, + no_op_cast_rule, + no_op_expand_rule, reshape_reshape_rule, slice_split_rule, - transpose_identity_rule, + no_op_transpose_rule, transpose_transpose_rule, unsqueeze_unsqueeze_rule, squeeze_reshape_1d_rule, diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py index 291128157d..5e262a785e 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -89,13 +89,13 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # Register the rewrite rules -remove_redundant_slice = RewriteRule( +collapse_slice_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _check_if_redundant_slice, ) -remove_redundant_slice2 = RewriteRule( +collapse_slice2_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _same_shape, @@ -104,4 +104,4 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # NOTE: The second rule subsumes the first one. So, we may be able to remove the first one, # provided shape-inference is run before the rewriter and computes the shape of the slice output. -rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2]) +rules = RewriteRuleSet([collapse_slice_rule, collapse_slice2_rule]) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index 51e4e20db3..a5ceb00468 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -167,21 +167,14 @@ def pattern(self, op, x): fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule() -fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule() +fuse_batchnorm_into_conv_transpose_rule = FuseBatchNormIntoConvTranspose().rule() fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule() -def fuse_batchnorm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse BatchNormalization nodes - into preceding nodes such as Conv, ConvTranspose, and Gemm. - - Returns: - RewriteRuleSet - """ - return RewriteRuleSet( - [ - fuse_batchnorm_into_conv_rule, - fuse_batchnorm_into_convtranspose_rule, - fuse_batchnorm_into_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 90412e4fe2..3e617340ff 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -74,7 +74,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -133,7 +133,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -197,7 +197,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -224,7 +224,7 @@ def test_fuse_batchnorm_non_initializers(self): """) onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied self.assertEqual(count, 0) @@ -248,7 +248,7 @@ def test_fuse_batchnorm_graph_inputs(self): onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = _fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied as W is a graph input self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py index 7aeae57ccd..39aab00eda 100644 --- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py @@ -327,25 +327,17 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"]) -normalize_pad_format_conv = NormalizePadFormatConv.rule() -normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule() -fuse_pad_into_conv = FuseConvPad.rule() -fuse_pad_into_conv_integer = FuseConvIntegerPad.rule() - - -def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet: - """Returns a set of rewrite rules that fuse Pad nodes into preceding: - - Conv - - ConvInteger - - Returns: - RewriteRuleSet - """ - return orp.RewriteRuleSet( - [ - normalize_pad_format_conv, - normalize_pad_format_conv_integer, - fuse_pad_into_conv, - fuse_pad_into_conv_integer, - ] - ) +normalize_pad_format_conv_rule = NormalizePadFormatConv.rule() +normalize_pad_format_conv_integer_rule = NormalizePadFormatConvInteger.rule() +fuse_pad_into_conv_rule = FuseConvPad.rule() +fuse_pad_into_conv_integer_rule = FuseConvIntegerPad.rule() + + +rules = orp.RewriteRuleSet( + [ + normalize_pad_format_conv_rule, + normalize_pad_format_conv_integer_rule, + fuse_pad_into_conv_rule, + fuse_pad_into_conv_integer_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py index 484ca679fc..5d294cdbd7 100644 --- a/onnxscript/rewriter/rules/common/_fuse_relus_clips.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py @@ -169,25 +169,17 @@ def pattern(self, op, x): return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"])) -fuse_successive_relu_rule = FuseSuccessiveRelu().rule() -fuse_successive_clip_rule = FuseSuccessiveClip().rule() -fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() -fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule() - - -def fuse_relus_clips_rules() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse successive Relu/Clip nodes. - - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - fuse_successive_clip_relu_rule, - fuse_successive_relu_clip_rule, - fuse_successive_relu_rule, - fuse_successive_clip_rule, - ] - ) +successive_relu_rule = FuseSuccessiveRelu().rule() +successive_clip_rule = FuseSuccessiveClip().rule() +successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() +successive_relu_clip_rule = FuseSuccessiveReluClip().rule() + + +rules = RewriteRuleSet( + [ + successive_clip_relu_rule, + successive_relu_clip_rule, + successive_relu_rule, + successive_clip_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py index 60747a61b3..df2d669930 100644 --- a/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -15,11 +15,11 @@ RewriteRule, testing, ) -from onnxscript.rewriter.rules import fuse_relus_clips +from onnxscript.rewriter.rules.common import _fuse_relus_clips from onnxscript.rewriter.rules.common._fuse_relus_clips import ( - fuse_successive_clip_relu_rule, - fuse_successive_clip_rule, - fuse_successive_relu_clip_rule, + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, ) @@ -40,7 +40,7 @@ def run_test( onnx_checker.CheckerPass(True)(base_model) base_model = shape_inference.infer_shapes(base_model) updated_model = self.clone_model(base_model) - _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model) + _ = _fuse_relus_clips.rules.apply_to_model(updated_model) # Check expected op_types self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) @@ -214,7 +214,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Relu(X) Y = Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -222,7 +222,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -245,7 +245,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Relu(X) Y = Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -253,7 +253,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -334,7 +334,7 @@ def test_fail_fuse_successive_clips_non_initializers(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is not a constant.") + self.run_failed_condition_test(model, successive_clip_rule, "is not a constant.") def test_fail_fuse_successive_clips_graph_inputs(self): model = ir.from_onnx_text(""" @@ -346,7 +346,7 @@ def test_fail_fuse_successive_clips_graph_inputs(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.") + self.run_failed_condition_test(model, successive_clip_rule, "is a graph input.") class FuseReluClipIntegrationTest(_FuseReluClipTestBase): diff --git a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index 0bce4dd48f..ee0abf74f2 100644 --- a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -18,4 +18,4 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) +gemm_to_matmul_add_rule = RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) diff --git a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py index 4abaf5ca13..90551d8d3b 100644 --- a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py @@ -25,7 +25,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -70,7 +70,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -94,7 +94,7 @@ def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -115,7 +115,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -136,7 +136,7 @@ def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_ """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -159,7 +159,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -182,7 +182,7 @@ def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_b """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -203,7 +203,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -226,7 +226,7 @@ def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_ """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -247,7 +247,7 @@ def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broad """ ) model = ir.serde.deserialize_model(model_proto) - replacement_count = _gemm_to_matmul_add.rule.apply_to_model(model) + replacement_count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(replacement_count, 1) self.assertEqual(len(model.graph), 4) @@ -268,7 +268,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broad """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -289,7 +289,7 @@ def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = _gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py index dc0364a778..fe7a4a6cd8 100644 --- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py @@ -84,20 +84,11 @@ def pattern(self, op, input_a, input_b, input_c): transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule() -def gemm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse MatMul + Add patterns into a single Gemm node, - handling cases where one or both MatMul inputs are transposed. - - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - transpose_ab_matmul_add_to_gemm_rule, - transpose_a_matmul_add_to_gemm_rule, - transpose_b_matmul_add_to_gemm_rule, - matmul_add_to_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + transpose_ab_matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, + matmul_add_to_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index c43a21d9df..6bb3af095f 100644 --- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -129,7 +129,7 @@ def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): bias_as_inputs=bias_as_inputs, ) updated_model = self.clone_model(base_model) - count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul + Add are fused into Gemm self.assertEqual(count, 1) @@ -176,7 +176,7 @@ def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transA=True, ) updated_model = self.clone_model(base_model) - count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, W) + Add are fused into Gemm self.assertEqual(count, 1) @@ -225,7 +225,7 @@ def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transB=True, ) updated_model = self.clone_model(base_model) - count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(X, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) @@ -275,7 +275,7 @@ def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inpu transB=True, ) updated_model = self.clone_model(base_model) - count = _matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) diff --git a/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py index 5852e85dc3..cca5f36558 100644 --- a/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py @@ -107,7 +107,7 @@ def rewrite(self, op, updates, **_): return op.Identity(updates) -rule = ScatterAllDynamic.rule() -static_rule = ScatterAllStatic.rule() +no_op_dynamic_scatter_nd_rule = ScatterAllDynamic.rule() +no_op_static_scatter_nd_rule = ScatterAllStatic.rule() -rules = RewriteRuleSet([rule, static_rule]) +rules = RewriteRuleSet([no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule]) From a3bdf59eb22d266d5784304b4add5f110b5249b3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 2 Sep 2025 16:01:39 -0700 Subject: [PATCH 11/12] import Signed-off-by: Justin Chu --- .../rules/common/_broadcast_to_matmul_test.py | 24 +++++++++---------- .../rules/common/_gemm_to_matmul_add.py | 4 +++- .../rules/common/_matmul_add_to_gemm_test.py | 4 +++- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py index 39c4eaeb6e..4e33544986 100644 --- a/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py @@ -108,7 +108,7 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) model = _infer_shapes(model) @@ -151,7 +151,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest ) ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -178,7 +178,7 @@ def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_n """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -202,7 +202,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_inputs_are_not_broadcast ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -226,7 +226,7 @@ def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_i """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -249,7 +249,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -272,7 +272,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_se """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -295,7 +295,7 @@ def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -318,7 +318,7 @@ def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_b """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -342,7 +342,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_second_input_is_one_dime ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -366,7 +366,7 @@ def test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -387,7 +387,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. self.assertEqual(len(model.graph), 3) diff --git a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index ee0abf74f2..e51b4b22fa 100644 --- a/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -18,4 +18,6 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -gemm_to_matmul_add_rule = RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) +gemm_to_matmul_add_rule = RewriteRule( + reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape +) diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index 6bb3af095f..c4f9abe65c 100644 --- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -101,7 +101,9 @@ def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs): updated_model = self.clone_model(base_model) tracer = MatchingTracer() - count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer) + count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model( + updated_model, tracer=tracer + ) # Check that the model is unchanged self.assertEqual(count, 0) From 715f9754110fc456a8e1eba545f34d004cb1cebd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 2 Sep 2025 16:17:52 -0700 Subject: [PATCH 12/12] tests Signed-off-by: Justin Chu --- onnxscript/rewriter/pattern_test.py | 4 ++-- .../rules/common/_fuse_pad_into_conv_test.py | 24 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 30ea175984..49ace2fb81 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -13,7 +13,7 @@ from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op from onnxscript.rewriter import pattern -from onnxscript.rewriter.rules import cast_constant_of_shape +from onnxscript.rewriter.rules.common import _cast_constant_of_shape logger = logging.getLogger(__name__) @@ -307,7 +307,7 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 2) self.assertEqual(len(model.graph), 2) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py index 58df13ac0c..740f8b3358 100644 --- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -12,10 +12,10 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _fuse_pad_into_conv from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( - fuse_pad_into_conv, - fuse_pad_into_conv_rule_set, - normalize_pad_format_conv, + fuse_pad_into_conv_rule, + normalize_pad_format_conv_rule, ) @@ -118,7 +118,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_a updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -209,11 +209,11 @@ def test_unsupported_fuse_pad_into_conv( # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = fuse_pad_into_conv.apply_to_model(base_model, tracer=tracer) + count = fuse_pad_into_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[fuse_pad_into_conv][0] + tracer_match = tracer.best_matches_map[fuse_pad_into_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, err_msg) @@ -255,7 +255,7 @@ def test_fuse_pad_into_conv_integer( updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -344,7 +344,7 @@ def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_p updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) onnx_checker.CheckerPass(True)(updated_model) # Check conv has changed @@ -372,11 +372,11 @@ def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, error_msg) @@ -393,11 +393,11 @@ def test_unsupported_normalize_pad_format_on_weights(self): # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape")