From 51dd346030443484f0cdc6ae124ecbf89ac54247 Mon Sep 17 00:00:00 2001 From: AyoubMDL Date: Tue, 9 Sep 2025 12:18:26 +0200 Subject: [PATCH 1/2] refactor(fuse_batchnorm): use ClassVar --- .../rewriter/rules/common/_fuse_batchnorm.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index a5ceb00468..9d8b8f23f4 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Mapping +from typing import ClassVar, Mapping import numpy as np @@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra class _FuseBatchNormBase(RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" - def __init__( - self, - op_type: str, - name: str | None = None, - remove_nodes: bool = True, - as_function: bool = False, - ) -> None: - super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function) - self.op_type = op_type - @abstractmethod def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" @@ -116,8 +106,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M class FuseBatchNormIntoConv(_FuseBatchNormBase): """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" - def __init__(self): - super().__init__("Conv") + op_type: ClassVar = "Conv" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 0 @@ -133,8 +122,7 @@ def pattern(self, op, x): class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" - def __init__(self): - super().__init__("ConvTranspose") + op_type: ClassVar = "ConvTranspose" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 1 @@ -150,8 +138,7 @@ def pattern(self, op, x): class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" - def __init__(self): - super().__init__("Gemm") + op_type: ClassVar = "Gemm" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return ( From badd6af648596736ade95f8b767e4df8dec4bc4c Mon Sep 17 00:00:00 2001 From: AyoubMDL Date: Tue, 9 Sep 2025 12:18:53 +0200 Subject: [PATCH 2/2] [Rewriter]: add fuse_batchnorm to the default rules --- onnxscript/rewriter/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 232750af78..fc000dc176 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -35,6 +35,7 @@ _broadcast_to_matmul, _cast_constant_of_shape, _collapse_slices, + _fuse_batchnorm, _fuse_pad_into_conv, _fuse_relus_clips, _min_max_to_clip, @@ -53,6 +54,7 @@ *_basic_rules.basic_optimization_rules(), *_redundant_scatter_nd.rules, *_fuse_pad_into_conv.rules, + *_fuse_batchnorm.rules, )