From 07f3e4cdfc68395dd2879566df4f4b3c3cafb340 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Sep 2025 08:51:34 -0700 Subject: [PATCH 001/123] chore(deps): bump ruff from 0.12.10 to 0.12.11 in /requirements/lintrunner (#2535) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 41e736dcb4..a17c852e86 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.12.10 +ruff==0.12.11 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250402 From 8974f5ec189703d27b402b9d2d4cd8e03895b18f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 2 Sep 2025 22:46:33 +0200 Subject: [PATCH 002/123] Implements repeat_interleave (#2477) Similar to #2464. Does not support all the cases but we can add them in other PRs. --------- Signed-off-by: xadupre Co-authored-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 110 +++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 61 ++++++++++ .../function_libs/torch_lib/ops_test_data.py | 34 ++++++ 3 files changed, 201 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab992e0580..2e6bf9530c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: return op.Tile(self_expanded, repeats) -def aten_repeat_interleave( - repeats: TensorType, output_size: Optional[int] = None +@torch_op("aten::repeat_interleave.self_int", trace_only=True) +def aten_repeat_interleave_self_int( + self: TensorType, repeats: int, dim: Optional[int] = None ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor""" + """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor - raise NotImplementedError() + The trick is to repeat in one direction orthogonal to reshape. + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat_interleave(2, dim=0) + + is equivalent to: + + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, t.shape[1])) + """ + if dim is None: + raise NotImplementedError("No conversion available yet when dim is None.") + + self_rank = len(self.shape) + pos_dim = (dim + self_rank) % self_rank + unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) + tiles = [1] * (self_rank + 1) + tiles[pos_dim + 1] = repeats + tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) + tiled = op.Tile(unsqueezed, tile_repeat) + if self_rank == 1: + return op.Identity(tiled) + final_shape = op.Concat( + op.Shape(self, start=0, end=dim), + op.Constant(value_ints=[-1]), + op.Shape(self, start=dim + 1), + axis=0, + ) + return op.Reshape(tiled, final_shape) + + +@torch_op("aten::repeat_interleave.Tensor", trace_only=True) +def aten_repeat_interleave_Tensor( + self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None +) -> TensorType: + """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor + + When `repeats` is a tensor, each line is multiplied + by a different number. + There are multiple strategies. Here is one. + + .. code-block:: python + + import torch + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + times = torch.tensor([2, 3], dtype=torch.int64) + y = x.repeat_interleave(times, dim=0) + print("repeat_interleave") + print(y) + + ci = times.cumsum(dim=0) + rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1)) + srows = times.shape[0] - rows.to(torch.int64).sum(axis=0) + indices = srows.reshape((-1, )) + print("decomposed") + print(x[indices, :]) + """ + if repeats is None: + repeats = self + self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1) + if dim is None: + # flatten + self = op.Reshape(self, [-1]) + rk = 1 + else: + rk = len(self.shape) + + if rk > 2: + shape_x0 = op.Shape(self, start=0, end=1) + shape_x = op.Shape(self, start=1) + self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) + elif rk == 1: + shape_x = None + self = op.Reshape(self, [-1, 1]) + else: + if rk != 2: + raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave") + shape_x = None + + ci = op.CumSum(repeats, [0]) + last_ci = op.Gather(ci, [-1]) + trange = op.Range(0, op.Squeeze(last_ci, [0]), 1) + rows = op.Less(trange, op.Unsqueeze(ci, [-1])) + srows = op.Sub( + op.Shape(self, start=0, end=1), + op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]), + ) + indices = op.Reshape(srows, [-1]) + values = op.GatherND(self, op.Unsqueeze(indices, [-1])) + if rk == 2: + return values + # shape_x is None at this stage. + assert shape_x is None # for mypy + return op.Reshape( + values, + op.Concat([-1], shape_x, axis=0) if shape_x else [-1], + ) @torch_op("aten::reshape") diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index ab58bbc1a1..a0d0a0d880 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -76,6 +76,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_integer_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind, dim=0) + + onnx_program = torch.onnx.export( + Model(), + ( + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + torch.tensor([1, 2], dtype=torch.int64), + ), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor_none(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.tensor([1, 2, 3, 2], dtype=torch.int64), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + dynamo=True, + optimize=False, + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "ind"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_sdpa_with_bool_attn_mask(self): class ScaledDotProductAttention(torch.nn.Module): def forward(self, query, key, value, attn_mask): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7af7413185..01db7161b5 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1250,6 +1250,40 @@ def _where_input_wrangler( core_ops.aten_remainder, ), TorchLibOpInfo("repeat", core_ops.aten_repeat), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) + .skip( + matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is a Tensor"), + ) + .skip( + dtypes=(torch.bool,), + reason="bool not supported", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor) + .skip( + matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is an int"), + ) + .skip( + dtypes=(torch.bool,), + reason="bool not supported", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), From 7b047742e83a82f867a3c7af873ac58eaaf0eb36 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 10:07:01 -0700 Subject: [PATCH 003/123] [torchlib] Modify aten_unbind to use None for split_sizes (#2536) According to https://onnx.ai/onnx/operators/onnx__SplitToSequence.html#summary, `If the argument split is not specified, a default scalar value of 1 is used as the value of split`, and this is the only case when `keepdims` can be set to `0`. Fixes https://github.com/microsoft/onnxscript/issues/2533 --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2e6bf9530c..e950699aca 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8718,12 +8718,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: return op.CastLike(self, other) -@torch_op("aten::unbind.int") +@torch_op("aten::unbind.int", trace_only=True) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - split_sizes = op.Constant(value_int=1) - return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) + return op.SplitToSequence(self, axis=dim, keepdims=False) @torch_op("aten::unflatten.int", trace_only=True) From 54de7417bea31fdebb6b082ea1cacc4632e1fc81 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 12:03:13 -0700 Subject: [PATCH 004/123] Refactor rewrite rules into the rewriter.rules namespace (#2531) Organize all rules into a directory that is not with the rewriter infrastructure: - `onnxscript.rewriter.rules.common.*` for existing rules - `onnxscript.rewriter.rules.fusion.*` for onnx fusion rules --------- Signed-off-by: Justin Chu --- onnxscript/rewriter/__init__.py | 38 +++---- .../rewriter/onnx_fusions/_onnx_fusions.py | 2 +- .../onnx_fusions/_onnx_fusions_test.py | 2 +- onnxscript/rewriter/ort_fusions/_core.py | 5 +- onnxscript/rewriter/pattern_test.py | 5 +- onnxscript/rewriter/rules/__init__.py | 2 + onnxscript/rewriter/rules/common/__init__.py | 103 ++++++++++++++++++ .../common/_basic_rules.py} | 12 +- .../common/_basic_rules_test.py} | 20 ++-- .../common/_broadcast_to_matmul.py} | 0 .../common/_broadcast_to_matmul_test.py} | 28 ++--- .../common/_cast_constant_of_shape.py} | 0 .../common/_cast_constant_of_shape_test.py} | 6 +- .../common/_collapse_slices.py} | 6 +- .../common/_collapse_slices_test.py} | 14 +-- .../common/_fuse_batchnorm.py} | 23 ++-- .../common/_fuse_batchnorm_test.py} | 13 ++- .../common/_fuse_pad_into_conv.py} | 36 +++--- .../common/_fuse_pad_into_conv_test.py} | 26 ++--- .../common/_fuse_relus_clips.py} | 36 +++--- .../common/_fuse_relus_clips_test.py} | 24 ++-- .../common/_gemm_to_matmul_add.py} | 6 +- .../common/_gemm_to_matmul_add_test.py} | 26 ++--- .../common/_matmul_add_to_gemm.py} | 25 ++--- .../common/_matmul_add_to_gemm_test.py} | 18 +-- .../{no_op.py => rules/common/_no_op.py} | 0 .../common/_no_op_test.py} | 4 +- .../common/_redundant_scatter_nd.py} | 6 +- .../common/_redundant_scatter_nd_test.py} | 6 +- onnxscript/rewriter/rules/fusion/__init__.py | 2 + .../fusion}/_layer_norm.py | 0 .../fusion}/_layer_norm_test.py | 2 +- .../fusion}/_rms_normalization.py | 0 .../fusion}/_rotary_embedding.py | 0 34 files changed, 289 insertions(+), 207 deletions(-) create mode 100644 onnxscript/rewriter/rules/__init__.py create mode 100644 onnxscript/rewriter/rules/common/__init__.py rename onnxscript/rewriter/{basic_rules.py => rules/common/_basic_rules.py} (98%) rename onnxscript/rewriter/{basic_rules_test.py => rules/common/_basic_rules_test.py} (96%) rename onnxscript/rewriter/{broadcast_to_matmul.py => rules/common/_broadcast_to_matmul.py} (100%) rename onnxscript/rewriter/{broadcast_to_matmul_test.py => rules/common/_broadcast_to_matmul_test.py} (94%) rename onnxscript/rewriter/{cast_constant_of_shape.py => rules/common/_cast_constant_of_shape.py} (100%) rename onnxscript/rewriter/{cast_constant_of_shape_test.py => rules/common/_cast_constant_of_shape_test.py} (89%) rename onnxscript/rewriter/{collapse_slices.py => rules/common/_collapse_slices.py} (95%) rename onnxscript/rewriter/{collapse_slices_test.py => rules/common/_collapse_slices_test.py} (91%) rename onnxscript/rewriter/{fuse_batchnorm.py => rules/common/_fuse_batchnorm.py} (92%) rename onnxscript/rewriter/{fuse_batchnorm_test.py => rules/common/_fuse_batchnorm_test.py} (94%) rename onnxscript/rewriter/{fuse_pad_into_conv.py => rules/common/_fuse_pad_into_conv.py} (95%) rename onnxscript/rewriter/{fuse_pad_into_conv_test.py => rules/common/_fuse_pad_into_conv_test.py} (95%) rename onnxscript/rewriter/{fuse_relus_clips.py => rules/common/_fuse_relus_clips.py} (89%) rename onnxscript/rewriter/{fuse_relus_clips_test.py => rules/common/_fuse_relus_clips_test.py} (94%) rename onnxscript/rewriter/{gemm_to_matmul_add.py => rules/common/_gemm_to_matmul_add.py} (76%) rename onnxscript/rewriter/{gemm_to_matmul_add_test.py => rules/common/_gemm_to_matmul_add_test.py} (92%) rename onnxscript/rewriter/{matmul_add_to_gemm.py => rules/common/_matmul_add_to_gemm.py} (84%) rename onnxscript/rewriter/{matmul_add_to_gemm_test.py => rules/common/_matmul_add_to_gemm_test.py} (94%) rename onnxscript/rewriter/{no_op.py => rules/common/_no_op.py} (100%) rename onnxscript/rewriter/{no_op_test.py => rules/common/_no_op_test.py} (98%) rename onnxscript/rewriter/{redundant_scatter_nd.py => rules/common/_redundant_scatter_nd.py} (96%) rename onnxscript/rewriter/{redundant_scatter_nd_test.py => rules/common/_redundant_scatter_nd_test.py} (96%) create mode 100644 onnxscript/rewriter/rules/fusion/__init__.py rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_layer_norm.py (100%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_layer_norm_test.py (98%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_rms_normalization.py (100%) rename onnxscript/rewriter/{onnx_fusions => rules/fusion}/_rotary_embedding.py (100%) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index d3e7a7891e..1d07e9f5af 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -22,17 +22,7 @@ import onnx_ir.passes.common as common_passes from onnxscript import ir -from onnxscript.rewriter import ( - basic_rules, - broadcast_to_matmul, - cast_constant_of_shape, - collapse_slices, - fuse_pad_into_conv, - fuse_relus_clips, - no_op, - pattern, - redundant_scatter_nd, -) +from onnxscript.rewriter import pattern from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( RewriterContext, @@ -40,17 +30,27 @@ RewriteRuleClassBase, RewriteRuleSet, ) +from onnxscript.rewriter.rules.common import ( + _basic_rules, + _broadcast_to_matmul, + _cast_constant_of_shape, + _collapse_slices, + _fuse_pad_into_conv, + _fuse_relus_clips, + _no_op, + _redundant_scatter_nd, +) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( - *no_op.rules.rules, # TODO: merge this rule into constant folding? - *broadcast_to_matmul.rules.rules, - *cast_constant_of_shape.rules.rules, - *collapse_slices.rules.rules, - *fuse_relus_clips.fuse_relus_clips_rules().rules, - *basic_rules.basic_optimization_rules().rules, - *redundant_scatter_nd.rules.rules, - *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules, + *_no_op.rules, # TODO: merge this rule into constant folding? + *_broadcast_to_matmul.rules, + *_cast_constant_of_shape.rules, + *_collapse_slices.rules, + *_fuse_relus_clips.rules, + *_basic_rules.basic_optimization_rules(), + *_redundant_scatter_nd.rules, + *_fuse_pad_into_conv.rules, ) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 0a45f3017c..bd73cb1f6d 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding +from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding def _get_onnx_opset_version(model: ir.Model) -> int | None: diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py index 59a460005a..22d6120da1 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py @@ -8,7 +8,7 @@ from parameterized import parameterized import onnxscript -import onnxscript.rewriter.onnx_fusions as onnx_fusions +from onnxscript.rewriter import onnx_fusions from onnxscript.rewriter.models import _rotary_embedding_models diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index faca1f9ba8..8f3c7c463a 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -8,7 +8,7 @@ import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize -from onnxscript.rewriter import gemm_to_matmul_add, rewrite +from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import ( instance_to_group_normalization, softmax, @@ -33,6 +33,7 @@ fuse_skip_layer_normalization, fuse_skip_rms_normalization, ) +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add ORT_PATTERN_REWRITE_RULES = [ *softmax.rules.rules, @@ -133,7 +134,7 @@ def optimize_for_ort( - The optimized `ir.Model` after applying transformer-specific fusions. - A dictionary with a count of each of the fusions applied. """ - rewrite(model, [gemm_to_matmul_add.rule]) + rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule]) model, fusion_count = fuse_xformers( model, debug=debug, diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index bf5940e97c..49ace2fb81 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -12,7 +12,8 @@ import onnxscript.optimizer from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op -from onnxscript.rewriter import cast_constant_of_shape, pattern +from onnxscript.rewriter import pattern +from onnxscript.rewriter.rules.common import _cast_constant_of_shape logger = logging.getLogger(__name__) @@ -306,7 +307,7 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 2) self.assertEqual(len(model.graph), 2) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) diff --git a/onnxscript/rewriter/rules/__init__.py b/onnxscript/rewriter/rules/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py new file mode 100644 index 0000000000..752e3c9430 --- /dev/null +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +__all__ = [ + "add_0_rule", + "cast_cast_rule", + "cast_constant_of_shape_rule", + "cast_constant_of_shape_without_value_rule", + "collapse_slice_rule", + "collapse_slice2_rule", + "div_by_1_rule", + "dropout_inference_rule", + "dropout_zero_rule", + "fuse_batchnorm_into_conv_rule", + "fuse_batchnorm_into_conv_transpose_rule", + "fuse_batchnorm_into_gemm_rule", + "fuse_pad_into_conv_integer_rule", + "fuse_pad_into_conv_rule", + "gemm_to_matmul_add_rule", + "matmul_add_to_gemm_rule", + "mul_by_1_rule", + "no_op_cast_rule", + "no_op_dynamic_scatter_nd_rule", + "no_op_expand_rule", + "no_op_static_scatter_nd_rule", + "no_op_transpose_rule", + "normalize_pad_format_conv_integer_rule", + "normalize_pad_format_conv_rule", + "one_reshape_matmul_reshape_rule", + "reshape_reshape_rule", + "slice_split_rule", + "squeeze_reshape_1d_rule", + "sub_0_rule", + "successive_clip_relu_rule", + "successive_clip_rule", + "successive_relu_clip_rule", + "successive_relu_rule", + "transpose_a_matmul_add_to_gemm_rule", + "transpose_ab_matmul_add_to_gemm_rule", + "transpose_b_matmul_add_to_gemm_rule", + "transpose_transpose_rule", + "two_reshapes_matmul_reshape_rule", + "unsqueeze_unsqueeze_rule", +] + +from onnxscript.rewriter.rules.common._basic_rules import ( + cast_cast_rule, + no_op_cast_rule, + no_op_expand_rule, + no_op_transpose_rule, + reshape_reshape_rule, + slice_split_rule, + squeeze_reshape_1d_rule, + transpose_transpose_rule, + unsqueeze_unsqueeze_rule, +) +from onnxscript.rewriter.rules.common._broadcast_to_matmul import ( + one_reshape_matmul_reshape_rule, + two_reshapes_matmul_reshape_rule, +) +from onnxscript.rewriter.rules.common._cast_constant_of_shape import ( + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, +) +from onnxscript.rewriter.rules.common._collapse_slices import ( + collapse_slice2_rule, + collapse_slice_rule, +) +from onnxscript.rewriter.rules.common._fuse_batchnorm import ( + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, +) +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_integer_rule, + fuse_pad_into_conv_rule, + normalize_pad_format_conv_integer_rule, + normalize_pad_format_conv_rule, +) +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, + successive_relu_rule, +) +from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule +from onnxscript.rewriter.rules.common._matmul_add_to_gemm import ( + matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_ab_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, +) +from onnxscript.rewriter.rules.common._no_op import ( + add_0_rule, + div_by_1_rule, + dropout_inference_rule, + dropout_zero_rule, + mul_by_1_rule, + sub_0_rule, +) +from onnxscript.rewriter.rules.common._redundant_scatter_nd import ( + no_op_dynamic_scatter_nd_rule, + no_op_static_scatter_nd_rule, +) diff --git a/onnxscript/rewriter/basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py similarity index 98% rename from onnxscript/rewriter/basic_rules.py rename to onnxscript/rewriter/rules/common/_basic_rules.py index 2788cb7cda..6f38050f3e 100644 --- a/onnxscript/rewriter/basic_rules.py +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -281,11 +281,11 @@ def check(self, context, x, axes1, axes2) -> MatchResult: # Create rule instances cast_cast_rule = CastCast.rule() -cast_identity_rule = CastIdentity.rule() -expand_identity_rule = ExpandIdentity.rule() +no_op_cast_rule = CastIdentity.rule() +no_op_expand_rule = ExpandIdentity.rule() reshape_reshape_rule = ReshapeReshape.rule() slice_split_rule = SlicesSplit.rule() -transpose_identity_rule = TransposeIdentity.rule() +no_op_transpose_rule = TransposeIdentity.rule() transpose_transpose_rule = TransposeTranspose.rule() unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule() squeeze_reshape_1d_rule = SqueezeReshape.rule() @@ -309,11 +309,11 @@ def basic_optimization_rules() -> RewriteRuleSet: return RewriteRuleSet( [ cast_cast_rule, - cast_identity_rule, - expand_identity_rule, + no_op_cast_rule, + no_op_expand_rule, reshape_reshape_rule, slice_split_rule, - transpose_identity_rule, + no_op_transpose_rule, transpose_transpose_rule, unsqueeze_unsqueeze_rule, squeeze_reshape_1d_rule, diff --git a/onnxscript/rewriter/basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py similarity index 96% rename from onnxscript/rewriter/basic_rules_test.py rename to onnxscript/rewriter/rules/common/_basic_rules_test.py index bcb6db4aa8..8709300763 100644 --- a/onnxscript/rewriter/basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -12,9 +12,9 @@ import onnxscript import onnxscript.onnx_types as ot -import onnxscript.rewriter.basic_rules as basic_rules from onnxscript import ir from onnxscript.onnx_opset import opset18 +from onnxscript.rewriter.rules.common import _basic_rules FLOAT = onnx.TensorProto.FLOAT @@ -98,7 +98,7 @@ def _check_model( ] ) def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -126,7 +126,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): ] ) def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -153,7 +153,7 @@ def cast_cast_model(x): ] ) def test_cast_cast_rule(self, _: str, type1, type2, type3): - rule = basic_rules.cast_cast_rule + rule = _basic_rules.cast_cast_rule model_proto = self._double_cast_model(type1, type2, type3) model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) @@ -172,7 +172,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3): ] ) def test_cast_identity_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -228,7 +228,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model): def test_expand_identity_rule( self, _: str, model: ir.Model, expected_nodes: tuple[str, ...] ): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -310,7 +310,7 @@ def test_expand_identity_rule( ] ) def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -369,7 +369,7 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): ] ) def test_reshape_reshape_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -420,7 +420,7 @@ def _slices_split_models(cls): def test_slices_split_rule(self): for model_proto in self._slices_split_models(): ir_model = ir.serde.deserialize_model(model_proto) - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) @@ -428,7 +428,7 @@ def test_slices_split_rule(self): self._check_model(model_proto, rewritten_model) def test_squeeze_reshape_1d_rule(self): - rule = basic_rules.squeeze_reshape_1d_rule + rule = _basic_rules.squeeze_reshape_1d_rule def check(model_script, expected_count) -> None: model_proto = model_script.to_model_proto() diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py similarity index 100% rename from onnxscript/rewriter/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul.py diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py similarity index 94% rename from onnxscript/rewriter/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py index c2f3b31f90..4e33544986 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py @@ -9,7 +9,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import broadcast_to_matmul +from onnxscript.rewriter.rules.common import _broadcast_to_matmul def _infer_shapes(model: ir.Model) -> ir.Model: @@ -38,7 +38,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -108,7 +108,7 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) model = _infer_shapes(model) @@ -151,7 +151,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest ) ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -178,7 +178,7 @@ def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_n """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -202,7 +202,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_inputs_are_not_broadcast ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -226,7 +226,7 @@ def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_i """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -249,7 +249,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -272,7 +272,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_se """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -295,7 +295,7 @@ def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -318,7 +318,7 @@ def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_b """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -342,7 +342,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_second_input_is_one_dime ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -366,7 +366,7 @@ def test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -387,7 +387,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. self.assertEqual(len(model.graph), 3) diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py similarity index 100% rename from onnxscript/rewriter/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape.py diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py similarity index 89% rename from onnxscript/rewriter/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py index 35151e17d9..794491024b 100644 --- a/onnxscript/rewriter/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import cast_constant_of_shape +from onnxscript.rewriter.rules.common import _cast_constant_of_shape class CastConstantOfShapeTest(unittest.TestCase): @@ -23,7 +23,7 @@ def test_cast_after_constant_of_shape_is_fused(self): ) onnx.checker.check_model(input_model_proto, True) model = ir.serde.deserialize_model(input_model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) @@ -42,7 +42,7 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py similarity index 95% rename from onnxscript/rewriter/collapse_slices.py rename to onnxscript/rewriter/rules/common/_collapse_slices.py index 291128157d..5e262a785e 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -89,13 +89,13 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # Register the rewrite rules -remove_redundant_slice = RewriteRule( +collapse_slice_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _check_if_redundant_slice, ) -remove_redundant_slice2 = RewriteRule( +collapse_slice2_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _same_shape, @@ -104,4 +104,4 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # NOTE: The second rule subsumes the first one. So, we may be able to remove the first one, # provided shape-inference is run before the rewriter and computes the shape of the slice output. -rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2]) +rules = RewriteRuleSet([collapse_slice_rule, collapse_slice2_rule]) diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/rules/common/_collapse_slices_test.py similarity index 91% rename from onnxscript/rewriter/collapse_slices_test.py rename to onnxscript/rewriter/rules/common/_collapse_slices_test.py index 52b59f9037..727240344d 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices_test.py @@ -6,10 +6,10 @@ import numpy as np import onnx.parser -import onnx.shape_inference from onnxscript import ir -from onnxscript.rewriter import collapse_slices, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _collapse_slices _INT64_MAX = 9223372036854775807 @@ -30,7 +30,7 @@ def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -55,7 +55,7 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -80,7 +80,7 @@ def test_slice_unequal_dynamic_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 0) def test_slice_equal_dynamic_shape(self): @@ -98,7 +98,7 @@ def test_slice_equal_dynamic_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) def test_slice_equal_dynamic_shape_but_step_reverse(self): @@ -116,6 +116,6 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) # Should not change the output shape if we did not use the default step of 1 self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py similarity index 92% rename from onnxscript/rewriter/fuse_batchnorm.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm.py index 51e4e20db3..a5ceb00468 100644 --- a/onnxscript/rewriter/fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -167,21 +167,14 @@ def pattern(self, op, x): fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule() -fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule() +fuse_batchnorm_into_conv_transpose_rule = FuseBatchNormIntoConvTranspose().rule() fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule() -def fuse_batchnorm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse BatchNormalization nodes - into preceding nodes such as Conv, ConvTranspose, and Gemm. - - Returns: - RewriteRuleSet - """ - return RewriteRuleSet( - [ - fuse_batchnorm_into_conv_rule, - fuse_batchnorm_into_convtranspose_rule, - fuse_batchnorm_into_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py similarity index 94% rename from onnxscript/rewriter/fuse_batchnorm_test.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 20d272abd7..3e617340ff 100644 --- a/onnxscript/rewriter/fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -8,7 +8,8 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import fuse_batchnorm, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _fuse_batchnorm class FuseBatchnormTest(unittest.TestCase): @@ -73,7 +74,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -132,7 +133,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -196,7 +197,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -223,7 +224,7 @@ def test_fuse_batchnorm_non_initializers(self): """) onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied self.assertEqual(count, 0) @@ -247,7 +248,7 @@ def test_fuse_batchnorm_graph_inputs(self): onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied as W is a graph input self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py similarity index 95% rename from onnxscript/rewriter/fuse_pad_into_conv.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py index 7aeae57ccd..39aab00eda 100644 --- a/onnxscript/rewriter/fuse_pad_into_conv.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py @@ -327,25 +327,17 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"]) -normalize_pad_format_conv = NormalizePadFormatConv.rule() -normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule() -fuse_pad_into_conv = FuseConvPad.rule() -fuse_pad_into_conv_integer = FuseConvIntegerPad.rule() - - -def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet: - """Returns a set of rewrite rules that fuse Pad nodes into preceding: - - Conv - - ConvInteger - - Returns: - RewriteRuleSet - """ - return orp.RewriteRuleSet( - [ - normalize_pad_format_conv, - normalize_pad_format_conv_integer, - fuse_pad_into_conv, - fuse_pad_into_conv_integer, - ] - ) +normalize_pad_format_conv_rule = NormalizePadFormatConv.rule() +normalize_pad_format_conv_integer_rule = NormalizePadFormatConvInteger.rule() +fuse_pad_into_conv_rule = FuseConvPad.rule() +fuse_pad_into_conv_integer_rule = FuseConvIntegerPad.rule() + + +rules = orp.RewriteRuleSet( + [ + normalize_pad_format_conv_rule, + normalize_pad_format_conv_integer_rule, + fuse_pad_into_conv_rule, + fuse_pad_into_conv_integer_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py similarity index 95% rename from onnxscript/rewriter/fuse_pad_into_conv_test.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py index dfbf117bd1..740f8b3358 100644 --- a/onnxscript/rewriter/fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -12,10 +12,10 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing -from onnxscript.rewriter.fuse_pad_into_conv import ( - fuse_pad_into_conv, - fuse_pad_into_conv_rule_set, - normalize_pad_format_conv, +from onnxscript.rewriter.rules.common import _fuse_pad_into_conv +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_rule, + normalize_pad_format_conv_rule, ) @@ -118,7 +118,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_a updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -209,11 +209,11 @@ def test_unsupported_fuse_pad_into_conv( # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = fuse_pad_into_conv.apply_to_model(base_model, tracer=tracer) + count = fuse_pad_into_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[fuse_pad_into_conv][0] + tracer_match = tracer.best_matches_map[fuse_pad_into_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, err_msg) @@ -255,7 +255,7 @@ def test_fuse_pad_into_conv_integer( updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -344,7 +344,7 @@ def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_p updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) onnx_checker.CheckerPass(True)(updated_model) # Check conv has changed @@ -372,11 +372,11 @@ def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, error_msg) @@ -393,11 +393,11 @@ def test_unsupported_normalize_pad_format_on_weights(self): # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape") diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py similarity index 89% rename from onnxscript/rewriter/fuse_relus_clips.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips.py index 484ca679fc..5d294cdbd7 100644 --- a/onnxscript/rewriter/fuse_relus_clips.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py @@ -169,25 +169,17 @@ def pattern(self, op, x): return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"])) -fuse_successive_relu_rule = FuseSuccessiveRelu().rule() -fuse_successive_clip_rule = FuseSuccessiveClip().rule() -fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() -fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule() - - -def fuse_relus_clips_rules() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse successive Relu/Clip nodes. - - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - fuse_successive_clip_relu_rule, - fuse_successive_relu_clip_rule, - fuse_successive_relu_rule, - fuse_successive_clip_rule, - ] - ) +successive_relu_rule = FuseSuccessiveRelu().rule() +successive_clip_rule = FuseSuccessiveClip().rule() +successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() +successive_relu_clip_rule = FuseSuccessiveReluClip().rule() + + +rules = RewriteRuleSet( + [ + successive_clip_relu_rule, + successive_relu_clip_rule, + successive_relu_rule, + successive_clip_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py similarity index 94% rename from onnxscript/rewriter/fuse_relus_clips_test.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py index d58b493fb4..df2d669930 100644 --- a/onnxscript/rewriter/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -13,13 +13,13 @@ MatchingTracer, MatchStatus, RewriteRule, - fuse_relus_clips, testing, ) -from onnxscript.rewriter.fuse_relus_clips import ( - fuse_successive_clip_relu_rule, - fuse_successive_clip_rule, - fuse_successive_relu_clip_rule, +from onnxscript.rewriter.rules.common import _fuse_relus_clips +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, ) @@ -40,7 +40,7 @@ def run_test( onnx_checker.CheckerPass(True)(base_model) base_model = shape_inference.infer_shapes(base_model) updated_model = self.clone_model(base_model) - _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model) + _ = _fuse_relus_clips.rules.apply_to_model(updated_model) # Check expected op_types self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) @@ -214,7 +214,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Relu(X) Y = Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -222,7 +222,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -245,7 +245,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Relu(X) Y = Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -253,7 +253,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -334,7 +334,7 @@ def test_fail_fuse_successive_clips_non_initializers(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is not a constant.") + self.run_failed_condition_test(model, successive_clip_rule, "is not a constant.") def test_fail_fuse_successive_clips_graph_inputs(self): model = ir.from_onnx_text(""" @@ -346,7 +346,7 @@ def test_fail_fuse_successive_clips_graph_inputs(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.") + self.run_failed_condition_test(model, successive_clip_rule, "is a graph input.") class FuseReluClipIntegrationTest(_FuseReluClipTestBase): diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py similarity index 76% rename from onnxscript/rewriter/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index 09666466d3..e51b4b22fa 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from onnxscript.rewriter._rewrite_rule import RewriteRule -from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape +from onnxscript.rewriter.rules.common._broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against @@ -18,4 +18,6 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) +gemm_to_matmul_add_rule = RewriteRule( + reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape +) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py similarity index 92% rename from onnxscript/rewriter/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py index aab56cc3fe..90551d8d3b 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py @@ -5,7 +5,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add class ReshapeGemmReshapeTest(unittest.TestCase): @@ -25,7 +25,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -70,7 +70,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -94,7 +94,7 @@ def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -115,7 +115,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -136,7 +136,7 @@ def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -159,7 +159,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -182,7 +182,7 @@ def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_b """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -203,7 +203,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -226,7 +226,7 @@ def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -247,7 +247,7 @@ def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broad """ ) model = ir.serde.deserialize_model(model_proto) - replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + replacement_count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(replacement_count, 1) self.assertEqual(len(model.graph), 4) @@ -268,7 +268,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broad """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -289,7 +289,7 @@ def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) diff --git a/onnxscript/rewriter/matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py similarity index 84% rename from onnxscript/rewriter/matmul_add_to_gemm.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py index dc0364a778..fe7a4a6cd8 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py @@ -84,20 +84,11 @@ def pattern(self, op, input_a, input_b, input_c): transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule() -def gemm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse MatMul + Add patterns into a single Gemm node, - handling cases where one or both MatMul inputs are transposed. - - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - transpose_ab_matmul_add_to_gemm_rule, - transpose_a_matmul_add_to_gemm_rule, - transpose_b_matmul_add_to_gemm_rule, - matmul_add_to_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + transpose_ab_matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, + matmul_add_to_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py similarity index 94% rename from onnxscript/rewriter/matmul_add_to_gemm_test.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index fd08125807..c4f9abe65c 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -9,8 +9,8 @@ from parameterized import parameterized from onnxscript import ir -from onnxscript.rewriter import MatchingTracer, MatchStatus, matmul_add_to_gemm, testing -from onnxscript.rewriter.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.common import _matmul_add_to_gemm class _MatMulAddToGemmTestBase(unittest.TestCase): @@ -101,13 +101,15 @@ def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs): updated_model = self.clone_model(base_model) tracer = MatchingTracer() - count = matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer) + count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model( + updated_model, tracer=tracer + ) # Check that the model is unchanged self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[matmul_add_to_gemm_rule][0] + tracer_match = tracer.best_matches_map[_matmul_add_to_gemm.matmul_add_to_gemm_rule][0] self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) self.assertRegex( tracer_match.match_result.reason, "Rank of input_a and input_b must be 2" @@ -129,7 +131,7 @@ def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): bias_as_inputs=bias_as_inputs, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul + Add are fused into Gemm self.assertEqual(count, 1) @@ -176,7 +178,7 @@ def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transA=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, W) + Add are fused into Gemm self.assertEqual(count, 1) @@ -225,7 +227,7 @@ def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(X, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) @@ -275,7 +277,7 @@ def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inpu transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/rules/common/_no_op.py similarity index 100% rename from onnxscript/rewriter/no_op.py rename to onnxscript/rewriter/rules/common/_no_op.py diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py similarity index 98% rename from onnxscript/rewriter/no_op_test.py rename to onnxscript/rewriter/rules/common/_no_op_test.py index 2b2a57f32a..7815473e34 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/rules/common/_no_op_test.py @@ -5,13 +5,13 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import no_op +from onnxscript.rewriter.rules.common import _no_op class NoOpTest(unittest.TestCase): def _check(self, model_text: str) -> None: model = ir.from_onnx_text(model_text) - count = no_op.rules.apply_to_model(model) + count = _no_op.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") diff --git a/onnxscript/rewriter/redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py similarity index 96% rename from onnxscript/rewriter/redundant_scatter_nd.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd.py index 5852e85dc3..cca5f36558 100644 --- a/onnxscript/rewriter/redundant_scatter_nd.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py @@ -107,7 +107,7 @@ def rewrite(self, op, updates, **_): return op.Identity(updates) -rule = ScatterAllDynamic.rule() -static_rule = ScatterAllStatic.rule() +no_op_dynamic_scatter_nd_rule = ScatterAllDynamic.rule() +no_op_static_scatter_nd_rule = ScatterAllStatic.rule() -rules = RewriteRuleSet([rule, static_rule]) +rules = RewriteRuleSet([no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule]) diff --git a/onnxscript/rewriter/redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py similarity index 96% rename from onnxscript/rewriter/redundant_scatter_nd_test.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py index d2ba51eec4..96e3bcc80c 100644 --- a/onnxscript/rewriter/redundant_scatter_nd_test.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py @@ -13,7 +13,7 @@ import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op -from onnxscript.rewriter import redundant_scatter_nd +from onnxscript.rewriter.rules.common import _redundant_scatter_nd shape_inference = ShapeInferencePass() onnx_check = CheckerPass(True) @@ -48,7 +48,7 @@ def model_script( onnx_check(model) shape_inference(model) onnxscript.optimizer.fold_constants(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) onnx_check(model) optimized_model_proto = ir.serde.serialize_model(model) @@ -94,7 +94,7 @@ def test_redundant_scatter_nd_static_indices(self): model.graph.initializers["indices"] = indices_value original_model_proto = ir.serde.serialize_model(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertIn("Identity", [node.op_type for node in model.graph]) diff --git a/onnxscript/rewriter/rules/fusion/__init__.py b/onnxscript/rewriter/rules/fusion/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm.py b/onnxscript/rewriter/rules/fusion/_layer_norm.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_layer_norm.py rename to onnxscript/rewriter/rules/fusion/_layer_norm.py diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py similarity index 98% rename from onnxscript/rewriter/onnx_fusions/_layer_norm_test.py rename to onnxscript/rewriter/rules/fusion/_layer_norm_test.py index 6c9734d058..6ea7f116fb 100644 --- a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py +++ b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py @@ -10,7 +10,7 @@ import onnxscript.rewriter.testing from onnxscript import FLOAT, OnnxFunction, script from onnxscript import opset18 as op -from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization +from onnxscript.rewriter.rules.fusion._layer_norm import fuse_layer_normalization @script() diff --git a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py b/onnxscript/rewriter/rules/fusion/_rms_normalization.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_rms_normalization.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization.py diff --git a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_rotary_embedding.py rename to onnxscript/rewriter/rules/fusion/_rotary_embedding.py From a925acc00f824186fd37bd0036d0745897d4b41a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 3 Sep 2025 15:39:47 -0700 Subject: [PATCH 005/123] [torchlib] Improve pixel_shuffle (#2537) Simplify the graph when input rank is 4, in which case we don't need to do any shape manipulation. Fix https://github.com/pytorch/pytorch/issues/162061 --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 27 ++++++++++++------- .../function_libs/torch_lib/ops_test_data.py | 14 ++-------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e950699aca..8bb1665aaf 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6691,34 +6691,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType: raise NotImplementedError() -@torch_op("aten::pixel_shuffle") +@torch_op("aten::pixel_shuffle", trace_only=True) def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor""" - self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] + if len(self.shape) == 4: + return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") - output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0) + final_dims = op.Shape(depth_to_space, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(depth_to_space, output_shape, allowzero=True) -@torch_op("aten::pixel_unshuffle") +@torch_op("aten::pixel_unshuffle", trace_only=True) def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.SpaceToDepth(self, blocksize=downscale_factor) - self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0) + final_dims = op.Shape(space_to_depth, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 01db7161b5..646a5133fa 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1084,26 +1084,16 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.pixel_shuffle", core_ops.aten_pixel_shuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "nn.functional.pixel_unshuffle", core_ops.aten_pixel_unshuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "ops.aten.reflection_pad1d", From 456a6bc6d5bdcf31bb4c0b268954e736e555751e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 4 Sep 2025 10:15:52 -0700 Subject: [PATCH 006/123] Update constant folding behavior for large tensors (#2488) Suggested by https://github.com/microsoft/onnxscript/issues/2466, I updated the constant folder logic to allow **Constant folding customization:** * Replaced the `always_fold_ops` parameter with a `should_fold` callable that determines on a per-node basis whether folding should occur. This allows users to specify more complex folding policies and makes the API more explicit. (`FoldConstantsPass`, `fold_constants`) [[1]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L902-R904) [[2]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L913-R918) [[3]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L1248-R1268) [[4]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L1263-R1285) [[5]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L1276-R1295) **Logging and diagnostics improvements:** * Upgraded logging throughout the folding process to provide more informative messages, including reasons for skipping nodes (e.g., control flow, non-deterministic ops, large inputs, or graph inputs) and explicit logging when `should_fold` returns a decision. [[1]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L964-R958) [[2]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L990-R984) [[3]](diffhunk://#diff-99f13fcef1aa8c8c81fd51b7a71477572a86c91ed0dc8e618e18644d70fdf8b5L1075-R1141) **Code cleanup and minor fixes:** * Removed the unused `_update_type` function. Fix https://github.com/microsoft/onnxscript/issues/2466 cc @iksnagreb --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/optimizer/_constant_folding.py | 155 +++++++++++------- .../optimizer/_constant_folding_test.py | 27 +++ 2 files changed, 122 insertions(+), 60 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 3269f9d51e..5f34e430dc 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -9,7 +9,7 @@ import logging import math import typing -from typing import Any, Callable, Collection, Iterable, Sequence, Union +from typing import Any, Callable, Iterable, Sequence, Union import numpy as np import onnx @@ -34,6 +34,13 @@ } ) +# A list of ops to always fold regardless of their input size limits, as long as +# they are the single consumer of the large input tensors +_DEFAULT_ALWAYS_FOLD_OPS = frozenset( + { + ("", "Transpose"), + } +) logger = logging.getLogger(__name__) @@ -332,12 +339,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None: return None -def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None: - if type is not None: - # TODO: merge types - value.type = type - - def _get_input_element_type(node: ir.Node, index: int) -> int: input = _get_input(node, index) if input is not None and input.type is not None: @@ -899,9 +900,10 @@ class FoldConstantsPass(ir.passes.InPlacePass): shape_inference: Whether to perform shape inference. input_size_limit: Maximum size of input tensors to fold. output_size_limit: Maximum size of output tensors to fold. - always_fold_ops: Collection of op types that should always be folded. - For ops from the default opset, only op_type is neede (e.g. "Transpose"), - otherwise specify the domain with ``{domain}::{op_type}``. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. """ def __init__( @@ -910,18 +912,12 @@ def __init__( shape_inference: bool, input_size_limit: int, output_size_limit: int, - always_fold_ops: Collection[str] = frozenset(["Transpose"]), + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, ) -> None: self.shape_inference = shape_inference self.input_size_limit = input_size_limit self.output_size_limit = output_size_limit - ops = [] - for name in always_fold_ops: - domain, op_type = name.split("::", 1) if "::" in name else ("", name) - if domain == "ai.onnx": - domain = "" - ops.append((domain, op_type)) - self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops) + self.should_fold = should_fold self._opset_imports: dict[str, int] = {} self._counts: dict[str, int] = {} @@ -961,7 +957,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: input_data = {k: v for k, v in input_data.items() if v is not None} if any(t is None for t in input_types.values()): logger.debug( - "Skipping shape inference for node %s due to missing input type.", + "Skipping shape inference for node %r due to missing input type.", node.name, ) else: @@ -987,7 +983,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: logger.debug( - "Skipping shape inference for node %s due to exception: %s", + "Skipping shape inference for node %r due to exception: %s", node.name, e, ) @@ -1072,7 +1068,23 @@ def process_node(self, node: ir.Node) -> Replacement | None: output = [output] return Replacement(output, context.nodes) - if _is_control_flow_op(node) or _is_non_deterministic_op(node): + if _is_control_flow_op(node): + logger.info( + "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet", + node.name, + node.domain, + node.op_type, + ) + + return None + + if _is_non_deterministic_op(node): + logger.info( + "Skipping constant folding for non-deterministic op %r (%s::%s)", + node.name, + node.domain, + node.op_type, + ) return None if _is_onnx_op(node, "Constant"): @@ -1080,47 +1092,70 @@ def process_node(self, node: ir.Node) -> Replacement | None: return None if any(x.is_graph_input() for x in node.inputs if x is not None): - # Do not fold any graph inputs to preserve graph signature + logger.info( + "Skipping constant folding for node %r because it is graph input to preserve graph signature", + node.name, + ) return None # Ensure all node inputs are constants if any(x.const_value is None for x in node.inputs if x is not None): - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "Skipping constant folding for node %s because it has non-constant inputs", - node, - [x.name for x in node.inputs if x is not None], - ) return None - input_tensors = [x.const_value if x is not None else None for x in node.inputs] - if any( - tensor.size > self.input_size_limit - for tensor in input_tensors - if tensor is not None - ): - if (node.domain, node.op_type) in self.always_fold_ops and all( - len(input.consumers()) == 1 for input in node.inputs if input is not None - ): - # If the op is in always_fold_ops and all inputs are used only by this node, - # we can still fold it even if the input size exceeds the limit. - logger.debug( - "Folding large constant for node %s because it is in the always_fold_ops list", - node, + should_fold = self.should_fold(node) + + if should_fold is False: + logger.info( + "Skipping constant folding for node %r because should_fold returned False", + node.name, + ) + return None + + elif should_fold is None: + # Use default rules to decide whether to fold the node: + # - ConstantOfShape is preserved to avoid increasing model size unnecessarily + # - If the any tensor input size exceeds the input_size_limit, skip folding the node + if _is_onnx_op(node, "ConstantOfShape"): + logger.info( + "Skipping constant folding for node %r because ConstantOfShape is preserved by default", + node.name, ) - else: - # Skip folding large tensors - if logger.isEnabledFor(logging.DEBUG): - input_sizes = [ - tensor.size for tensor in input_tensors if tensor is not None - ] - logger.debug( - "Skipping constant folding for node %s due to large input size: %s", - node, - input_sizes, - ) return None + input_tensors = [x.const_value if x is not None else None for x in node.inputs] + large_inputs = [ + tensor is not None and tensor.size > self.input_size_limit + for tensor in input_tensors + ] + if any(large_inputs): + # Decide whether to fold large constants + assert len(node.inputs) == len(large_inputs) + if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all( + len(input.consumers()) == 1 or (not is_large) + for input, is_large in zip(node.inputs, large_inputs) + if input is not None + ): + # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node, + # we can still fold it even if the input size exceeds the limit + pass + else: + # Skip folding large tensors + if logger.isEnabledFor(logging.INFO): + input_sizes = [ + tensor.size for tensor in input_tensors if tensor is not None + ] + logger.info( + "Skipping constant folding for node %r due to large input sizes: %s", + node, + input_sizes, + ) + return None + else: + logger.info( + "Constant folding node %r because should_fold returned True", + node.name, + ) + input_values = [_get_numpy_value(x) for x in node.inputs] def convert(av): @@ -1128,6 +1163,7 @@ def convert(av): return ir.serde.serialize_tensor(av.value) return av.value + # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node attr_values = {name: convert(attr) for name, attr in node.attributes.items()} outputs = _reference_evaluator.evaluate( node.domain, node.op_type, version, *input_values, **attr_values @@ -1137,7 +1173,7 @@ def convert(av): return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): replacement = self.new_constant(node, outputs) - if _is_onnx_op(node, "ConstantOfShape") or replacement is None: + if replacement is None: return None return Replacement(replacement.outputs, [replacement]) else: @@ -1245,7 +1281,7 @@ def fold_constants( onnx_shape_inference: bool = False, input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, - always_fold_ops: Collection[str] = frozenset(["Transpose"]), + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, ) -> FoldConstantsResult: """ Applies constant folding optimization to the model. @@ -1260,10 +1296,9 @@ def fold_constants( output_size_limit: The maximum size of output tensors that can be stored after constant folding. Defaults to `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`. - always_fold_ops: A collection of op types that should always be folded, - regardless of their input or output sizes. For ops from the default opset, - only op_type is neede (e.g. "Transpose"), otherwise specify the domain - with ``{domain}::{op_type}``. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding, False if it should not be folded, + or None to use the default rules. Defaults to a function that always returns None. Returns: An instance of `FoldConstantsResult`. @@ -1273,6 +1308,6 @@ def fold_constants( shape_inference=onnx_shape_inference, input_size_limit=input_size_limit, output_size_limit=output_size_limit, - always_fold_ops=always_fold_ops, + should_fold=should_fold, ) return folder_pass(model) # type: ignore[return-value] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 8c05fbc0a4..6b2557551e 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -581,6 +581,33 @@ def test_transpose_is_always_folded(self): ops = [node.op_type for node in optimized.graph] self.assertEqual(ops, ["Constant"]) + def test_node_is_folded_if_specified_as_should_fold(self): + model_text = """ + + agraph (float[M, 256] x) => (float[42, 42] z) + + { + z = ConstantOfShape (w) + } + """ + model = ir.from_onnx_text(model_text) + + # ConstantOfShape is not folded by default + optimized = self._fold(model) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["ConstantOfShape"]) + + # But ConstantOfShape is folded when specified in should_fold + optimized = self._fold( + model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None + ) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["Constant"]) + np.testing.assert_array_equal( + optimized.graph.node(0).attributes["value"].as_tensor().numpy(), + np.ones((42, 42), dtype=np.int64), + ) + def test_multi_graph_identity_output_preserves_output_name(self): model = """ From fc792e40d33656abc093cb760870360c45bf536b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 4 Sep 2025 10:16:55 -0700 Subject: [PATCH 007/123] [torchlib] Improve handling of SymInt[] (#2522) Previously sizes coming in as `SymInt[]` are first concatenated as INT64 then used. This created inefficiencies where we could not process any static dims from the size list and had to treat the whole shape as dynamic. In aten_expand, this meant we needed to add `Abs` on the shape. This change updates the functions that take `SymInt[]` such that they are no longer turned into INT64 first. I updated aten_expand to process constant `-1` values so an `Abs` is not required. I also added a helper `merge_dims` to create constants for consecutive constant dims first before concatinating. --------- Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/common.py | 21 ++++++ .../function_libs/torch_lib/ops/core.py | 72 +++++++++---------- 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index d7784a5289..b3ebbc1c53 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -5,6 +5,8 @@ # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" from __future__ import annotations +from collections.abc import Sequence + import numpy.typing as npt import onnx @@ -78,3 +80,22 @@ def constant( A constant node. """ return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype))) + + +def merge_dims(dims: Sequence[int | INT64]) -> INT64: + """Concatenate dimensions into a single value.""" + + if not dims: + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [])) + + neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1])) + + result_dims = [ + op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d) + for d in dims + ] + + # Set the output type to INT64 so op.Concat can be used + for dim in result_dims: + dim.dtype = ir.DataType.INT64 + return op.Concat(*result_dims, axis=0) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8bb1665aaf..3607a11361 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::broadcast_to") -def aten_broadcast_to(self: TTensor, size: INT64) -> TTensor: +@torch_op("aten::broadcast_to", trace_only=True) +def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor: """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - + size = common_ops.merge_dims(size) return op.Expand(self, size) @@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward( @torch_op("aten::empty.memory_format", trace_only=True) def aten_empty( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False, memory_format: str = "", ) -> TensorType: # type: ignore[type-var] - # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # using Zeros to simulate np.empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + # using Zeros to simulate empty() + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3334,17 +3334,18 @@ def aten_empty_quantized( @torch_op("aten::empty_strided", trace_only=True) def aten_empty_strided( - size: INT64, + size: Sequence[INT64], stride: INT64, layout: str = "", + dtype: int = FLOAT.dtype, device: str = "", pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor # using Zeros to simulate empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand", trace_only=True) -def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor: +def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor: """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1. # To support -1 dim, we need to convert -1 to 1. - size = op.Abs(size) - return op.Expand(self, size) + # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely + # and isn't expected to appear from correct usages of SymInt. + size = [1 if isinstance(s, int) and s == -1 else s for s in size] + return op.Expand(self, common_ops.merge_dims(size)) @torch_op("aten::expand_as", trace_only=True) @@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor( ) -@torch_op("aten::reshape") -def aten_reshape(self: TTensor, shape: IntType) -> TTensor: +@torch_op("aten::reshape", trace_only=True) +def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor: """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" - - # Reshape only support INT64 as 'shape' - shape = op.Cast(shape, to=INT64.dtype) + shape = common_ops.merge_dims(shape) return op.Reshape(self, shape) @@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) -def aten_view(self: TTensor, size: IntType) -> TTensor: +def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size, allowzero=True) -@torch_op(("aten::view", "aten::_unsafe_view"), complex=True) -def aten_view_complex(self: TTensor, size: IntType) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True) +def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input - complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) + complex_size = common_ops.merge_dims([*size, 2]) return op.Reshape(self, complex_size, allowzero=True) -@torch_op("aten::view_as") +@torch_op("aten::view_as", trace_only=True) def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" @@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_copy") -def aten_view_copy(self: TTensor, size: IntType) -> TTensor: +@torch_op("aten::view_copy", trace_only=True) +def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor: """view_copy(Tensor self, SymInt[] size) -> Tensor""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size) @@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor): "aten::where.ScalarSelf", "aten::where.ScalarOther", "aten::where.self", - ) + ), + trace_only=True, ) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" @@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::zeros", trace_only=True) def aten_zeros( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -9270,9 +9270,9 @@ def aten_zeros( """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) From d98e3dd0ae7caa15b6dba251f82f7450a68dd505 Mon Sep 17 00:00:00 2001 From: Karl Sassie Date: Thu, 4 Sep 2025 23:59:54 +0200 Subject: [PATCH 008/123] [torch] Fix incorrect Concat when processing dynamic paddings (#2540) See issue #2539 for a better explanation. I know crazy stuff right =^). --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bccddb88a6..88b5bf807e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -1503,7 +1503,7 @@ def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64: paddings = [*paddings, *zeros] # Interleave the padding values paddings = paddings[-2::-2] + paddings[-1::-2] - return op.Concat(paddings, axis=0) + return op.Concat(*paddings, axis=0) @torch_op("aten::pad", trace_only=True) From 19349018cb256c2f579f7b809433960360f89911 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 4 Sep 2025 15:28:35 -0700 Subject: [PATCH 009/123] Add test for dynamic padding (#2541) This is a follow up of https://github.com/microsoft/onnxscript/pull/2540 to add a test described in https://github.com/microsoft/onnxscript/issues/2539. Fix https://github.com/microsoft/onnxscript/issues/2539 Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/e2e_ops_tests.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index a0d0a0d880..253637ccd2 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -159,6 +159,21 @@ def forward(self, query, key, value, attn_mask): ) _testing.assert_onnx_program(onnx_program) + def test_dynamic_paddings(self): + class Model(torch.nn.Module): + def forward(self, x): + height = x.size(2) # height is SymInt + x = torch.nn.functional.pad(x, (0, 0, 0, height), mode="replicate") + return x + + onnx_program = torch.onnx.export( + Model(), + (torch.rand(1, 1, 1, 1),), + dynamo=True, + dynamic_shapes=({2: torch.export.Dim("H")},), + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From e76bfe0d95b4fc259ceacc75d916b61c016bb861 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 5 Sep 2025 12:48:56 -0700 Subject: [PATCH 010/123] [Reland] Update SplitToSequence in constant folding (#2544) Split input (SymbolicTensor) could have no const_value, but with shape that gives us information of how many outputs an op.Split should return. --- onnxscript/optimizer/_constant_folding.py | 40 ++++++++++---- .../optimizer/_constant_folding_test.py | 54 +++++++++++++++++++ 2 files changed, 83 insertions(+), 11 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 5f34e430dc..350277cc01 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -801,27 +801,45 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: axis = axis + rank if axis < 0 or axis >= rank: return None - split_dimension_size = shape[axis] - if not isinstance(split_dimension_size, int): - return None + # NOTE: Split needs to either be a scalar or a 1-D tensor. We need to + # calculate the number of outputs for Split. + # If split is a scalar, we split into chunks of size 'split' if possible. + # * the split dimension size and split_value has to be known. + # If split is a 1-D tensor, we split into 'size(split)' chunks + # * Get the size from split_value if it's numpy array. + # * Get the size from symbolic shape if split_value is not available. split_value = _get_numpy_value(split) - if split_value is None: + split_shape = ( + split.shape.numpy() if split.shape is not None and split.shape.is_static() else None + ) + + # No information about split value or shape. + if split_value is None and split_shape is None: return None - assert isinstance(split_value, np.ndarray) - if split_value.ndim == 0: - # split into chunks all of size 'split' if possible. - num_outputs = math.ceil(split_dimension_size / split_value.item()) + if isinstance(split_shape, tuple) and len(split_shape) == 1: + # If split_shape is known, we can use it to determine the number of outputs. + split_dimension_size = split_shape[0] + assert isinstance(split_dimension_size, int) + num_outputs = split_dimension_size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split( - input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs - ) + split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) + elif split_value.ndim == 0: + # split into chunks all of size 'split' if possible. + split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs + ) else: return None diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 6b2557551e..d3d76c4a23 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -346,6 +346,60 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( self.assertEqual(len(optimized.graph), 7) self.assertEqual(optimized.graph[6].op_type, "Concat") + def test_dynamic_split_to_sequence_list_shape_rewrite(self): + # split is a graph input with known 1-D static shape [4]; values unknown (not constant) + # Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1 + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[2,N] x, int64[4] split) => (float[2,N] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + i3 = Constant () + s3 = SequenceAt (splits, i3) + return_val = Concat (s0, s1, s2, s3) +}""" + optimized = self._fold(model) + # Expect: Split + Concat (index constants & SequenceAt removed) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 4) + self.assertEqual(split_nodes[0].op_type, "Split") + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + + def test_dynamic_split_to_sequence_list_shape_no_keepdims(self): + # keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic. + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,M] x, int64[3] split) => (float[1,M] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + return_val = Concat (s0, s1, s2) +}""" + optimized = self._fold(model) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 3) + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + # Each split output should have a corresponding Squeeze (keepdims=0 branch) + squeeze_nodes = [n for n in optimized.graph if n.op_type == "Squeeze"] + self.assertEqual(len(squeeze_nodes), 3) + def test_initializer_input_not_folded(self): model_text = """ From 5762a6977606d19bfe87d21bd2d21e34269413af Mon Sep 17 00:00:00 2001 From: Ayoub BIH <89558574+AyoubMDL@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:10:47 +0200 Subject: [PATCH 011/123] [Rewriter]: add fusion rules for successive Min/Max patterns (#2500) This PR adds the following transformation: - Min(Min(X)) -> Min(X) - Max(Max(X)) -> Max(X) - Min(Max(X)) -> Clip(X) - Max(Min(X)) -> Clip(X) --- onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/rules/common/__init__.py | 10 + .../rewriter/rules/common/_min_max_to_clip.py | 253 ++++++++++++ .../rules/common/_min_max_to_clip_test.py | 367 ++++++++++++++++++ 4 files changed, 632 insertions(+) create mode 100644 onnxscript/rewriter/rules/common/_min_max_to_clip.py create mode 100644 onnxscript/rewriter/rules/common/_min_max_to_clip_test.py diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 1d07e9f5af..232750af78 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -37,6 +37,7 @@ _collapse_slices, _fuse_pad_into_conv, _fuse_relus_clips, + _min_max_to_clip, _no_op, _redundant_scatter_nd, ) @@ -47,6 +48,7 @@ *_broadcast_to_matmul.rules, *_cast_constant_of_shape.rules, *_collapse_slices.rules, + *_min_max_to_clip.rules, *_fuse_relus_clips.rules, *_basic_rules.basic_optimization_rules(), *_redundant_scatter_nd.rules, diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 752e3c9430..e86b46cd7b 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -15,6 +15,10 @@ "fuse_batchnorm_into_gemm_rule", "fuse_pad_into_conv_integer_rule", "fuse_pad_into_conv_rule", + "min_min_rule", + "max_max_rule", + "min_max_rule", + "max_min_rule", "gemm_to_matmul_add_rule", "matmul_add_to_gemm_rule", "mul_by_1_rule", @@ -89,6 +93,12 @@ transpose_ab_matmul_add_to_gemm_rule, transpose_b_matmul_add_to_gemm_rule, ) +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, +) from onnxscript.rewriter.rules.common._no_op import ( add_0_rule, div_by_1_rule, diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip.py b/onnxscript/rewriter/rules/common/_min_max_to_clip.py new file mode 100644 index 0000000000..88ae495dbc --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses successive Min/Max patterns in ONNX graphs. + +Supported transformations: +- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const) +- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const) +- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub) +- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub) + +Where: + - fused_const is the reduction (min or max) over all constant inputs. + - For Clip fusion: + * All constant inputs must be scalars. + * The effective lower bound is the maximum of all lower-bound constants. + * The effective upper bound is the minimum of all upper-bound constants. + + For the case of Max(Min(X, upper_bound), lower_bound): + * The rule applies only if lower_bound ≤ upper_bound. + +General constraints: + - The first input may be any tensor. + - All other inputs must be constant tensors (from Constant nodes or initializers). +""" + +import abc +import functools +from typing import ClassVar + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC): + """Base class for Min/Max fusion rewrites. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - If ``need_scalars`` is True (Clip fusion), all constants must be scalars. + - If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound. + """ + + need_scalars: ClassVar = False + check_bounds: ClassVar = False + + @abc.abstractmethod + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: ... + + def rewrite(self, op, x, out1, out2): + first_node = out1.producer() + second_node = out2.producer() + + # Compute new constants for the fused op + constants = self.compute_constants(first_node, second_node, x.name) + + initializers = [op.initializer(constant, name=name) for constant, name in constants] + + return op.op( + self.op_type, + inputs=[x, *initializers], + ) + + def _is_scalar(self, v: np.ndarray) -> bool: + return np.isscalar(v) or np.size(v) == 1 + + def check(self, context, out1, out2, **_): + """Condition to check if we need to replace the pattern. + + Conditions: + - The min and max input nodes must not be graph inputs. + - These inputs (except the first) must be constant values (from Constant nodes or initializers). + - In the case of Min(Max) and Max(Min) patterns: + * All inputs must be scalars (as Clip requires scalars). + For Max(Min) pattern: + * The lower bound must be less than or equal to the upper bound. + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Not used + check_result = MatchResult() + + first_node = out1.producer() + second_node = out2.producer() + + # Ensure all inputs except the first are constants + for input_ in first_node.inputs[1:] + second_node.inputs[1:]: + if ir.convenience.get_const_tensor(input_) is None: + return check_result.fail(f"{input_.name} is not a constant.") + + # If scalars are required (Clip fusion), enforce scalar-ness + if self.need_scalars and not self._is_scalar(input_.const_value.numpy()): + return check_result.fail(f"{input_.name} is not a scalar.") + + if self.need_scalars and self.check_bounds: + # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound + lower_bound, upper_bound = self.compute_constants(first_node, second_node) + if lower_bound[0].numpy() > upper_bound[0].numpy(): + return check_result.fail( + f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater " + f"than upper bound ({upper_bound[0].numpy()})." + ) + + return check_result + + +class FuseSuccessiveMin(_FuseMinMaxBase): + """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + """ + + op_type: ClassVar = "Min" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")] + + def pattern(self, op, x): + return op.Min( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseSuccessiveMax(_FuseMinMaxBase): + """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + """ + + op_type: ClassVar = "Max" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")] + + def pattern(self, op, x): + return op.Max( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMaxMinToClip(_FuseMinMaxBase): + """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. + """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + lower_bound = np.max([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + upper_bound = np.min([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Min( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMinMaxToClip(_FuseMinMaxBase): + """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. + - Requires ``lower_bound <= upper_bound``. + """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + check_bounds: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + upper_bound = np.min([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + lower_bound = np.max([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Max( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +min_min_rule = FuseSuccessiveMin().rule() +max_max_rule = FuseSuccessiveMax().rule() +min_max_rule = FuseMinMaxToClip().rule() +max_min_rule = FuseMaxMinToClip().rule() + + +rules = RewriteRuleSet( + [ + min_min_rule, + max_max_rule, + min_max_rule, + max_min_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py new file mode 100644 index 0000000000..dd09078a9e --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py @@ -0,0 +1,367 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, + rules, +) + + +class _TestMinMaxToClipBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250817) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = rules.apply_to_model(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = ( + self.rng.integers( + low=-10, + high=10, + size=(2, *updated_model.graph.inputs[0].shape[1:]), + dtype=np.int32, + ), + ) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class TestFuseSuccessiveMinOrMax(_TestMinMaxToClipBase): + @parameterized.expand( + [ + ("int32_min", "int32", "Min"), + ("int32_max", "int32", "Max"), + ("float32_min", "float", "Min"), + ("float32_max", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 32, 14, 17] X) => ({dtype} [N, ?, ?, ?] Y) + <{dtype}[1] cst1 = {{3}}, {dtype}[1] cst2 = {{6}}> + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min_multi", "int32", "Min"), + ("int32_max_multi", "int32", "Max"), + ("float32_min_multi", "float", "Min"), + ("float32_max_multi", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_multiple_inputs(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 3, 3] X) => ({dtype}[N, 3, 3] Y) + <{dtype}[3] cst1 = {{2, 5, 8}}, + {dtype}[1] cst2 = {{4}}, + {dtype}[3] cst3 = {{3, 1, -6}}, + {dtype}[1] cst4 = {{10}}, + {dtype}[3] cst5 = {{-2, 7, 9}}, + {dtype}[1] cst6 = {{0}}, + {dtype}[3] cst7 = {{11, -3, 4}}> + {{ + x1 = {op_type}(X, cst1, cst2, cst3, cst4) + Y = {op_type}(x1, cst5, cst6, cst7) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min", "Min"), + ("int32_max", "Max"), + ("float32_min", "Min"), + ("float32_max", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + cst2 = Constant() + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=["Constant", op_type]) + + @parameterized.expand( + [ + ("min_nonconst", "Min", min_min_rule), + ("max_nonconst", "Max", max_max_rule), + ] + ) + def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewrite_rule): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] Y) + + {{ + cst1 = ReduceMean(X) + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_failed_condition_test(model, rewrite_rule, "is not a constant.") + + @parameterized.expand( + [ + ("min_graph_input", "Min"), + ("max_graph_input", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type]) + + +class TestMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_min_max_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_min_max_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + max = Constant() + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_min_max_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_min_max_to_clip_invalid_bounds(self): + """Min node should have the max value and Max node should have the min value.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "Invalid bounds:") + + def test_failure_fuse_min_max_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(model, min_max_rule, "is not a constant.") + + def test_failure_min_max_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "is not a scalar") + + +class TestMaxMinToClip(_TestMinMaxToClipBase): + def test_successful_max_min_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + min = Constant() + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_max_min_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_check_bounds(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_fuse_max_min_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_failed_condition_test(model, max_min_rule, "is not a constant.") + + def test_failure_max_min_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y) + + { + x1 = Max(X, min) + Y = Min(x1, max) + } + """) + self.run_failed_condition_test(base_model, max_min_rule, "is not a scalar") + + +class TestIntegrationMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_full_chain_fusion(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min1) + x2 = Min(x1, min2) + x3 = Max(x2, max1) + x4 = Max(x3, max2) + x5 = Min(x4, min3) + x6 = Max(x5, max3) + Y = Min(x6, min4) + } + """) + self.run_test(model, expected_op_types=["Clip", "Clip", "Clip"]) + + +if __name__ == "__main__": + unittest.main() From f5f9e6a616c763b731c97e2d8dae6eac3544f674 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 5 Sep 2025 14:17:26 -0700 Subject: [PATCH 012/123] Update onnx-weekly version to 1.20.0 (#2545) --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index e2eda3baa9..9c5363b8af 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.19.0.dev20250602 +onnx-weekly==1.20.0.dev20250901 From d0fb218c03c8bb1e041b9f081d7dd61d59e519ef Mon Sep 17 00:00:00 2001 From: Johan MEJIA <69996955+Johansmm@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:17:57 +0200 Subject: [PATCH 013/123] [rewriter] Unify reshape flatten ops (#2518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Following (https://github.com/microsoft/onnxscript/issues/2301), `flatten_to_reshape_rule` rule set is introduced to reduce the following list of operators: - Reshape ∘ Flatten -> Reshape - Flatten ∘ Reshape -> Reshape Note to support this changes: - `ReshapeReshape` rule is updated to support more cases. - `Flatten2Reshape` rule is introduced to convert Flatten ops into Reshape when possible. --- onnxscript/rewriter/rules/common/__init__.py | 2 + .../rewriter/rules/common/_basic_rules.py | 87 +++++- .../rules/common/_basic_rules_test.py | 264 ++++++++++++++---- 3 files changed, 288 insertions(+), 65 deletions(-) diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index e86b46cd7b..0b01bade72 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -10,6 +10,7 @@ "div_by_1_rule", "dropout_inference_rule", "dropout_zero_rule", + "flatten_to_reshape_rule", "fuse_batchnorm_into_conv_rule", "fuse_batchnorm_into_conv_transpose_rule", "fuse_batchnorm_into_gemm_rule", @@ -48,6 +49,7 @@ from onnxscript.rewriter.rules.common._basic_rules import ( cast_cast_rule, + flatten_to_reshape_rule, no_op_cast_rule, no_op_expand_rule, no_op_transpose_rule, diff --git a/onnxscript/rewriter/rules/common/_basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py index 6f38050f3e..b7a648880a 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules.py +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -11,6 +11,8 @@ from typing import ClassVar, Sequence +import numpy as np + from onnxscript import ir from onnxscript.rewriter import _ir_utils as ir_utils from onnxscript.rewriter._basics import MatchResult @@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape): return op.Reshape(op.Reshape(x, shape_ignored), shape) def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): - return op.Reshape(x, shape) + new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name)) + return op.Reshape(x, new_shape, allowzero=self._allowzero) def check(self, context, x, shape_ignored, shape) -> MatchResult: check_result = MatchResult() - if shape_ignored.const_value is None: - return check_result.fail("Shape ignored is not a constant.") - if shape.const_value is None: + + # Shape must be a constant. + if (np_shape := ir_utils.get_numpy_value(shape)) is None: return check_result.fail("Shape is not a constant.") - if shape.const_value.numpy().min() <= 0: - return check_result.fail("Shape has non-positive values.") + # Convert to array to support assignment destination. + self._new_shape = np.array(np_shape, np_shape.dtype) + + # Try to replace {0,-1} values in shape if reshape output is known. + if (reshape_output := context.output_values[0].shape) is not None: + for i, dim in enumerate(reshape_output): + if isinstance(dim, int) and dim > 0: + self._new_shape[i] = dim + + # Constraints for shape. + self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0) + if self._allowzero == 1 and any(self._new_shape == 0): + return check_result + if any(self._new_shape == 0) and any(self._new_shape < 0): + return check_result.fail("Shape cannot contain both 0 and -1 dimensions.") + elif np.count_nonzero(self._new_shape == 0) > 1: + return check_result.fail("Shape cannot contain more than one 0 dimension.") + + # At this point, we can safely replace '0' with '-1'. + # Note allowzero is removed since at this point it does not have any effect. + self._allowzero = None + self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape) return check_result @@ -279,6 +302,55 @@ def check(self, context, x, axes1, axes2) -> MatchResult: return check_result +class Flatten2Reshape(RewriteRuleClassBase): + """Convert ``Flatten(x)`` to Reshape.""" + + def pattern(self, op, x: ir.Value): + return op.Flatten(x) + + def rewrite(self, op, x: ir.Value): + new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape")) + return op.Reshape(x, new_shape) + + def check(self, context, x: ir.Value) -> MatchResult: + check_result = MatchResult() + self._new_shape = np.array([-1, -1], "int64") + + # Convert axis in a positive value if possible. + axis = context.root.attributes.get_int("axis", 1) + input_rank = None + if (input_shape := x.shape) is not None: + input_rank = len(input_shape) + if axis < 0: + axis += input_rank + + # Compute reshape shape following axis attribute. + if axis == 0: + self._new_shape[0] = 1 + elif axis == 1: + self._new_shape[0] = 0 + elif axis == input_rank: + self._new_shape[1] = 1 + + # Try to update shape if output is known. + if (output_shape := context.output_values[0].shape) is not None: + for i, dim in enumerate(output_shape): + if isinstance(dim, int): + self._new_shape[i] = dim + + # Try to update shape if input is known. + if input_shape is not None: + if all(isinstance(dim, int) for dim in input_shape[:axis]): + self._new_shape[0] = np.prod(input_shape[:axis]) + if all(isinstance(dim, int) for dim in input_shape[axis:]): + self._new_shape[1] = np.prod(input_shape[axis:]) + + # Verify if it is possible to apply rule. + if np.count_nonzero(self._new_shape == -1) > 1: + return check_result.fail("Impossible to compute new shape.") + return check_result + + # Create rule instances cast_cast_rule = CastCast.rule() no_op_cast_rule = CastIdentity.rule() @@ -289,6 +361,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult: transpose_transpose_rule = TransposeTranspose.rule() unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule() squeeze_reshape_1d_rule = SqueezeReshape.rule() +flatten_to_reshape_rule = Flatten2Reshape.rule() def basic_optimization_rules() -> RewriteRuleSet: @@ -311,6 +384,8 @@ def basic_optimization_rules() -> RewriteRuleSet: cast_cast_rule, no_op_cast_rule, no_op_expand_rule, + # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule + flatten_to_reshape_rule, reshape_reshape_rule, slice_split_rule, no_op_transpose_rule, diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 8709300763..9ce74b22a2 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -14,6 +14,8 @@ import onnxscript.onnx_types as ot from onnxscript import ir from onnxscript.onnx_opset import opset18 +from onnxscript.rewriter import MatchingTracer, testing +from onnxscript.rewriter import pattern as orp from onnxscript.rewriter.rules.common import _basic_rules FLOAT = onnx.TensorProto.FLOAT @@ -29,6 +31,10 @@ def _make_model(*args, **kwargs) -> ir.Model: return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs)) +def clone_model(model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + class BasicRulesTest(unittest.TestCase): def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: feeds: dict[str, Any] = {} @@ -318,65 +324,6 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): self.assertEqual(["Constant", "Unsqueeze"], [n.op_type for n in model.graph]) self._check_model(model_proto, rewritten_model) - @parameterized.parameterized.expand( - [ - ( - "double_reshape_1", - _make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]), - onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])], - [ - onnx.numpy_helper.from_array( - np.array([4, 5, 3], dtype=np.int64), name="shape_" - ), - onnx.numpy_helper.from_array( - np.array([5, 4, 3], dtype=np.int64), name="shape" - ), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ), - ), - ( - "double_reshape_2", - _make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]), - onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])], - [ - onnx.numpy_helper.from_array( - np.array([-1], dtype=np.int64), name="shape_" - ), - onnx.numpy_helper.from_array( - np.array([5, 4, 3], dtype=np.int64), name="shape" - ), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ), - ), - ] - ) - def test_reshape_reshape_rule(self, _: str, model: ir.Model): - rule_set = _basic_rules.basic_optimization_rules() - model_proto = ir.serde.serialize_model(model) - rule_set.apply_to_model(model) - rewritten_model = ir.serde.serialize_model(model) - - self.assertEqual(["Reshape"], [n.op_type for n in model.graph]) - self._check_model(model_proto, rewritten_model) - @classmethod def _slices_split_models(cls): models = [ @@ -465,5 +412,204 @@ def model3(X: ot.FLOAT[1, 1]): check(model3, 0) +class ReshapeReshapeTest(unittest.TestCase): + @staticmethod + def create_model( + input_shape, shape1, shape2, allowzero1=0, allowzero2=0, infer_shape=False + ): + def _convert_shape(shape, name): + if isinstance(shape, np.ndarray): + shape = tape.initializer(ir.Tensor(shape, name=name)) + elif isinstance(shape, (list, tuple)): + shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64)) + tape.graph_like.inputs.append(shape) + else: + raise TypeError(f"Unsupported type {type(shape)} for shape.") + return shape + + x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT)) + y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT)) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + reshape = tape.op( + "Reshape", + inputs=[x, _convert_shape(shape1, "shape_")], + attributes={"allowzero": allowzero1}, + ) + tape.op( + "Reshape", + inputs=[reshape, _convert_shape(shape2, "shape")], + attributes={"allowzero": allowzero2}, + output=y, + ) + model = ir.Model(tape.graph_like, ir_version=10) + + # Infer shapes. + if infer_shape: + model = ir.passes.common.ShapeInferencePass()(model).model + return model + + @parameterized.parameterized.expand( + [ + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 8), [2, 0, 3, -1], [0, 3, 2, 8]), + ((3, 4, 8), [3, 4, -1], [-1, 12], 1), + ((3, 4, 2), [0, 4, -1], [12, -1], 0, 1), + ((3, 0, 8), [4, 2, 0, 0], [3, 0], 1, 1), + ] + ) + def test_reshape_reshape_rule( + self, input_shape, shape1, shape2, allowzero1=0, allowzero2=0 + ): + model = self.create_model( + input_shape, + np.array(shape1, dtype="int64"), + np.array(shape2, dtype="int64"), + allowzero1=allowzero1, + allowzero2=allowzero2, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(10).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand([([3, 2, 3, 3, 3], 1), ([0, -1, 3, 2], 0)]) + def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0): + input_shape = (3, 6, 9) + shape1 = np.array(shape1, dtype="int64") + # Build the model with unknown shape1. + model = self.create_model( + input_shape, + (shape1.size,), + np.array((1, 6, 27), dtype="int64"), + allowzero1=allowzero1, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + feeds = { + "X": np.random.default_rng(2).random(input_shape, dtype="float32"), + "shape_": shape1, + } + testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0) + + @parameterized.parameterized.expand( + [((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)] + ) + def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0): + # Note that shape inference is required for this test to be valid. + shape2 = np.array(shape2, dtype="int64") + model = self.create_model( + input_shape, + np.array((3, 2, -1), dtype="int64"), + shape2, + allowzero2=allowzero2, + infer_shape=True, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(7).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand( + [ + ((3,), "is not a constant"), + (np.array([0, -1], dtype="int64"), "both 0 and -1 dimensions"), + (np.array([0, 0, 3], dtype="int64"), "more than one 0 dimension"), + ] + ) + def test_unsupported_reshape_reshape(self, shape2, error_msg): + model = self.create_model((1, 2, 3), np.array([1, 6], dtype="int64"), shape2) + + # Check rewrite approach. + tracer = MatchingTracer() + count = _basic_rules.reshape_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.reshape_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, error_msg) + + +class Flatten2ReshapeTest(unittest.TestCase): + @staticmethod + def create_model(input_shape, axis=1): + x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT)) + y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT)) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + tape.op("Flatten", inputs=[x], attributes={"axis": axis}, output=y) + model = ir.Model(tape.graph_like, ir_version=10) + return model + + @parameterized.parameterized.expand(list(range(-5, 6))) + def test_flatten_to_reshape_rule(self, axis): + input_shape = (1, 4, 8, 7, 5) + model = self.create_model(input_shape=input_shape, axis=axis) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(13).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand(list(range(-4, 5))) + def test_flatten_to_reshape_dynamic_input(self, axis): + model = self.create_model(input_shape=("N", "C1", "C2", "C3"), axis=axis) + # Rule is supported in all cases if the output shape is known for non-special cases. + input_shape = (1, 2, 3, 4) + if axis not in {-3, 0, 1, 4}: + out_shape = ir.Shape((np.prod(input_shape[:axis]), np.prod(input_shape[axis:]))) + model.graph.outputs[0].shape = out_shape + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(17).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + def test_unsupported_flatten_to_reshape(self): + model = self.create_model(input_shape=("N", "C1", "C2"), axis=2) + + # Check rewrite approach. + tracer = MatchingTracer() + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.flatten_to_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, "Impossible to compute new shape") + + if __name__ == "__main__": unittest.main(verbosity=2) From 9036fabf140e8b3015a947ad3710c08a86097506 Mon Sep 17 00:00:00 2001 From: Ayoub BIH <89558574+AyoubMDL@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:39:27 +0200 Subject: [PATCH 014/123] [Rewriter] Support specifying node name in rewrites (#2474) Allows passing a node name when defining a rewrite. fixes https://github.com/microsoft/onnxscript/issues/2435 --------- Co-authored-by: Justin Chu --- onnxscript/ir/_tape.py | 22 ++++++++++++++++++++-- onnxscript/ir/_tape_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 79312eaefa..78dce2739e 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -17,7 +17,17 @@ class Builder(tape.Tape): - """An extension of the tape that provides a more convenient API for constructing the IR.""" + """An extension of the tape that provides a more convenient API for constructing the IR. + + Example: + >>> from onnxscript import ir + >>> from onnxscript.ir import _tape + >>> op = _tape.Builder() + >>> input = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))) + >>> relu_val = op.Relu(input, _name="relu_node", _domain="", _version=18, _outputs=["relu_out"]) + + Note: When passing `_name`, ensure it is unique to avoid duplicate node names. + """ def __getattr__(self, op_type: str) -> Any: return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) @@ -26,6 +36,8 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, domain = kwargs.pop("_domain", "") version = kwargs.pop("_version", None) outputs = kwargs.pop("_outputs", 1) + name = kwargs.pop("_name", None) + if isinstance(outputs, Sequence): num_outputs = len(outputs) else: @@ -34,7 +46,12 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, if num_outputs == 1: value = super().op( - op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + name=name, ) if isinstance(outputs, Sequence): value.name = outputs[0] @@ -45,6 +62,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, attributes=kwargs, domain=domain, version=version, + name=name, num_outputs=num_outputs, ) if isinstance(outputs, Sequence): diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py index 46cbcc23fe..f8210e7a0b 100644 --- a/onnxscript/ir/_tape_test.py +++ b/onnxscript/ir/_tape_test.py @@ -5,6 +5,7 @@ import unittest from onnxscript import ir +from onnxscript.ir import _tape class TestTape(unittest.TestCase): @@ -72,5 +73,32 @@ def test_op_multi_out(self): self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) +class TestBuilder(unittest.TestCase): + def test_op_name(self): + op = _tape.Builder() + + input_a = ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + input_b = ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + add = op.Add(input_a, input_b, _name="add_node") + _ = op.Relu(add, _name="relu_node") + self.assertEqual(op.nodes[0].name, "add_node") + self.assertEqual(op.nodes[1].name, "relu_node") + + def test_op_name_multi_out(self): + op = _tape.Builder() + + input_a = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + _ = op.CustomOp(input_a, _name="custom_node", _outputs=3) + self.assertEqual(op.nodes[0].name, "custom_node") + + if __name__ == "__main__": unittest.main() From cec5396648fa1aacfd914e6c838642efd8420976 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 8 Sep 2025 15:26:59 -0700 Subject: [PATCH 015/123] Do not try to fold op.SplitToSequence when split is `None` (#2550) split is an optional input to op.SplitToSequence. --- onnxscript/optimizer/_constant_folding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 350277cc01..62c28894c0 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -784,6 +784,9 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. """ input = node.inputs[0] + if len(node.inputs) == 1: + # split is not provided + return None split = node.inputs[1] output = node.outputs[0] From 647b22ab412c28dd5c4721f26930a934ffefb807 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 9 Sep 2025 10:19:41 -0700 Subject: [PATCH 016/123] Bump version to 0.5.0 (#2538) Because there will be breaking changes in this release --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 267577d47e..8f0916f768 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.4.1 +0.5.0 From 0e79b62b0ba1c91b8c3ea53b348d17e1da6cf58a Mon Sep 17 00:00:00 2001 From: Ayoub BIH <89558574+AyoubMDL@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:44:21 +0200 Subject: [PATCH 017/123] [Rewriter] Add fuse batchnorm to default rules (#2553) This PR adds `fuse_batchnorm` rules to default rules. --------- Co-authored-by: Justin Chu --- onnxscript/rewriter/__init__.py | 2 ++ .../rewriter/rules/common/_fuse_batchnorm.py | 21 ++++--------------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 232750af78..fc000dc176 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -35,6 +35,7 @@ _broadcast_to_matmul, _cast_constant_of_shape, _collapse_slices, + _fuse_batchnorm, _fuse_pad_into_conv, _fuse_relus_clips, _min_max_to_clip, @@ -53,6 +54,7 @@ *_basic_rules.basic_optimization_rules(), *_redundant_scatter_nd.rules, *_fuse_pad_into_conv.rules, + *_fuse_batchnorm.rules, ) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index a5ceb00468..9d8b8f23f4 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Mapping +from typing import ClassVar, Mapping import numpy as np @@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra class _FuseBatchNormBase(RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" - def __init__( - self, - op_type: str, - name: str | None = None, - remove_nodes: bool = True, - as_function: bool = False, - ) -> None: - super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function) - self.op_type = op_type - @abstractmethod def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" @@ -116,8 +106,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M class FuseBatchNormIntoConv(_FuseBatchNormBase): """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" - def __init__(self): - super().__init__("Conv") + op_type: ClassVar = "Conv" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 0 @@ -133,8 +122,7 @@ def pattern(self, op, x): class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" - def __init__(self): - super().__init__("ConvTranspose") + op_type: ClassVar = "ConvTranspose" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 1 @@ -150,8 +138,7 @@ def pattern(self, op, x): class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" - def __init__(self): - super().__init__("Gemm") + op_type: ClassVar = "Gemm" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return ( From 821015a652c31381349c5ec7de62b8a21a0fe3cb Mon Sep 17 00:00:00 2001 From: Kaiyu Shi Date: Wed, 10 Sep 2025 21:45:52 +0800 Subject: [PATCH 018/123] Add Conv-Affine(Mul+Add) and hardswish fusion (#2472) Close #2468 - Absorbs Affine into Conv: - Mul + Add + Conv ==> Conv - Conv + Mul + Add ==> Conv - Fuse HardSwish: - Add + Clip + Div ==> HardSigmoid - HardSigmoid + Mul ==> HardSwish - Add + Clip + Mul + Div ==> HardSwish (Since the order of operator matters, I have to create different rewrite pattern for this) May not be generic enough, but works for us in `paddleOCRv4` model. Another question is hardswish is introduced in opset-v14, will onnxscript handles older opset version or rewrite rules take care of this? --------- Co-authored-by: Kaiyu Shi --- onnxscript/rewriter/rules/common/__init__.py | 8 + .../rules/common/_fuse_conv_affine.py | 112 ++++++++++++++ .../rules/common/_fuse_conv_affine_test.py | 115 ++++++++++++++ .../rewriter/rules/common/_fuse_hardswish.py | 142 ++++++++++++++++++ .../rules/common/_fuse_hardswish_test.py | 117 +++++++++++++++ 5 files changed, 494 insertions(+) create mode 100644 onnxscript/rewriter/rules/common/_fuse_conv_affine.py create mode 100644 onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py create mode 100644 onnxscript/rewriter/rules/common/_fuse_hardswish.py create mode 100644 onnxscript/rewriter/rules/common/_fuse_hardswish_test.py diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 0b01bade72..14ed3587f3 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -2,11 +2,13 @@ # Licensed under the MIT License. __all__ = [ "add_0_rule", + "affine_conv_fusion_rule", "cast_cast_rule", "cast_constant_of_shape_rule", "cast_constant_of_shape_without_value_rule", "collapse_slice_rule", "collapse_slice2_rule", + "conv_affine_fusion_rule", "div_by_1_rule", "dropout_inference_rule", "dropout_zero_rule", @@ -14,6 +16,7 @@ "fuse_batchnorm_into_conv_rule", "fuse_batchnorm_into_conv_transpose_rule", "fuse_batchnorm_into_gemm_rule", + "fuse_hardswish_rules", "fuse_pad_into_conv_integer_rule", "fuse_pad_into_conv_rule", "min_min_rule", @@ -76,6 +79,11 @@ fuse_batchnorm_into_conv_transpose_rule, fuse_batchnorm_into_gemm_rule, ) +from onnxscript.rewriter.rules.common._fuse_conv_affine import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) +from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( fuse_pad_into_conv_integer_rule, fuse_pad_into_conv_rule, diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py new file mode 100644 index 0000000000..2aaba5cd73 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Absorbs affine operation into convolution (best effort): +- Conv(Mul(Add(x))) -> Conv (only conv without padding can be fused) +- Add(Mul(Conv)) -> Conv (for all convolutions) +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value + + +class _ConvAffineFusionBase(pattern.RewriteRuleClassBase): + def check( + self, + context, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + if get_const_value(w) is None: + return check_result.fail("The weight of Conv should be constant") + if get_const_value(b) is None: + return check_result.fail("The bias of Conv should be constant") + if get_singleton_value(scale) is None: + return check_result.fail("Operand for Mul should be constant scalar value") + if get_singleton_value(offset) is None: + return check_result.fail("Operand for Add should be constant scalar value") + return check_result + + +class AffineConvFusion(_ConvAffineFusionBase): + """Pattern: scalar Mul + scalar Add + Conv (1x1) --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return op.Conv( + x * scale + offset, + w, + b, + pads=[0, 0, 0, 0], + _allow_other_attributes=True, + _outputs=["conv_out"], + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor( + b_value + np.sum(w_value * offset_value, axis=(1, 2, 3), keepdims=False) + ) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes) + + +class ConvAffineFusion(_ConvAffineFusionBase): + """Pattern: Conv + scalar Mul + scalar Add --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return ( + op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale + + offset + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_weight = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor(b_value * scale_value + offset_value) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_weight, offset_bias, **conv_attributes) + + +affine_conv_fusion_rule = AffineConvFusion().rule() +conv_affine_fusion_rule = ConvAffineFusion().rule() diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py new file mode 100644 index 0000000000..4f1f671f43 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import rewrite, testing +from onnxscript.rewriter.rules.common import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) + + +class FuseConvAffineTest(unittest.TestCase): + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def test_conv_affine_fusion(self): + tape = ir.tape.Tape() + x = ir.Input( + "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT) + ) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]}) + mul_out = tape.op("Mul", [conv_out, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.Input( + "z", + shape=ir.Shape([1, 3, 32, 32]), + type=ir.TensorType(ir.DataType.FLOAT), + ), + ) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[z], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[conv_affine_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + def test_affine_conv_fusion_without_pad(self): + tape = ir.tape.Tape() + x = ir.Input( + "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT) + ) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + mul_out = tape.op("Mul", [x, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.Input( + "z", + shape=ir.Shape([1, 3, 32, 32]), + type=ir.TensorType(ir.DataType.FLOAT), + ), + ) + conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]}) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[conv_out], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[affine_conv_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish.py b/onnxscript/rewriter/rules/common/_fuse_hardswish.py new file mode 100644 index 0000000000..6d2e8c84e1 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Does the following transformation: +- Div(Clip(Add(x))) -> HardSigmoid +- Mul(HardSigmoid(x), x) -> HardSwish +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter._rewrite_rule import RewriteRuleSet + + +class _HardSigmoidFusionBase(pattern.RewriteRuleClassBase): + """HardSwish requires constant values so we check in base class.""" + + def check( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + + if not is_singleton_value(clip_min, 0.0, rtol=1e-4): + return check_result.fail("Swish requires min value of 0 for clip") + if not is_singleton_value(clip_max, 6.0, rtol=1e-4): + return check_result.fail("Swish requires max value of 6 for clip") + if not is_singleton_value(bias, 3.0, rtol=1e-4): + return check_result.fail("Swish requires bias value of 3") + if not is_singleton_value(divisor, 6.0, rtol=1e-4): + return check_result.fail("Swish requires divisor value of 6") + return check_result + + +class HardSwishFusion(_HardSigmoidFusionBase): + """Fuse Add(_, 3) + Clip<0, 6>(_) + Mul + Div(_, 6) into HardSwish + + In this case we can't make HardSigmoid fusion first. The Mul + is placed before Div while HardSigmoid requires Add+Clip+Div. + """ + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) * x + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return op.HardSwish(x) + + +class HardSwishFusionFromHardSigmoid(pattern.RewriteRuleClassBase): + """Fuse HardSigmoid + Mul into HardSwish""" + + def pattern(self, op, x: ir.Value) -> ir.Value: + # Floating point matching for 1/6 is not exact, so we use isclose below + out = op.HardSigmoid(x, _allow_other_attributes=True, _outputs=["hardsigmoid_out"]) + out = out * x + return out + + def check(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> MatchResult: + check_result = MatchResult() + hardsigmoid = hardsigmoid_out.producer() + # Use getter to protect when 'alpha' / 'beta' is not in attributes + alpha = hardsigmoid.attributes.get_float("alpha", -1) + beta = hardsigmoid.attributes.get_float("beta", -1) + if not np.isclose(alpha, 1 / 6): + return check_result.fail( + "HardSigmoid alpha must be 1/6 to get fused into HardSwish" + ) + if not np.isclose(beta, 0.5): + return check_result.fail( + "HardSigmoid beta must be 0.5 to get fused into HardSwish" + ) + return check_result + + def rewrite(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> ir.Value: + return op.HardSwish(x) + + +class HardSigmoidFusion(_HardSigmoidFusionBase): + """Fuse HardSigmoid only for HardSwish hyper-parameters: alpha=1/6, beta=0.5""" + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return op.HardSigmoid(x, alpha=1 / 6, beta=0.5) + + +def fuse_hardswish_rules() -> RewriteRuleSet: + """Returns the rewrite rules for fusing HardSwish and HardSigmoid.""" + return RewriteRuleSet( + [ + HardSwishFusion().rule(), + HardSigmoidFusion().rule(), + HardSwishFusionFromHardSigmoid().rule(), + ], + commute=True, + ) diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py new file mode 100644 index 0000000000..36556e9cff --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +import onnxruntime as ort +from onnx_ir.passes.common import onnx_checker, shape_inference + +from onnxscript import optimizer +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import fuse_hardswish_rules + + +class FuseHardSwishTest(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250621) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = fuse_hardswish_rules().apply_to_model(updated_model) + + # Polish model to remove unused constants + updated_model = optimizer.optimize(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = (self.rng.integers(low=-10, high=10, size=(2 * 32), dtype=np.int32),) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + + # Validate serialized model + output_model_proto = ir.to_proto(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_hardsigmoid_fusion(self): + model_text = """ + + hardsigmoid (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + y = Div(clipped, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSigmoid"]) + + def test_hardswish_fusion(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + mul_x = Mul(clipped, x) + y = Div(mul_x, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_mul_last(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + div_x = Div(clipped, six) + y = Mul(div_x, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_from_sigmoid(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + hardsigmoid_out = HardSigmoid(x) + y = Mul(hardsigmoid_out, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + +if __name__ == "__main__": + unittest.main() From 710d597cfacda33e24c936e519b79fd9a344916a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 10 Sep 2025 13:12:16 -0700 Subject: [PATCH 019/123] Fix rewriter and CI tests for the latest onnx-ir version (#2554) Fix rewriter CI tests for the latest onnx-ir version (currently in main). Since the latest onnx-ir is now returning tuples for repeated attributes, we need to update the comparison logic to account for that. --------- Signed-off-by: Justin Chu --- onnxscript/rewriter/_fusion_utils.py | 2 +- onnxscript/rewriter/_pattern_ir.py | 9 ++++++++- onnxscript/rewriter/_rewrite_rule.py | 4 ++-- onnxscript/rewriter/ort_fusions/attention.py | 2 +- .../rewriter/ort_fusions/fused_matmul_rule_sets.py | 6 +++--- .../ort_fusions/fused_matmul_rule_sets_test.py | 12 ++++++------ onnxscript/rewriter/ort_fusions/gqa.py | 2 +- onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py | 2 +- onnxscript/rewriter/ort_fusions/mha.py | 2 +- onnxscript/rewriter/ort_fusions/mha_bias.py | 2 +- .../rewriter/ort_fusions/skip_normalization.py | 4 ++-- onnxscript/rewriter/pattern_test.py | 2 +- 12 files changed, 28 insertions(+), 21 deletions(-) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index dbf16ae3d3..f6a7204ac8 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -13,7 +13,7 @@ Dim = Union[int, ir.SymbolicDim] -def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: +def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: if val.shape is None: return False if val.shape.rank() != len(shape): diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index f64d3fca3c..9b81e33581 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -126,7 +126,14 @@ def __init__(self, value: SupportedAttrTypes): self._value = value def matches(self, attr: ir.Attr) -> bool: - return isinstance(attr, ir.Attr) and attr.value == self._value + if attr.type in { + ir.AttributeType.INTS, + ir.AttributeType.FLOATS, + ir.AttributeType.STRINGS, + }: + # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. + return tuple(attr.value) == tuple(self._value) + return attr.value == self._value def __str__(self) -> str: return str(self._value) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 9481ca5077..af0165dea0 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -392,7 +392,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False """ @@ -463,7 +463,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py index 4a4cd0ad8e..ce234bbb63 100644 --- a/onnxscript/rewriter/ort_fusions/attention.py +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -160,7 +160,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 5082c20464..cdc50c99ae 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -79,7 +79,7 @@ def check( # Check that last two dimensions are swapped expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - if perm != expected_perm: + if list(perm) != expected_perm: return check_result.fail("Permutation values for Transpose are not correct.") elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or ( self._pos == 2 and not _ir_utils.has_rank(y, 2) @@ -188,7 +188,7 @@ def check( trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) transposed_node = _get_node(transposed, "Transpose") - perm = transposed_node.attributes["perm"].as_ints() + perm = list(transposed_node.attributes["perm"].as_ints()) if not perm: return check_result.fail("Permutation values for Transpose are not correct.") @@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): if perm: # Check that the two dimensions are swapped - if perm != [1, 0]: + if tuple(perm) != (1, 0): return check_result.fail( "Permutation values for Transpose are not correct." ) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index 527d4826d5..f82702d557 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -284,7 +284,7 @@ def _check_model( opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul]) expected = ref.run(None, feeds) got = opt.run(None, feeds) - self.assertEqual(len(expected), len(got)) + self.assertEqual(len(got), len(expected)) for a, b in zip(expected, got): np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) @@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func): ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand([("should_not_match", _should_not_match)]) @@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func): self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) self.assertEqual( - ["Transpose", "MatMul", "Transpose"], [n.op_type for n in ir_model.graph], + ["Transpose", "MatMul", "Transpose"], ) self._check_model(model_proto, rewritten_model, atol=1e-6) @@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func): common_passes.ShapeInferencePass()(ir_model) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func): ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 99852f712a..5fff910bcf 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -247,7 +247,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return False diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py index 0d404b2754..51355fc8cf 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py +++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py @@ -84,7 +84,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) # Check that if x is being split into q, k, v correctly # based on hidden sizes diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index e2987cfc5e..321e895f44 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -157,7 +157,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py index 28b9646ddc..9ecf2ce017 100644 --- a/onnxscript/rewriter/ort_fusions/mha_bias.py +++ b/onnxscript/rewriter/ort_fusions/mha_bias.py @@ -78,7 +78,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if query_matmul.dtype not in valid_float_types: return check_result.fail("Query is not a float or float16 type.", query_matmul) diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index f7a376aef9..c76a7454cb 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -60,7 +60,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( @@ -184,7 +184,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 49ace2fb81..0a29080b4d 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -674,7 +674,7 @@ def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]: function = model.functions[function_id] self.assertEqual([x.op_type for x in function], ["Add", "Transpose"]) transpose_node = function[1] - self.assertEqual(transpose_node.attributes["perm"].value, [1, 0]) + self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0]) onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"]) From 50d7e87f6d64418d5fb542b14612d4d560967384 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 10 Sep 2025 17:19:49 -0700 Subject: [PATCH 020/123] [torchlib] Mark atan2 as trace_only and map NaN to 0 (#2557) Fix https://github.com/pytorch/pytorch/issues/162570 --------- Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++++++--- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3607a11361..a66faae0be 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -925,16 +925,21 @@ def aten_atan(self: TFloat) -> TFloat: return op.Atan(self) -@torch_op("aten::atan2") +@torch_op("aten::atan2", trace_only=True) def aten_atan2(self: TFloat, other: TFloat) -> TFloat: """atan2(Tensor self, Tensor other) -> Tensor""" # self is y, and other is x on coordinate slope = op.Div(self, other) atan = op.Atan(slope) + zero = common_ops.constant(0.0, dtype=self.dtype) + pi = common_ops.constant(_MATH_PI, dtype=self.dtype) - second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI) - result = op.Where(other < 0.0, second_third_quadrant, atan) + second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi) + result = op.Where(op.Less(other, zero), second_third_quadrant, atan) + + # Map NaN to 0 to match PyTorch behavior + result = op.Where(op.IsNaN(result), zero, result) return result diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 646a5133fa..0cf8898241 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -578,7 +578,7 @@ def _where_input_wrangler( TorchLibOpInfo("asin", core_ops.aten_asin), TorchLibOpInfo("asinh", core_ops.aten_asinh), TorchLibOpInfo("atan", core_ops.aten_atan), - TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}), + TorchLibOpInfo("atan2", core_ops.aten_atan2), TorchLibOpInfo("atanh", core_ops.aten_atanh), TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), From 366f7be321f3c44a1236a0f702b492cf767418e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 12 Sep 2025 12:03:04 +0200 Subject: [PATCH 021/123] [torchlib] Fix repeat_interleave when repeats is a symbolic tensor (#2548) --- .../function_libs/torch_lib/ops/core.py | 35 ++++++++++++------- .../function_libs/torch_lib/e2e_ops_tests.py | 21 +++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a66faae0be..6698a2ccdb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7332,16 +7332,25 @@ def aten_repeat_interleave_self_int( self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) - tiles = [1] * (self_rank + 1) - tiles[pos_dim + 1] = repeats - tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) - tiled = op.Tile(unsqueezed, tile_repeat) + if isinstance(repeats, int): + tiles = [1] * (self_rank + 1) + tiles[pos_dim + 1] = repeats + tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) + else: + # repeats is a symbolic tensor + tile_repeat = op.Concat( + op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), + op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), + op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), + axis=0, + ) + tiled = op.Expand(unsqueezed, tile_repeat) if self_rank == 1: return op.Identity(tiled) final_shape = op.Concat( op.Shape(self, start=0, end=dim), op.Constant(value_ints=[-1]), - op.Shape(self, start=dim + 1), + op.Shape(self, start=pos_dim + 1), axis=0, ) return op.Reshape(tiled, final_shape) @@ -7380,20 +7389,22 @@ def aten_repeat_interleave_Tensor( if dim is None: # flatten self = op.Reshape(self, [-1]) - rk = 1 + rank = 1 else: - rk = len(self.shape) + rank = len(self.shape) - if rk > 2: + if rank > 2: shape_x0 = op.Shape(self, start=0, end=1) shape_x = op.Shape(self, start=1) self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) - elif rk == 1: + elif rank == 1: shape_x = None self = op.Reshape(self, [-1, 1]) else: - if rk != 2: - raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave") + if rank != 2: + raise NotImplementedError( + f"rank(self)={rank} not implemented for repeat_interleave" + ) shape_x = None ci = op.CumSum(repeats, [0]) @@ -7406,7 +7417,7 @@ def aten_repeat_interleave_Tensor( ) indices = op.Reshape(srows, [-1]) values = op.GatherND(self, op.Unsqueeze(indices, [-1])) - if rk == 2: + if rank == 2: return values # shape_x is None at this stage. assert shape_x is None # for mypy diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 253637ccd2..c0139328a4 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -137,6 +137,27 @@ def forward(self, x, ind): ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_symbolic_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave( + y, x.shape[1], dim=1 + ) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "y"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_sdpa_with_bool_attn_mask(self): class ScaledDotProductAttention(torch.nn.Module): def forward(self, query, key, value, attn_mask): From 8ed3521a5040daa1a517fe9baa987c6cf48621b9 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 12 Sep 2025 07:35:49 -0700 Subject: [PATCH 022/123] Support `enable_gqa` and only support 4D Q, K, and V (#2558) 1. Support `enable_gqa` 2. Align PyTorch setting to unsupport Q, K, and V when they are not 4D: https://github.com/pytorch/pytorch/blob/62843c14bbf694f5722fd6e1075da4792507fe42/torch/onnx/_internal/exporter/_torchlib/ops/nn.py#L131-L133 NOTE: torch.nn.functional.scaled_dot_product_attention actually supports 3D, and even Q-3D with K and V - 4D in op tests. --- onnxscript/function_libs/torch_lib/ops/nn.py | 77 +++++++++++++++++-- .../function_libs/torch_lib/e2e_ops_tests.py | 30 ++++++++ .../function_libs/torch_lib/ops_test_data.py | 12 +++ 3 files changed, 114 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 88b5bf807e..1a31c9eac8 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat: return scale +def _attention_repeat_kv_for_group_query( + query: TFloat, key: TFloat, value: TFloat +) -> Tuple[TFloat, TFloat]: + """Expand key and value for group query attention. + + repeat_interleave is applied on key and value to match the number of heads in query. + + Args: + query: Tensor of shape [B, q_num_heads, q_S, E] + key: Tensor of shape [B, k_num_heads, kv_S, E] + value: Tensor of shape [B, v_num_heads, kv_S, E] + + Returns: + Tuple of (expanded_key, expanded_value) where: + - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E] + - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E + """ + + assert ( + query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0 + ), ( + "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" + ) + + # NOTE: QKV are expected to be 4D tensors + + batch_size = op.Shape(query, start=0, end=1) # [B] + q_num_heads = op.Shape(query, start=1, end=2) # [Hq] + kv_num_heads = op.Shape(key, start=1, end=2) # [Hk] + qk_head_size = op.Shape(key, start=3, end=4) # [Dk] + v_head_size = op.Shape(value, start=3, end=4) # [Dv] + new_kv_seq_len = op.Shape(key, start=2, end=3) # [T] + + interleave_dim = op.Div(q_num_heads, kv_num_heads) # Hq / Hk + two = op.Constant(value_int=2) + k_unsqueezed = op.Unsqueeze(key, two) # [B, Hk, 1, T, Dk] + v_unsqueezed = op.Unsqueeze(value, two) # [B, Hv, 1, T, Dv] + + k_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0 + ) + k_expand = op.Expand(k_unsqueezed, k_expand_shape) + v_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0 + ) + v_expand = op.Expand(v_unsqueezed, v_expand_shape) + + k_attention_shape = op.Concat( + batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0 + ) + v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0) + + expanded_key = op.Reshape(k_expand, k_attention_shape) + expanded_value = op.Reshape(v_expand, v_attention_shape) + + return expanded_key, expanded_value + + @torch_op("aten::scaled_dot_product_attention", trace_only=True) def aten_scaled_dot_product_attention( query: TFloat, @@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention( "is_causal and attn_mask cannot be set at the same time" ) - assert not enable_gqa, ( - "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" ) # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention( if is_causal: attn_mask = _causal_attention_mask(query, key) + if enable_gqa: + key, value = _attention_repeat_kv_for_group_query(query, key, value) + else: + assert query.shape[1] == key.shape[1] == value.shape[1], ( + "SDPA (MHA) requires q_num_heads = kv_num_heads" + ) + if attn_mask is None: return _aten_scaled_dot_product_attention_no_mask_onnx( query, key, value, scale, dropout_p @@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask( assert (not is_causal) or (is_causal and attn_mask is None), ( "is_causal and attn_mask cannot be set at the same time" ) - - assert not enable_gqa, ( - "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" ) if scale is None: @@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask( query, key, value, attn_mask, scale, dropout_p ) + if enable_gqa: + key, value = _attention_repeat_kv_for_group_query(query, key, value) + if attn_mask is None: return _aten_scaled_dot_product_attention_no_mask_onnx( query, key, value, scale, dropout_p diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index c0139328a4..1b0410c27f 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -195,6 +195,36 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_enable_gqa_in_attention(self): + class Model(torch.nn.Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable + q, + k, + v, + enable_gqa=True, + ) + + model = Model() + + query = torch.randn(2, 4, 8, 16) + key = torch.randn(2, 2, 8, 16) + value = torch.randn(2, 2, 8, 16) + + onnx_program = torch.onnx.export( + model, + ( + query, + key, + value, + ), + input_names=["query", "key", "value"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 0cf8898241..cf3dd9cf83 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1908,6 +1908,12 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", + ) + .xfail( + matcher=lambda sample: len(sample.input.shape) != 4 + or len(sample.args[0].shape) != 4 + or len(sample.args[1].shape) != 4, + reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( "ops.aten._scaled_dot_product_flash_attention", @@ -1959,6 +1965,12 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", + ) + .xfail( + matcher=lambda sample: len(sample.input.shape) != 4 + or len(sample.args[0].shape) != 4 + or len(sample.args[1].shape) != 4, + reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", From 39f1015ec7d394384a0c931482b71d0d52311554 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:12:59 +0000 Subject: [PATCH 023/123] [torchlib] Implement torch.ops.prims.broadcast_in_dim.default (#2382) This PR implements the missing `torch.ops.prims.broadcast_in_dim.default` operation that appears in BERT_pytorch and other PyTorch models. ## Overview The `broadcast_in_dim` operation is a primitive that broadcasts a tensor to a target shape by specifying which dimensions of the output correspond to the input tensor dimensions. This is different from standard broadcasting operations. ## Implementation Details **Function signature:** ```python def prims_broadcast_in_dim( a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int] ) -> TensorType: ``` **Parameters:** - `a`: Input tensor to broadcast - `shape`: Target output shape - `broadcast_dimensions`: Specifies which dimensions of the output shape correspond to the input tensor dimensions **Example:** ```python # Input tensor: [3, 4] # Target shape: [2, 3, 5, 4] # broadcast_dimensions: [1, 3] # Result: Input dimension 0 (size 3) maps to output dimension 1 # Input dimension 1 (size 4) maps to output dimension 3 # Output dimensions 0 and 2 are broadcasted (filled from size 1) ``` Fixes #2218. Fix https://github.com/pytorch/pytorch/issues/135343 --------- Signed-off-by: Justin Chu Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu --- .../function_libs/torch_lib/ops/prims.py | 25 +++++++++++-- tests/function_libs/torch_lib/extra_opinfo.py | 36 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index ed870b0d7d..f53e9c1133 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -176,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("prims::broadcast_in_dim", trace_only=True) def prims_broadcast_in_dim( - a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int] + a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int] ) -> TensorType: """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)""" - raise NotImplementedError() + target_rank = len(shape) + + if not broadcast_dimensions: + # Special case: no broadcast dimensions - all target dims should be 1 + return op.Expand(a, common_ops.merge_dims(shape)) + + # Create base shape of all 1s + ones = [1] * target_rank + + # For each broadcast dimension, we'll replace the 1 with the actual input dimension + # Since broadcast_dimensions is compile-time known, we can do this with individual operations + intermediate_shape = ones + + for i, broadcast_dim in enumerate(broadcast_dimensions): + # Get the input dimension value + input_dim_value = op.Shape(a, start=i, end=i + 1) + intermediate_shape[broadcast_dim] = input_dim_value + + # Reshape input to intermediate shape and expand to target + reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape)) + return op.Expand(reshaped, shape) def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType: diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index ca80cf5172..4f4a3872e1 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -87,6 +87,35 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + # cases: (input_shape, target_shape, broadcast_dimensions) + # broadcast_dimensions maps each input dim to an axis in target_shape + cases = ( + # scalar -> 1-D tensor + ((), (3,), ()), + # identity (no-op broadcast) + ((3,), (3,), (0,)), + # rank-preserving broadcast where singleton dims expand + ((1, 3, 1), (2, 3, 4), (0, 1, 2)), + # input rank 2 -> output rank 3, input dims map to trailing axes + ((3, 1), (2, 3, 4), (1, 2)), + # add leading broadcast axis + ((3, 4), (1, 3, 4), (1, 2)), + # insert broadcasting in middle axis + ((3,), (2, 3, 1), (1,)), + ) + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + for shape, target_shape, broadcast_dimensions in cases: + tensor = make_arg(shape) + yield opinfo_core.SampleInput(tensor, args=(target_shape, broadcast_dimensions)) + + def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -2687,6 +2716,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_trilinear3d_vec, supports_out=False, ), + opinfo_core.ReductionOpInfo( + "ops.prims.broadcast_in_dim.default", + op=torch.ops.prims.broadcast_in_dim.default, + dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs_broadcast_in_dim, + supports_out=False, + ), opinfo_core.ReductionOpInfo( "ops.prims.var.default", nan_policy="propagate", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index cf3dd9cf83..b1e0c529ec 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2136,6 +2136,7 @@ def _where_input_wrangler( "Our implementation is based on that for CUDA" ), ), + TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim), TorchLibOpInfo( "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)} ), From 8944f04c372d845df2430bde5fac3a45147978f9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Sep 2025 09:48:41 -0700 Subject: [PATCH 024/123] Bump version from 0.5.0 to 0.5.1 (#2559) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 8f0916f768..4b9fcbec10 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.0 +0.5.1 From 92633a694a3ca7ded2c0cf4d331bd2ab385b7034 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Sep 2025 15:55:37 -0700 Subject: [PATCH 025/123] Remove CheckerPass from ort_fusion (#2560) Since onnxruntime defines `SimplifiedLayerNormalization` incorrectly in the standard domain, the checker will fail. Fixing this for Olive. --- onnxscript/rewriter/ort_fusions/_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 8f3c7c463a..ea7af31b3e 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -150,7 +150,6 @@ def optimize_for_ort( common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), common_passes.RemoveInitializersFromInputsPass(), common_passes.ShapeInferencePass(), - common_passes.CheckerPass(), ) assert passes.in_place result = passes(model) From a70ee8d0905f563c840bbd5338595e9ac6b1b5b4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Sep 2025 16:24:59 -0700 Subject: [PATCH 026/123] Use ir.val to replace ir.Input (#2556) Use ir.val to replace ir.Input because ir.Input was deprecated --------- Signed-off-by: Justin Chu --- noxfile.py | 2 +- onnxscript/ir/__init__.py | 154 +----------------- .../bfloat16_utils/bfloat16_converter_test.py | 6 +- .../rules/common/_basic_rules_test.py | 10 +- .../rules/common/_fuse_pad_into_conv_test.py | 8 +- .../rules/common/_matmul_add_to_gemm_test.py | 8 +- pyproject.toml | 3 +- 7 files changed, 20 insertions(+), 171 deletions(-) diff --git a/noxfile.py b/noxfile.py index f69c5af9bd..989b10b16e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -42,7 +42,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.7" +ONNX_IR = "onnx_ir==0.1.9" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 3fa204b405..6240347886 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -1,154 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""In-memory intermediate representation for ONNX graphs.""" - -__all__ = [ - # Modules - "serde", - "traversal", - "convenience", - "external_data", - "tape", - # IR classes - "Tensor", - "ExternalTensor", - "StringTensor", - "LazyTensor", - "SymbolicDim", - "Shape", - "TensorType", - "OptionalType", - "SequenceType", - "SparseTensorType", - "TypeAndShape", - "Value", - "Attr", - "RefAttr", - "Node", - "Function", - "Graph", - "GraphView", - "Model", - # Constructors - "AttrFloat32", - "AttrFloat32s", - "AttrGraph", - "AttrGraphs", - "AttrInt64", - "AttrInt64s", - "AttrSparseTensor", - "AttrSparseTensors", - "AttrString", - "AttrStrings", - "AttrTensor", - "AttrTensors", - "AttrTypeProto", - "AttrTypeProtos", - "Input", - # Protocols - "ArrayCompatible", - "DLPackCompatible", - "TensorProtocol", - "ValueProtocol", - "ModelProtocol", - "NodeProtocol", - "GraphProtocol", - "GraphViewProtocol", - "AttributeProtocol", - "ReferenceAttributeProtocol", - "SparseTensorProtocol", - "SymbolicDimProtocol", - "ShapeProtocol", - "TypeProtocol", - "MapTypeProtocol", - "FunctionProtocol", - # Enums - "AttributeType", - "DataType", - # Types - "OperatorIdentifier", - # Protobuf compatible types - "TensorProtoTensor", - # Conversion functions - "from_proto", - "from_onnx_text", - "to_proto", - # Convenience constructors - "tensor", - "node", - # Pass infrastructure - "passes", - # IO - "load", - "save", -] - -from onnx_ir import ( - ArrayCompatible, - Attr, - AttrFloat32, - AttrFloat32s, - AttrGraph, - AttrGraphs, - AttributeProtocol, - AttributeType, - AttrInt64, - AttrInt64s, - AttrSparseTensor, - AttrSparseTensors, - AttrString, - AttrStrings, - AttrTensor, - AttrTensors, - AttrTypeProto, - AttrTypeProtos, - DataType, - DLPackCompatible, - ExternalTensor, - Function, - FunctionProtocol, - Graph, - GraphProtocol, - GraphView, - GraphViewProtocol, - Input, - LazyTensor, - MapTypeProtocol, - Model, - ModelProtocol, - Node, - NodeProtocol, - OperatorIdentifier, - OptionalType, - RefAttr, - ReferenceAttributeProtocol, - SequenceType, - Shape, - ShapeProtocol, - SparseTensorProtocol, - SparseTensorType, - StringTensor, - SymbolicDim, - SymbolicDimProtocol, - Tensor, - TensorProtocol, - TensorProtoTensor, - TensorType, - TypeAndShape, - TypeProtocol, - Value, - ValueProtocol, - convenience, - external_data, - from_onnx_text, - from_proto, - load, - node, - passes, - save, - serde, - tape, - tensor, - to_proto, - traversal, -) +# pylint: disable=wildcard-import,unused-wildcard-import +from onnx_ir import * # type: ignore # noqa: F403 diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py index b9666fba3a..a64d6e6023 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py @@ -14,11 +14,11 @@ class Bfloat16ConversionTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4])) + self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4])) self.v0.dtype = ir.DataType.BFLOAT16 - self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4])) + self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4])) self.v1.dtype = ir.DataType.BFLOAT16 - self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4])) + self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4])) self.v2.dtype = ir.DataType.BFLOAT16 self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py index 9ce74b22a2..7d4e9d9b33 100644 --- a/onnxscript/rewriter/rules/common/_basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -421,14 +421,14 @@ def _convert_shape(shape, name): if isinstance(shape, np.ndarray): shape = tape.initializer(ir.Tensor(shape, name=name)) elif isinstance(shape, (list, tuple)): - shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64)) + shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape)) tape.graph_like.inputs.append(shape) else: raise TypeError(f"Unsupported type {type(shape)} for shape.") return shape - x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT)) - y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) # Build the graph. @@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg): class Flatten2ReshapeTest(unittest.TestCase): @staticmethod def create_model(input_shape, axis=1): - x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT)) - y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) # Build the graph. diff --git a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py index 740f8b3358..ded57fe023 100644 --- a/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -61,13 +61,13 @@ def build_model( # Register operations in the tape idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT - x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype)) + x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype)) y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes) y = tape.op( op_type, inputs=[y, self.get_conv_weights(weight_shape, tape)], attributes=conv_attributes, - output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)), + output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)), ) if op_type == "ConvInteger": y.dtype = ir.DataType.INT32 @@ -290,12 +290,12 @@ def build_model( raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.") # Register operations in the tape - x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) y = tape.op( "Conv", inputs=[x, *conv_inputs], attributes=conv_attributes, - output=ir.Input("Y", shape=output_shape, type=x.type), + output=ir.val("Y", shape=output_shape, type=x.type), ) # Build the model diff --git a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index c4f9abe65c..4c643801fc 100644 --- a/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -46,10 +46,10 @@ def get_test_model( bias_shape = weight_shape[0] if transB else weight_shape[-1] output_shape = ir.Shape(("?",) * input_shape.rank()) - x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) if weight_as_inputs: - w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT)) + w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT)) inputs.append(w) else: w = ir.tensor( @@ -58,7 +58,7 @@ def get_test_model( w = tape.initializer(w) if bias_as_inputs: - b = ir.Input( + b = ir.val( "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT) ) inputs.append(b) @@ -77,7 +77,7 @@ def get_test_model( y = tape.op( "Add", inputs=[y, b], - output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), ) # Build the model diff --git a/pyproject.toml b/pyproject.toml index 1f720c1168..3df6b3995c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.7,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.9,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.16", "packaging", "typing_extensions>=4.10", @@ -41,7 +41,6 @@ onnxscript = ["py.typed"] onnx = ["py.typed"] [tool.pytest.ini_options] -filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"] addopts = "-rsfEX --tb=short --color=yes" [tool.mypy] From ea790222deaada24d4567cd49400b8838a96c31c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:12:02 -0700 Subject: [PATCH 027/123] chore(deps): bump ruff from 0.12.11 to 0.13.0 in /requirements/lintrunner (#2563) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index a17c852e86..0dd608a643 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.12.11 +ruff==0.13.0 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250402 From f529292844a863b1aa77a20ea531c6bb0291a506 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 16 Sep 2025 10:37:06 -0700 Subject: [PATCH 028/123] Bump version from 0.5.1 to 0.5.2 (#2565) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 4b9fcbec10..cb0c939a93 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.1 +0.5.2 From 3156bed261246c842cbec5f1825cd1667a71a857 Mon Sep 17 00:00:00 2001 From: deoxy Date: Fri, 19 Sep 2025 00:13:47 +0900 Subject: [PATCH 029/123] [torchlib] Fix aten_gather to correctly handle scalar indices (#2566) Fixes #2564 Signed-off-by: Linsho Kaku --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6698a2ccdb..95fbe39811 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3814,11 +3814,15 @@ def aten_gather( else: return op.Expand(self, op.Shape(index)) - if len(index.shape) == 0: - return op.Identity(self) + is_scalar_index = len(index.shape) == 0 + if is_scalar_index: + index = op.Unsqueeze(index, [0]) index = op.Cast(index, to=INT64.dtype) result = op.GatherElements(self, index, axis=dim) + + if is_scalar_index: + result = op.Squeeze(result, [0]) return result From 79afb878b4f516c6d4997de101dec7541ea42df9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Sep 2025 14:33:50 -0700 Subject: [PATCH 030/123] [rewriter] Remove generic pattern matcher (#2567) It is obsolete and the capability is covered by the simple pattern matcher. --------- Signed-off-by: Justin Chu --- .lintrunner.toml | 1 - examples/pattern_rewriting.py | 25 - onnxscript/rewriter/_rewrite_rule.py | 7 +- onnxscript/rewriter/generic_pattern.py | 702 -------------------- onnxscript/rewriter/generic_pattern_test.py | 607 ----------------- pyproject.toml | 34 - 6 files changed, 1 insertion(+), 1375 deletions(-) delete mode 100644 onnxscript/rewriter/generic_pattern.py delete mode 100644 onnxscript/rewriter/generic_pattern_test.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 7b31bab564..907f3bfce6 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -57,7 +57,6 @@ exclude_patterns = [ 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME 'onnxscript/tools/function_unittest_producer.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME - 'onnxscript/rewriter/generic_pattern.py', # FIXME ] command = [ 'python', diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py index 7b5c56d5e3..fd84d7f3cb 100644 --- a/examples/pattern_rewriting.py +++ b/examples/pattern_rewriting.py @@ -141,28 +141,3 @@ def rotary_apply_pattern(op, x, pos_ids, axis): rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) rule.apply_to_model(ir_model) - -# TODO(rama): Update the following, the trace-printed looks different now. - -###################################### -# The logs shows every time the algorithm rejected a pattern. -# We can see the following: -# -# :: -# -# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast -# --hint--: BACKWARD: different node types -# --pattern -# ConcatTraining(transpose, transpose) -> (output, length) -# -- model -# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1) -# iteration=1 -# --marked-- #2 -# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320] -# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472] -# len(stacked)=0:[] -# -# Line 673 in file `generic_pattern.py`, the match was rejected. -# It says while comparing two nodes in the backward direction, -# node types do not match. -# It also says that two nodes were actually matched. diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index af0165dea0..8964230fe0 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -82,12 +82,7 @@ def __init__( if isinstance(matcher, _matcher.PatternMatcher): self._matcher = matcher elif matcher is None: - if target_pattern.has_single_output_node: - self._matcher = _matcher.SimplePatternMatcher(self._target_pattern) - else: - import onnxscript.rewriter.generic_pattern as generic_pattern - - self._matcher = generic_pattern.GenericPatternMatcher(self._target_pattern) + self._matcher = _matcher.SimplePatternMatcher(self._target_pattern) else: self._matcher = matcher(self._target_pattern) self._verbose = verbose diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py deleted file mode 100644 index 12827b3116..0000000000 --- a/onnxscript/rewriter/generic_pattern.py +++ /dev/null @@ -1,702 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import collections -import inspect -import os -import textwrap -import warnings -from typing import Any, Callable, Iterator, Sequence - -import onnxscript.rewriter.pattern as orp -from onnxscript import ir - - -class PatternMatchResult: - """Stores information about a match if a match was successful. - - * pattern: the GraphPattern which found this result - * model_nodes: the graph nodes that matched the pattern - * matched_pattern_to_model_value: a mapping from ValuePattern to ir.Value - * kwargs: additional attributes the user may add through the method - :meth:`PatternMatchResult.add_kwargs` - """ - - def __init__( - self, - pattern: orp.GraphPattern, - model_nodes: Sequence[ir.Node], - ): - pattern_nodes: list[orp.NodePattern] = list(pattern) - assert len(model_nodes) == len(pattern_nodes) - self.pattern = pattern - self.model_nodes = model_nodes - self.kwargs: dict[str, Any] = {} - self.matched_pattern_to_model_value: dict[orp.ValuePattern, ir.Value] = {} - - for graph_node, pattern_node in zip(model_nodes, pattern_nodes): - assert graph_node.op_identifier() == pattern_node.op_identifier(), ( - f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}" - ) - assert len(graph_node.inputs) == len(pattern_node.inputs), ( - f"Unexpected number of inputs for type {graph_node.op_identifier()}" - ) - for a, b in zip(graph_node.inputs, pattern_node.inputs): - if b is None: - # optional input or not an interesting input - continue - self._bind(b, a) - - assert len(graph_node.outputs) == len(pattern_node.outputs), ( - f"Unexpected number of outputs for type {graph_node.op_identifier()}" - ) - for a, b in zip(graph_node.outputs, pattern_node.outputs): - self._bind(b, a) - - def _bind(self, value_pattern: orp.ValuePattern, value: ir.Value) -> None: - map = self.matched_pattern_to_model_value - if value_pattern in map: - assert map[value_pattern] == value, ( - f"Ambiguities, pattern output {value_pattern!r} means " - f"{value!r} or {map[value_pattern]}" - ) - else: - map[value_pattern] = value - - def add_kwargs(self, name: str, value: Any): - """Adds an attribute, it can be done when the match is being validated, - this attribute can be used when building the replacement nodes. - """ - self.kwargs[name] = value - - def __repr__(self) -> str: - return ( - f"PatternMatchResult: {len(self.model_nodes)} nodes ..., {self.pattern.inputs}, " - f"{self.pattern.outputs})" - ) - - -def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult: - """Converts a PatternMatchResult into a MatchResult. - - TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified. - """ - result = orp.MatchResult() - for node in pmr.model_nodes: - result.add_node(node) - - for var, val in pmr.matched_pattern_to_model_value.items(): - if var.name is not None: - result.bind(var.name, val) - result.outputs.extend([pmr.matched_pattern_to_model_value[v] for v in pmr.pattern.outputs]) - return result - - -def _value_to_str(value: ir.Value | orp.ValuePattern) -> str: - return value.name if value.name is not None else "anonymous:" + str(id(value)) - - -def _opt_value_to_str(value: ir.Value | orp.ValuePattern | None) -> str: - return _value_to_str(value) if value is not None else "None" - - -def _node_to_str(node: ir.Node | orp.NodePattern) -> str: - inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs) - outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs) - op_type = node.op_type - domain = str(node.domain) - qualified_op = f"{domain}.{op_type}" if domain else op_type - return f"{outputs} = {qualified_op}({inputs})" - - -# def _pattern_node_to_str(node: orp.NodePattern) -> str: -# inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs) -# outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs) -# return f"{outputs} = {node.op_type}({inputs})" - - -class GenericPatternMatcher(orp.PatternMatcher): - """ - Implements a pattern optimization for quick experimentation. - - Current limitation: - - * The current implementation does match on domain name (easy fix). - * It does not compares attributes either (easy fix as well). - """ - - def __init__(self, pattern: orp.GraphPattern) -> None: - super().__init__(pattern) - - def enumerate_matches( - self, - model: ir.Model, - graph_or_function: ir.Graph | ir.Function, - node: ir.Node | None = None, - verbose: int = 0, - ) -> Iterator: - """Enumerates all the matches.""" - if node is None: - matched = [] - for node in graph_or_function: - res = self.match(model, graph_or_function, node, verbose=verbose) - if res: - matched.append(res) - yield res - else: - res = self.match(model, graph_or_function, node, verbose=verbose) - if res: - yield res - - def none( - self, - node: ir.Node | None = None, - lineno: int | None = None, - msg: str = "", - ) -> None: - """Must be called every time a match fails to trace it. - - It may be useful which reason made a pattern matching fail. - Instead of returning None, method *match* can return the following - expression: - - :: - - return self.none(node, inspect.currentframe().f_lineno) - - By setting the verbosity (see next Section), the user may then know - which lines in the code returned None and which condition failed. - If logs are fully enabled, it shows information about matched none - and the line deciding the matched failed. - For example, this tells the matching failed at line 601 in ``generic_pattern.py``. - It happens when propagating the match in the backward directions. - The unmatched types are Mul, MatMul and below, - it shows the matched nodes. The first one was Cast. - And the failure happened at iteration 5. - ``139774002356544-139774000632672`` is the pair of ids used in container ``matched``. - ``id(node)`` is used as a unique identifiers of the nodes. - - :: - - [RotaryEmbeddingPattern.match] NONE - line: 601:__main__, op_type=Cast - --hint--: BACKWARD: different node types - --pattern - Mul(pos_ids, cast) -> (mul) - -- model - MatMul(/_original_modu...Expand_output_0, /_original_modu...b/Cast_output_0) -> (/_original_modu...MatMul_output_0) - iteration=5 - --matched-- #6 - Cast(/_original_modu...mb/Cos_output_0) ~ Cast(cos) [139774002356544-139774000632672] - Cos(/_original_modu...ncat_1_output_0) ~ Cos(concattraining-transpose-0) [139774002356448-139774000632048] - ConcatTraining(/_original_modu...nspose_output_0,/_original_modu...nspose_output_0) ~ ConcatTraining(transpose,transpose) [139774002356352-139774000631712] - Transpose(/_original_modu...MatMul_output_0) ~ Transpose(mul) [139774002356256-139774000631184] - Sin(/_original_modu...ncat_1_output_0) ~ Sin(concattraining-transpose-0) [139774002358512-139774000631568] - Cast(/_original_modu...mb/Sin_output_0) ~ Cast(sin) [139774002358608-139774000632384] - len(stack)=0:[] - - 'hints' are not added everywhere. More can easily be added with method ``_hint``. - """ - if node and self.verbose: - if self.verbose >= 10: - if hasattr(self, "_debug"): - msg2 = self._debug_print() - if msg2: - msg2 = f"\n{textwrap.indent(msg2, ' ')}" - else: - msg2 = "" - print( - f"[{self.__class__.__name__}.match] Match failed at line: {lineno}:" - f"{os.path.split(self.__class__.__module__)[-1]}, " - f"op_type={node.op_type}{msg}{msg2}" - ) - return None - - def print_match(self, graph_node: ir.Node, pattern_node: orp.NodePattern) -> str: - s1 = _node_to_str(graph_node) - s2 = _node_to_str(pattern_node) - return f"match {s1} with pattern: {s2}" - - def _debug_print(self) -> str: - if not hasattr(self, "_debug"): - return "" - - def _s(s: str) -> str: - if len(s) <= 30: - return s - return f"{s[:15]}...{s[-15:]}" - - def _p(n: ir.Node, full: bool = False) -> str: - if full: - return str(n) - return _node_to_str(n) - - rows = [] - for k, v in sorted(self._debug.items()): - if k == "stack": - rows.append(f"len({k})={len(v)}:{v}") # type: ignore[arg-type] - continue - if k == "iteration": - rows.append(f"{k}={v}") - continue - if k == "matched": - rows.append(f"--matched-- #{len(v)}") # type: ignore[arg-type] - for pattern_node, graph_node in v.items(): - rows.append( - f" {_p(pattern_node)} ~ {_p(graph_node)} [{id(pattern_node)}-{id(graph_node)}]" - ) - continue - if k == "hint": - rows.append(f"--hint--: {v[0]}") # type: ignore[arg-type] - for i in v[1:]: - if isinstance(i, str): - rows.append(" " + i) - if isinstance(i, ir.Node): - rows.append(" " + _p(i, full=True)) - continue - if k in {"node", "pattern", "pattern_node", "pattern_nodes"}: - continue - rows.append(f"-- not shown {k}") - - return "\n".join(rows) - - def _hint(self, *args: Any) -> None: - """Add debugging information to help users.""" - self._debug["hint"] = args - - def _match_backward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_node: ir.Node, - pattern_node: orp.NodePattern, - ) -> int | None: - """ - Matches backward. - - Args: - starting_node: root node (the node the matched begain with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes, None or False to indicate a failed match - """ - match_count = 0 - - # predecessors - if len(graph_node.inputs) != len(pattern_node.inputs): - # not the same number of inputs - self._hint( - "BACKWARD: not the same number of inputs", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(graph_input.uses()) != len(pattern_input.uses()): - self._hint( - "BACKWARD: one input is used outside the pattern", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_value, pattern_value in zip(graph_node.inputs, pattern_node.inputs): - # TODO(rama): Handle constant-pattern - pattern_pred = pattern_value.producer() - if pattern_pred is None: - # pattern_pred is None means the pattern backward search ends here. - result = self._match_values_forward( - starting_node, matched, stack, graph_value, pattern_value - ) - if result is None: - return result - match_count += result - continue - graph_pred = graph_value.producer() - if graph_pred is None: - # No node in the graph. - return self.none(starting_node, inspect.currentframe().f_lineno) - if graph_pred.op_identifier() != pattern_pred.op_identifier(): - self._hint( - "BACKWARD: different node types", - "--pattern", - _node_to_str(pattern_pred), - "-- model", - _node_to_str(graph_pred), - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - # matching backward - if pattern_pred not in matched: - if self.verbose >= 10: - print( - f"[GenericPattern._match_backward] {self.print_match(graph_pred, pattern_pred)}" - ) - matched[pattern_pred] = graph_pred - stack.append(pattern_pred) - match_count += 1 - if self.verbose > 5 and match_count > 0: - print(f"[GenericPatternMatcher._match_backward] add {match_count} nodes") - return match_count - - def _match_values_forward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_value: ir.Value, - pattern_value: orp.ValuePattern, - ) -> int | None: - """ - Matches forward. - - Args: - starting_node: root node (the node the match begins with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_value: value coming from the graph - pattern_value: pattern value coming from the pattern - - Returns: - number of matched nodes to continue, None or False to indicate a failed match - """ - match_count = 0 - graph_node_users = [user for user, _ in graph_value.uses()] - pattern_node_users = [user for user, _ in pattern_value.uses()] - if not pattern_node_users: - # The pattern has no node forward, the matching stops. - return match_count - if len(graph_node_users) < len(pattern_node_users): - # Not enough node in the graph to match the pattern. A match is not possible - return self.none(starting_node, inspect.currentframe().f_lineno) - - # Here comes the fun part, there is the same number of successors or more - # nodes in the graph to match with the pattern. - # And we have to handle the nodes already matched as found. - # Hopefully, there is only one option. - - if len(graph_node_users) == len(pattern_node_users) == 1: - # Let's deal with the simple case - if graph_node_users[0].op_identifier() != pattern_node_users[0].op_identifier(): - return self.none(starting_node, inspect.currentframe().f_lineno) - - node = pattern_node_users[0] - if node not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}" - ) - matched[node] = graph_node_users[0] - stack.append(node) - match_count += 1 - return match_count - - # Let's remove the nodes already matched. - pattern_node_users_not_matched = [ - unmatched_node - for unmatched_node in pattern_node_users - if unmatched_node not in matched - ] - pattern_node_users_matched = [ - matched[matched_node] - for matched_node in pattern_node_users - if matched_node in matched - ] - assert len(pattern_node_users_matched) + len(pattern_node_users_not_matched) == len( - pattern_node_users - ), ( - f"pattern_node_users_not_matched={pattern_node_users_not_matched}, " - f"pattern_node_users_matched={pattern_node_users_matched}, " - f"pattern_node_users={pattern_node_users}, " - f"matched={matched}" - ) - free = list(set(graph_node_users) - set(pattern_node_users_matched)) - if not pattern_node_users_not_matched: - # Everything is already matched. - return match_count - if len(free) < len(pattern_node_users_not_matched): - # Not enough successors to match the remaining patterns. - return self.none(starting_node, inspect.currentframe().f_lineno) - if len(pattern_node_users_not_matched) == len(free) == 1: - # Only one option again. - graph_node = free[0] - if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(starting_node, inspect.currentframe().f_lineno) - - key = pattern_node_users_not_matched[0] - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}" - ) - matched[key] = graph_node - stack.append(key) - match_count += 1 - return match_count - - # And now another fun part, let's try to handle the case when - # there is only one option, matching on node type only returns one - # option. - expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched] - got_op_type = [_.op_identifier() for _ in free] - - ec = collections.Counter(expected_op_type) - gc = collections.Counter(got_op_type) - if len(ec) != len(gc) or set(ec) != set(gc): - # unique operator types is different. - self._hint( - "FORWARD: unique operator types are different", - "-- pattern", - ec, - pattern_value, - "-- model", - gc, - graph_value, - "-- model-matched", - pattern_node_users_matched, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - for k, v in ec.items(): - if gc[k] < v: - # Not enough types to match. - return self.none(starting_node, inspect.currentframe().f_lineno) - - # At this stage, we know matching the types is possible. - # We first mark whatever is possible. - ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched} - gtype_to_node = {_.op_identifier(): _ for _ in free} - missing = [] - for k, v in ec.items(): - if gc[k] == v == 1: - key = id(ptype_to_node[k]) - if key not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward] match " - f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}" - ) - matched[key] = gtype_to_node[k] - stack.append(key) - match_count += 1 - else: - missing.append(k) - - if not missing: - return match_count - - # At this stage, there are mutiple options for matching. We can: - # 1. make assumptions and continue - # 2. mark the node as incomplete matching, we could end up stuck anyway. - raise NotImplementedError( - f"There are more than one option, this will be implemented later, ec={ec}, gc={gc}" - ) - - def _match_forward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_node: ir.Node, - pattern_node: orp.NodePattern, - ) -> int | None: - """ - Matches forward. - - Args: - starting_node: root node (the node the match begins with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes to continue, None or False to indicate a failed match - """ - match_count = 0 - - # successors - if len(graph_node.outputs) != len(pattern_node.outputs): - # not the same number of outputs - self._hint( - "FORWARD: not the same number of output_names", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_output, pattern_output in zip(graph_node.outputs, pattern_node.outputs): - result = self._match_values_forward( - starting_node, matched, stack, graph_output, pattern_output - ) - if result is None: - return result - match_count += result - - if self.verbose > 5 and match_count > 0: - print(f"[GenericPatternMatcher._match_forward] add {match_count} nodes") - return match_count - - def match( - self, - model: ir.Model, - graph_or_function: ir.Graph | ir.Function, - node: ir.Node, - *, - verbose: int = 0, - remove_nodes: bool = True, - tracer: orp.MatchingTracer | None = None, - ) -> orp.MatchResult | None: - if not remove_nodes: - raise NotImplementedError( - "remove_nodes=False is not implemented in GenericPatternMatcher" - ) - del model - del graph_or_function - self.verbose = verbose - self._debug = {} - - # Let's match the last node. - # Then we need to match successors and predecessors. - last_pattern_node = self.pattern.node(-1) - if node.op_identifier() != last_pattern_node.op_identifier(): - # The last node does not have the same op_identifier(). - return self.none() - - if self.verbose > 5: - print( - f"[GenericPatternMatcher.match] Matching started at node: {_node_to_str(node)}" - ) - if self.verbose >= 10: - print(f"[GenericPatternMatcher.match] match pattern {self}") - - all_pattern_nodes = set(self.pattern) - matched: dict[orp.NodePattern, ir.Node] = {last_pattern_node: node} - stack: list[orp.NodePattern] = [last_pattern_node] - iteration = 0 - - if self.verbose > 5: - self._debug = dict( - pattern=self.pattern, - matched=matched, - stack=stack, - iteration=iteration, - node=node, - pattern_node=last_pattern_node, - pattern_nodes=self.pattern, - ) - - max_iter = self.pattern.num_nodes() * 2 - while stack and iteration < max_iter: - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - f"\nall_pattern_nodes={all_pattern_nodes}" - ) - - # TODO(justinchuby): Change to a for loop - iteration += 1 - if self.verbose > 5: - print( - f"[GenericPatternMatcher.match] iteration={iteration} " - f"n_matched={len(matched)}, n_stack={len(stack)}, " - f"matched_types={collections.Counter(_.op_identifier() for _ in matched)}" - ) - next_pattern_node = stack.pop() - next_graph_node = matched[next_pattern_node] - - result = self._match_backward( - node, matched, stack, next_graph_node, next_pattern_node - ) - if result is None: - if self.verbose > 5: - print("[GenericPatternMatcher.match] done. backward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - ) - - result = self._match_forward( - node, matched, stack, next_graph_node, next_pattern_node - ) - if result is None: - if self.verbose > 5: - print("[GenericPatternMatcher.match] done. forward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - ) - - if self.verbose > 5: - self._debug["iteration"] = iteration - - if iteration >= max_iter and stack: - self._hint(f"reached {iteration}>={max_iter} iterations") - return self.none(node, inspect.currentframe().f_lineno) - - if self.verbose > 5: - print(f"[GenericPatternMatcher.match] done. {len(matched)} matched nodes") - - # At this point, the pattern is matched but let's make sure. - assert len(matched) == self.pattern.num_nodes(), ( - f"Number of matched nodes is different, {len(matched)} matched nodes, " - f"and {len(self.pattern)} nodes in the pattern, matched is {matched}" - ) - assert len(stack) == 0, f"There are still {len(stack)} nodes to explore." - - # We order the matched nodes in the same order than the pattern - # to let next functions to be able to build the matching again. - matched_nodes = [matched[pattern_node] for pattern_node in self.pattern] - return _to_match_result(PatternMatchResult(self.pattern, matched_nodes)) - - -def make_pattern_rule( - match_pattern_function: Callable, - apply_pattern_function: Callable, - validate_mapping: Callable | None = None, - verbose: int = 0, -) -> orp.RewriteRule: - """ - Creates a rewriting rule from a callable or a function proto. - - Args: - match_pattern_function: an onnxscript-like function that defines - the pattern subgraph (nodes) to be replaced - apply_pattern_function: an onnxscript-like function that constructs - the replacement subgraph (new nodes replacing the matched nodes) - validate_mapping: a function that validates the matching subgraph once - it is found. If it returns False the pattern is not applied. - If not specified, it is equivalent to a function that always return True - verbose: verbosity level - - Returns: - the rewriting rule - """ - - warnings.warn( - "make_pattern_rule(...) is deprecated, use pattern.RewriteRule(...) instead", - FutureWarning, - stacklevel=2, - ) - pattern = orp._to_graph_pattern(match_pattern_function) - matcher = GenericPatternMatcher(pattern) - return orp.RewriteRule( - pattern, - apply_pattern_function, - validate_mapping, - matcher, - verbose=verbose, - ) diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py deleted file mode 100644 index dadaf5e8bb..0000000000 --- a/onnxscript/rewriter/generic_pattern_test.py +++ /dev/null @@ -1,607 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import contextlib -import io -import os -import unittest - -import numpy as np -import onnx -import onnx.parser -import onnx.reference -import onnxruntime as ort -import parameterized - -from onnxscript import ir -from onnxscript.rewriter import generic_pattern, pattern - -FLOAT = onnx.TensorProto.FLOAT - - -@parameterized.parameterized_class( - ("matcher_algo",), - [ - (generic_pattern.GenericPatternMatcher,), - (pattern.SimplePatternMatcher,), - ], -) -class GenericPatternTest(unittest.TestCase): - def _range(self, *shape, bias: float | None = None): - n = np.prod(shape) - x = np.arange(n).astype(np.float32) / n - if bias: - x = x + bias - return x.reshape(tuple(shape)).astype(np.float32) - - def test_graph_pattern_builder(self): - """Test replacing Add + Add by AddAdd.""" - - def match_pattern(op, x, y, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - return op.Add(tmp, z) - - def apply_pattern(op, x, y, z, **_): - """Builds the replacement graph.""" - return op.AddAdd(x, y, z, _domain="ZZZ") - - def validate_mapping(context, x, y, z, **_) -> bool: - """Validates the mapping.""" - del context - return True - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - validate_mapping, - self.matcher_algo, - ) - - class AddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, z): - return (x + y + z,) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["final"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - ], - [onnx.helper.make_tensor_value_info("final", FLOAT, [None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - rewriten_model = ir.serde.serialize_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def test_graph_pattern_builder_multi_outputs(self): - def match_pattern(op, x, y, w, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - tmp2 = op.Add(tmp, w) - r1 = op.Add(tmp, z) - return tmp2, r1 - - def apply_pattern(op, x, y, w, z, **_): - """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2) - - def validate_mapping(context, **_) -> bool: - return True - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - validate_mapping, - self.matcher_algo, - verbose=10, - ) - - class AddAddAddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, w, z): - return (x + y + w, x + y + z) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "w"], ["f1"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["f2"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("w", FLOAT, [None, None]), - ], - [ - onnx.helper.make_tensor_value_info("f1", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("f2", FLOAT, [None, None]), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "w": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAddAddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def check_with_ort(self, model: onnx.ModelProto, providers=None): - if providers is None: - providers = ["CPUExecutionProvider"] - - if isinstance(model, onnx.ModelProto): - model = model.SerializeToString() - session = ort.InferenceSession(model, providers=providers) - return session - - def get_rotary_model(self): - inputs = [ - onnx.helper.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), - onnx.helper.make_tensor_value_info("pos_ids", FLOAT, shape=[]), - onnx.helper.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), - ] - nodes = [ - onnx.helper.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), - onnx.helper.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), - onnx.helper.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), - onnx.helper.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), - onnx.helper.make_node( - "ConcatTraining", - ["_onx_transpose0", "_onx_transpose0"], - ["_onx_concattraining0", "_onx_concattraining1"], - domain="com.microsoft", - ), - onnx.helper.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), - onnx.helper.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), - onnx.helper.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), - onnx.helper.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), - ] - outputs = [ - onnx.helper.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), - onnx.helper.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), - ] - model = onnx.helper.make_model( - onnx.helper.make_graph( - nodes, - "experiment", - inputs, - outputs, - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 18), - ], - ) - return model - - def test_shared_root_value_test(self): - def match_pattern(op, x): - t1 = op.Sin(x) - t2 = op.Cos(x) - return t1, t2 - - def apply_pattern(op, x, **_): - return op.SinCos(x, _domain="com.microsoft", _outputs=2) - - rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) - model_proto = onnx.parser.parse_model( - """ - - agraph (float[N] y) => (float[N] z) - { - temp1 = Sin(y) - temp2 = Cos(y) - z = Add(temp1, temp2) - } - """ - ) - onnx.checker.check_model(model_proto) - model = onnx.shape_inference.infer_shapes(model_proto) - ir_model = ir.serde.deserialize_model(model) - rule.apply_to_model(ir_model) - rewritten_model = ir.serde.serialize_model(ir_model) - graph = rewritten_model.graph - self.assertEqual(len(graph.node), 2) - self.assertEqual(graph.node[0].op_type, "SinCos") - - def test_shared_root_value_extra_use(self): - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.") - - def match_pattern(op, x): - t1 = op.Sin(x) - t2 = op.Cos(x) - return t1, t2 - - def apply_pattern(op, x, **_): - return op.SinCos(x, _domain="com.microsoft", _outputs=2) - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - matcher=self.matcher_algo, - ) - model_proto = onnx.parser.parse_model( - """ - - agraph (float[N] y) => (float[N] z) - { - temp1 = Sin(y) - temp2 = Cos(y) - w = Add(temp1, temp2) - z = Mul(w, y) - } - """ - ) - onnx.checker.check_model(model_proto) - model = onnx.shape_inference.infer_shapes(model_proto) - ir_model = ir.serde.deserialize_model(model) - rule.apply_to_model(ir_model) - graph = ir_model.graph - self.assertEqual(len(graph), 3) - self.assertEqual(graph.node(0).op_type, "SinCos") - - def test_rotary_embedding(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def match_pattern(op, x, pos_ids, axis): - # original code: the code does verifies the constant yet - # unsqueeze = op.Unsqueeze(x, [1]) - - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, - transpose, - _domain="com.microsoft", - _outputs=2, - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_mapping(match_result, **_) -> bool: - del match_result - return True - - def apply_pattern(op, x, pos_ids, axis, **_): - del axis - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - return op.RotaryEmbedding( - x, - pos_ids, - cos_cache, - sin_cache, - _domain="com.microsoft", - _outputs=2, - ) - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - validate_mapping, - self.matcher_algo, - verbose=10, - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - self.assertIn("[GenericPatternMatcher.match", out) - - def test_rotary_embedding_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, transpose, _domain="com.microsoft", _outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. - del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis, **_): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 - ) - return part1, part2 - - rule = pattern.RewriteRule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - self.matcher_algo, - verbose=10, - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(justinchuby): Remove this assert - capturing stdout is not robust - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - self.assertIn("[GenericPatternMatcher.match", out) - - def test_rotary_emb_file_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, transpose, _domain="com.microsoft", _outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. - del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 - ) - return part1, part2 - - model_path = "gemma_optimized_pre_grad_training_2.onnx" - if not os.path.exists(model_path): - raise unittest.SkipTest(f"{model_path!r} is missing") - model = onnx.load(model_path) - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule = pattern.RewriteRule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - self.matcher_algo, - verbose=10, - ) - - rule.apply_to_model(ir_model) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - buffer = rewriten_model.SerializeToString() - with open(f"{model}.opt.onnx", "wb") as f: - f.write(buffer) - self.check_with_ort(rewriten_model) - - def test_transpose_transpose_onnxscript(self): - # TODO(rama): Attribute-parameters not yet supported in multi-output matching. - # def transpose_transpose_pattern(op, X, perm0, perm1): - # xt = op.Transpose(X, perm=perm0) - # Y = op.Transpose(xt, perm=perm1) - # return Y - - def transpose_transpose_pattern(op, X): - XT = op.Transpose(X, _outputs=["XT"]) - Y = op.Transpose(XT, _outputs=["Y"]) - return Y - - def transpose_transpose_mapping(perm0, perm1): - new_perm = [0 for p in perm0] - for i, p in enumerate(perm1): - new_perm[i] = perm0[p] - # replace by return [perm0[p] for p in perm1] ? - return new_perm - - def transpose_transpose_check(op, **_) -> bool: - return True - - def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): - perm0 = XT.producer().attributes.get("perm") - if perm0 is not None: - perm0 = perm0.value # TODO(rama): handle RefAttr - perm1 = Y.producer().attributes.get("perm") - if perm1 is not None: - perm1 = perm1.value # TODO(rama): handle RefAttr - if perm0 is None and perm1 is None: - return op.Identity(X) - if perm0 is None: - perm0 = range(len(perm1) - 1, -1, -1) - if perm1 is None: - perm1 = range(len(perm0) - 1, -1, -1) - composed_perm = transpose_transpose_mapping(perm0, perm1) - return op.Transpose(X, perm=composed_perm) - - rule = pattern.RewriteRule( - transpose_transpose_pattern, - transpose_transpose_apply_pattern, - transpose_transpose_check, - self.matcher_algo, - verbose=0, - ) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]), - onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ) - - # back to ir - ir_model = ir.serde.deserialize_model(model) - - # starts matching - - rule.apply_to_model(ir_model) - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Transpose"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - node = rewriten_model.graph.node[0] - self.assertEqual(len(node.attribute), 1) - att = node.attribute[0] - self.assertEqual(att.name, "perm") - self.assertEqual(list(att.ints), [2, 0, 1]) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/pyproject.toml b/pyproject.toml index 3df6b3995c..5f31581494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,40 +79,6 @@ module = [ ] ignore_errors = true -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern_test.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment' -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment,union-attr,func-returns-value,annotation-unchecked,arg-type,index,name-defined,attr-defined' -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - [tool.pylint.messages_control] # NOTE: This list is for vscode. Add new disables in pyproject_pylint.toml for lintrunner # Exclude patterns should be modified in .lintrunner.toml From 27c7f09099c05ddc5cfb1491832f6f6e007eee5b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 22:20:42 +0000 Subject: [PATCH 031/123] chore(deps): bump ruff from 0.13.0 to 0.13.1 in /requirements/lintrunner (#2568) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 0dd608a643..b2be2fa2f3 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.13.0 +ruff==0.13.1 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250402 From f54cf47749ab7ffbe424c6e736ec4d74aa4c15b2 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 23 Sep 2025 10:23:47 -0700 Subject: [PATCH 032/123] Add GQA fusion to ONNX fusions (#2524) Add GQA fusion to ONNX fusions. TODO: * Test cases. (Fusion seems to work on Gemma3, but more to be done.) --------- Signed-off-by: Ganesan Ramalingam Co-authored-by: Justin Chu --- .../rewriter/onnx_fusions/_onnx_fusions.py | 3 +- onnxscript/rewriter/rules/fusion/_gqa.py | 113 ++++++++++++++++++ onnxscript/rewriter/rules/fusion/_gqa_test.py | 97 +++++++++++++++ onnxscript/rewriter/testing.py | 68 ++++++++--- 4 files changed, 263 insertions(+), 18 deletions(-) create mode 100644 onnxscript/rewriter/rules/fusion/_gqa.py create mode 100644 onnxscript/rewriter/rules/fusion/_gqa_test.py diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index bd73cb1f6d..008a995764 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding +from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding def _get_onnx_opset_version(model: ir.Model) -> int | None: @@ -24,6 +24,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: counts: dict[str, int] = {} counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug) + counts["GQA"] = _gqa.fuse_gqa(model, debug=debug) return counts diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py new file mode 100644 index 0000000000..8d6f156ed5 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Union + +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _basics, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("ONNXGQA", remove_nodes=False) + + def pattern( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + ): + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2) + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE) + present_key_BHStD = op.Reshape( + present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. + present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2) + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) + present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE) + present_value_BHStD = op.Reshape( + present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"] + ) + + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + pattern.Var("mask", can_match_none=True), + _outputs=["attention_BHSDh"], + ) + + return attention_BHSDh + + def check( + self, + context: _basics.MatchContext, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + present_key_BHStD, + present_value_BHStD, + **_, + ): + bindings: dict[str, Dim] = {} + # Check that inputs to new Attention node have expected shapes + _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"]) + _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"]) + _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"]) + # We need to check that the Expand/Reshape arguments are as expected. + # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes. + # TODO (rama): May be better to check the actual Expand/Reshape arguments. + _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"]) + _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"]) + + return True + + def rewrite( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + mask, + attention_BHSDh, + **_, + ): + original_attention_node = attention_BHSDh.producer() + original_attrs = original_attention_node.attributes + return op.Attention( + query_BHSD, + key_BHkvSD, + value_BHkvSD, + mask, + past_key_BHkvSpD, + past_value_BHkvSpD, + **original_attrs, + ) + + +_basic_gqa_rule = OnnxGroupQueryAttention.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) diff --git a/onnxscript/rewriter/rules/fusion/_gqa_test.py b/onnxscript/rewriter/rules/fusion/_gqa_test.py new file mode 100644 index 0000000000..baf80c4b8c --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa_test.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import onnx +import onnx_ir as ir +from packaging import version + +import onnxscript +import onnxscript.optimizer +import onnxscript.rewriter.testing +from onnxscript import FLOAT, script +from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa + +op = onnxscript.values.Opset("", 23) + +H = [8] # Number of attention heads +Hkv = [4] # Number of key/value heads (H should be divisible by Hkv) +D = [64] # Head size +G = [2] # Number of groups + + +@script(ir_version=10) +def _gqa_script( + query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64 + key_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + value_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + past_key_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 + past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 +) -> FLOAT[2, 8, 4, 64]: + """Basic GQA pattern that should be fused into an Attention op.""" + + # Concatenate past_key cache and current key + present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + + # Unsqueeze to add group dimension + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + + # Calculate shapes dynamically + B = op.Shape(query_BHSD, start=0, end=1) # [B] + T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P] + + # Create expand shape [B, Hkv, G, S+P, D] + expand_shape = op.Concat(B, Hkv, G, T, D, axis=0) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + + # Create reshape shape [B, H, S+P, D] + reshape_shape = op.Concat(B, H, T, D, axis=0) + present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Same for value + present_value_BHkvStD = op.Concat( + past_value_BHkvPD, value_BHkvSD, axis=-2 + ) # [B, Hkv, S+P, D] + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + present_value_BHkvGStD = op.Expand( + present_value_BHkv1StD, expand_shape + ) # [B, Hkv, G, S+P, D] + present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Attention computation + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + ) + + return attention_BHSDh + + +class GQAFusionTest(unittest.TestCase): + def test_basic_gqa_fusion(self): + """Test basic GQA fusion pattern.""" + model_proto = _gqa_script.to_model_proto() + + # Apply GQA fusion + model = ir.serde.deserialize_model(model_proto) + onnxscript.optimizer.optimize(model) + count = fuse_gqa(model) + self.assertGreater(count, 0, "GQA fusion should have occurred") + + # We can't yet test numerical equivalence because of a bug in the op spec/implementation. + onnx_ver = version.parse(onnx.__version__) + if onnx_ver >= version.parse("1.19.1") and not ( + onnx_ver.is_prerelease or onnx_ver.is_devrelease + ): + # Only official releases >= 1.19.1 + onnxscript.optimizer.remove_unused_nodes(model) + rewritten_model_proto = ir.serde.serialize_model(model) + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py index 591f9387c2..2a9d24ee01 100644 --- a/onnxscript/rewriter/testing.py +++ b/onnxscript/rewriter/testing.py @@ -6,6 +6,7 @@ import numpy as np import onnx +import onnx.reference import onnxruntime as ort from onnxscript import ir @@ -32,10 +33,11 @@ def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]: def assert_numerically_equal( original_model_proto: onnx.ModelProto | ir.Model, rewritten_model_proto: onnx.ModelProto | ir.Model, - args: tuple[Any, ...] | dict[str, Any], + args: tuple[Any, ...] | dict[str, Any] | None = None, ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL, rtol: float = 1, atol: float = 1e-3, + use_reference: bool = False, ): """Assert that the two models are numerically equal. @@ -46,6 +48,7 @@ def assert_numerically_equal( ort_optimization_level: Onnxruntime optimization level. rtol: Relative tolerance. atol: Absolute tolerance. + use_reference: If True, use ONNX reference implementation instead of ONNXRuntime. """ if isinstance(original_model_proto, ir.Model): @@ -53,7 +56,10 @@ def assert_numerically_equal( if isinstance(rewritten_model_proto, ir.Model): rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto) - if isinstance(args, dict): + if args is None: + original_proto_ort_inputs = generate_random_inputs(original_model_proto) + the_rewritten_proto_ort_inputs = original_proto_ort_inputs + elif isinstance(args, dict): original_proto_ort_inputs = args the_rewritten_proto_ort_inputs = args else: @@ -64,21 +70,34 @@ def assert_numerically_equal( k.name: v for k, v in zip(rewritten_model_proto.graph.input, args) } - original_proto_ort_inference_session = _ort_session_initializer( - original_model_proto.SerializeToString(), ort_optimization_level - ) - run_options = ort.RunOptions() - run_options.log_severity_level = 3 # 3: Error - original_outputs = original_proto_ort_inference_session.run( - None, original_proto_ort_inputs, run_options=run_options - ) - - the_rewritten_proto_ort_inference_session = _ort_session_initializer( - rewritten_model_proto.SerializeToString(), ort_optimization_level - ) - the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( - None, the_rewritten_proto_ort_inputs, run_options=run_options - ) + if use_reference: + # Use ONNX reference implementation + original_evaluator = _reference_session( + original_model_proto.SerializeToString(), ort_optimization_level + ) + original_outputs = original_evaluator.run(None, original_proto_ort_inputs) + + rewritten_evaluator = _reference_session( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs) + else: + # Use ONNXRuntime + original_proto_ort_inference_session = _ort_session_initializer( + original_model_proto.SerializeToString(), ort_optimization_level + ) + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + original_outputs = original_proto_ort_inference_session.run( + None, original_proto_ort_inputs, run_options=run_options + ) + + the_rewritten_proto_ort_inference_session = _ort_session_initializer( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( + None, the_rewritten_proto_ort_inputs, run_options=run_options + ) np.testing.assert_allclose( original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True @@ -103,3 +122,18 @@ def _ort_session_initializer( provider for provider in possible_providers if provider in available_providers ] return ort.InferenceSession(model, providers=providers, sess_options=session_options) + + +def _reference_session( + model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel +) -> onnx.reference.ReferenceEvaluator: + """Initialize an ONNX reference evaluator with the specified model.""" + # Parse the model from bytes if needed + if isinstance(model, (str, bytes)): + model_proto = onnx.load_from_string(model) + else: + model_proto = model + + # Note: ort_optimization_level is ignored for reference implementation + # as it doesn't have equivalent optimization levels + return onnx.reference.ReferenceEvaluator(model_proto) From e67eeefc8bc2b120bab79a8d04f303690ddc4bc0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 23 Sep 2025 12:47:28 -0700 Subject: [PATCH 033/123] [torchlib] Simplify linalg_vector_norm to remove the redundant Abs (#2570) This happens in some of the LORA models. When we use ReduceL1/ReduceL2 or when ord is an even number, we don't need to take Abs of the input Signed-off-by: Justin Chu --------- Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/linalg.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 05bac181ca..c9d870bd86 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -330,8 +330,9 @@ def aten_linalg_vector_norm( keepdim = False else: dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - self = op.Abs(self) + if math.isinf(ord): + self = op.Abs(self) if ord > 0: return op.ReduceMax(self, dim, keepdims=keepdim) else: @@ -345,6 +346,9 @@ def aten_linalg_vector_norm( elif ord == 2.0: return op.ReduceL2(self, dim, keepdims=keepdim) else: + if ord < 0 or ord % 2 != 0: + # Not an even integer (could be odd, fractional or negative), use Abs + self = op.Abs(self) self_pow = op.Pow(self, ord) exp = op.CastLike(1 / ord, self) return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp) From 7e45333e58657d584b8503aaffb0ed3537023605 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 23 Sep 2025 20:57:06 -0700 Subject: [PATCH 034/123] [torchlib] Add trace_only flag to aten_copy, aten_tril, aten_triu (#2572) --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 95fbe39811..99fc6fb44f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2236,7 +2236,7 @@ def aten_convolution_overrideable( raise NotImplementedError() -@torch_op("aten::copy") +@torch_op("aten::copy", trace_only=True) def aten_copy( self: TTensor, src: TTensor2, @@ -8690,7 +8690,7 @@ def aten_triangular_solve( raise NotImplementedError() -@torch_op("aten::tril") +@torch_op("aten::tril", trace_only=True) def aten_tril(self: TTensor, diagonal: int = 0) -> TTensor: """tril(Tensor self, int diagonal=0) -> Tensor""" @@ -8718,7 +8718,7 @@ def aten_triplet_margin_loss( raise NotImplementedError() -@torch_op("aten::triu") +@torch_op("aten::triu", trace_only=True) def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor: """triu(Tensor self, int diagonal=0) -> Tensor""" From 168fd8a63c6591b132c9393c8cf5e1d9a2aba933 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 24 Sep 2025 13:47:23 -0700 Subject: [PATCH 035/123] Bump version from 0.5.2 to 0.5.3 (#2571) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index cb0c939a93..be14282b7f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.2 +0.5.3 From dddf0c2f97c4839b5fbcdbd1c0509562a922a7fe Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 26 Sep 2025 12:29:53 -0700 Subject: [PATCH 036/123] Fix Onnx 23 Rotary Fusion (#2576) Fix Onnx 23 Rotary Fusion --------- Signed-off-by: Ganesan Ramalingam --- .../fusion/_rms_normalization_test.py} | 34 ++---------- .../rules/fusion/_rotary_embedding.py | 33 ++++++++++-- .../rules/fusion/_rotary_embedding_test.py | 53 +++++++++++++++++++ 3 files changed, 85 insertions(+), 35 deletions(-) rename onnxscript/rewriter/{onnx_fusions/_onnx_fusions_test.py => rules/fusion/_rms_normalization_test.py} (53%) create mode 100644 onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py similarity index 53% rename from onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization_test.py index 22d6120da1..e70c4ec7a0 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py @@ -5,14 +5,12 @@ import unittest import onnx_ir as ir -from parameterized import parameterized import onnxscript -from onnxscript.rewriter import onnx_fusions -from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rms_normalization -class OnnxFusionsTest(unittest.TestCase): +class RmsNormOnnxFusionsTest(unittest.TestCase): def test_rms_normalization_fusion(self): opset23 = onnxscript.values.Opset("", 23) @@ -34,34 +32,10 @@ def rms_norm_script(embedding, layernorm_weight): output_types=[onnxscript.FLOAT[128]], ) model = ir.serde.deserialize_model(rms_norm_model_proto) - onnx_fusions.fuse(model, debug=True) + count = _rms_normalization.fuse_rms_normalization(model) + self.assertEqual(count, 1) self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization") - @parameterized.expand( - [ - ( - "test_case_1", - _rotary_embedding_models.test_case_1, - ), - ( - "test_case_2", - _rotary_embedding_models.test_case_2, - ), - ] - ) - def test_rotary_embedding_fusion(self, _: str, test_data_constructor): - test = test_data_constructor() - for opset_version in [22, 23]: - model: ir.Model = test.get_onnx_model() - model.graph.opset_imports[""] = opset_version - onnxscript.optimizer.optimize(model) - onnx_fusions.fuse(model) - op_types = [n.op_type for n in model.graph] - if opset_version == 22: - self.assertNotIn("RotaryEmbedding", op_types) - else: - self.assertIn("RotaryEmbedding", op_types) - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 2009c6953f..524b6f4806 100644 --- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -30,13 +30,34 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2): class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase): def __init__(self): - super().__init__(name="RotaryEmbedding23") + super().__init__(name="RotaryEmbedding23", remove_nodes=False) - def pattern(self, op, x, cos, sin, start1, end1, start2, end2): - return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): + freqs_repeated = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(freqs_repeated) + sin = op.Sin(freqs_repeated) + cos_4d = op.Unsqueeze(cos, one1) + sin_4d = op.Unsqueeze(sin, one2) + return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d - def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() + + def is_one(val): + """Check if val is a 0/1 dimensional tensor with a single element equal to 1.""" + np_val = _ir_utils.get_numpy_value(val) + return ( + np_val is not None + and np_val.size == 1 + and np_val.ndim <= 1 + and np_val.item() == 1 + ) + + if not is_one(one1): + return check_result.fail("Unsqueeze axes is not [1]", one1) + if not is_one(one2): + return check_result.fail("Unsqueeze axes is not [1]", one2) + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: return check_result.fail("Input is not known to be a 4D tensor.", x) @@ -59,8 +80,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: ) return check_result - def rewrite(self, op, x, cos, sin, **_): + def rewrite(self, op, x, freqs, **_): num_heads = x.shape[1] + cos = op.Cos(freqs) + sin = op.Sin(freqs) return op.RotaryEmbedding( x, cos, diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py new file mode 100644 index 0000000000..b8ffe95cac --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx +import onnx_ir as ir +from packaging.version import Version +from parameterized import parameterized + +import onnxscript +import onnxscript.rewriter.testing +from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rotary_embedding + + +class RotaryEmbeddingOnnxFusionTest(unittest.TestCase): + @parameterized.expand( + [ + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "test_case_2", + _rotary_embedding_models.test_case_2, + ), + ] + ) + def test_rotary_embedding_fusion(self, _: str, test_data_constructor): + test = test_data_constructor() + model: ir.Model = test.get_onnx_model() + model.graph.opset_imports[""] = 23 + model_proto = ir.serde.serialize_model(model) + onnxscript.optimizer.optimize(model) + _rotary_embedding.fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + rewritten_model_proto = ir.serde.serialize_model(model) + inputs = test.get_ort_inputs() + + onnx_version = Version(onnx.__version__) + min_version = Version("1.19.1") + is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease) + if onnx_version >= min_version and is_stable: + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, args=inputs, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() From df8f706fc763697f9c453c54dd7efc16ee23a2a4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 15:44:46 -0700 Subject: [PATCH 037/123] [torchlib] Support integers in logical_and/or ops and update other logical ops (#2582) This PR 1. Consolidates logic for `bitwise_*` functions so that the `logical_*` functions are no longer handling bool overloads of the bitwise ops. 2. Adds support for integer inputs in the `logical_*` implementations. Replacement of #2579. --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 117 ++++++++++-------- .../function_libs/torch_lib/ops_test_data.py | 4 + 2 files changed, 67 insertions(+), 54 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 99fc6fb44f..96b92c2e8e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat: @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) -def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: +def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # TODO(microsoft/onnxruntime#15977): Improve fp16 precision + + if self.dtype == ir.DataType.BOOL: + # alpha can also be bool + if alpha == 0: + return op.Identity(self) + return op.Or(self, other) + if alpha != 1.0: alpha = op.CastLike(alpha, other) other = op.Mul(other, alpha) @@ -1233,15 +1239,19 @@ def aten_binomial( "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", - "_operator::and_", ), trace_only=True, ) -def aten_bitwise_and(self: TInt, other: TInt) -> TInt: +def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_and implements the BOOL variant - return op.BitwiseAnd(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseAnd(self, other) + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1329,11 +1339,14 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: @torch_op("aten::bitwise_not", trace_only=True) -def aten_bitwise_not(self: TInt) -> TInt: +def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" - # logical_not implements the BOOL variant - return op.BitwiseNot(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if self.dtype.is_integer(): + return op.BitwiseNot(self) + raise NotImplementedError(f"Not implemented for type {self.dtype}") @torch_op( @@ -1341,15 +1354,19 @@ def aten_bitwise_not(self: TInt) -> TInt: "aten::bitwise_or.Tensor", "aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", - "_operator::or_", ), trace_only=True, ) -def aten_bitwise_or(self: TInt, other: TInt) -> TInt: +def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_or implements the BOOL variant - return op.BitwiseOr(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseOr(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1487,11 +1504,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: ), trace_only=True, ) -def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: +def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_xor implements the BOOL variant + assert self.dtype == other.dtype - return op.BitwiseXor(self, other) + if self.dtype.is_integer(): + return op.BitwiseXor(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op("aten::blackman_window", trace_only=True) @@ -5010,58 +5031,46 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op( - ( - "aten::logical_and", - "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_and", trace_only=True) +def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - return op.And(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) -def aten_logical_not(self: BOOL) -> BOOL: +@torch_op("aten::logical_not", trace_only=True) +def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" - return op.Not(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_or", - "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", - "aten::add.Tensor", - "aten::add.Scalar", - ), - trace_only=True, -) -def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: +@torch_op(("aten::logical_or"), trace_only=True) +def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - return op.Or(self, other) + assert self.dtype == other.dtype + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: + +@torch_op("aten::logical_xor", trace_only=True) +def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" - return op.Xor(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op("aten::logit", private=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b1e0c529ec..98d10d9e5b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1631,6 +1631,10 @@ def _where_input_wrangler( dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), reason="fixme: test is unstable on macosx, windows", ), + TorchLibOpInfo("logical_and", core_ops.aten_logical_and), + TorchLibOpInfo("logical_not", core_ops.aten_logical_not), + TorchLibOpInfo("logical_or", core_ops.aten_logical_or), + TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail( From 94fb24fa0862d23069f2087007db4456ac376243 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 20:21:15 -0700 Subject: [PATCH 038/123] Record names of contributing values in the constant folding pass (#2575) Record names of contributing values in the constant folding pass to the newly created output as metadata, so that downstream users like Olive can use the info for further manipulations. This is useful for Olive to identify transposed lora weights in the graph. --------- Signed-off-by: Justin Chu --- docs/api/optimizer.md | 1 - onnxscript/optimizer/__init__.py | 8 ++---- onnxscript/optimizer/_constant_folding.py | 35 ++++++++++++++++++++++- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/docs/api/optimizer.md b/docs/api/optimizer.md index 90de403099..6c8adf21bb 100644 --- a/docs/api/optimizer.md +++ b/docs/api/optimizer.md @@ -15,5 +15,4 @@ optimizer.inline optimizer.basic_constant_propagation optimizer.fold_constants - optimizer.remove_unused_nodes ``` diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 6260829249..978a1b4d65 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -19,12 +19,8 @@ import onnxscript.optimizer._constant_folding as constant_folding from onnxscript import ir -from onnxscript.optimizer._constant_folding import ( - basic_constant_propagation, -) -from onnxscript.optimizer._constant_folding import ( - fold_constants as fold_constants_ir, -) +from onnxscript.optimizer._constant_folding import basic_constant_propagation +from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir from onnxscript.optimizer._optimizer import optimize_ir _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 62c28894c0..b959e8df73 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -5,6 +5,13 @@ from __future__ import annotations +__all__ = [ + "basic_constant_propagation", + "fold_constants", + "FoldConstantsPass", + "FOLDED_FROM_KEY", +] + import dataclasses import logging import math @@ -23,6 +30,9 @@ DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512 +# Key used to store the metadata +FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from" + _NON_DETERMINISTIC_OPS = frozenset( { @@ -914,6 +924,24 @@ def merge_dims(dim1, dim2): return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) +def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None: + """Record the set of original input values that contributed to the constant-folded outputs.""" + folded_from: set[str] = set() + for input in original_node.inputs: + if input is None: + continue + folded_from.update(input.meta.get(FOLDED_FROM_KEY, set())) + assert input.name is not None + folded_from.add(input.name) + + for new_output in replacement.new_outputs: + if new_output is None: + continue + new_output.meta[FOLDED_FROM_KEY] = folded_from + # Store the string representation of the set to metadata_props to persist it across serialization + new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from)) + + class FoldConstantsPass(ir.passes.InPlacePass): """A pass that folds constant expressions in the model. @@ -1203,9 +1231,14 @@ def convert(av): ) return None - def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: + def replace_node( + self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function + ) -> None: logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + # Record the names of the values that has contributed to the replacement + _record_contributing_values(node, replacement) + ir.convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) From 3a26097c9fe629d6e01fec8a3ffb99457ea26054 Mon Sep 17 00:00:00 2001 From: Daniel Zhang Date: Tue, 30 Sep 2025 13:27:05 +0800 Subject: [PATCH 039/123] Merge output shape with input shape instead of override (#2578) `_constant_folding.cast` override `output.shape` with `input.shape`, that may make a static shape to dynamic shape. Here should use `_merge_shapes` instead. --- onnxscript/optimizer/_constant_folding.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index b959e8df73..6aae8efab3 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -501,9 +501,7 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: # should handle this. Only the optimization to eliminate redundant Cast ops # should be needed here. - input_shape = input.shape - if input_shape is not None: - output.shape = input_shape.copy() + output.shape = _merge_shapes(output.shape, input.shape) input_dtype = _get_input_element_type(node, 0) output_dtype = _get_int_attribute(node, "to", None) From 35054209a513c35e17797669436313adcc7fe8cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 11:55:03 -0700 Subject: [PATCH 040/123] [torchlib] Add back operator and/or (#2590) Previously the entries were mistakenly removed. --------- Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 96b92c2e8e..dfbd562708 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1239,6 +1239,7 @@ def aten_binomial( "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", + "_operator::and_", ), trace_only=True, ) @@ -1354,6 +1355,7 @@ def aten_bitwise_not(self: TTensor) -> TTensor: "aten::bitwise_or.Tensor", "aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", + "_operator::or_", ), trace_only=True, ) @@ -5051,7 +5053,7 @@ def aten_logical_not(self: TTensor) -> BOOL: return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op(("aten::logical_or"), trace_only=True) +@torch_op("aten::logical_or", trace_only=True) def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" From 9b54ad549aa927469e666404437c706d43c43f92 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 30 Sep 2025 12:26:30 -0700 Subject: [PATCH 041/123] Extend utilities for checking a scalar value (#2587) Extend the `is_singleton_value` utility to check for singleton values that may be either 0D or 1D tensors. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_ir_utils.py | 23 ++++++++++++++----- .../rules/fusion/_rotary_embedding.py | 14 ++--------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 6af84dd1d8..91c3308bc2 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None -def get_singleton_value(val: ir.Value | None, rank: int | None = None): +def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None): """Returns element of a single element tensor constant value, and None otherwise. - If rank is specified, it checks that the value has the given rank. + If an int rank is specified, it checks that the value has the given rank. + If the rank is a sequence of ints, it checks that the value has one of the given ranks. + + Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and + `rank=(0,1)` checks for either a scalar or a 1D tensor. """ np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: - if rank is None or (np_val.ndim == rank): - return np_val.item() + value = np_val.item() + if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)): + return value + if isinstance(rank, Sequence) and (np_val.ndim in rank): + return value return None def is_singleton_value( - val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None + val: ir.Value | None, + expected: float | int | Callable, + *, + rtol: float | None = None, + rank: int | Sequence[int] | None = None, ) -> bool: """Returns True if the value is a single element tensor with given value, and False otherwise.""" - scalar = get_singleton_value(val) + scalar = get_singleton_value(val, rank=rank) if scalar is None: return False if callable(expected): diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 524b6f4806..b659afdbc0 100644 --- a/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -43,19 +43,9 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() - def is_one(val): - """Check if val is a 0/1 dimensional tensor with a single element equal to 1.""" - np_val = _ir_utils.get_numpy_value(val) - return ( - np_val is not None - and np_val.size == 1 - and np_val.ndim <= 1 - and np_val.item() == 1 - ) - - if not is_one(one1): + if not _ir_utils.is_singleton_value(one1, 1): return check_result.fail("Unsqueeze axes is not [1]", one1) - if not is_one(one2): + if not _ir_utils.is_singleton_value(one2, 1): return check_result.fail("Unsqueeze axes is not [1]", one2) # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) From 722765500257cdcc89a59eec35a5c2f17f79e522 Mon Sep 17 00:00:00 2001 From: Daniel Zhang Date: Wed, 1 Oct 2025 04:04:20 +0800 Subject: [PATCH 042/123] Merge input and output shape when removing identity (#2588) Similar with #2578, for this case: ```python import torch import torch.nn as nn class Model(nn.Module): def forward(self, x): return x.new_zeros(x.shape) def main(): model = Model() args = torch.rand(4, 4), batch = torch.export.Dim("batch") dynamic_shapes = {"x": {0: batch}} torch.onnx.export( model, args, "model_test.onnx", dynamic_shapes=dynamic_shapes, dynamo=True, ) if __name__ == "__main__": main() ``` --------- Co-authored-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6aae8efab3..8317d2be63 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -608,6 +608,9 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] output = node.outputs[0] if input is not None and output is not None: + input.shape = _merge_shapes(input.shape, output.shape) + if input.type is None: + input.type = output.type state.set_sym_value(output, input) return None From a1db753311ffa82b52f96e200087845b6ca247b0 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 30 Sep 2025 17:15:41 -0700 Subject: [PATCH 043/123] Add NaN handling in softmax pattern in SDPA fusion (#2593) Add NaN handling in softmax pattern in SDPA fusion Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/sdpa.py | 3 + onnxscript/rewriter/ort_fusions/sdpa_test.py | 85 ++++++++++++++++---- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 1d339f43e7..55b38e9ad4 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -88,6 +88,9 @@ def pattern( ) attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + adj_attn_weight = op.Where(is_nan, 0.0, attn_weight) + attn_weight = pattern.OrValue([adj_attn_weight, attn_weight]) attn_output = op.MatMul(attn_weight, value) return attn_output diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 90bcd26097..c5326a77b9 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value): scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value): scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier_k) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask): scaled_attn_score = op.Div(attn_score, divisor) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask): scaled_attn_score = op.Mul(attn_score, multiplier) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask): scaled_attn_score = op.Div(attn_score, divisor) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -237,7 +285,10 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask): scaled_attn_score = op.Mul(attn_score, multiplier) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output From 09bbd270156e0c241b8b8a27cb25107a55926c97 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 17:19:25 -0700 Subject: [PATCH 044/123] Remove usages of ir.Input in test (#2591) It was deprecated Signed-off-by: Justin Chu --- .../rules/common/_fuse_conv_affine_test.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py index 4f1f671f43..d456cab76b 100644 --- a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py @@ -18,9 +18,7 @@ def clone_model(self, model: ir.Model) -> ir.Model: def test_conv_affine_fusion(self): tape = ir.tape.Tape() - x = ir.Input( - "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT) - ) + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) @@ -31,10 +29,10 @@ def test_conv_affine_fusion(self): z = tape.op( "Add", [mul_out, offset], - output=ir.Input( + output=ir.val( "z", + dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]), - type=ir.TensorType(ir.DataType.FLOAT), ), ) @@ -65,9 +63,7 @@ def test_conv_affine_fusion(self): def test_affine_conv_fusion_without_pad(self): tape = ir.tape.Tape() - x = ir.Input( - "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT) - ) + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) @@ -77,10 +73,10 @@ def test_affine_conv_fusion_without_pad(self): z = tape.op( "Add", [mul_out, offset], - output=ir.Input( + output=ir.val( "z", + dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]), - type=ir.TensorType(ir.DataType.FLOAT), ), ) conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]}) From 88b03d80799f6c47323b524ba9b56272ff8adca2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 17:24:52 -0700 Subject: [PATCH 045/123] Improve aten_floor_divide for int inputs (#2592) Fix aten_floor_divide for negative int inputs and large int inputs. I also combined the int and float overloads for https://github.com/microsoft/onnxscript/issues/2580 Fix #2589 --------- Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 28 +++++++++++-------- tests/function_libs/torch_lib/extra_opinfo.py | 11 +------- .../function_libs/torch_lib/ops_test_data.py | 1 - 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dfbd562708..1a688a4277 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt: @torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: +def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: """floor_divide(Tensor self, Tensor other) -> Tensor""" - return op.Floor(op.Div(self, other)) + if self.dtype.is_floating_point(): + return op.Floor(op.Div(self, other)) + assert self.dtype.is_integer() -@torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide_int(self: TInt, other: TInt) -> TInt: - """floor_divide(Tensor self, Tensor other) -> Tensor""" + if not self.dtype.is_signed(): + return op.Div(self, other) - # TODO(justinchuby): This can be simplified if we can constrain the - # inputs to be positive integers. Consider how we can embed constraints in the model. - dtype = self.dtype - self = op.Cast(self, to=FLOAT.dtype) - other = op.Cast(other, to=FLOAT.dtype) - result = op.Floor(op.Div(self, other)) - return op.Cast(result, to=dtype) + # Convert truncation to flooring + # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 + # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + offset = op.And( + op.Not(op.Equal(op.Sign(self), op.Sign(other))), + op.Cast(op.Mod(self, other), to=BOOL.dtype), + ) + offset = op.Cast(offset, to=self.dtype) + return op.Sub(op.Div(self, other), offset) @torch_op("_operator::floordiv", trace_only=True) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 4f4a3872e1..b03cb5880a 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2270,18 +2270,9 @@ def __init__(self): opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", - dtypes=common_dtype.floating_types_and_half(), + dtypes=common_dtype.all_types_and_half(), rhs_make_tensor_kwargs=dict(exclude_zero=True), ), - opinfo_core.BinaryUfuncInfo( - "ops.aten.floor_divide.int", - aten_name="floor_divide", - op=torch.ops.aten.floor_divide, - dtypes=common_dtype.integral_types(), - # Create only positive inputs - lhs_make_tensor_kwargs=dict(low=0), - rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), - ), opinfo_core.OpInfo( "ops.aten.hamming_window", aten_name="hamming_window", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 98d10d9e5b..92495d201a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -794,7 +794,6 @@ def _where_input_wrangler( TorchLibOpInfo("flatten", core_ops.aten_flatten), TorchLibOpInfo("floor", core_ops.aten_floor), TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), - TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), From 149d567592cdb5f8c9608259aab3315e0c4b1bdb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 17:36:33 -0700 Subject: [PATCH 046/123] Fix collapse slices rewrite rules to handle unknown dims (#2583) Fixes https://github.com/microsoft/onnxscript/issues/2577 Signed-off-by: Justin Chu --- noxfile.py | 2 +- onnxscript/rewriter/rules/common/_collapse_slices.py | 4 ++++ pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 989b10b16e..ac9296a5cd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -42,7 +42,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.9" +ONNX_IR = "onnx_ir==0.1.10" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py index 5e262a785e..eda8547037 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -85,6 +85,10 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ if not is_singleton_value(steps, 1): return False + # If any dim is unknown, the shapes are not the same + if data.shape.has_unknown_dim() or slice_output.shape.has_unknown_dim(): + return False + return data.shape == slice_output.shape diff --git a/pyproject.toml b/pyproject.toml index 5f31581494..4f7edc9bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.9,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.10,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.16", "packaging", "typing_extensions>=4.10", From 929a7f2211d8da894da2d3fe5fe48456362ddbec Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 18:11:15 -0700 Subject: [PATCH 047/123] Expose the should_fold option to optimize() (#2594) Signed-off-by: Justin Chu --- onnxscript/optimizer/_optimizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 384cc12fd4..307144462f 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from typing import Callable import onnx_ir as ir import onnx_ir.passes.common as common_passes @@ -21,6 +22,7 @@ def optimize_ir( stop_if_no_change: bool = True, input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, inline: bool = True, ) -> None: """Optimizes a model. @@ -29,11 +31,15 @@ def optimize_ir( model: The model to be optimized. num_iterations: Number of times the optimization loop is repeated. onnx_shape_inference: Applies node-level shape-inference as part of optimization + stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. input_size_limit: Will not apply constant folding to ops with any input of size greater than this. Does not apply to special ops like Shape() and Size(). output_size_limit: Will not rewrite any foldable-op into a Constant op if the size of the output tensor is greater than this. - stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. inline: If True, inlines all functions in the model. """ passes = [ @@ -43,6 +49,7 @@ def optimize_ir( shape_inference=onnx_shape_inference, input_size_limit=input_size_limit, output_size_limit=output_size_limit, + should_fold=should_fold, ), rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), common_passes.RemoveUnusedNodesPass(), From 81f8444df82e63dfe5eaf541d2f7d954d5a96ff0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Sep 2025 20:45:49 -0700 Subject: [PATCH 048/123] Bump version from 0.5.3 to 0.5.4 (#2595) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index be14282b7f..7d8568351b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.3 +0.5.4 From b7ccc86768f047992af3d2a45274013a85b9e324 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 09:37:55 -0700 Subject: [PATCH 049/123] Update torch api error message to include value names (#2599) Update torch api error message to include value names when raising error on uninitialized values Signed-off-by: Justin Chu --- onnxscript/_framework_apis/torch_2_5.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index 2f8601c7c6..162faf4b75 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -67,12 +67,14 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike """Save the model with external data. The model is unchanged after saving.""" # TODO(#1835): Decide if we want to externalize large attributes as well - for value in model.graph.initializers.values(): - if value.const_value is None: - raise ValueError( - "The model contains uninitialized initializer values. " - "Please make sure all initializer values are initialized." - ) + uninitialized_values = [ + value.name for value in model.graph.initializers.values() if value.const_value is None + ] + if uninitialized_values: + raise ValueError( + f"The model contains uninitialized initializer values ({uninitialized_values}). " + "Please make sure all initializer values are initialized." + ) destination_path = pathlib.Path(model_path) data_path = f"{destination_path.name}.data" From 30ae54b91acf3cdc419544d663acad73fe76944c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:53:53 -0700 Subject: [PATCH 050/123] Remove beartype (#2603) As it is unused Signed-off-by: Justin Chu --- noxfile.py | 1 - onnxscript/_internal/runtime_typing.py | 43 -------------------------- requirements-dev.txt | 3 -- 3 files changed, 47 deletions(-) delete mode 100644 onnxscript/_internal/runtime_typing.py diff --git a/noxfile.py b/noxfile.py index ac9296a5cd..23c2963998 100644 --- a/noxfile.py +++ b/noxfile.py @@ -12,7 +12,6 @@ COMMON_TEST_DEPENDENCIES = ( - "beartype==0.17.2", "expecttest==0.1.6", "hypothesis", "numpy", diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py deleted file mode 100644 index 3cf8a8db57..0000000000 --- a/onnxscript/_internal/runtime_typing.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""An internal wrapper for the beartype library. - -Decorate a function with `@runtime_typing.checked` to enable runtime -type checking. The decorator is a no-op when the `beartype` library is not -installed. -""" - -import typing -import warnings - -__all__ = [ - "checked", -] - -T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any]) - -try: - from beartype import beartype as _beartype_decorator - from beartype import roar as _roar - - checked = typing.cast(typing.Callable[[T], T], _beartype_decorator) - - # Beartype warns when we import from typing because the types are deprecated - # in Python 3.9. But there will be a long time until we can move to using - # the native container types for type annotations (when 3.9 is the lowest - # supported version). So we silence the warning. - warnings.filterwarnings( - "ignore", - category=_roar.BeartypeDecorHintPep585DeprecationWarning, - ) -except ImportError: - - def checked(func: T) -> T: # type: ignore[no-redef] - return func - -except Exception as e: # pylint: disable=broad-exception-caught - # Warn errors that are not import errors (unexpected). - warnings.warn(f"{e}", stacklevel=2) - - def checked(func: T) -> T: # type: ignore[no-redef] - return func diff --git a/requirements-dev.txt b/requirements-dev.txt index 355fce3bff..b689d9bad5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,9 +17,6 @@ sphinx>=6 myst_nb chardet -# Torch lib -beartype!=0.16.0 - # Testing expecttest==0.1.6 hypothesis From 897345de82e22c042a007410b92bbdeb91b81cc6 Mon Sep 17 00:00:00 2001 From: deoxy Date: Tue, 7 Oct 2025 00:21:05 +0900 Subject: [PATCH 051/123] Separated implementation of aten::scatter overloads (#2605) close #2601 #2602 This PR refactors the implementation of `aten::scatter` overloads, improving the clarity of the ONNX output generated by `aten::scatter.src.` I've also added new tests to verify the correctness of these changes. To make the added tests pass, I needed to also address the issue reported in #2602, which is included in this PR's diff. Signed-off-by: Linsho Kaku --- .../function_libs/torch_lib/ops/core.py | 22 ++++-- tests/function_libs/torch_lib/extra_opinfo.py | 75 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 2 + 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a688a4277..11f26b8141 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7736,17 +7736,29 @@ def aten_scalar_tensor_sym_number( return common_ops.cast_to(s, dtype=dtype) -@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True) -def aten_scatter( +@torch_op("aten::scatter.src", trace_only=True) +def aten_scatter_src( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute index: TInt, src: TReal, ) -> TReal: - """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" + """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" + return op.ScatterElements(self, index, src, axis=dim) + - update = op.Expand(src, op.Shape(index)) - return op.ScatterElements(self, index, update, axis=dim) +@torch_op("aten::scatter.value", trace_only=True) +def aten_scatter_value( + self: TReal, + dim: int, # we have to use int here because ScatterElements() will use this attribute + index: TInt, + value: TReal, +) -> TReal: + """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor""" + # Ensure value is a scalar tensor and expand it to match index shape + scalar_tensor = op.CastLike(value, self) + src = op.Expand(scalar_tensor, op.Shape(index)) + return op.ScatterElements(self, index, src, axis=dim) @torch_op("aten::scatter_add", trace_only=True) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index b03cb5880a..f6f2a276fa 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1365,6 +1365,65 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(input_, args=(src, *args)) +def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.src + cases = [ + # (self_shape, index_shape, src_shape, dim) + ((5, 5), (2, 3), (2, 3), 0), # 2D scatter on dim=0 + ((5, 5), (3, 2), (3, 2), 1), # 2D scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0), # 3D scatter on dim=0 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1), # 3D scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2), # 3D scatter on dim=2 + ((10,), (3,), (3,), 0), # 1D scatter + ] + + for self_shape, index_shape, src_shape, dim in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + src_tensor = make_arg(src_shape) + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor)) + + +def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.value + cases = [ + # (self_shape, index_shape, dim, value) + ((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value + ((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 0, 0.0), # 3D scatter on dim=0 with scalar value + ((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 2, -1.0), # 3D scatter on dim=2 with scalar value + ((10,), (3,), 0, 5.0), # 1D scatter with scalar value + ] + + for self_shape, index_shape, dim, value in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value)) + + def sample_inputs__scaled_dot_product_flash_attention( op_info, device, dtype, requires_grad, **kwargs ): @@ -2533,6 +2592,22 @@ def __init__(self): sample_inputs_func=sample_inputs_slice_scatter, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.scatter.src", + op=torch.ops.aten.scatter.src, + aten_name="scatter.src", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_src, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.scatter.value", + op=torch.ops.aten.scatter.value, + aten_name="scatter.value", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_value, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten._softmax", op=torch.ops.aten._softmax, # pylint: disable=protected-access diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 92495d201a..ff4a68d2f6 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2108,6 +2108,8 @@ def _where_input_wrangler( reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), + TorchLibOpInfo("ops.aten.scatter.src", core_ops.aten_scatter_src), + TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), TorchLibOpInfo( From aa2cf4aa5f22ef53cf9bc018b1cb1892bddc4752 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Oct 2025 22:34:48 +0000 Subject: [PATCH 052/123] chore(deps): bump onnx-weekly from 1.20.0.dev20250901 to 1.20.0.dev20251006 in /requirements/ci (#2610) --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 9c5363b8af..e005031603 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.20.0.dev20250901 +onnx-weekly==1.20.0.dev20251006 From 6718ef0390d41c78d8da17e90d2325f3b2a76825 Mon Sep 17 00:00:00 2001 From: deoxy Date: Tue, 7 Oct 2025 14:28:39 +0900 Subject: [PATCH 053/123] Enhanced type annotations and simplified implementation of scatter.value (#2612) follow #2605 --------- Signed-off-by: Linsho Kaku --- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++-------- tests/function_libs/torch_lib/extra_opinfo.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 11f26b8141..a03eab1263 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7738,26 +7738,26 @@ def aten_scalar_tensor_sym_number( @torch_op("aten::scatter.src", trace_only=True) def aten_scatter_src( - self: TReal, + self: TTensor, dim: int, # we have to use int here because ScatterElements() will use this attribute index: TInt, - src: TReal, -) -> TReal: + src: TTensor, +) -> TTensor: """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" return op.ScatterElements(self, index, src, axis=dim) @torch_op("aten::scatter.value", trace_only=True) def aten_scatter_value( - self: TReal, + self: TTensor, dim: int, # we have to use int here because ScatterElements() will use this attribute index: TInt, - value: TReal, -) -> TReal: + value: float, +) -> TTensor: """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor""" # Ensure value is a scalar tensor and expand it to match index shape - scalar_tensor = op.CastLike(value, self) - src = op.Expand(scalar_tensor, op.Shape(index)) + scalar_tensor = ir.tensor([value], dtype=self.dtype) + src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor) return op.ScatterElements(self, index, src, axis=dim) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index f6f2a276fa..51f9c233ad 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1407,9 +1407,9 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs) # (self_shape, index_shape, dim, value) ((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value ((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value - ((3, 4, 5), (2, 2, 3), 0, 0.0), # 3D scatter on dim=0 with scalar value + ((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value ((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value - ((3, 4, 5), (2, 2, 3), 2, -1.0), # 3D scatter on dim=2 with scalar value + ((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value ((10,), (3,), 0, 5.0), # 1D scatter with scalar value ] From 7f3325b339b9c8d08ab4e6e18fa1317c877b0dc5 Mon Sep 17 00:00:00 2001 From: deoxy Date: Wed, 8 Oct 2025 02:39:25 +0900 Subject: [PATCH 054/123] support for scalar args to aten::scatter (#2613) close #2600 Signed-off-by: Linsho Kaku --- .../function_libs/torch_lib/ops/core.py | 6 +++ tests/function_libs/torch_lib/extra_opinfo.py | 44 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a03eab1263..0584522864 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7744,6 +7744,10 @@ def aten_scatter_src( src: TTensor, ) -> TTensor: """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) + if len(src.shape) == 0: + src = op.Unsqueeze(src, [0]) return op.ScatterElements(self, index, src, axis=dim) @@ -7756,6 +7760,8 @@ def aten_scatter_value( ) -> TTensor: """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor""" # Ensure value is a scalar tensor and expand it to match index shape + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) scalar_tensor = ir.tensor([value], dtype=self.dtype) src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor) return op.ScatterElements(self, index, src, axis=dim) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 51f9c233ad..0155c6fa73 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1394,6 +1394,35 @@ def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs): src_tensor = make_arg(src_shape) yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor)) + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index, scalar src (dim_size=5) + dim_size = 5 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src)) + + # Test case: single-element tensor index, scalar src (dim_size=7) + dim_size = 7 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src)) + + # Test case: scalar index, single-element tensor src (dim_size=3) + dim_size = 3 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, src_1d)) + + # Test case: single-element tensor index, single-element tensor src (dim_size=10) + dim_size = 10 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d)) + def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ -1423,6 +1452,21 @@ def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs) ] yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value)) + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index with scalar value (dim_size=6, value_type=torch.long) + dim_size = 6 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value)) + + # Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float) + dim_size = 8 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + random_value = torch.rand((), device=device, dtype=torch.float).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value)) + def sample_inputs__scaled_dot_product_flash_attention( op_info, device, dtype, requires_grad, **kwargs From a106bad29cdf1b0c0a1bccb7cd6797ea42f4598a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:42:39 -0700 Subject: [PATCH 055/123] chore(deps): bump ruff from 0.13.1 to 0.13.2 in /requirements/lintrunner (#2584) --- onnxscript/irbuilder.py | 2 +- requirements/lintrunner/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index b4d378bd17..76023ea002 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -214,7 +214,7 @@ def __str__(self): def debug_print(self): if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s: %s", type(self), str(self)) + logger.debug("%s: %s", type(self), self) def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index b2be2fa2f3..c71e5de95a 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.13.1 +ruff==0.13.2 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250402 From 8e4d41d96a4bb3a0bae8e34ce53473bf51cee42f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:43:42 -0700 Subject: [PATCH 056/123] [torchlib] Implement aten_bilinear function using Einsum (#2574) This PR implements the `aten_bilinear` function that was previously raising `NotImplementedError`. The bilinear transformation computes `y = x1^T A x2 + b` where: - `input1` has shape `(..., in1_features)` - `input2` has shape `(..., in2_features)` - `weight` has shape `(out_features, in1_features, in2_features)` - `bias` has shape `(out_features)` (optional) - Output has shape `(..., out_features)` ## Implementation Details The implementation is done using einsum. --------- Signed-off-by: Justin Chu Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 19 +++++++++- tests/function_libs/torch_lib/extra_opinfo.py | 38 +++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 3 ++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 0584522864..e26c9f4e4d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1195,6 +1195,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor: return op.CastLike(sampled, self) +@torch_op("aten::bilinear", trace_only=True) def aten_bilinear( input1: TensorType, input2: TensorType, @@ -1203,7 +1204,23 @@ def aten_bilinear( ) -> TensorType: """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor""" - raise NotImplementedError() + # Bilinear transformation: y = x1^T A x2 + b + # input1 shape: (..., in1_features) + # input2 shape: (..., in2_features) + # weight shape: (out_features, in1_features, in2_features) + # bias shape: (out_features) - optional + # output shape: (..., out_features) + + # Use Einsum to compute the bilinear transformation + # "...i,oij,...j->...o" means: + # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o] + result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o") + + # Add bias if provided + if bias is not None: + result = op.Add(result, bias) + + return result def aten_binary_cross_entropy_with_logits( diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 0155c6fa73..5d7deb1695 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(item, dtype=dtype) +def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for bilinear operation.""" + del op_info + del kwargs + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + # Test cases: (batch_size, in1_features, in2_features, out_features) + cases = [ + (2, 3, 4, 5), # Basic case + (1, 2, 2, 1), # Minimal case + (3, 5, 7, 4), # Different dimensions + (2, 1, 1, 3), # Single input features + ] + + for batch_size, in1_features, in2_features, out_features in cases: + input1 = make_arg((batch_size, in1_features)) + input2 = make_arg((batch_size, in2_features)) + weight = make_arg((out_features, in1_features, in2_features)) + bias = make_arg((out_features,)) + + # Test with bias + yield opinfo_core.SampleInput(input1, args=(input2, weight, bias)) + + # Test without bias (only for first case to avoid too many tests) + if batch_size == 2: + yield opinfo_core.SampleInput(input1, args=(input2, weight, None)) + + def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ -2283,6 +2314,13 @@ def __init__(self): # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. OP_DB: List[opinfo_core.OpInfo] = [ + opinfo_core.OpInfo( + "bilinear", + op=torch.nn.functional.bilinear, + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs_bilinear, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index ff4a68d2f6..36ea29f77d 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -657,6 +657,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), + TorchLibOpInfo( + "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} + ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with From e8d906acaeb087aef6981c91e562beefd0fd857e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:47:16 -0700 Subject: [PATCH 057/123] chore(deps): bump actions/setup-python from 5 to 6 (#2551) --- .github/workflows/lint.yaml | 2 +- .github/workflows/main.yaml | 6 +++--- .github/workflows/pages.yaml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 88787d6cce..3fe51a3a5a 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -45,7 +45,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: # Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset. python-version: "3.10" diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c547608cc6..faf40b9ec3 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -59,7 +59,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install nox @@ -97,7 +97,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: pip @@ -121,7 +121,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 - name: Update readme run: | python docs/update_readme.py diff --git a/.github/workflows/pages.yaml b/.github/workflows/pages.yaml index c38de94b15..ce638dc60d 100644 --- a/.github/workflows/pages.yaml +++ b/.github/workflows/pages.yaml @@ -29,7 +29,7 @@ jobs: - name: Setup Pages uses: actions/configure-pages@v4 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - uses: actions/checkout@v5 From 256be119d73a06750989c4cfa34dfec28045e0cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:47:45 -0700 Subject: [PATCH 058/123] chore(deps): bump editorconfig-checker from 3.2.0 to 3.4.0 in /requirements/lintrunner (#2499) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index c71e5de95a..f07a2b52ed 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -8,4 +8,4 @@ types-PyYAML==6.0.12.20250402 # PYLINT pylint==3.3.6 # EDITORCONFIG-CHECKER -editorconfig-checker==3.2.0 +editorconfig-checker==3.4.0 From 8e449da0116714b29a91141fb3709468b9def191 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 10:48:00 -0700 Subject: [PATCH 059/123] chore(deps): bump types-pyyaml from 6.0.12.20250402 to 6.0.12.20250915 in /requirements/lintrunner (#2562) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index f07a2b52ed..38cad45b39 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -4,7 +4,7 @@ lintrunner-adapters>=0.8.0 ruff==0.13.2 # MYPY mypy==1.10.1 -types-PyYAML==6.0.12.20250402 +types-PyYAML==6.0.12.20250915 # PYLINT pylint==3.3.6 # EDITORCONFIG-CHECKER From 4eaf36d0297d6544cf4a27d8c59a32092451f1b5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 18:09:38 +0000 Subject: [PATCH 060/123] chore(deps): bump pylint from 3.3.6 to 3.3.9 in /requirements/lintrunner (#2608) --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 38cad45b39..f95977610e 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -6,6 +6,6 @@ ruff==0.13.2 mypy==1.10.1 types-PyYAML==6.0.12.20250915 # PYLINT -pylint==3.3.6 +pylint==3.3.9 # EDITORCONFIG-CHECKER editorconfig-checker==3.4.0 From 075fc4d1401e4fb0f9f24c157c0df7c747491bcf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 14:39:06 -0700 Subject: [PATCH 061/123] Simplify aten_unbind when shape is static (#2597) Add static shape handling to aten_unbind function. Fix https://github.com/microsoft/onnxscript/issues/2596 --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 6 ++++++ tests/function_libs/torch_lib/ops_test.py | 3 ++- tests/function_libs/torch_lib/ops_test_data.py | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e26c9f4e4d..9e6aa69edc 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8799,6 +8799,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" + if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): + # We can create a definitive split op if the input shape is static + # Only torch>=2.7 supports correctly generating the correct number of outputs for Split + outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + return [op.Squeeze(out, [dim]) for out in outputs] + return op.SplitToSequence(self, axis=dim, keepdims=False) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 7ba6f9d37f..45875043ea 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -39,6 +39,7 @@ from torch.utils import _pytree as pytree import onnxscript +from onnxscript._internal import version_utils from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -200,7 +201,7 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("unbind") + or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7")) or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} ): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 36ea29f77d..c8d0bf5786 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1481,6 +1481,7 @@ def _where_input_wrangler( reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", ) .xfail( + enabled_if=version_utils.torch_older_than("2.7"), dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), From 9ab7527f8c1e6a62604f3041540737d9d0bd4490 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 8 Oct 2025 15:51:15 -0700 Subject: [PATCH 062/123] Consolidate overloads in torchlib (#2604) The goal is to have no overloads and remove the PyTorch dispatcher. Right now there are still the following ops that need to be addressed: ``` Registering private function: aten::as_strided Registering private function: aten::embedding_bag Registering private function: aten::embedding_bag.padding_idx Registering overload for function: aten::index.Tensor Registering overload for function: aten::_unsafe_index.Tensor Registering overload for function: aten::index_put ``` I did a bit of cleaning up in tests and torchlib as well. https://github.com/microsoft/onnxscript/issues/2580 --------- Signed-off-by: Justin Chu --- noxfile.py | 6 +- onnxscript/backend/onnx_export_test.py | 1 + .../function_libs/torch_lib/ops/core.py | 742 +++++------------- onnxscript/function_libs/torch_lib/ops/nn.py | 82 +- requirements/ci/requirements-ort-nightly.txt | 2 +- tests/function_libs/torch_lib/ops_test.py | 6 +- .../function_libs/torch_lib/ops_test_data.py | 561 ++----------- 7 files changed, 308 insertions(+), 1092 deletions(-) diff --git a/noxfile.py b/noxfile.py index 23c2963998..60c2bb901b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,9 +29,9 @@ "ml-dtypes", ) ONNX = "onnx==1.17" -ONNX_RUNTIME = "onnxruntime==1.20.1" -PYTORCH = "torch==2.5.1" -TORCHVISON = "torchvision==0.20.1" +ONNX_RUNTIME = "onnxruntime==1.23.0" +PYTORCH = "torch==2.7.1" +TORCHVISON = "torchvision==0.22.1" TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 49eb398750..1f913ed897 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -84,6 +84,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), + skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"), ) if sys.platform == "win32": diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9e6aa69edc..e837bfadae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -18,21 +18,16 @@ import torch from onnxscript import ( - BFLOAT16, BOOL, COMPLEX64, COMPLEX128, DOUBLE, FLOAT, - FLOAT16, INT8, INT16, INT32, INT64, UINT8, - UINT16, - UINT32, - UINT64, graph, ir, ) @@ -77,13 +72,11 @@ def aten__local_scalar_dense(self: TensorType) -> TensorType: @torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax_half( - self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool -) -> FLOAT: +def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 - if half_to_float: + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: self = op.Cast(self, to=FLOAT.dtype) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -93,44 +86,23 @@ def aten__log_softmax_half( return result -@torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax( - self: TFloatHighPrecision, - dim: int, - half_to_float: bool, -) -> TFloatHighPrecision: - """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" +@torch_op("aten::_softmax", trace_only=True) +def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: + """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 + + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + self = op.Cast(self, to=FLOAT.dtype) + if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.LogSoftmax(self, axis=dim) + result = op.Softmax(self, axis=dim) if self_is_scalar: + # Convert to scalar when input is scalar result = op.Squeeze(result) - return result - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: - self = op.Cast(self, to=FLOAT.dtype) - - return aten_softmax_no_dtype(self, dim) - - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax( - self: TFloatHighPrecision, dim: int, half_to_float: bool -) -> TFloatHighPrecision: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only to reuse aten_softmax_no_dtype - - del half_to_float # Unused - return aten_softmax_no_dtype(self, dim) + return result @torch_op(("aten::abs", "_operator::abs"), trace_only=True) @@ -380,7 +352,6 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::all.dims", trace_only=True) def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" @@ -499,7 +470,6 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::any.dims", trace_only=True) def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) @@ -739,7 +709,6 @@ def aten_argmax( return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -752,7 +721,6 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -780,7 +748,6 @@ def aten_argmin( return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -793,7 +760,6 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -1282,78 +1248,30 @@ def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: ), trace_only=True, ) -def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT16.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: +def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" # assert other >= 0 - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT32.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT64.dtype) - + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + else: + raise NotImplementedError(f"Not implemented for type {self.dtype}") -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT8.dtype) + return op.Cast(result, to=signed_dtype) @torch_op("aten::bitwise_not", trace_only=True) @@ -1395,115 +1313,37 @@ def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) + ), + trace_only=True, ) -def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: +def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - # 0xFFFFFFFFFFFFFFFF - op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), - other, - direction="RIGHT", - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) - ) - + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + mask = ir.tensor(0xFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + else: + raise NotImplementedError(f"Not implemented for type {self.dtype}") -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" - ) + mask = op.BitShift(mask, other, direction="RIGHT") mask = op.BitwiseNot(mask) # Do logical shift shifted = op.BitShift(self, other, direction="RIGHT") @@ -1511,7 +1351,7 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) + negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) ) @@ -2173,7 +2013,6 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, @@ -2645,80 +2484,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: +def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims - # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 - # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 - # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 - if dim1 < 0: - dim1 = dim1 + len(self.shape) - if dim2 < 0: - dim2 = dim2 + len(self.shape) - - self_rank = len(self.shape) - perm = list(range(self_rank)) - perm.remove(dim1) - perm.remove(dim2) - perm.append(dim1) - perm.append(dim2) - - # If rank=2, then axes=[0]; if rank=3, then axes=[1] - # This is because computing diagonal sum is on dim2 after transpose by perm - axes = [self_rank - 2] - - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col - mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) - result = op.ReduceSum(result, keepdims=False, axes=axes) - # min(row, col) - min_dim_size = op.Min(dim1_size, dim2_size) - # take 2 tensors as example: - # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 - # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 - # 3 rows x 5 cols 5 rows x 3 cols - # offset diagonal offset diagonal - # ---------------- ---------------- - # -4 0 -6 0 - # -3 0 -5 0 - # -2 1 -4 1 - # -1 2 -3 2 - # 0 3 -2 3 - # 1 3 -1 3 - # 2 3 0 3 - # 3 2 1 2 - # 4 1 2 1 - # 5 0 3 0 - # 6 0 4 0 - - # From above table, we can get the logic below - offset_val = op.Constant(value_ints=[offset]) - if offset < 0: - # row + offset - length = op.Add(dim1_size, offset_val) - start = op.Constant(value_ints=[0]) - else: # offset >= 0 - # col - offset - length = op.Sub(dim2_size, offset_val) - start = offset_val - - # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = op.Add(start, length) - result = op.Slice(result, start, end, axes=axes) - - return result - - -@torch_op("aten::diagonal", trace_only=True) -def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: - """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + is_bool = self.dtype == BOOL.dtype # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 @@ -2745,10 +2514,16 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) + + if is_bool: + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + else: + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2785,7 +2560,9 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - result = op.Cast(result, to=BOOL.dtype) + + if is_bool: + result = op.Cast(result, to=BOOL.dtype) return result @@ -2896,45 +2673,37 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat: +def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" assert rounding_mode in {"trunc", "floor", None} - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. - # Equivalent to C-style integer division - return aten_trunc(op.Div(self, other)) - if rounding_mode == "floor": - return op.Floor(op.Div(self, other)) - - return op.Div(self, other) - + if self.dtype.is_integer(): + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode_int( - self: TInt, other: TInt, rounding_mode: Optional[str] = None -) -> TensorType: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + result = aten_trunc(quotient) + return op.CastLike(result, self) + if rounding_mode == "floor": + result = op.Floor(quotient) + return op.CastLike(result, self) - Variant for integer inputs. - """ - assert rounding_mode in {"trunc", "floor", None} + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + # Float inputs if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(quotient) - return op.CastLike(result, self) + return aten_trunc(op.Div(self, other)) if rounding_mode == "floor": - result = op.Floor(quotient) - return op.CastLike(result, self) + return op.Floor(op.Div(self, other)) - assert rounding_mode is None - # When rounding_mode is None, the return type is float32 - return quotient + return op.Div(self, other) @torch_op("aten::dot", trace_only=True) @@ -3888,26 +3657,18 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), trace_only=True, ) -def aten_ge(self: TReal, other: TReal) -> BOOL: - """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.GreaterOrEqual(self, other) - - -@torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), - trace_only=True, -) -def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: +def aten_ge(self: TTensor, other: TTensor) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T + if self.dtype == ir.DataType.BOOL: + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + return op.Or(self, op.Not(other)) - return op.Or(self, op.Not(other)) + return op.GreaterOrEqual(self, other) def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: @@ -4036,25 +3797,19 @@ def aten_gru_cell( ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), trace_only=True, ) -def aten_gt(self: TReal, other: TReal) -> BOOL: +def aten_gt(self: TTensor, other: TTensor) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Greater(self, other) - + if self.dtype == ir.DataType.BOOL: + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F -@torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), - trace_only=True, -) -def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: - """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F + return op.And(self, op.Not(other)) - return op.And(self, op.Not(other)) + return op.Greater(self, other) @torch_op("aten::hamming_window", trace_only=True) @@ -4875,26 +4630,19 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), trace_only=True, ) -def aten_le(self: TReal, other: TReal) -> BOOL: +def aten_le(self: TTensor, other: TTensor) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.LessOrEqual(self, other) - - -@torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), - trace_only=True, -) -def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: - """le.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T + return op.Or(other, op.Not(self)) - return op.Or(other, op.Not(self)) + return op.LessOrEqual(self, other) @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) @@ -5096,29 +4844,23 @@ def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloat) -> TFloat: - return op.Log(op.Div(self, op.Sub(1.0, self))) +@torch_op("aten::logit", trace_only=True) +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: + """logit(Tensor self, float? eps=None) -> Tensor""" + one = ir.tensor(1, dtype=self.dtype) + + if eps is None: + return op.Log(op.Div(self, op.Sub(one, self))) + one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) + eps = ir.tensor(eps, dtype=self.dtype) -@torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: - eps = op.CastLike(eps, self) - one = op.CastLike(1.0, self) - temporary_self = op.Where(self <= one - eps, self, one - eps) + temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) z = op.Where(temporary_self < eps, eps, temporary_self) return op.Log(op.Div(z, op.Sub(one, z))) -@torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: - """logit(Tensor self, float? eps=None) -> Tensor""" - if eps is None: - return _aten_logit_onnx(self) - return _aten_logit_clamp_onnx(self, eps) - - def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: """logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" @@ -5175,26 +4917,18 @@ def aten_lstm_mps_backward( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, ) -def aten_lt(self: TReal, other: TReal) -> BOOL: - """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Less(self, other) - - -@torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), - trace_only=True, -) -def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: +def aten_lt(self: TTensor, other: TTensor) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F + if self.dtype == ir.DataType.BOOL: + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F + return op.And(other, op.Not(self)) - return op.And(other, op.Not(self)) + return op.Less(self, other) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -5368,18 +5102,14 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op("aten::maximum") -def aten_maximum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::maximum", trace_only=True) +def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" - return op.Max(self, other) - - -@torch_op("aten::maximum") -def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: - """maximum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) - return op.Or(self, other) + return op.Max(self, other) @torch_op("aten::mean") @@ -5414,7 +5144,7 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::min") +@torch_op("aten::min", trace_only=True) def aten_min(self: TReal) -> TReal: """min(Tensor self) -> Tensor""" @@ -5435,18 +5165,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T return result, indices -@torch_op("aten::minimum") -def aten_minimum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::minimum", trace_only=True) +def aten_minimum(self: TTensor, other: TTensor) -> TTensor: """minimum(Tensor self, Tensor other) -> Tensor""" - return op.Min(self, other) - - -@torch_op("aten::minimum") -def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL: - """minimum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Min(self, other) def aten_miopen_batch_norm( @@ -5789,23 +5515,13 @@ def aten_msort(self: TensorType) -> TensorType: ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), trace_only=True, ) -def aten_mul(self: TReal, other: TReal) -> TReal: +def aten_mul(self: TTensor, other: TTensor) -> TTensor: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Mul(self, other) - - -@torch_op( - ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), - trace_only=True, -) -def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" - - # TODO(justinchuby): Handle cases where type reconcilation is not enough, - # since different ONNX operators are used based on different data types. + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Mul(self, other) @torch_op( @@ -6047,7 +5763,6 @@ def aten_native_batch_norm( return norm, input_mean, input_rstd -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_training_onnx( input: TFloat, weight: TFloat, @@ -6099,7 +5814,6 @@ def _aten_native_batch_norm_training_onnx( return norm, mean, rstd, running_mean, new_running_var -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_inference_onnx( input: TFloat, weight: TFloat, @@ -6269,22 +5983,10 @@ def aten_native_group_norm( if bias is None: # Set to 0.0 as default, the shape is Channel size bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) - # Accoding to Torch, return rstd instead of var - norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps) - return norm, mean, rstd - - -@torch_op("aten::native_group_norm", private=True) -def _aten_native_group_norm_onnx( - input: TFloat, - weight: TFloat, - bias: TFloat, - group: INT64, - eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate + # This implementation should be simplified after opset 21 neg_1 = op.Constant(value_ints=[-1]) # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter group_tensor = op.Reshape(group, neg_1) @@ -6321,7 +6023,9 @@ def _aten_native_group_norm_onnx( sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) # In Pytorch, vstd = 1/(sqrt(var + eps)) var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) - rstd = op.Div(1.0, op.Sqrt(var + eps)) + eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype)) + one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + rstd = op.Div(one, op.Sqrt(op.Add(var, eps))) # Get the correct shape [N, group] for mean again mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False) return norm_result, mean, rstd @@ -6533,16 +6237,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op( - ( - "aten::normal.Tensor_float", - "aten::normal.Tensor_Tensor", - "aten::normal.float_Tensor", - "aten::normal.float_float", - "aten::normal_functional", - ), - trace_only=True, -) +@torch_op("aten::normal_functional", trace_only=True) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6571,7 +6266,7 @@ def aten_normal_float_float( return op.Cast(result, to=dtype) -@torch_op("aten::normal.float_Tensor") +@torch_op("aten::normal.float_Tensor", trace_only=True) def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: """normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6581,7 +6276,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float") +@torch_op("aten::normal.Tensor_float", trace_only=True) def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor""" @@ -6590,7 +6285,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor") +@torch_op("aten::normal.Tensor_Tensor", trace_only=True) def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -7298,10 +6993,15 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TFloat, other: TFloat) -> TFloat: +@torch_op( + ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True +) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype.is_integer(): + return op.Mod(self, other) + # TODO(justinchuby): Improve fp16 precision by following the logic in # https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -7311,15 +7011,6 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat: return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) -def aten_remainder_int(self: TInt, other: TInt) -> TInt: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Mod(self, other) - - def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)""" @@ -7538,23 +7229,29 @@ def aten_rnn_tanh_cell( def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 0: return op.Identity(self) elif self.shape[0] == 0: # empty tensor return op.Identity(self) + + # NOTE: In pytorch, default value of dims is an empty list. + if len(dims) == 0: # Empty sequence + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + return _aten_roll_shift_no_dim_onnx(self, shifts[0]) else: - # NOTE: In pytorch, default value of dims is an empty list. - if len(dims) == 0: # Empty sequence - # assert isinstance(shifts, int) - return _aten_roll_shift_no_dim_onnx(self, shifts) - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - result = self - for i, shift in enumerate(shifts): - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + assert len(shifts) == len(dims) + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result @torch_op("aten::roll", trace_only=True, complex=True) @@ -7563,6 +7260,12 @@ def aten_roll_complex( ) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 1: return op.Identity(self) @@ -7573,37 +7276,34 @@ def aten_roll_complex( self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) if not dims: - # assert isinstance(shifts, int) - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) result = op.Concat(shift_real, shift_imag, axis=-1) else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + assert len(shifts) == len(dims) for i, dim in enumerate(dims): - shift = op.Gather(shifts, i, axis=0) - self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) result = op.Concat(self_real, self_imag, axis=-1) return result -@torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] self_flatten = op.Reshape(self, neg_1) # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: + if shift < 0: # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = -shift_tensor + slice_length = op.Constant(value_ints=[-shift]) else: # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - shift_tensor + slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) # Get second part of the tensor, e.g. [A,B,C] suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) # Get first part of the tensor, e.g. [D] @@ -7613,15 +7313,13 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: return op.Reshape(result, op.Shape(self)) -@torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor + dim_tensor = op.Constant(value_ints=[dim]) + if shift < 0: + slice_length = op.Constant(value_ints=[-shift]) else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor + slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) @@ -7700,7 +7398,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s: float, + s: TensorType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -7709,8 +7407,7 @@ def aten_scalar_tensor( """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) @@ -7739,20 +7436,6 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number( - s: TensorType, - dtype: int = FLOAT.dtype, - layout: str = "", - device: str = "", - pin_memory: bool = False, -) -> RealType: - """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1: - dtype = FLOAT.dtype - return common_ops.cast_to(s, dtype=dtype) - - @torch_op("aten::scatter.src", trace_only=True) def aten_scatter_src( self: TTensor, @@ -8140,7 +7823,7 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) - if dtype != -1: + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) if self_is_scalar: # Convert to scalar when input is scalar @@ -8149,21 +7832,6 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: return result -@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: - """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - - self_is_scalar = len(self.shape) == 0 - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.Softmax(self, axis=dim) - if self_is_scalar: - # Convert to scalar when input is scalar - result = op.Squeeze(result) - - return result - - @torch_op("aten::sort", trace_only=True) def aten_sort( self: TReal, dim: int = -1, descending: bool = False, stable: bool = False diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 1a31c9eac8..2a7a46ec28 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -294,20 +294,16 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu", trace_only=True) -def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - return op.Celu(self, alpha=alpha) # op.Celu only support float32 + if self.dtype != FLOAT.dtype: + self_upcasted = op.Cast(self, to=FLOAT.dtype) + # op.Celu only support float32 + return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype) -@torch_op("aten::celu", trace_only=True) -def aten_celu_type_promoted( - self: TFloatUnlessFloat32, alpha: float = 1.0 -) -> TFloatUnlessFloat32: - """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - - self_upcasted = op.Cast(self, to=FLOAT.dtype) - return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self) + return op.Celu(self, alpha=alpha) @torch_op("aten::col2im", trace_only=True) @@ -1804,7 +1800,7 @@ def aten_scaled_dot_product_attention( query: TFloat, key: TFloat, value: TFloat, - attn_mask: Optional[TFloat] = None, + attn_mask: Optional[TensorType] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1854,6 +1850,11 @@ def aten_scaled_dot_product_attention( query, key, value, scale, dropout_p ) + if attn_mask.dtype == ir.DataType.BOOL: + return _aten_scaled_dot_product_attention_bool_mask_onnx( + query, key, value, attn_mask, scale, dropout_p + ) + return _aten_scaled_dot_product_attention_float_mask_onnx( query, key, value, attn_mask, scale, dropout_p ) @@ -1921,7 +1922,6 @@ def aten__scaled_dot_product_flash_attention( ) -@torch_op("aten::_scaled_dot_product_efficient_attention", private=True) def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, @@ -2016,64 +2016,6 @@ def aten__scaled_dot_product_efficient_attention( ) -@torch_op("aten::scaled_dot_product_attention", trace_only=True) -def aten_scaled_dot_product_attention_bool_mask( - query: TFloat, - key: TFloat, - value: TFloat, - attn_mask: Optional[BOOL] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - Equivalent to the PyTorch code:: - scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale - attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask - attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask - attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p) - return attn_weight @ V - - where Q, K, V are the query, key, and value tensors, respectively. - L is the target sequence length, S is the source sequence length, and E is the embedding size. - """ - # Use trace_only to handle optional inputs - assert (not is_causal) or (is_causal and attn_mask is None), ( - "is_causal and attn_mask cannot be set at the same time" - ) - assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( - "only 4D query, key, and value are supported" - ) - - if scale is None: - scale = _attention_scale(query) - scale = op.CastLike(scale, query) - - if is_causal: - attn_mask = _causal_attention_mask(query, key) - # The causal mask is always float - return _aten_scaled_dot_product_attention_float_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - if enable_gqa: - key, value = _attention_repeat_kv_for_group_query(query, key, value) - - if attn_mask is None: - return _aten_scaled_dot_product_attention_no_mask_onnx( - query, key, value, scale, dropout_p - ) - - return _aten_scaled_dot_product_attention_bool_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - def _aten_scaled_dot_product_attention_no_mask_onnx( query: TFloat, key: TFloat, diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index 4ed908b4e2..b54550738b 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.0.dev20250517001 +onnxruntime==1.23.0.dev20251001001 diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 45875043ea..a45050fb22 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -99,7 +99,7 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo @@ -110,10 +110,12 @@ def test_script_function_passes_checker( onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): func = torchlib_op_info.op + if not hasattr(func, "op_schema"): + raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index c8d0bf5786..b60fd8cf31 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -48,7 +48,6 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops @@ -459,40 +458,13 @@ def _where_input_wrangler( fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), + TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo( - "ops.aten._local_scalar_dense", - core_ops.aten__local_scalar_dense, - ), - TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), - TorchLibOpInfo( - "ops.aten._log_softmax_half", - core_ops.aten__log_softmax_half, + "ops.aten._log_softmax", + core_ops.aten__log_softmax, tolerance={torch.float16: (1e-3, 1e-3)}, - ) - .xfail( - reason="PyTorch does not implement _log_softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) - .xfail( - reason="PyTorch does not implement _softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), @@ -503,10 +475,7 @@ def _where_input_wrangler( reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo( - "all", - core_ops.aten_all, - ).skip( + TorchLibOpInfo("all", core_ops.aten_all).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -541,32 +510,14 @@ def _where_input_wrangler( reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), - TorchLibOpInfo( - "addr", - core_ops.aten_addr, - tolerance={torch.float16: (3e-3, 4e-3)}, - ), - TorchLibOpInfo( - "amax", - core_ops.aten_amax, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "amin", - core_ops.aten_amin, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "any", - core_ops.aten_any, - ).skip( + TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), + TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("any", core_ops.aten_any).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo( - "any_dim", - core_ops.aten_any_dim, - ).skip( + TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", @@ -584,76 +535,46 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_1d_Sequence", - core_ops.aten_atleast_1d_sequence, - ) + TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_2d_Sequence", - core_ops.aten_atleast_2d_sequence, - ) + TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_3d_Sequence", - core_ops.aten_atleast_3d_sequence, - ) + TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), @@ -671,16 +592,10 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), - TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), - TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), - TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), + TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), - TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), - TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), - TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), + TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), @@ -698,10 +613,7 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( - enabled_if=version_utils.torch_older_than("2.7"), - reason="Test for chunk is not configured for torch<2.7", - ), + TorchLibOpInfo("chunk", core_ops.aten_chunk), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, @@ -737,7 +649,6 @@ def _where_input_wrangler( TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB TorchLibOpInfo("diagonal", core_ops.aten_diagonal), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", @@ -755,7 +666,6 @@ def _where_input_wrangler( # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( "empty", @@ -765,8 +675,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", - dtypes=(torch.int32,), + reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -800,21 +709,15 @@ def _where_input_wrangler( TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like", - core_ops.aten_full_like, - ).skip( - enabled_if=ops_test_common.IS_MACOS, - reason="fixme: memory allocation issue on CI", + TorchLibOpInfo("full_like", core_ops.aten_full_like).skip( + enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, reason="fixme: ORT does not support empty tensors as input", ), TorchLibOpInfo("ge", core_ops.aten_ge), - TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), @@ -828,9 +731,7 @@ def _where_input_wrangler( reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, - input_wrangler=_index_put_input_wrangler, + "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, @@ -870,20 +771,13 @@ def _where_input_wrangler( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", ) - .xfail( - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32), + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("le_bool", core_ops.aten_le_bool), - TorchLibOpInfo( - "lerp", - core_ops.aten_lerp, - tolerance={torch.float16: (2e-3, 2e-1)}, - ), + TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -922,7 +816,6 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), - TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", @@ -938,19 +831,12 @@ def _where_input_wrangler( reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), TorchLibOpInfo("maximum", core_ops.aten_maximum), - TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), - TorchLibOpInfo( - "mean", - core_ops.aten_mean, - input_wrangler=_mean_input_wrangler, - ).skip( + TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", - core_ops.aten_mean_dim, - input_wrangler=_mean_input_wrangler, + "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", @@ -962,15 +848,11 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "min", - core_ops.aten_min, - ).skip( + TorchLibOpInfo("min", core_ops.aten_min).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), TorchLibOpInfo("minimum", core_ops.aten_minimum), - TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -979,39 +861,19 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), - TorchLibOpInfo( - "mv", - core_ops.aten_mv, - tolerance={torch.float16: (3e-2, 1e-2)}, - ), + TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), + TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty", - core_ops.aten_new_empty, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_empty_strided", - core_ops.aten_new_empty_strided, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_full", - core_ops.aten_new_full, - ), - TorchLibOpInfo( - "new_ones", - core_ops.aten_new_ones, - ), - TorchLibOpInfo( - "new_zeros", - core_ops.aten_new_zeros, + "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True ), + TorchLibOpInfo("new_full", core_ops.aten_new_full), + TorchLibOpInfo("new_ones", core_ops.aten_new_ones), + TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), - TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -1024,9 +886,7 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", - core_ops.aten_dropout, - input_wrangler=_dropout_input_wrangler, + "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -1037,10 +897,7 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), - ).skip( - dtypes=(torch.float16,), - reason="fixme: results mismatch in torch nightly.", - ), + ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, @@ -1075,10 +932,7 @@ def _where_input_wrangler( tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) - .skip( - variant_name="circular", - reason="fixme: ORT does not support the circular mode", - ) + .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct", @@ -1100,10 +954,7 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.reflection_pad1d", nn_ops.aten_reflection_pad1d, - ).xfail( - dtypes=(torch.int64,), - reason="Torch not implement reflection_pad1d for int64.", - ), + ).xfail(dtypes=(torch.int64,), reason="Torch not implement reflection_pad1d for int64."), TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -1112,26 +963,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo( - "nn.functional.relu", - nn_ops.aten_relu, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "nn.functional.relu6", - nn_ops.aten_relu6, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "ops.aten.replication_pad1d", - nn_ops.aten_replication_pad1d, - ), + TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), + TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), + TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, @@ -1141,10 +975,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", ) - .xfail( + .skip( variant_name="replicate_negative", - enabled_if=not version_utils.torch_older_than("2.2"), - reason="fixme: negative padding is not implemented yet", + reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -1160,15 +993,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", - nn_ops.aten_mse_loss, - input_wrangler=_mse_loss_input_wrangler, + "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler ), - TorchLibOpInfo( - "nonzero", - core_ops.aten_nonzero, - input_wrangler=_nonzero_input_wrangler, - ) + TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1231,26 +1058,19 @@ def _where_input_wrangler( nondeterministic=True, ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), - reason="fixme: Shape inference error", + dtypes=(torch.float16,), reason="fixme: Shape inference error" ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo( - "remainder", - core_ops.aten_remainder, - ), + TorchLibOpInfo("remainder", core_ops.aten_remainder), TorchLibOpInfo("repeat", core_ops.aten_repeat), TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) .skip( matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is a Tensor"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1264,10 +1084,7 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is an int"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1297,14 +1114,9 @@ def _where_input_wrangler( complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", - core_ops.aten_scalar_tensor_complex, - complex=True, + "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True ), - TorchLibOpInfo( - "scatter_add", - core_ops.aten_scatter_add, - ) + TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) .xfail( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1353,48 +1165,10 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. Tests pass for float32.", ), - TorchLibOpInfo( - "split_with_sizes", - core_ops.aten_split_with_sizes, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo( - "split", - core_ops.aten_split, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), + TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1404,11 +1178,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim_complex, - complex=True, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1418,10 +1188,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze", - core_ops.aten_squeeze, - ).skip( + TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -1430,20 +1197,14 @@ def _where_input_wrangler( TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo( - "t", - core_ops.aten_t, - ).xfail( + TorchLibOpInfo("t", core_ops.aten_t).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo( - "tile", - core_ops.aten_tile, - ).skip( + TorchLibOpInfo("tile", core_ops.aten_tile).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", @@ -1471,20 +1232,7 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo( - "unbind", - core_ops.aten_unbind, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - enabled_if=version_utils.torch_older_than("2.7"), - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("unbind", core_ops.aten_unbind), TorchLibOpInfo("unflatten", core_ops.aten_unflatten), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), @@ -1503,10 +1251,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo( - "arange_start_step", - core_ops.aten_arange_start_step, - ) + TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1516,10 +1261,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange_start", - core_ops.aten_arange_start, - ) + TorchLibOpInfo("arange_start", core_ops.aten_arange_start) .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1529,10 +1271,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange", - core_ops.aten_arange, - ) + TorchLibOpInfo("arange", core_ops.aten_arange) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. https://github.com/microsoft/onnxscript/issues/974", @@ -1555,10 +1294,7 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - ).xfail( - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor", - ), + ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", @@ -1578,19 +1314,13 @@ def _where_input_wrangler( tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), - TorchLibOpInfo( - "grid_sampler_2d", - core_ops.aten_grid_sampler_2d, - ) + TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ) - .skip( - dtypes=(torch.float16,), - reason="fixme: Accuracy is not high enough", - ), + .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), TorchLibOpInfo( "nn.functional.group_norm", nn_ops.aten_group_norm, @@ -1651,10 +1381,7 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "max", - core_ops.aten_max, - ).skip( + TorchLibOpInfo("max", core_ops.aten_max).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1712,8 +1439,7 @@ def _where_input_wrangler( reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - core_ops.aten__native_batch_norm_no_stats, + "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", @@ -1734,10 +1460,6 @@ def _where_input_wrangler( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, tolerance={torch.float16: (1e-2, 7e-3)}, - ).xfail( - dtypes=(torch.float16,), - reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", - enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", @@ -1819,9 +1541,7 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", - core_ops.aten_conv3d, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} ), TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), @@ -1902,11 +1622,6 @@ def _where_input_wrangler( nn_ops.aten_scaled_dot_product_attention, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype == torch.bool, - reason="this overload takes a non-boolean mask", - ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -1929,15 +1644,7 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( - device_type="cpu", - reason="_scaled_dot_product_flash_attention only supports CUDA", - ), + ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, @@ -1945,40 +1652,10 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( + ).skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), - TorchLibOpInfo( - "nn.functional.scaled_dot_product_attention_bool_mask", - nn_ops.aten_scaled_dot_product_attention_bool_mask, - tolerance={torch.float32: (3e-4, 1.5e-5)}, - ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype != torch.bool, - reason="this overload takes a boolean mask", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ) - .xfail( - matcher=lambda sample: len(sample.input.shape) != 4 - or len(sample.args[0].shape) != 4 - or len(sample.args[1].shape) != 4, - reason="torch sdpa is expected to pass in 4d q, k, and v.", - ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, @@ -1998,10 +1675,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bilinear2d.vec", - nn_ops.aten_upsample_bilinear2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, @@ -2021,10 +1695,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bicubic2d.vec", - nn_ops.aten_upsample_bicubic2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, @@ -2033,38 +1704,14 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d", - nn_ops.aten_upsample_nearest1d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d", - nn_ops.aten_upsample_nearest2d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d", - nn_ops.aten_upsample_nearest3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.default", - nn_ops.aten_upsample_trilinear3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.vec", - nn_ops.aten_upsample_trilinear3d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), + TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), + TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), + TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), TorchLibOpInfo( "roll", @@ -2082,10 +1729,7 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail( - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ) + .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -2159,40 +1803,13 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_left_shift", - ( - "bitwise_left_shift_int8", - "bitwise_left_shift_int16", - "bitwise_left_shift_int32", - "bitwise_left_shift_int64", - ), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_right_shift", - ( - "bitwise_right_shift_int8", - "bitwise_right_shift_int16", - "bitwise_right_shift_int32", - "bitwise_right_shift_int64", - ), -) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) -ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2202,20 +1819,6 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.scaled_dot_product_attention", - ("nn.functional.scaled_dot_product_attention_bool_mask",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.celu", - ("nn.functional.celu_type_promoted",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) -) -ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) From cb6f873612d05d7e5abf40dd1fe49325b5143a46 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 8 Oct 2025 22:18:47 -0700 Subject: [PATCH 063/123] chore(deps): bump onnxruntime from 1.23.0.dev20250517001 to 1.23.1 in /requirements/ci (#2614) --- requirements/ci/requirements-ort-nightly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index b54550738b..cb16597719 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.0.dev20251001001 +onnxruntime==1.23.1 From 59c3d32ea0cf18fbd348d8b4e23fdb8dad6427ea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 14:34:18 -0700 Subject: [PATCH 064/123] [torchlib] Fix implementations for bitwise_* overloads (#2618) Some overloads for bitwise_* can accept scalar inputs which do not have the dtype. This PR creates implementations for the overloads. Fix https://github.com/microsoft/onnxscript/issues/2617 --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 144 +++++++++++++----- .../function_libs/torch_lib/e2e_ops_tests.py | 13 ++ 2 files changed, 122 insertions(+), 35 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e837bfadae..5127f3f9f6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1220,8 +1220,6 @@ def aten_binomial( @torch_op( ( "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", "_operator::and_", ), trace_only=True, @@ -1229,42 +1227,61 @@ def aten_binomial( def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None - if self.dtype.is_integer(): + if dtype.is_integer(): return op.BitwiseAnd(self, other) - if self.dtype == ir.DataType.BOOL: + if dtype == ir.DataType.BOOL: return op.And(self, other) raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") +@torch_op("aten::bitwise_and.Scalar", trace_only=True) +def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor""" + + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_and(self, other_tensor) + + +@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True) +def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor: + """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_and(self_tensor, other) + + @torch_op( ( "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", "_operator::__lshift__", - "aten::__lshift__.Scalar", ), trace_only=True, ) def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + # assert other >= 0 - if self.dtype.bitwidth == 8: + if dtype.bitwidth == 8: unsigned_dtype = ir.DataType.UINT8 signed_dtype = ir.DataType.INT8 - elif self.dtype.bitwidth == 16: + elif dtype.bitwidth == 16: unsigned_dtype = ir.DataType.UINT16 signed_dtype = ir.DataType.INT16 - elif self.dtype.bitwidth == 32: + elif dtype.bitwidth == 32: unsigned_dtype = ir.DataType.UINT32 signed_dtype = ir.DataType.INT32 - elif self.dtype.bitwidth == 64: + elif dtype.bitwidth == 64: unsigned_dtype = ir.DataType.UINT64 signed_dtype = ir.DataType.INT64 else: - raise NotImplementedError(f"Not implemented for type {self.dtype}") + raise NotImplementedError(f"Not implemented for type {dtype}") self = op.Cast(self, to=unsigned_dtype) other = op.Cast(other, to=unsigned_dtype) @@ -1274,6 +1291,22 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: return op.Cast(result, to=signed_dtype) +@torch_op( + ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True +) +def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_left_shift(self, other_tensor) + + +@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True) +def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_left_shift(self_tensor, other) + + @torch_op("aten::bitwise_not", trace_only=True) def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" @@ -1288,8 +1321,6 @@ def aten_bitwise_not(self: TTensor) -> TTensor: @torch_op( ( "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", "_operator::or_", ), trace_only=True, @@ -1297,45 +1328,62 @@ def aten_bitwise_not(self: TTensor) -> TTensor: def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None - if self.dtype.is_integer(): + if dtype.is_integer(): return op.BitwiseOr(self, other) - if self.dtype == ir.DataType.BOOL: + if dtype == ir.DataType.BOOL: return op.Or(self, other) raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") +@torch_op("aten::bitwise_or.Scalar", trace_only=True) +def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_or(self, other_tensor) + + +@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True) +def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_or(self_tensor, other) + + @torch_op( ( "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", - "aten::__rshift__.Scalar", ), trace_only=True, ) def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - if self.dtype.bitwidth == 8: + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.bitwidth == 8: unsigned_dtype = ir.DataType.UINT8 signed_dtype = ir.DataType.INT8 mask = ir.tensor(0xFF, dtype=unsigned_dtype) - elif self.dtype.bitwidth == 16: + elif dtype.bitwidth == 16: unsigned_dtype = ir.DataType.UINT16 signed_dtype = ir.DataType.INT16 mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) - elif self.dtype.bitwidth == 32: + elif dtype.bitwidth == 32: unsigned_dtype = ir.DataType.UINT32 signed_dtype = ir.DataType.INT32 mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) - elif self.dtype.bitwidth == 64: + elif dtype.bitwidth == 64: unsigned_dtype = ir.DataType.UINT64 signed_dtype = ir.DataType.INT64 mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF else: - raise NotImplementedError(f"Not implemented for type {self.dtype}") + raise NotImplementedError(f"Not implemented for type {dtype}") negative = op.Less(self, 0) self = op.Cast(self, to=unsigned_dtype) @@ -1356,24 +1404,50 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: @torch_op( - ( - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), - trace_only=True, + ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True ) +def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_right_shift(self, other_tensor) + + +@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True) +def aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_right_shift(self_tensor, other) + + +@torch_op("aten::bitwise_xor.Tensor", trace_only=True) def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - assert self.dtype == other.dtype - if self.dtype.is_integer(): + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.is_integer(): return op.BitwiseXor(self, other) - if self.dtype == ir.DataType.BOOL: + if dtype == ir.DataType.BOOL: return op.Xor(self, other) raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") +@torch_op("aten::bitwise_xor.Scalar", trace_only=True) +def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_xor(self, other_tensor) + + +@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True) +def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_xor(self_tensor, other) + + @torch_op("aten::blackman_window", trace_only=True) def aten_blackman_window( window_length: int, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1b0410c27f..754f5e2a25 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -225,6 +225,19 @@ def forward(self, q, k, v): ) _testing.assert_onnx_program(onnx_program) + def test_bitwise_and_scalar(self): + class Model(torch.nn.Module): + def forward(self, x): + return x & 3 + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([1, 2, 3, 4, 5]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 28a8f561957c46131581bc33c8b43508f41b844f Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 10 Oct 2025 09:30:25 -0700 Subject: [PATCH 065/123] Fix constant in constant folding (#2622) This PR moves the processing of constant ops upward to return before node-level shape type inference (including serialization) and optimizer optimization. Essentially, avoiding serializing constant ops (potentially large weights in LLMs) reduces the export time in optimize_ir. Before this PR: Screenshot 2025-10-09 141403 After this PR: Screenshot 2025-10-09 141238 --- onnxscript/optimizer/_constant_folding.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 8317d2be63..9a740c783c 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -76,7 +76,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool: def _process_constant_node(node: ir.Node) -> None: """Sets const_value of output value of a Constant op node.""" - if node.op_type != "Constant" or node.domain != "": + if not _is_onnx_op(node, "Constant"): return if len(node.attributes) != 1: return @@ -1099,8 +1099,12 @@ def process_node(self, node: ir.Node) -> Replacement | None: self._modified = True # TODO(rama): consider merging type/other info from both values + # Propagate const_value, and manually find out shape and type + # to avoid potentially expensive shape inference on large tensors. + if _is_onnx_op(node, "Constant"): + _process_constant_node(node) # Do incremental shape inference - if self.shape_inference and not _is_control_flow_op(node): + elif self.shape_inference and not _is_control_flow_op(node): self._do_inference(node) if node.domain not in self._opset_imports: @@ -1118,6 +1122,10 @@ def process_node(self, node: ir.Node) -> Replacement | None: output = [output] return Replacement(output, context.nodes) + if _is_onnx_op(node, "Constant"): + logger.debug("Skipping constant folding for Constant node %r", node.name) + return None + if _is_control_flow_op(node): logger.info( "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet", @@ -1137,10 +1145,6 @@ def process_node(self, node: ir.Node) -> Replacement | None: ) return None - if _is_onnx_op(node, "Constant"): - _process_constant_node(node) - return None - if any(x.is_graph_input() for x in node.inputs if x is not None): logger.info( "Skipping constant folding for node %r because it is graph input to preserve graph signature", From 071ff1eb833defcb25c7eefa69917372a69e11ce Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 10 Oct 2025 11:19:10 -0700 Subject: [PATCH 066/123] Create helper for comparing semantic equivalence of shapes (#2620) This pull request introduces new utility functions for comparing shapes and dimensions in the intermediate representation (IR) utilities, and refactors existing rewrite rules to use these new utilities. The goal is to improve semantic correctness and code clarity when checking shape and dimension equality, especially in the presence of symbolic or unknown values. Key changes: **New IR utility functions:** * Added `same_shape` and `same_dim` functions to `_ir_utils.py` for more robust and semantically correct comparison of shapes and dimensions, accounting for unknown or symbolic values. **Refactoring of rewrite rules to use new utilities:** * Updated `_collapse_slices.py` and `_redundant_scatter_nd.py` to use `_ir_utils.same_shape` and `_ir_utils.same_dim` instead of direct equality checks or previous logic, ensuring that shape and dimension comparisons are handled consistently and correctly. [[1]](diffhunk://#diff-bd2dba53e1a4b4fb79975f7bceacf4b1c5b0b38a10d953af1e18a0b7af6c1050L85-R88) [[2]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL57-R57) [[3]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL90-R90) **Code consistency improvements:** * Standardized imports in affected files to use `_ir_utils` consistently, replacing previous aliasing or direct imports. [[1]](diffhunk://#diff-bd2dba53e1a4b4fb79975f7bceacf4b1c5b0b38a10d953af1e18a0b7af6c1050L8-R8) [[2]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL23-R23) [[3]](diffhunk://#diff-47bc4cbfc2fee996791be5a58bf9447dd44dd833e540139b5cd18b807757be4aL44-R44) --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/rewriter/_ir_utils.py | 24 +++++++++++++++++++ .../rewriter/rules/common/_collapse_slices.py | 10 +++----- .../rules/common/_redundant_scatter_nd.py | 8 +++---- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 91c3308bc2..953d5f33d5 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -152,3 +152,27 @@ def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None: if dim < 0 or dim >= shape.rank(): return None return shape[dim] + + +def same_shape(shape1: ir.Shape | None, shape2: ir.Shape | None) -> bool: + """Check if two shapes are semantically the same.""" + if shape1 is None or shape2 is None: + return False + + # If any dim is unknown, the shapes are not the same + if shape1.has_unknown_dim() or shape2.has_unknown_dim(): + return False + + return shape1 == shape2 + + +def same_dim(dim1: ir.SymbolicDim | int, dim2: ir.SymbolicDim | int) -> bool: + """Check if two dimensions are semantically the same.""" + if type(dim1) is not type(dim2): + return False + if isinstance(dim1, int) and isinstance(dim2, int): + return dim1 == dim2 + assert isinstance(dim1, ir.SymbolicDim) and isinstance(dim2, ir.SymbolicDim) + if dim1.value is None or dim2.value is None: + return False + return dim1.value == dim2.value diff --git a/onnxscript/rewriter/rules/common/_collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py index eda8547037..21b2694b82 100644 --- a/onnxscript/rewriter/rules/common/_collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -5,7 +5,7 @@ import logging from onnxscript import ir -from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet logger = logging.getLogger(__name__) @@ -82,14 +82,10 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ if data.shape is None or slice_output.shape is None: return False - if not is_singleton_value(steps, 1): + if not _ir_utils.is_singleton_value(steps, 1): return False - # If any dim is unknown, the shapes are not the same - if data.shape.has_unknown_dim() or slice_output.shape.has_unknown_dim(): - return False - - return data.shape == slice_output.shape + return _ir_utils.same_shape(data.shape, slice_output.shape) # Register the rewrite rules diff --git a/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py index cca5f36558..09c5db7735 100644 --- a/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py @@ -20,7 +20,7 @@ import onnx_ir as ir import onnxscript.rewriter -from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet @@ -41,7 +41,7 @@ def check(self, context, data, axis, transposed_data, **_): # Check that updated-indices represent the full range of the first dimension of the transposed data. # That is: check that the data.shape[axis] matches transposed_data.shape[0]. result = onnxscript.rewriter.MatchResult() - axis_value = ir_utils.get_singleton_value(axis) + axis_value = _ir_utils.get_singleton_value(axis) if not isinstance(axis_value, int): return result.fail("Axis value must be a constant integer.", axis) shape: ir.Shape | None = data.shape @@ -54,7 +54,7 @@ def check(self, context, data, axis, transposed_data, **_): "Transposed data shape is not statically known.", transposed_data ) actual_dim_value = transposed_data_shape[0] - if updated_dim_value != actual_dim_value: + if not _ir_utils.same_dim(updated_dim_value, actual_dim_value): # The first dimension of the transposed data does not match the updated dimension, # so we cannot apply this rule. return result.fail( @@ -87,7 +87,7 @@ def check(self, context, data, indices, updates, **_): return result.fail("The value 'data' shape is not statically known.", data) if updates.shape is None: return result.fail("The value 'updates' shape is not statically known.", updates) - if data.shape != updates.shape: + if not _ir_utils.same_shape(data.shape, updates.shape): return result.fail( "The shape of 'data' and 'updates' are different.", [data, updates] ) From 32a61f4bd5c7e51c9c5aaa562dc4dc90f67bf6b9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 13 Oct 2025 09:31:20 -0700 Subject: [PATCH 067/123] [torchlib] Deprecate Rank and IsScalar (#2624) Deprecate Rank and IsScalar and remove all usages. Do not remove the definitions because older versions of PyTorch assumes their existance. --------- Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/common.py | 14 +++- .../function_libs/torch_lib/ops/core.py | 65 +++++++------------ onnxscript/function_libs/torch_lib/ops/nn.py | 12 ++-- .../torch_lib/ops_test_common.py | 16 ----- 4 files changed, 39 insertions(+), 68 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index b3ebbc1c53..38544b59ba 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -28,14 +28,24 @@ @onnxscript.script(common_opset) def Rank(input: tensor_typing.TTensor) -> INT64: - """Take the rank of the input tensor.""" + """Deprecated. + + NOTE: Do not remove, for backward compatibility with PyTorch < 2.10. + + Take the rank of the input tensor. + """ return op.Size(op.Shape(input)) @onnxscript.script(common_opset) def IsScalar(input: tensor_typing.TTensor) -> BOOL: - """Return whether the input has rank 0, or is a scalar.""" + """Deprecated. + + NOTE: Do not remove, for backward compatibility with PyTorch < 2.10. + + Return whether the input has rank 0, or is a scalar. + """ return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0)) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5127f3f9f6..e088b887f6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -54,7 +54,6 @@ _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi -Rank = common_ops.Rank @torch_op("aten::_local_scalar_dense", trace_only=True) @@ -947,11 +946,11 @@ def reshape_to_1d(tensor): return op.SequenceMap(self, body=reshape_to_1d) -@torch_op("aten::atleast_2d") +@torch_op("aten::atleast_2d", trace_only=True) def aten_atleast_2d(self: TTensor) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - if Rank(self) <= 1: + if len(self.shape) <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1])) return op.Identity(self) @@ -975,7 +974,7 @@ def reshape_to_2d(tensor): def aten_atleast_3d(self: TTensor) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - rank = Rank(self) + rank = len(self.shape) if rank <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) elif rank == 2: @@ -1820,39 +1819,21 @@ def aten_conj_physical(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::constant_pad_nd") -def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor: +@torch_op("aten::constant_pad_nd", trace_only=True) +def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor: """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor""" # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. # n is the dimension of input. # assume zero-dimensions in the beginning - # rank = len(self.shape) # rank must be scalar - # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad)) + rank = len(self.shape) + paddings = list(pad) + [0] * (rank * 2 - len(pad)) # reverse order and collate first beginnings and then ends - # paddings = paddings[-2::-2] + paddings[-1::-2] - - neg_1 = op.Constant(value_ints=[-1]) - - zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad)) - zero_count = op.Reshape(zero_count, neg_1) - zero = op.Constant(value_ints=[0]) - zeros = op.Expand(zero, zero_count) - torch_paddings = op.Concat(pad, zeros, axis=0) - size_d = op.Size(torch_paddings) - steps = op.Constant(value_ints=[-2]) - - starts = steps - ends = op.Sub(starts, size_d) - odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - - starts = neg_1 - ends = op.Sub(starts, size_d) - even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) + paddings = paddings[-2::-2] + paddings[-1::-2] + constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype)) - onnx_padding = op.Concat(odd_elements, even_elements, axis=0) - return op.Pad(self, onnx_padding, value) + return op.Pad(self, paddings, constant_value) @torch_op("aten::contiguous", trace_only=True) @@ -3996,7 +3977,7 @@ def reshape_to_atleast_2d(tensor): result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0) # hstack expects a non-empty sequence of tensors. So we don't need to check for length - rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2) + rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2) if rank_1d_or_less: result = op.Reshape(result, op.Constant(value_ints=[-1])) return result @@ -6076,7 +6057,7 @@ def aten_native_group_norm( norm = op.Reshape(norm, op.Shape(input), allowzero=True) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy - input_rank = Rank(input) + input_rank = len(input.shape) axes_unsqueeze = op.Range(1, input_rank - 1, 1) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) @@ -8229,7 +8210,7 @@ def aten_symeig( def aten_t(self: TTensor) -> TTensor: """t(Tensor(a) self) -> Tensor(a)""" - rank = Rank(self) + rank = len(self.shape) if rank == 2: result = op.Transpose(self, perm=[1, 0]) else: @@ -8312,26 +8293,24 @@ def aten_threshold_backward( raise NotImplementedError() -@torch_op("aten::tile") -def aten_tile(self: TTensor, dims: INT64) -> TTensor: +@torch_op("aten::tile", trace_only=True) +def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor: """tile(Tensor self, int[] dims) -> Tensor""" - self_rank = Rank(self) - dims_rank = op.Size(dims) - diff = op.Sub(self_rank, dims_rank) + self_rank = len(self.shape) + dims_rank = len(dims) + diff = self_rank - dims_rank if diff > 0: # dims is shorter than self.shape # pad dims with 1 - diff_1d = op.Reshape(diff, op.Constant(value_ints=[1])) - exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) - dims = op.Concat(exapnd_ones, dims, axis=0) + exapnd_ones = [1] * diff + dims = [*exapnd_ones, *dims] - if diff < 0: + elif diff < 0: # dims is longer than self.shape # pad self.shape with 1 - diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1])) - exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) + exapnd_ones = op.Constant(value_ints=[1] * (-diff)) self_shape = op.Shape(self) self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0) self = op.Reshape(self, self_final_shape, allowzero=True) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 2a7a46ec28..4f81cc7907 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -18,7 +18,6 @@ from typing import Optional, Sequence, Tuple, TypeVar, Union from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir -from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -32,7 +31,6 @@ from onnxscript.onnx_types import TensorType _MATH_PI = math.pi -Rank = common_ops.Rank _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 @@ -576,7 +574,7 @@ def aten_group_norm( norm = op.Reshape(norm, op.Shape(input)) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy - input_rank = Rank(input) + input_rank = len(input.shape) one = op.Constant(value_int=1) axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) @@ -999,7 +997,7 @@ def _aten_max_pool_onnx( ceil_mode: bool, unbatched_rank: int, ) -> TFloatOrUInt8: - self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 self = op.Unsqueeze(self, [0]) @@ -1133,7 +1131,7 @@ def _aten_max_pool_with_indices_onnx( n_dims_zero: Sequence[int], n_dims_axes: Sequence[int], ) -> Tuple[TFloatOrUInt8, INT64]: - self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank if self_rank_is_unbatched_rank: self = op.Unsqueeze(self, axes=[0]) @@ -1362,11 +1360,11 @@ def aten_nll_loss( ) -> TFloat: """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - self_rank_is_1 = Rank(self) == 1 + self_rank_is_1 = len(self.shape) == 1 if self_rank_is_1: # self rank should be at least 2 self = op.Unsqueeze(self, [0]) - rank_target = Rank(target) + rank_target = len(target.shape) if rank_target == 0: # target rank should be at least 1 target = op.Unsqueeze(target, [0]) diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index decaddddf4..99594ee17e 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -26,7 +26,6 @@ import numpy as np import onnx -import onnx_ir.passes.common as common_passes import onnxruntime as ort import onnxruntime.capi.onnxruntime_pybind11_state import pytest @@ -37,7 +36,6 @@ import onnxscript import onnxscript.evaluator from onnxscript import ir -from onnxscript.function_libs.torch_lib.ops import common as common_ops from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") @@ -412,19 +410,6 @@ def _format_model_and_input_information(onnx_model, inputs): } -def add_torchlib_common_imports(model: ir.Model) -> None: - """Hack to add torchlib common imports to the model.""" - - model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 - rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) - is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()) - model.functions[rank_func.identifier()] = rank_func - model.functions[is_scalar_func.identifier()] = is_scalar_func - removal_pass = common_passes.RemoveUnusedFunctionsPass() - assert removal_pass.in_place - removal_pass(model) - - def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: """Checks if the dtype is compatible with the schema. @@ -593,7 +578,6 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, proto = onnxscript_function.to_function_proto() ir_function = ir.serde.deserialize_function(proto) onnx_model.functions[identifier] = ir_function - add_torchlib_common_imports(onnx_model) # Make sure the model is valid model_proto = ir.to_proto(onnx_model) try: From dd14682b08518d88ccffb659003b750aeab36491 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 13 Oct 2025 15:55:02 -0700 Subject: [PATCH 068/123] [torchlib] Fix operator add (#2630) Operator add may take in python scalar sometimes (and doesn't have dtype). This PR splits the implementation out so that it is handled differently. Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e088b887f6..2ec3b8f207 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -148,6 +148,11 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: return op.Add(self, other) +@torch_op(("_operator::add"), trace_only=True) +def operator_add(self: TTensor, other: TTensor) -> TTensor: + return op.Add(self, other) + + @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True) def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -5567,7 +5572,7 @@ def aten_msort(self: TensorType) -> TensorType: @torch_op( - ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), + ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), trace_only=True, ) def aten_mul(self: TTensor, other: TTensor) -> TTensor: @@ -5579,6 +5584,11 @@ def aten_mul(self: TTensor, other: TTensor) -> TTensor: return op.Mul(self, other) +@torch_op("_operator::mul", trace_only=True) +def operator_mul(self: TTensor, other: TTensor) -> TTensor: + return op.Mul(self, other) + + @torch_op( ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), trace_only=True, From f44b314871e36c107bef76ac0c76b102e1623507 Mon Sep 17 00:00:00 2001 From: NoRaincheck <2498638+NoRaincheck@users.noreply.github.com> Date: Wed, 15 Oct 2025 01:33:13 +1100 Subject: [PATCH 069/123] Allow `opset_version` to be set explicitly when exporting (#2615) I think it would be nice to explicitly set opset_version when exporting, particularly when a custom/particular Opset is being used and the default opset can't be inferred. Example: ```py from onnxscript import script from onnxscript import opset15 as op from onnxscript.values import Opset import numpy as np from onnxscript import STRING from onnxruntime import InferenceSession ai_onnx = Opset("ai.onnx.ml", version=2) @script(ai_onnx, default_opset = op) def label_encoder(X: STRING["D"]): Y = ai_onnx.LabelEncoder(X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42) # Y = Y + 0.0 # to force opset version downgrade return Y print(label_encoder(np.array(["a", "b", "c"]))) session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString()) for key, value in {"a": 0, "b": 1, "c": 2}.items(): assert label_encoder(np.array([key]))[0] == value assert session.run(None, {"X": np.array([key])})[0] == value ``` This currently errors with ```sh Traceback (most recent call last): File "/Users/XXX/Development/projects/jet/test_onnxscript_label.py", line 25, in session = InferenceSession(label_encoder.to_model_proto(ir_version=10).SerializeToString()) File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 472, in __init__ self._create_inference_session(providers, provider_options, disabled_optimizers) File "/Users/XXX/Development/projects/jet/.venv/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 552, in _create_inference_session sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model) onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : /Users/runner/work/1/s/onnxruntime/core/graph/model_load_utils.h:56 void onnxruntime::model_load_utils::ValidateOpsetForDomain(const std::unordered_map &, const logging::Logger &, bool, const std::string &, int) ONNX Runtime only *guarantees* support for models stamped with official released onnx opset versions. Opset 23 is under development and support for this is limited. The operator schemas and or other functionality may change before next ONNX release and in this case ONNX Runtime will not guarantee backward compatibility. Current official support for domain ai.onnx is till opset 22. ``` To force it to work in the current state, one would have to do: ```py @script(ai_onnx, default_opset = op) def label_encoder(X: STRING["D"]): Y = ai_onnx.LabelEncoder(X, keys_strings=["a", "b", "c"], values_int64s=[0, 1, 2], default_int64=42) Y = Y + 0.0 # to force opset version downgrade/inserted from the to_model_proto call return Y ``` To force the opset to be downgraded, since the `default_opset` is never called. Happy to be challenged if there is a better way. I can imagine something weird/unintended might occur if the user sets `default_opset` to something other than what is defined in `@script(..., default_opset=)` but that generally shouldn't be a problem? --- onnxscript/irbuilder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 76023ea002..4274bf2062 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -321,6 +321,7 @@ def to_model_proto( input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, value_infos: dict[str, ONNXType] | None = None, + opset_version: int | None = None, **kwargs, ) -> onnx.ModelProto: """Converts this instance into a `onnx.ModelProto`. @@ -336,6 +337,8 @@ def to_model_proto( are set to be of the corresponding type in this list. value_infos: A dictionary mapping intermediate variable names to ONNX types. Used to set value_info for intermediate variables. + opset_version: The standard opset version to use for the model if it + cannot be inferred. Otherwise defaults to the current opset version. kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. Returns: @@ -393,8 +396,8 @@ def to_proto(f): if "" not in opsets: # No operator is using the standard opset. - # A default value is given. - opsets[""] = onnx_opset_version() + # Use the specified version if provided or the default value. + opsets[""] = opset_version if opset_version is not None else onnx_opset_version() if "ir_version" not in kwargs: kwargs["ir_version"] = select_ir_version(opsets[""]) From b6a2d02c00b201a323d6a8c2ceb61e0481156691 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 09:24:46 -0700 Subject: [PATCH 070/123] Remove redundant registration of operator::add (#2631) I forgot to remove the previous registration. --- .../function_libs/torch_lib/ops/core.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2ec3b8f207..36f2a70f8c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -132,7 +132,7 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) +@torch_op("aten::add.Tensor", trace_only=True) def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -148,7 +148,15 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: return op.Add(self, other) -@torch_op(("_operator::add"), trace_only=True) +@torch_op("aten::add.Scalar", trace_only=True) +def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_add(self, other, alpha=alpha) + + +@torch_op("_operator::add", trace_only=True) def operator_add(self: TTensor, other: TTensor) -> TTensor: return op.Add(self, other) @@ -8113,9 +8121,7 @@ def aten_std_mean_correction( @torch_op( ( "aten::sub.Tensor", - "aten::sub.Scalar", "aten::subtract.Tensor", - "aten::subtract.Scalar", "_operator::sub", ), trace_only=True, @@ -8128,6 +8134,14 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return op.Sub(self, other) +@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True) +def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_sub(self, other, alpha=alpha) + + @torch_op( ( "aten::sub.Tensor", From 811937ce3732536b26daf4969934e0d72d5239bd Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 14 Oct 2025 09:55:17 -0700 Subject: [PATCH 071/123] Merge shapes only in identity op and nodel-level shape inference (#2623) node-level shape inference covers the forward shape inference, and relying on the logic of constant-folding, we only need `_merge_shapes` in identity op to have backward shape inference. --- onnxscript/optimizer/_constant_folding.py | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 9a740c783c..927d8e47f6 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -496,13 +496,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: if input is None or output is None: return None - # TODO(rama): Parts of the following logic (implementing type/shape inference - # for Cast op) should be unnecessary. Generic incremental shape-inference - # should handle this. Only the optimization to eliminate redundant Cast ops - # should be needed here. - - output.shape = _merge_shapes(output.shape, input.shape) - input_dtype = _get_input_element_type(node, 0) output_dtype = _get_int_attribute(node, "to", None) if output_dtype is not None: @@ -608,6 +601,7 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] output = node.outputs[0] if input is not None and output is not None: + # NOTE: backward shape inference input.shape = _merge_shapes(input.shape, output.shape) if input.type is None: input.type = output.type @@ -904,7 +898,11 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None -def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: +def _merge_shapes( + preferred_shape: ir.Shape | None, other_shape: ir.Shape | None +) -> ir.Shape | None: + """Merge two shapes, preferring dimensions from preferred_shapes.""" + def merge_dims(dim1, dim2): if dim1 == dim2: return dim1 @@ -916,13 +914,15 @@ def merge_dims(dim1, dim2): return dim2 return dim1 - if shape1 is None: - return shape2 - if shape2 is None: - return shape1 - if len(shape1) != len(shape2): + if preferred_shape is None: + return other_shape + if other_shape is None: + return preferred_shape + if len(preferred_shape) != len(other_shape): raise ValueError("Shapes must have the same rank.") - return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) + return ir.Shape( + [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)] + ) def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None: @@ -1029,6 +1029,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: inferred_shape = ir.serde.deserialize_type_proto_for_shape( inferred_type ) + # NOTE: forward shape inference output.shape = _merge_shapes(output.shape, inferred_shape) output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: From 75b3d42e3416a0c18cf90136d6ca582fa4abc782 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 15 Oct 2025 11:01:40 -0700 Subject: [PATCH 072/123] Fix GQA fusion to produce present key/value (#2634) Output present key value from the Attention op because past key value is provided. Previously the Attention op created would consume past key/value but not produce present key/value, which is not correct for ORT. image Replaces https://github.com/microsoft/onnxscript/pull/2632 Signed-off-by: Justin Chu --- onnxscript/rewriter/rules/fusion/_gqa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py index 8d6f156ed5..c12dcc7140 100644 --- a/onnxscript/rewriter/rules/fusion/_gqa.py +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -52,7 +52,7 @@ def pattern( _outputs=["attention_BHSDh"], ) - return attention_BHSDh + return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD def check( self, @@ -103,6 +103,7 @@ def rewrite( past_key_BHkvSpD, past_value_BHkvSpD, **original_attrs, + _outputs=3, ) From 8089bc7b641d307709a79ec15c0b7980c2793384 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 16 Oct 2025 13:19:40 -0700 Subject: [PATCH 073/123] Add RMS Normalization rule variant (#2638) Add RMS Normalization rule variant to support different order of multiplying by scale. Signed-off-by: Ganesan Ramalingam --- .../rewriter/ort_fusions/rms_normalization.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index de6e51a5c0..6e9810ce63 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -31,6 +31,10 @@ class RmsNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, _mul_order: bool): + super().__init__(name) + self._mul_order = _mul_order + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) x_square = op.Pow(x, 2.0) @@ -42,7 +46,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) # To support float16, we need to ensure the scale is casted or not. scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale]) - return op.Mul(scale, normalized) + # Workaround: can't use OrValue for final (returned) value + if self._mul_order: + return op.Mul(normalized, scale) + else: + return op.Mul(scale, normalized) def check( self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ @@ -77,8 +85,10 @@ def rewrite(self, op, x, scale, epsilon, **_): ) -_rule = RmsNormFusion.rule() -rms_normalization_rules = [_rule] +_rule1 = RmsNormFusion.rule("RmsNormFusion1", _mul_order=False) +_rule2 = RmsNormFusion.rule("RmsNormFusion2", _mul_order=True) + +rms_normalization_rules = [_rule1, _rule2] rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) From dd8cb694e08f1b0f127a7ce21e46a929b34678de Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 16 Oct 2025 14:56:08 -0700 Subject: [PATCH 074/123] [DRAFT] Extend GQA fusion for Gemma3 (#2639) Gemma3 applies an extra (simplified) normalization to query and key before the rotary embedding. Extend GQA fusion to handle this. TODO: will add test-case separately. Signed-off-by: Ganesan Ramalingam Co-authored-by: Justin Chu --- onnxscript/rewriter/ort_fusions/gqa.py | 39 ++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 5fff910bcf..f1971904f0 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -166,11 +166,23 @@ def pattern( # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + # Gemma variant uses normalization of query/key before rotary embedding: + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BHSDh_normalized"] + ) + query_BHSDh = pattern.OrValue([query_BHSDh, query_BHSDh_normalized]) + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + # Gemma variant uses normalization of query/key before rotary embedding: + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BHkvSDh_normalized"] + ) + key_BHkvSDh = pattern.OrValue([key_BHkvSDh, key_BHkvSDh_normalized]) + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) @@ -316,6 +328,10 @@ def rewrite( cos, sin, mask, + query_BSHDh, + key_BSHkvDh, + query_BHSDh_normalized=None, + key_BHkvSDh_normalized=None, **_, ): # Note that the following optimization is specific to current ORT GenAI attention-mask @@ -335,6 +351,29 @@ def rewrite( seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32) max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) + + if query_BHSDh_normalized is not None: + # We apply normalization without the transpose, which is fused into GQA + norm_node = query_BHSDh_normalized.producer() + norm_attrs = norm_node.attributes + norm_scale = norm_node.inputs[1] + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, norm_scale, **norm_attrs + ) + reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1]) + query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD) + + if key_BHkvSDh_normalized is not None: + # We apply normalization without the transpose, which is fused into GQA + norm_node = key_BHkvSDh_normalized.producer() + norm_attrs = norm_node.attributes + norm_scale = norm_node.inputs[1] + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, norm_scale, **norm_attrs + ) + reshape_BSHkvDh_to_BSDkv = op.Constant(value_ints=[0, 0, -1]) + key_BSDkv = op.Reshape(key_BSHkvDh_normalized, reshape_BSHkvDh_to_BSDkv) + return op.GroupQueryAttention( query_BSD, key_BSDkv, From 55f5b82dc9fa2f333d04d05e0f1519f50635e86e Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 16 Oct 2025 15:21:40 -0700 Subject: [PATCH 075/123] Bump version to 0.5.5 (#2640) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7d8568351b..d1d899fa33 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.4 +0.5.5 From 80f28c9843e1ce6909f22acc017d0270f753e4ff Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 17 Oct 2025 16:00:54 -0700 Subject: [PATCH 076/123] Add Gemma3 GQA fusion test case (#2642) Add Gemma3 GQA fusion test case: variant with SimplifiedLayerNormalization applied to query and key. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gqa_test.py | 275 ++++++++++++++++++++ 1 file changed, 275 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 64cb84d18e..c7ed888142 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -361,6 +361,281 @@ def test_fusion(self): assert_allclose(outputs3, source_model_outputs) +class GemmaGQAFusionTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + assert (self.num_heads % self.kv_num_heads) == 0, ( + "num_heads must be divisible by kv_num_heads" + ) + self.num_groups = self.num_heads // self.kv_num_heads + self.total_seqlen = self.seqlen + self.past_seqlen + + # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen + + # Input/output types have some dimensions as dynamic (even though the + # test case instance has specific values above). + self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + FLOAT["Dh"], # query_scale + FLOAT["Dh"], # key_scale + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value + ) + + self.inputs = { + "query": np.random.rand(B, S, D).astype(np.float32), + "key": np.random.rand(B, S, Dkv).astype(np.float32), + "value": np.random.rand(B, S, Dkv).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "query_scale": np.random.rand(Dh).astype(np.float32), + "key_scale": np.random.rand(Dh).astype(np.float32), + } + + def source_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scale): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. + B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids_1d, [0]) + position_ids_k = op.Unsqueeze(position_ids_1d, [0]) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh_normalized, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh_normalized, + position_ids_k, + cos, + sin, + ) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + seq_len = op.Shape(query, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len_0D = op.Squeeze(past_seq_length) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But duplicating same logic here. + total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_val = op.Constant(value=minval_tp) + mask_all_min = op.Expand(min_val, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask_B1ST) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + + return gqa + + def test_fusion(self): + """Test that GQA fusion is successful on source model and produces an equivalent model.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + # Some shapes need to be present in input model for fusion to be successful. + # (i) Shape inference doesn't handle handle ORT contrib ops. + # (ii) TODO: investigate if Reshape(..., ["B", "S", -1, Dh]) handled precisely + # by shape inference. + query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "query_BHSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.seqlen, self.head_size], + ) + key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "key_BHkvSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.kv_num_heads, self.seqlen, self.head_size], + ) + query_BSHDh_value_info = onnx.helper.make_tensor_value_info( + "query_BSHDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.num_heads, self.head_size], + ) + key_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "key_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info( + "key_BSHkvDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.kv_num_heads, self.head_size], + ) + key_transposed_value_info = onnx.helper.make_tensor_value_info( + "key_transposed", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.head_size, self.total_seqlen], + ) + value_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "value_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + source_model.graph.value_info.extend( + [ + query_BHSDh_rope_value_info, + key_BHkvSDh_rope_value_info, + query_BSHDh_value_info, + key_BHSDh_value_info, + key_BSHkvDh_value_info, + key_transposed_value_info, + value_BHSDh_value_info, + ] + ) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_sdpa(inferred_model, debug=True) + self.assertGreater(count, 0) + + count = fuse_gqa(inferred_model, debug=True) + self.assertGreater(count, 0) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs3 = session.run(None, inputs) + + self.assertEqual(len(outputs3), len(source_model_outputs)) + assert_allclose(outputs3, source_model_outputs) + + class GQAFusionTest2(unittest.TestCase): @unittest.skip("Needs too much memory.") def test_phi4lm(self): From 8a94ad646440f462dd9ae1de6b303fe5f7b7f564 Mon Sep 17 00:00:00 2001 From: Ayoub BIH <89558574+AyoubMDL@users.noreply.github.com> Date: Sat, 18 Oct 2025 21:26:32 +0200 Subject: [PATCH 077/123] [Rewriter]: introduce remove_optional_bias (#2635) Fixes https://github.com/microsoft/onnxscript/issues/2547. I've kept the same ops as in https://github.com/microsoft/onnxscript/pull/2555. --- onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/rules/common/__init__.py | 10 + .../rules/common/_remove_optional_bias.py | 123 +++++++++ .../common/_remove_optional_bias_test.py | 237 ++++++++++++++++++ 4 files changed, 372 insertions(+) create mode 100644 onnxscript/rewriter/rules/common/_remove_optional_bias.py create mode 100644 onnxscript/rewriter/rules/common/_remove_optional_bias_test.py diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index fc000dc176..75f43bf3ea 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -41,6 +41,7 @@ _min_max_to_clip, _no_op, _redundant_scatter_nd, + _remove_optional_bias, ) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) @@ -55,6 +56,7 @@ *_redundant_scatter_nd.rules, *_fuse_pad_into_conv.rules, *_fuse_batchnorm.rules, + *_remove_optional_bias.rules, ) diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 14ed3587f3..76d9e4f4b0 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -34,6 +34,10 @@ "normalize_pad_format_conv_integer_rule", "normalize_pad_format_conv_rule", "one_reshape_matmul_reshape_rule", + "remove_optional_bias_from_conv_rule", + "remove_optional_bias_from_conv_transpose_rule", + "remove_optional_bias_from_gemm_rule", + "remove_optional_bias_from_qlinear_conv_rule", "reshape_reshape_rule", "slice_split_rule", "squeeze_reshape_1d_rule", @@ -121,3 +125,9 @@ no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule, ) +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias.py b/onnxscript/rewriter/rules/common/_remove_optional_bias.py new file mode 100644 index 0000000000..ead8a73eab --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Remove optional bias when it is all zero from Conv, ConvTranspose, Gemm and QLinearConv operations.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _RemoveOptionalBias(RewriteRuleClassBase): + def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: + node = out.producer() + + return op.op( + self.op_type, + inputs=node.inputs[:-1], + attributes=node.attributes, + ) + + def check(self, context, b: ir.Value, **_) -> MatchResult: + """Condition to check if we need to replace the pattern. + + The pattern is applied only when the bias is all zeros. The bias should be + a constant value (i.e., provided by Constant nodes or initializers). + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Unused + check_result = MatchResult() + + # Check if bias is a constant/initializer + bias_tensor = ir.convenience.get_const_tensor(b) + if bias_tensor is None: + return check_result.fail("Bias is not a constant/initializer.") + + # Check if bias is all zeros + bias_array = bias_tensor.numpy() + if not np.equal(bias_array, 0.0).all(): + return check_result.fail("Bias is not all zeros.") + + return check_result + + +class RemoveOptionalBiasFromConv(_RemoveOptionalBias): + """Remove zero bias from Conv operation.""" + + op_type: ClassVar[str] = "Conv" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Conv(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromConvTranspose(_RemoveOptionalBias): + """Remove zero bias from ConvTranspose operation.""" + + op_type: ClassVar[str] = "ConvTranspose" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.ConvTranspose(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromQLinearConv(_RemoveOptionalBias): + """Remove zero bias from QLinearConv operation.""" + + op_type: ClassVar[str] = "QLinearConv" + + def pattern( + self, + op: ir.tape.Tape, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b: ir.Value, + ) -> ir.Value: + return op.QLinearConv( + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b, + _outputs=["out"], + ) + + +class RemoveOptionalBiasFromGemm(_RemoveOptionalBias): + """Remove zero bias from Gemm operation.""" + + op_type: ClassVar[str] = "Gemm" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Gemm(x, w, b, _outputs=["out"]) + + +remove_optional_bias_from_conv_rule = RemoveOptionalBiasFromConv().rule() +remove_optional_bias_from_conv_transpose_rule = RemoveOptionalBiasFromConvTranspose().rule() +remove_optional_bias_from_qlinear_conv_rule = RemoveOptionalBiasFromQLinearConv().rule() +remove_optional_bias_from_gemm_rule = RemoveOptionalBiasFromGemm().rule() + +rules = RewriteRuleSet( + [ + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_qlinear_conv_rule, + remove_optional_bias_from_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py new file mode 100644 index 0000000000..4349d7aae3 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common import _remove_optional_bias +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) + + +class _RemoveOptionalBiasTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20251016) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def _get_test_model( + self, + op_type: str, + input_shape: ir.Shape, + weight_shape: ir.Shape, + zero_bias: bool, + attributes=None, + ): + tape = ir.tape.Tape() + bias_shape = weight_shape[1] if op_type == "ConvTranspose" else weight_shape[0] + output_shape = ir.Shape(("?",) * input_shape.rank()) + + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + + w = tape.initializer( + ir.tensor(self.rng.uniform(-0.5, 0.5, weight_shape).astype(np.float32), name="W") + ) + + if zero_bias: + bias = np.zeros(bias_shape, dtype=np.float32) + else: + bias = self.rng.uniform(-0.5, 0.5, bias_shape).astype(np.float32) + + b = tape.initializer(ir.tensor(bias, name="B")) + y = tape.op( + op_type, + inputs=[x, w, b], + attributes=attributes, + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + ) + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="test_model", + ), + ir_version=10, + ) + onnx_checker.CheckerPass(True)(ir_model) + return ir_model + + def run_test( + self, + base_model: ir.Model, + input_shape: tuple, + input_dtype=np.float32, + ): + updated_model = self.clone_model(base_model) + count = _remove_optional_bias.rules.apply_to_model(updated_model) + + # Check rule is applied + self.assertEqual(count, 1) + + # Check number of inputs is reduced + self.assertEqual( + len(updated_model.graph[0].inputs), len(base_model.graph[0].inputs) - 1 + ) + + # Prepare inputs + inputs = (self.rng.random(input_shape).astype(input_dtype),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class RemoveOptionalBiasGemmTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=True, + attributes={"transB": 1}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=False, + attributes={"transB": 1}, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_gemm_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasGonvTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=True, + attributes={"strides": (2, 2)}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasGonvTransposeTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=True, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_transpose_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasQLinearConvTest(_RemoveOptionalBiasTestBase): + def _get_test_model(self, zero_bias): + if zero_bias: + bias = np.zeros((16,), dtype=np.int32) + else: + bias = self.rng.uniform(-5, 5, (16,)).astype(np.int32) + + w = ir.tensor(self.rng.uniform(-5, 5, (16, 3, 3, 3)).astype(np.uint8), name="W") + b = ir.tensor(bias, name="B") + + model = ir.from_onnx_text( + """ + < ir_version: 10, opset_import: ["" : 20] > + test_model (uint8[N, 3, 32, 32] X) => (uint8 [N, ?, ?, ?] Y) + + { + Y = QLinearConv(X, x_scale, x_zero_point, W, w_scale, w_zero_point, y_scale, y_zero_point, B) + } + """, + initializers=[w, b], + ) + onnx_checker.CheckerPass(True)(model) + return model + + def test_successful_remove_optional_bias_qlinear_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model(zero_bias=True) + self.run_test(base_model, input_shape, np.uint8) + + def test_fail_remove_optional_bias_qlinear_conv(self): + base_model = self._get_test_model(zero_bias=False) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_qlinear_conv_rule, "Bias is not all zeros." + ) + + +if __name__ == "__main__": + unittest.main() From 04a9da427d04537d728868b36e74c8e11c9fa28d Mon Sep 17 00:00:00 2001 From: Daniel Zhang Date: Tue, 28 Oct 2025 01:27:57 +0800 Subject: [PATCH 078/123] Unsqueeze unbatched input of avg_pool (#2646) Onnx's `AveragePool` require input shape as `N,C,H,W`, but torch accept both `N,C,H,W` and `C,H,W`. Unsqueeze if input is unbatched, just like what `max_pool` does. --- onnxscript/function_libs/torch_lib/ops/nn.py | 60 +++++++++---------- .../function_libs/torch_lib/e2e_ops_tests.py | 24 ++++++++ 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4f81cc7907..5edcc233d0 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool( return (kernel_shape, strides, pads) +def _aten_avg_pool_onnx( + self: TFloat, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + ceil_mode: bool, + count_include_pad: bool, +) -> TFloat: + self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1 + if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = op.Unsqueeze(self, [0]) + + result = op.AveragePool( + self, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + if self_rank_is_unbatched_rank: + result = op.Squeeze(result, [0]) + + return result + + @torch_op("aten::avg_pool1d", trace_only=True) def aten_avg_pool1d( self: TFloat, @@ -134,16 +161,7 @@ def aten_avg_pool1d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, - ) - - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) @torch_op("aten::avg_pool2d", trace_only=True) @@ -167,15 +185,6 @@ def aten_avg_pool2d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, - ) - # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) # mask = [ # 1, 2, 3, S,..3, 2, 1 @@ -189,7 +198,7 @@ def aten_avg_pool2d( # S is stride size, in this case S=4, # S may dup lot of times according to the image size - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) def aten_avg_pool2d_backward( @@ -228,15 +237,6 @@ def aten_avg_pool3d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - kernel_shape=kernel_shape, - strides=strides, - pads=pads, - count_include_pad=count_include_pad, - ceil_mode=ceil_mode, - ) - # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) # mask = [ # 1, 2, 3, S,..3, 2, 1 @@ -250,7 +250,7 @@ def aten_avg_pool3d( # S is stride size, in this case S=4, # S may dup lot of times according to the image size - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) def aten_avg_pool3d_backward( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 754f5e2a25..3c557be4f0 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -238,6 +238,30 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_avg_pool(self): + class Model(torch.nn.Module): + def forward(self, x2d, x3d, x4d, x5d): + return ( + torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable + ) + + x2d = torch.randn(10, 10) + x3d = torch.randn(10, 10, 10) + x4d = torch.randn(10, 10, 10, 10) + x5d = torch.randn(10, 10, 10, 10, 10) + onnx_program = torch.onnx.export( + Model(), + (x2d, x3d, x4d, x5d), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 8c0b72b9d4bdf16cf3e0d86d0577c36ff7cea32b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 27 Oct 2025 10:30:10 -0700 Subject: [PATCH 079/123] Add a verbose mode to torch api for external data save (#2643) Show progress bar with tqdm when verbose is True. It will be enabled in PyTorch 2.10 image --------- Signed-off-by: Justin Chu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/_framework_apis/torch_2_5.py | 31 +++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index 162faf4b75..5bbb64af88 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -13,6 +13,7 @@ ] import dataclasses +import importlib.util import os import pathlib from typing import Callable @@ -63,7 +64,9 @@ def check_model(model: ir.Model) -> None: del model # Unused yet -def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None: +def save_model_with_external_data( + model: ir.Model, model_path: str | os.PathLike, verbose: bool = False +) -> None: """Save the model with external data. The model is unchanged after saving.""" # TODO(#1835): Decide if we want to externalize large attributes as well @@ -78,7 +81,31 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike destination_path = pathlib.Path(model_path) data_path = f"{destination_path.name}.data" - ir.save(model, model_path, external_data=data_path) + # Show a progress bar if verbose is True and tqdm is installed + use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None + + if use_tqdm: + import tqdm # pylint: disable=import-outside-toplevel + + with tqdm.tqdm() as pbar: + total_set = False + + def callback( + tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo + ) -> None: + nonlocal total_set + if not total_set: + pbar.total = metadata.total + total_set = True + + pbar.update() + pbar.set_description( + f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}" + ) + + ir.save(model, model_path, external_data=data_path, callback=callback) + else: + ir.save(model, model_path, external_data=data_path) def get_torchlib_ops() -> list[_OnnxFunctionMeta]: From bb75e2b69fbc450b16f4760ca6cba9544c93c6f3 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 27 Oct 2025 11:36:50 -0700 Subject: [PATCH 080/123] Support math trunc (#2653) Fix https://github.com/pytorch/pytorch/issues/166110 --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 36f2a70f8c..eb276a239c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8533,6 +8533,14 @@ def aten_trunc(self: TFloat) -> TFloat: return op.Floor(op.Abs(self)) * op.Sign(self) +@torch_op("math::trunc", trace_only=True) +def python_math_trunc(self: TFloat) -> TInt: + """trunc(Tensor self) -> Tensor""" + # NOTE: This is used in SymInt/SymBool/SymFloat context, so + # we don't expect overflow to happen here. + return op.Cast(self, to=INT64.dtype) + + @torch_op("aten::type_as", trace_only=True) def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: """type_as(Tensor self, Tensor other) -> Tensor""" From 3334ba10dbde6156bbcbe081536e51533dc185b5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Oct 2025 23:01:07 +0000 Subject: [PATCH 081/123] chore(deps): bump actions/upload-artifact from 4 to 5 (#2656) --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index faf40b9ec3..85d2a0b331 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -83,7 +83,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} - name: Upload torchlib error reports if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: Error reports (${{ matrix.name }}-${{ matrix.os }}) path: error_reports From b84d595efae441d8073127c3620147a6e6c8544c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Oct 2025 17:56:15 -0700 Subject: [PATCH 082/123] chore(deps): bump onnx-weekly from 1.20.0.dev20251006 to 1.20.0.dev20251027 in /requirements/ci (#2657) --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index e005031603..728f319adf 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.20.0.dev20251006 +onnx-weekly==1.20.0.dev20251027 From ad83914fa2deffac8a637775edeeeb58e3e7bd64 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:43:04 +0000 Subject: [PATCH 083/123] chore(deps): bump ruff from 0.13.2 to 0.14.2 in /requirements/lintrunner (#2658) --- .../tools/torch_lib/deduce_type_constraints_test.py | 2 +- .../function_libs/tools/torch_lib/generate_prims_signatures.py | 1 - .../onnxruntime/bfloat16_utils/bfloat16_converter_test.py | 1 - onnxscript/rewriter/rules/fusion/_layer_norm_test.py | 1 - requirements/lintrunner/requirements.txt | 2 +- tests/eager_mode_test.py | 1 - 6 files changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242a..b547737bf5 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -12,7 +12,7 @@ import parameterized import onnxscript -import onnxscript.function_libs.torch_lib.ops # Import to populate registry +import onnxscript.function_libs.torch_lib.ops # Import to populate registry # noqa: F401 from onnxscript.function_libs.tools.torch_lib import deduce_type_constraints from onnxscript.function_libs.torch_lib import registration diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index ebbdd43bd8..6661e34afe 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -16,7 +16,6 @@ from typing import Any, Dict, List, Sequence import torch -import torchgen.gen import torchgen.model from torch._ops import _OpNamespace from torchgen.model import FunctionSchema diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py index a64d6e6023..c527855bb7 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py @@ -3,7 +3,6 @@ import unittest import numpy as np -import onnx import onnx.checker import onnx.shape_inference import onnxruntime diff --git a/onnxscript/rewriter/rules/fusion/_layer_norm_test.py b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py index 6ea7f116fb..5e13f5e479 100644 --- a/onnxscript/rewriter/rules/fusion/_layer_norm_test.py +++ b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py @@ -5,7 +5,6 @@ import onnx_ir as ir -import onnxscript import onnxscript.optimizer import onnxscript.rewriter.testing from onnxscript import FLOAT, OnnxFunction, script diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index f95977610e..2a913e68f4 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.13.2 +ruff==0.14.2 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250915 diff --git a/tests/eager_mode_test.py b/tests/eager_mode_test.py index 566169f223..e4cb0ab313 100644 --- a/tests/eager_mode_test.py +++ b/tests/eager_mode_test.py @@ -6,7 +6,6 @@ import numpy as np import parameterized -import onnxscript import onnxscript.evaluator import onnxscript.tensor from onnxscript import opset17 as op From 69025f7f7b746eb62292135be2eb3e00f1b36890 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 28 Oct 2025 15:29:51 -0700 Subject: [PATCH 084/123] [version converter] Fix DFT opset 20 (#2659) Fixes https://github.com/pytorch/pytorch/issues/148687 Axis is actually the third input of DFT. --- onnxscript/version_converter/_version_converter.py | 3 ++- .../version_converter/_version_converter_test.py | 4 ++-- tests/function_libs/torch_lib/e2e_ops_tests.py | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index dddf11150c..cb7a6c43ad 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -155,12 +155,13 @@ def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> @register("DFT", node_version=19, up_conversion=True) def dft_19_20(node: ir.Node, op): input = node.inputs[0] + dft_length = node.inputs[1] if len(node.inputs) > 1 else None inverse = _get_int_attribute(node, "inverse", 0) onesided = _get_int_attribute(node, "onesided", 0) axis = _get_int_attribute(node, "axis", None) if axis is not None: axis_value = op.Constant(value_int=axis) - return op.DFT(input, axis_value, inverse=inverse, onesided=onesided) + return op.DFT(input, dft_length, axis_value, inverse=inverse, onesided=onesided) return None diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index cf6507196b..021c6e72bb 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -144,7 +144,7 @@ def test_version_convert_compatible(self): self.assertEqual(model.graph.node(3).version, 20) self.assertEqual(model.graph.node(3).op_type, "DFT") self.assertEqual(model.graph.node(3).version, 20) - self.assertEqual(len(model.graph.node(3).inputs), 2) + self.assertEqual(len(model.graph.node(3).inputs), 3) def test_version_convert_gridsample_linear(self): model = ir.from_onnx_text( @@ -241,7 +241,7 @@ def test_version_convert_inline(self): self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") self.assertEqual(model.graph.node(6).op_type, "DFT") self.assertEqual(model.graph.node(6).version, 20) - self.assertEqual(len(model.graph.node(6).inputs), 2) + self.assertEqual(len(model.graph.node(6).inputs), 3) class VersionConverter20to21Test(unittest.TestCase): diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 3c557be4f0..75457f77a2 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -238,6 +238,20 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_dft_axis_promoted_from_attribute_to_input(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access + + onnx_program = torch.onnx.export( + Model(), + (torch.randn(2, 3),), + opset_version=20, + dynamic_shapes=({0: "dim_x"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_avg_pool(self): class Model(torch.nn.Module): def forward(self, x2d, x3d, x4d, x5d): From 45b51898dc605c2ce3e800e23aeca05e25cacb4c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 28 Oct 2025 22:03:01 -0700 Subject: [PATCH 085/123] [torchlib] Fix concat when input tensor has shape `(0,)` (#2661) Filter out size-0 tensors. When there is only one input, create an identity op instead of a concat op. Fix https://github.com/microsoft/onnxscript/issues/2660 Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 16 +++++++++--- .../function_libs/torch_lib/e2e_ops_tests.py | 26 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index eb276a239c..be30520878 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1521,7 +1521,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::cat", trace_only=True, complex=True) +@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True, complex=True) def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" # Real representation unsqueezes the last dimension @@ -1534,8 +1534,18 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" - # Remove None tensors - tensors = [tensor for tensor in tensors if tensor is not None] + filtered_tensors = [] + for tensor in tensors: + # Remove None tensors + if tensor is None: + continue + # Remove empty tensors + if tensor.shape == (0,): + continue + filtered_tensors.append(tensor) + assert filtered_tensors, "aten::cat received all None or empty tensors" + if len(filtered_tensors) == 1: + return op.Identity(filtered_tensors[0]) return op.Concat(*tensors, axis=dim) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 75457f77a2..24ccaf4b40 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -276,6 +276,32 @@ def forward(self, x2d, x3d, x4d, x5d): ) _testing.assert_onnx_program(onnx_program) + def test_concat_with_empty_tensor(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.cat([x, torch.tensor([]), x], dim=0) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([1, 2]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_concat_with_empty_tensor_single_element(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.cat([x, torch.tensor([])], dim=1) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([[1, 2]]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 9e0366cc6410776a5e2dfae6b92b0c864e2a4127 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Wed, 29 Oct 2025 10:27:03 -0700 Subject: [PATCH 086/123] Create initializers not constant nodes in constant folding pass (#2650) Partially From #2598 This provides a better optimized graph after constant folding in terms of the number of nodes, which is better for debugging. --- noxfile.py | 2 +- onnxscript/optimizer/_constant_folding.py | 59 +++++++++++-------- .../optimizer/_constant_folding_test.py | 42 +++++++------ pyproject.toml | 2 +- 4 files changed, 60 insertions(+), 45 deletions(-) diff --git a/noxfile.py b/noxfile.py index 60c2bb901b..fc80761b68 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,7 +41,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.10" +ONNX_IR = "onnx_ir==0.1.12" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 927d8e47f6..e1bff26791 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1039,24 +1039,29 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) - def new_constant(self, node: ir.Node, value) -> ir.Node | None: - irvalue = node.outputs[0] - if not isinstance(value, np.ndarray): + def new_initializer(self, node: ir.Node, array) -> ir.Value | None: + original_value = node.outputs[0] + if not isinstance(array, np.ndarray): # ONNX does not have a way to represent non-tensor constants, eg. a sequence. # So, a constant-value of type sequence is not folded, but it can be used # to optimize subsequent operations when possible. logger.info( "Skip storing constant folded value %s due to unsupported type %s.", - irvalue.name, - type(value), + original_value.name, + type(array), ) return None - tensor = ir.tensor(value) - tensor.name = irvalue.name - irvalue.const_value = tensor + tensor = ir.tensor(array) + tensor.name = original_value.name + initializer = ir.Value( + name=original_value.name, + type=ir.TensorType(ir.DataType(tensor.dtype)), + shape=tensor.shape, # type: ignore[arg-type] + const_value=tensor, + ) - if value.size > self.output_size_limit: + if array.size > self.output_size_limit: # Handle examples like Transpose(weight) to be folded even if the size is large, # as long as weight has no other uses. This won't increase model size. removed_input_size = 0 @@ -1065,25 +1070,23 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None: array = _get_numpy_value(input) if array is not None: removed_input_size += array.size - increased_size = value.size - removed_input_size + increased_size = array.size - removed_input_size if increased_size > 0: logger.info( "Skip storing constant folded nvalue %s due to large size %s.", - irvalue.name, - value.size, + original_value.name, + array.size, ) return None logger.debug( - "New constant for value %s dtype: %s shape: %s", - irvalue.name, - value.dtype, - value.shape, + "New Initializer for value %s dtype: %s shape: %s", + original_value.name, + array.dtype, + array.shape, ) - attributes = ir.convenience.convert_attributes({"value": tensor}) - node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) - return node + return initializer def process_node(self, node: ir.Node) -> Replacement | None: """Process a node and return a Replacement if the node can be replaced.""" @@ -1109,7 +1112,13 @@ def process_node(self, node: ir.Node) -> Replacement | None: self._do_inference(node) if node.domain not in self._opset_imports: + logger.debug( + "Skipping constant folding for node %r due to missing opset import for domain %r.", + node.name, + node.domain, + ) return None + version = self._opset_imports[node.domain] op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) for optimizer in op_optimizers: @@ -1153,7 +1162,7 @@ def process_node(self, node: ir.Node) -> Replacement | None: ) return None - # Ensure all node inputs are constants + # Ensure all node inputs are constants or initializers if any(x.const_value is None for x in node.inputs if x is not None): return None @@ -1227,10 +1236,13 @@ def convert(av): if outputs is None: return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): - replacement = self.new_constant(node, outputs) - if replacement is None: + new_initializer_value = self.new_initializer(node, outputs) + if new_initializer_value is None: return None - return Replacement(replacement.outputs, [replacement]) + # Add the new initializer to the graph + assert node.graph is not None + node.graph.register_initializer(new_initializer_value) + return Replacement([new_initializer_value], []) else: logger.warning( "Skipping constant folding for op %s with multiple outputs.", node.op_type @@ -1244,7 +1256,6 @@ def replace_node( # Record the names of the values that has contributed to the replacement _record_contributing_values(node, replacement) - ir.convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index d3d76c4a23..d9395e811c 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -36,8 +36,8 @@ def test_fold_add(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_cast_like(self): model = """ @@ -51,8 +51,8 @@ def test_fold_cast_like(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_shape(self): model = """ @@ -67,8 +67,8 @@ def test_fold_shape(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_shape_slice(self): model = """ @@ -83,8 +83,8 @@ def test_fold_shape_slice(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_if_cond(self): model = """ @@ -130,9 +130,11 @@ def test_fold_inside_if_branch(self): optimized = self._fold(model) self.assertEqual(len(optimized.graph), 1) then_graph = optimized.graph[0].attributes["then_branch"].as_graph() - self.assertEqual(len(then_graph), 2) + self.assertEqual(len(then_graph), 1) + self.assertIn("temp", then_graph.initializers) else_graph = optimized.graph[0].attributes["else_branch"].as_graph() - self.assertEqual(len(else_graph), 2) + self.assertEqual(len(else_graph), 1) + self.assertIn("temp", else_graph.initializers) def test_fold_if_propagate(self): model = """ @@ -154,9 +156,8 @@ def test_fold_if_propagate(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "m_square") - self.assertEqual(optimized.graph[0].op_type, "Constant") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("m_square", optimized.graph.initializers) def test_fold_redundant_cast(self): model = """ @@ -209,8 +210,8 @@ def test_shape_inference(self): """ optimized = self._fold(model, onnx_shape_inference=True) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "C") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("C", optimized.graph.initializers) def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( self, @@ -614,7 +615,8 @@ def test_input_size_limit(self): # Since there is no increase in model-size, output-size is not a concern. optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256) ops = [node.op_type for node in optimized.graph] - self.assertEqual(ops, ["Constant", "Add"]) + self.assertEqual(ops, ["Add"]) + self.assertIn("w_squared", optimized.graph.initializers) def test_transpose_is_always_folded(self): model_text = """ @@ -633,7 +635,8 @@ def test_transpose_is_always_folded(self): # Input size limit will not prevent folding of Transpose op optimized = self._fold(model, input_size_limit=1) ops = [node.op_type for node in optimized.graph] - self.assertEqual(ops, ["Constant"]) + self.assertEqual(ops, []) + self.assertIn("z", optimized.graph.initializers) def test_node_is_folded_if_specified_as_should_fold(self): model_text = """ @@ -656,9 +659,10 @@ def test_node_is_folded_if_specified_as_should_fold(self): model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None ) ops = [node.op_type for node in optimized.graph] - self.assertEqual(ops, ["Constant"]) + self.assertEqual(ops, []) + self.assertIn("z", optimized.graph.initializers) np.testing.assert_array_equal( - optimized.graph.node(0).attributes["value"].as_tensor().numpy(), + optimized.graph.initializers["z"].const_value, np.ones((42, 42), dtype=np.int64), ) diff --git a/pyproject.toml b/pyproject.toml index 4f7edc9bf8..b318042633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.10,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.12,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.16", "packaging", "typing_extensions>=4.10", From fe50b8365da071aaedf471393f640130873c6893 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 29 Oct 2025 10:47:16 -0700 Subject: [PATCH 087/123] Add support for traced if statements in onnxscript script (#2644) Extend the onnxscript converter to handle conditionals that are evaluated at script-time (that is, in the style of trace-mode). This makes it easier to define parametric scripts that can be used to generate variations of a pattern: for example, like the many variations of the [SDPA pattern test cases](https://github.com/microsoft/onnxscript/blob/8a94ad646440f462dd9ae1de6b303fe5f7b7f564/onnxscript/rewriter/ort_fusions/sdpa_test.py#L40). This supports just a very basic version, where the if-condition is an outer-scope variable, like below: ```py if outer_scope_variable: ... else: ... ``` For such cases, the script will just include the then or else branch as appropriate, without generating an if-node. Also: introduce an analyzer class to encapsulate analysis information, and avoid updates to AST node. TODO: some simple extension may be useful (perhaps allow any expression in the if-condition that does not contain local variables). --------- Signed-off-by: Ganesan Ramalingam Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/_internal/analysis.py | 377 +++++++++++++++----------- onnxscript/_internal/analysis_test.py | 58 +++- onnxscript/converter.py | 44 ++- onnxscript/converter_test.py | 30 ++ 4 files changed, 337 insertions(+), 172 deletions(-) diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 0403f60c91..c89542d344 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -47,183 +47,254 @@ def get_id(e): return {get_id(lhs)} -def assigned_vars( - stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: - """Return the set of all variables that may be assigned to in an execution of input stmt - or sequence of statements. - """ - - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: - result: set[Any] = set() - for s in block: - result = result | assigned_vars(s, formatter) - return result - - if isinstance(stmt, ast.Assign): - return _lhs_vars(stmt.targets[0]) - if isinstance(stmt, ast.AnnAssign): - return _lhs_vars(stmt.target) - if isinstance(stmt, ast.Return): - return set() - if isinstance(stmt, ast.If): - return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) - if isinstance(stmt, ast.For): - return assigned_in_block(stmt.body) | {_get_loop_var(stmt, formatter)} - if isinstance(stmt, ast.While): - return assigned_in_block(stmt.body) - if isinstance(stmt, list): - return assigned_in_block(stmt) - if isinstance(stmt, ast.Break): - return set() - if ast_utils.is_print_call(stmt): - return set() - if ast_utils.is_doc_string(stmt): - return set() - error_message = formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") - raise ValueError(error_message) +class AstAnalyzer: + def __init__( + self, + fun: ast.FunctionDef, + formatter: sourceinfo.Formatter, + globals: dict[str, Any] | None = None, + ) -> None: + self._formatter = formatter + self._constant_if_condition: dict[ast.If, bool] = {} + self._live_in: dict[ast.stmt, Set[str]] = {} + self._live_out: dict[ast.stmt, Set[str]] = {} + if globals: + self._compute_constant_if_conditions(fun, globals) + self.do_liveness_analysis(fun) + def live_in(self, stmt: ast.stmt) -> Set[str] | None: + """Get the set of variables that are live at the entry of the given statement.""" + return self._live_in.get(stmt) -def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Perform liveness analysis of the given function-ast. The results of the - analysis are stored directly with each statement-ast `s` as attributes `s.live_in` - and `s.live_out`. - """ + def live_out(self, stmt: ast.stmt) -> Set[str] | None: + """Get the set of variables that are live at the exit of the given statement.""" + return self._live_out.get(stmt) - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - stmt.live_out = live_out # type: ignore[attr-defined] - live = do_visit(stmt, live_out) - stmt.live_in = live # type: ignore[attr-defined] - return live + def _compute_constant_if_conditions( + self, fun: ast.FunctionDef, globals: dict[str, Any] + ) -> None: + """Identify if-statements with constant conditions. - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: - for s in reversed(block): - live_out = visit(s, live_out) - return live_out + If-statements of the form `if name:` where `name` is an outer-scope variable + and name is not assigned to within the function body, are treated as constant + conditions. The value of such conditions is determined from the outer-scope. + """ + + assigned_vars = self.assigned_vars(fun.body) + for node in ast.walk(fun): + if isinstance(node, ast.If): + if isinstance(node.test, ast.Name): + python_var = node.test.id + if python_var not in assigned_vars and python_var in globals: + # Condition depends on an outer-scope variable. + self._constant_if_condition[node] = bool(globals[python_var]) + + def constant_if_condition(self, if_stmt: ast.If) -> Optional[bool]: + """Return the constant value of the if-statement condition, if it is constant. + + Args: + if_stmt: The if-statement-ast to analyze. + + Returns: + The constant boolean value of the if-statement condition, or None if not constant. + """ + return self._constant_if_condition.get(if_stmt, None) + + def assigned_vars(self, stmt: ast.stmt | list[ast.stmt]) -> Set[str]: + """Return the set of all variables that may be assigned to in an execution of input stmt + or sequence of statements. + """ + + def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + result: set[Any] = set() + for s in block: + result = result | self.assigned_vars(s) + return result if isinstance(stmt, ast.Assign): - return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + return _lhs_vars(stmt.targets[0]) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + return _lhs_vars(stmt.target) if isinstance(stmt, ast.Return): - return _used_vars(stmt.value) + return set() if isinstance(stmt, ast.If): - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) - return live1 | live2 | _used_vars(stmt.test) + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) + elif constant_cond: + return assigned_in_block(stmt.body) + else: + return assigned_in_block(stmt.orelse) if isinstance(stmt, ast.For): - p_loop_var = _get_loop_var(stmt, formatter) - prev = None - curr = live_out - while curr != prev: - prev = curr - curr = visitBlock(stmt.body, prev).difference({p_loop_var}) - return curr + return assigned_in_block(stmt.body) | {_get_loop_var(stmt, self._formatter)} if isinstance(stmt, ast.While): - cond_vars = _used_vars(stmt.test) - prev = None - curr = live_out | cond_vars - while curr != prev: - prev = curr - curr = visitBlock(stmt.body, prev) | cond_vars - return curr + return assigned_in_block(stmt.body) + if isinstance(stmt, list): + return assigned_in_block(stmt) if isinstance(stmt, ast.Break): - # The following is sufficient for the current restricted usage, where - # a (conditional) break is allowed only as the last statement of a loop. - # Break statements in the middle of the loop, however, will require - # a generalization. - return live_out - if ast_utils.is_doc_string(stmt): - return live_out + return set() if isinstance(stmt, ast.FunctionDef): - return live_out + # Supported function-definitions (used for higher order ops like Scan) + # do not assign to any variable in the outer scope. + return set() if ast_utils.is_print_call(stmt): - return live_out - raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")) + return set() + if ast_utils.is_doc_string(stmt): + return set() + error_message = self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + raise ValueError(error_message) - assert isinstance(fun, ast.FunctionDef) - live: set[Any] = set() - for s in reversed(fun.body): - live = visit(s, live) + def do_liveness_analysis(self, fun: ast.FunctionDef): + """Perform liveness analysis of the given function-ast.""" + def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + self._live_out[stmt] = live_out + live = do_visit(stmt, live_out) + self._live_in[stmt] = live + return live -def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): - """Return the set of variables that are used before being defined by given block. - In essence, this identifies the "inputs" to a given code-block. - For example, consider the following code-block: - :: + def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + for s in reversed(block): + live_out = visit(s, live_out) + return live_out - x = x + 10 - y = 20 - z = x + y - x = 30 + if isinstance(stmt, ast.Assign): + return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + if isinstance(stmt, ast.AnnAssign): + return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + if isinstance(stmt, ast.Return): + return _used_vars(stmt.value) + if isinstance(stmt, ast.If): + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return live1 | live2 | _used_vars(stmt.test) + elif constant_cond: + return visitBlock(stmt.body, live_out) + else: + return visitBlock(stmt.orelse, live_out) + if isinstance(stmt, ast.For): + p_loop_var = _get_loop_var(stmt, self._formatter) + prev = None + curr = live_out + while curr != prev: + prev = curr + curr = visitBlock(stmt.body, prev).difference({p_loop_var}) + return curr + if isinstance(stmt, ast.While): + cond_vars = _used_vars(stmt.test) + prev = None + curr = live_out | cond_vars + while curr != prev: + prev = curr + curr = visitBlock(stmt.body, prev) | cond_vars + return curr + if isinstance(stmt, ast.Break): + # The following is sufficient for the current restricted usage, where + # a (conditional) break is allowed only as the last statement of a loop. + # Break statements in the middle of the loop, however, will require + # a generalization. + return live_out + if ast_utils.is_doc_string(stmt): + return live_out + if isinstance(stmt, ast.FunctionDef): + return live_out + if ast_utils.is_print_call(stmt): + return live_out + raise ValueError( + self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + ) - The exposed_uses of this code-block is { x }. The value of z is not used within - the block. Even though the value of y is used within the block, it is assigned - a value before it is used. However, in contrast, the incoming value of x is used - (in the first statement). Hence x is included in the exposed_uses. - """ + assert isinstance(fun, ast.FunctionDef) + live: set[Any] = set() + for s in reversed(fun.body): + live = visit(s, live) - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: - for stmt in reversed(block): - live_out = visit(stmt, live_out) - return live_out + def exposed_uses(self, stmts: Sequence[ast.stmt]): + """Return the set of variables that are used before being defined by given block. + In essence, this identifies the "inputs" to a given code-block. + For example, consider the following code-block: + :: - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - if isinstance(stmt, ast.Assign): - return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) - if isinstance(stmt, ast.AnnAssign): - return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) - if isinstance(stmt, ast.Return): - return _used_vars(stmt.value) - if isinstance(stmt, ast.If): - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) - return (live1 | live2) | _used_vars(stmt.test) - if ast_utils.is_print_call(stmt): - return live_out - if ast_utils.is_doc_string(stmt): - return live_out - if isinstance(stmt, ast.For): - # Analysis assumes loop may execute zero times. Results can be improved - # for loops that execute at least once. - loop_var_set = {_get_loop_var(stmt, formatter)} - used_after_loop = live_out.difference(loop_var_set) - used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set) - used_in_loop_header = _used_vars(stmt.iter) - return used_inside_loop | used_in_loop_header | used_after_loop - if isinstance(stmt, ast.While): - # Analysis assumes loop may execute zero times. Results can be improved - # for loops that execute at least once. - used_inside_loop = visitBlock(stmt.body, set()) - used_in_loop_header = _used_vars(stmt.test) - return used_inside_loop | used_in_loop_header | live_out - if isinstance(stmt, ast.Break): - # Currently, we assume that break statements are only allowed as the last - # statement in a loop, as "if cond: break". - return live_out - if isinstance(stmt, ast.FunctionDef): - if stmt.name in live_out: - live_out.remove(stmt.name) - live_out = live_out | outer_scope_variables(stmt, formatter) + x = x + 10 + y = 20 + z = x + y + x = 30 + + The exposed_uses of this code-block is { x }. The value of z is not used within + the block. Even though the value of y is used within the block, it is assigned + a value before it is used. However, in contrast, the incoming value of x is used + (in the first statement). Hence x is included in the exposed_uses. + """ + + def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + for stmt in reversed(block): + live_out = visit(stmt, live_out) return live_out - raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")) - return visitBlock(stmts, set()) + def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + if isinstance(stmt, ast.Assign): + return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + if isinstance(stmt, ast.AnnAssign): + return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + if isinstance(stmt, ast.Return): + return _used_vars(stmt.value) + if isinstance(stmt, ast.If): + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return (live1 | live2) | _used_vars(stmt.test) + elif constant_cond: + return visitBlock(stmt.body, live_out) + else: + return visitBlock(stmt.orelse, live_out) + if ast_utils.is_print_call(stmt): + return live_out + if ast_utils.is_doc_string(stmt): + return live_out + if isinstance(stmt, ast.For): + # Analysis assumes loop may execute zero times. Results can be improved + # for loops that execute at least once. + loop_var_set = {_get_loop_var(stmt, self._formatter)} + used_after_loop = live_out.difference(loop_var_set) + used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set) + used_in_loop_header = _used_vars(stmt.iter) + return used_inside_loop | used_in_loop_header | used_after_loop + if isinstance(stmt, ast.While): + # Analysis assumes loop may execute zero times. Results can be improved + # for loops that execute at least once. + used_inside_loop = visitBlock(stmt.body, set()) + used_in_loop_header = _used_vars(stmt.test) + return used_inside_loop | used_in_loop_header | live_out + if isinstance(stmt, ast.Break): + # Currently, we assume that break statements are only allowed as the last + # statement in a loop, as "if cond: break". + return live_out + if isinstance(stmt, ast.FunctionDef): + if stmt.name in live_out: + live_out.remove(stmt.name) + live_out = live_out | self.outer_scope_variables(stmt) + return live_out + raise ValueError( + self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + ) + return visitBlock(stmts, set()) -def outer_scope_variables(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Return the set of outer-scope variables used in a nested function. + def outer_scope_variables(self, fun: ast.FunctionDef): + """Return the set of outer-scope variables used in a nested function. - Args: - fun: The function-ast to analyze. - formatter: The formatter object. + Args: + fun: The function-ast to analyze. + formatter: The formatter object. - Returns: - A set of variable names (strings). - """ - assert isinstance(fun, ast.FunctionDef) - used_vars_ = exposed_uses(fun.body, formatter) - inputs = [x.arg for x in fun.args.args] - return used_vars_.difference(inputs) + Returns: + A set of variable names (strings). + """ + assert isinstance(fun, ast.FunctionDef) + used_vars_ = self.exposed_uses(fun.body) + inputs = [x.arg for x in fun.args.args] + return used_vars_.difference(inputs) diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/analysis_test.py index 74e7ca4c18..7a7e5feaa0 100644 --- a/onnxscript/_internal/analysis_test.py +++ b/onnxscript/_internal/analysis_test.py @@ -14,24 +14,27 @@ class AnalysisResultsVisitor(ast.NodeVisitor): """Visitor class to flatten the results of liveness analysis in a pre-order traversal.""" - def __init__(self) -> None: + def __init__(self, analyzer: analysis.AstAnalyzer) -> None: super().__init__() self.results: list[Any] = [] + self.analyzer = analyzer def generic_visit(self, node): - if hasattr(node, "live_in"): - self.results.append(node.live_in) + live_in = self.analyzer.live_in(node) + if live_in is not None: + self.results.append(live_in) ast.NodeVisitor.generic_visit(self, node) if isinstance(node, (ast.For, ast.While)): last = node.body[-1] - self.results.append(last.live_out) # type: ignore + live_out = self.analyzer.live_out(last) + self.results.append(live_out) # type: ignore class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): source, parse_tree = ast_utils.get_src_and_ast(fun) - analysis.do_liveness_analysis(parse_tree, formatter(source)) - visitor = AnalysisResultsVisitor() + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + visitor = AnalysisResultsVisitor(analyzer) visitor.visit(parse_tree) return visitor.results @@ -113,7 +116,8 @@ def while_eg(x): class TestExposedUses(unittest.TestCase): def assertUses(self, f, expected): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.exposed_uses(parse_tree.body, formatter(source)) + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + result = analyzer.exposed_uses(parse_tree.body) self.assertEqual(result, set(expected)) def test_basic(self): @@ -190,7 +194,8 @@ def f(x): class TestAssignedVarAnalysis(unittest.TestCase): def assert_assigned_vars(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.assigned_vars(parse_tree.body, formatter(source)) + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + result = analyzer.assigned_vars(parse_tree.body) self.assertEqual(result, expected) def test_basic_defs(self): @@ -248,5 +253,42 @@ def f(x): self.assert_assigned_vars(f, {"x", "y"}) +class ConstantIfAnalysisTest(unittest.TestCase): + def test_constant_ifs(self): + cond1 = True + cond2 = False + + def f(x): + if cond1: + y = x + 1 + else: + y = x + 2 + if cond2: + z = y * 2 + else: + z = y * 3 + if x > 0: + w = z - 1 + else: + w = z + 1 + return w + + source, parse_tree = ast_utils.get_src_and_ast(f) + + analyzer = analysis.AstAnalyzer( + parse_tree, formatter(source), {"cond1": True, "cond2": False} + ) + for node in ast.walk(parse_tree): + if isinstance(node, ast.If): + result = analyzer.constant_if_condition(node) + if isinstance(node.test, ast.Name): + if node.test.id == "cond1": + self.assertEqual(result, True) + elif node.test.id == "cond2": + self.assertEqual(result, False) + else: + self.assertIsNone(result) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index dfcddefbd3..3e87c366ad 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -183,6 +183,13 @@ def __init__( self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._analyzer: analysis.AstAnalyzer | None = None + + @property + def analyzer(self) -> analysis.AstAnalyzer: + if self._analyzer is None: + raise RuntimeError("Analyzer not initialized.") + return self._analyzer @property def default_opset(self) -> values.Opset: @@ -1089,12 +1096,24 @@ def ret(exp, i, suffix): return ret(val, 0, "") def _translate_if_stmt(self, stmt: ast.If) -> None: - if hasattr(stmt, "live_out"): - live_defs = list( - stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message)) - ) - else: - live_defs = list(analysis.assigned_vars(stmt, self._message)) + constant_cond = self.analyzer.constant_if_condition(stmt) + if constant_cond is True: + # Translate only the "then" branch + for s in stmt.body: + self._translate_stmt(s) + return + if constant_cond is False: + # Translate only the "else" branch + for s in stmt.orelse: + self._translate_stmt(s) + return + live_def_set = self.analyzer.assigned_vars(stmt) + live_out = self.analyzer.live_out(stmt) + if live_out is not None: + # Ideally, live_out should never be None here. But handle this conditionally + # due to some existing usage. + live_def_set = live_out.intersection(live_def_set) + live_defs = list(live_def_set) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno thenGraph, sub_fct_then = self._translate_block( @@ -1174,9 +1193,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: else: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body - exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message) - vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message) - loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) + exposed_uses = self.analyzer.exposed_uses(loop_stmt.body) + vars_def_in_loop = self.analyzer.assigned_vars(loop_stmt.body) + live_out = self.analyzer.live_out(loop_stmt) + assert live_out is not None, "live_out cannot be None here." + loop_state_vars = vars_def_in_loop.intersection(exposed_uses | live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) @@ -1362,7 +1383,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: self._enter_scope(fn.name, fn) self._translate_function_def_common(fn) function_ir = self._exit_scope() - outer_scope_vars = analysis.outer_scope_variables(fn, self._message) + outer_scope_vars = self.analyzer.outer_scope_variables(fn) function_ir.outer_scope_variables = [ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars ] @@ -1448,10 +1469,11 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: self._set_default_opset(opset, stmt) domain = self.this_module.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) - analysis.do_liveness_analysis(stmt, self._message) + self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) fn_ir = self._translate_function_def_common(stmt) fn_ir.debug_print() self.this_module.add_function_def(fn_ir) + self._analyzer = None return fn_ir raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 9a7ca504a7..a35711aea9 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -710,6 +710,36 @@ def model(x): self.assertEqual(len(onnx_opset_import), 1) self.assertEqual(onnx_opset_import[0].version, 19) + def test_traced_if(self): + """Test that traced if statements are converted correctly.""" + + @script() + def add_model(x: FLOAT[10]) -> FLOAT[10]: + y = op.Add(x, x) + return y + + @script() + def sub_model(x: FLOAT[10]) -> FLOAT[10]: + y = op.Sub(x, x) + return y + + def make_model(flag: bool): + @script() + def model(x: FLOAT[10]) -> FLOAT[10]: + if flag: + y = op.Add(x, x) + else: + y = op.Sub(x, x) + return y + + return model.to_model_proto() + + model_true = make_model(True) + onnxscript.testing.assert_isomorphic(model_true, add_model.to_model_proto()) + + model_false = make_model(False) + onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto()) + if __name__ == "__main__": unittest.main(verbosity=2) From 647754f31cc164a772c6a3af4dae6470cd13e857 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 29 Oct 2025 11:05:38 -0700 Subject: [PATCH 088/123] Extend GQA fusion for Qwen (#2662) A couple of extensions to the GQA fusion pattern: * Support the case where there is no past key/value cache, and * Normalization and Transpose occur in the opposite order in Qwen (which has the same behavior). Support this pattern variation. TODO: add test-cases to cover and validate this --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gqa.py | 48 ++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index f1971904f0..907ffe27bc 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -163,6 +163,13 @@ def pattern( ): # Reshape query from (B, S, D) to (B, S, H, D/H) query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) + # Qwen variant uses normalization of query/key before rotary embedding: + # The normalization can happen before (eg., Qwen) or after the Transpose (eg., Gemma). + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BSHDh_normalized"] + ) + query_BSHDh = pattern.OrValue([query_BSHDh, query_BSHDh_normalized]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) @@ -174,6 +181,11 @@ def pattern( # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BSHkvDh_normalized"] + ) + key_BSHkvDh = pattern.OrValue([key_BSHkvDh, key_BSHkvDh_normalized]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) @@ -209,6 +221,8 @@ def pattern( # that share key/value. key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + # Concat with past_key is optional: + key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope]) key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) key_seq_BHTDh = op.Reshape( @@ -218,6 +232,8 @@ def pattern( # Concatenate past_value cache and current value, expand across heads # that share key/value. value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + # Concat with past_value is optional: + value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh]) value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) value_seq_BHTDh = op.Reshape( @@ -254,8 +270,23 @@ def check( query_BSHDh, key_BSHkvDh, mask, + query_BSHDh_normalized=None, + query_BHSDh_normalized=None, + key_BSHkvDh_normalized=None, + key_BHkvSDh_normalized=None, **_, ): + result = pattern.MatchResult() + if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None: + return result.fail( + "Query normalized twice", + [query_BSHDh_normalized, query_BHSDh_normalized], + ) + if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None: + return result.fail( + "Key normalized twice", + [key_BSHkvDh_normalized, key_BHkvSDh_normalized], + ) bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: @@ -268,9 +299,9 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: if no_match(value_BSDkv, ["B", "S", "Dkv"]): return False - if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + if past_key is not None and no_match(past_key, ["B", "Hkv", "P", "Dh"]): return False - if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + if past_value is not None and no_match(past_value, ["B", "Hkv", "P", "Dv"]): return False # TODO: verify Reshapes: @@ -278,7 +309,6 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # or check Reshape's shape-input value - result = pattern.MatchResult() num_heads = _ir_utils.get_dim(query_BSHDh, 2) kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) if not isinstance(num_heads, int): @@ -330,7 +360,9 @@ def rewrite( mask, query_BSHDh, key_BSHkvDh, + query_BSHDh_normalized=None, query_BHSDh_normalized=None, + key_BSHkvDh_normalized=None, key_BHkvSDh_normalized=None, **_, ): @@ -352,9 +384,10 @@ def rewrite( max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) - if query_BHSDh_normalized is not None: + normalized_query = query_BHSDh_normalized or query_BSHDh_normalized + if normalized_query is not None: # We apply normalization without the transpose, which is fused into GQA - norm_node = query_BHSDh_normalized.producer() + norm_node = normalized_query.producer() norm_attrs = norm_node.attributes norm_scale = norm_node.inputs[1] query_BSHDh_normalized = op.SimplifiedLayerNormalization( @@ -363,9 +396,10 @@ def rewrite( reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1]) query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD) - if key_BHkvSDh_normalized is not None: + normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized + if normalized_key is not None: # We apply normalization without the transpose, which is fused into GQA - norm_node = key_BHkvSDh_normalized.producer() + norm_node = normalized_key.producer() norm_attrs = norm_node.attributes norm_scale = norm_node.inputs[1] key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( From ee9a6e8f3548f138a6ee241bf7e62b7f77c103a1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 29 Oct 2025 13:56:05 -0700 Subject: [PATCH 089/123] Declare support for Python 3.14 in pyproject.toml (#2663) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b318042633..1e6a99f656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "License :: OSI Approved :: MIT License", ] dependencies = [ From 3846705a9571b47b189caa14755aef6041770574 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 30 Oct 2025 18:58:46 -0700 Subject: [PATCH 090/123] Clear initializers in constant folding pass (#2668) Clear unused initializers on the fly to prevent memory usage jump due to intermediate folded tensors. --------- Signed-off-by: Justin Chu --- VERSION | 2 +- onnxscript/optimizer/_constant_folding.py | 23 ++++++++++++ .../optimizer/_constant_folding_test.py | 35 +++++++++++++++---- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/VERSION b/VERSION index d1d899fa33..b49b25336d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.5 +0.5.6 diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index e1bff26791..dfb072417a 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1256,10 +1256,20 @@ def replace_node( # Record the names of the values that has contributed to the replacement _record_contributing_values(node, replacement) + + # Obtain the list of non-None inputs to the node before it is cleared by + # replace_nodes_and_values to check for unused initializers later. + node_inputs = [v for v in node.inputs if v is not None] + ir.convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) + if isinstance(root, ir.Graph): + # The old node should now be detached from the graph + assert node.graph is None + _clear_unused_initializers(node_inputs) + self._modified = True # TODO: what about new opset_imports? @@ -1336,6 +1346,19 @@ def _sym_value_can_replace_graph_output( return True +def _clear_unused_initializers(values: Sequence[ir.Value]) -> None: + # Detach all inputs to the node, then check for unused initializers + for value in values: + if value is None or not value.is_initializer(): + continue + + if not value.uses(): + assert value.is_initializer() + assert value.graph is not None + assert value.name is not None + value.graph.initializers.pop(value.name) + + @dataclasses.dataclass class FoldConstantsResult(ir.passes.PassResult): symbolic_value_map: dict[ir.Value, SymbolicValue] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index d9395e811c..96a143f81a 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -14,13 +14,20 @@ class FoldConstantsTest(unittest.TestCase): - def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs): + def _fold( + self, + model: ir.Model | str, + onnx_shape_inference: bool = False, + dce: bool = True, + **kwargs, + ): if isinstance(model, str): model = ir.from_onnx_text(model) _constant_folding.fold_constants( model, onnx_shape_inference=onnx_shape_inference, **kwargs ) - optimizer.remove_unused_nodes(model) + if dce: + optimizer.remove_unused_nodes(model) # Ensure the model is valid after optimization onnx.checker.check_model(ir.serde.serialize_model(model)) return model @@ -50,9 +57,16 @@ def test_fold_cast_like(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 1) + optimized = self._fold(model, dce=False) self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape(self): model = """ @@ -66,9 +80,18 @@ def test_fold_shape(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 1) + optimized = self._fold(model, dce=False) self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + self.assertNotIn("rank", optimized.graph.initializers) + self.assertNotIn("shape", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape_slice(self): model = """ From 8a7de40cf429aab291cdded1b84069bf600a123d Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 30 Oct 2025 22:41:12 -0700 Subject: [PATCH 091/123] Add GQA fusion test cases (#2669) Add GQA fusion test cases to cover extensions introduced recently to cover patterns seen in Qwen. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/gqa_test.py | 57 ++++++++++++++++----- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index c7ed888142..038749c017 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -10,6 +10,7 @@ import onnx_ir as ir import onnx_ir.passes.common.shape_inference as shape_inference import onnxruntime as ort +import parameterized import torch import onnxscript @@ -361,14 +362,26 @@ def test_fusion(self): assert_allclose(outputs3, source_model_outputs) +@parameterized.parameterized_class( + [ + {"with_past": True, "transpose_first": True}, + {"with_past": True, "transpose_first": False}, + {"with_past": False, "transpose_first": True}, + {"with_past": False, "transpose_first": False}, + ] +) class GemmaGQAFusionTest(unittest.TestCase): + with_past = True + transpose_first = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Config parameters self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? self.seqlen = 8 self.kv_seqlen = self.seqlen - self.past_seqlen = 16 + self.past_seqlen = 16 if self.with_past else 0 self.head_size = 16 self.num_heads = 20 self.kv_num_heads = 10 @@ -425,6 +438,8 @@ def __init__(self, *args, **kwargs): } def source_model_script(self): + with_past = self.with_past + transpose_first = self.transpose_first scale_factor = math.sqrt(math.sqrt(self.head_size)) minval = torch.finfo(torch.float32).min minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) @@ -458,16 +473,30 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only # one sequence length (S) for all Q, K, and V (with no cache). query_BSHDh = op.Reshape(query, shape_BSHDh) - query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - query_BHSDh_normalized = op.SimplifiedLayerNormalization( - query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 - ) - key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) - key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( - key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 - ) + + if transpose_first: + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + else: + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + query_BHSDh_normalized = op.Transpose( + query_BSHDh_normalized, perm=[0, 2, 1, 3] + ) + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh_normalized = op.Transpose( + key_BSHkvDh_normalized, perm=[0, 2, 1, 3] + ) value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) @@ -489,9 +518,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal cos, sin, ) - key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + if with_past: + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + else: + key_seq_BHkvSkvDh = key_BHkvSDh_rope + value_seq_BHkvSkvDh = value_BHkvSDh # Now, expand from shared heads to all heads key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) From 9b699aeb092d4be0ef0474c91e3a32e8620cfbb0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 31 Oct 2025 11:18:18 -0700 Subject: [PATCH 092/123] Improve constant folding error messages and allow Identity to skip shape merging (#2670) When Identity fails to merge shapes, allow the constant folder to proceed by ignoring the conflicting shape. Also improve error message to show node information if constant folding fails. --------- Signed-off-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 24 +++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index dfb072417a..2c6d9b46ff 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -602,7 +602,16 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: output = node.outputs[0] if input is not None and output is not None: # NOTE: backward shape inference - input.shape = _merge_shapes(input.shape, output.shape) + try: + input.shape = _merge_shapes(input.shape, output.shape) + except Exception as e: + logger.warning( + "[Constant folder] Cannot merge shapes on Identity node '%s' " + "(folded from: %s) because of error: %s", + node.name, + input.meta.get(FOLDED_FROM_KEY, set()), + e, + ) if input.type is None: input.type = output.type state.set_sym_value(output, input) @@ -919,7 +928,9 @@ def merge_dims(dim1, dim2): if other_shape is None: return preferred_shape if len(preferred_shape) != len(other_shape): - raise ValueError("Shapes must have the same rank.") + raise ValueError( + f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}" + ) return ir.Shape( [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)] ) @@ -1035,7 +1046,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: except Exception as e: logger.debug( "Skipping shape inference for node %r due to exception: %s", - node.name, + node, e, ) @@ -1124,7 +1135,12 @@ def process_node(self, node: ir.Node) -> Replacement | None: for optimizer in op_optimizers: assert optimizer context = RewriterContext() - output = optimizer(node, context, self._state) + try: + output = optimizer(node, context, self._state) + except Exception as e: + raise RuntimeError( + f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})" + ) from e if output is not None: if isinstance(output, Replacement): return output From 5be9d3b1460e92d39082cc6d8c65e32d27a20315 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 31 Oct 2025 16:23:27 -0700 Subject: [PATCH 093/123] Fix scalar constant check (#2672) Fix scalar constant check. TODO: for some optimizations, generalizations are possible, but they must be done on a rule-by-rule basis. Eg., for eliminating an addition of zero to x: eliminating this is always safe zero is a scalar, but if it is multi-dimensional, then it is safe if its rank is less than that of x. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/_matcher.py | 18 ++++++++++-------- .../models/_rotary_embedding_models.py | 12 ++++++------ onnxscript/rewriter/models/_smollm_1.py | 10 +++++----- onnxscript/rewriter/models/_smollm_2.py | 14 +++++++------- .../rewriter/ort_fusions/cos_sin_cache.py | 4 ++-- .../rewriter/ort_fusions/cos_sin_cache_test.py | 2 +- onnxscript/rewriter/ort_fusions/gqa.py | 4 ++-- onnxscript/rewriter/ort_fusions/gqa_test.py | 8 ++++---- .../rewriter/rules/common/_no_op_test.py | 16 ++++++++++++++++ 9 files changed, 53 insertions(+), 35 deletions(-) diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index e347b98375..f54b77033f 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -87,7 +87,7 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu ) try: - constant_value_numpy = constant_value.numpy() + numpy_value = constant_value.numpy() except FileNotFoundError: return self.fail(f"Constant value of {value.name} not available.") @@ -95,11 +95,13 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu if isinstance(pattern_constant_value, list): expected_shape = (len(pattern_constant_value),) - if constant_value_numpy.shape != expected_shape: - return self.fail(f"Value has mismatching shape, expecting {expected_shape}.") + if numpy_value.shape != expected_shape: + return self.fail( + f"Value {value.name} has shape {numpy_value.shape}, expecting {expected_shape}." + ) if not all( math.isclose( - constant_value_numpy.item(i), + numpy_value.item(i), pattern_constant_value[i], rel_tol=pattern_constant._rel_tol, abs_tol=pattern_constant._abs_tol, @@ -107,24 +109,24 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu for i in range(len(pattern_constant_value)) ): return self.fail( - f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}." + f"Value mismatch: expected {pattern_constant_value}, got {numpy_value}." ) return True # TODO (rama): allow users to specify shape requirement, if desired. - if constant_value_numpy.size != 1: + if numpy_value.ndim != 0: return self.fail( f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.", ) if not math.isclose( - constant_value_numpy.item(), + numpy_value.item(), pattern_constant_value, rel_tol=pattern_constant._rel_tol, abs_tol=pattern_constant._abs_tol, ): return self.fail( - f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.", + f"Constant value mismatch: expected {pattern_constant_value}, got {numpy_value.item()}.", ) return True diff --git a/onnxscript/rewriter/models/_rotary_embedding_models.py b/onnxscript/rewriter/models/_rotary_embedding_models.py index ecdb7d138b..3709cd04f7 100644 --- a/onnxscript/rewriter/models/_rotary_embedding_models.py +++ b/onnxscript/rewriter/models/_rotary_embedding_models.py @@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA emb = op.Concat(freqs, freqs, axis=-1) cos = op.Cos(emb) sin = op.Sin(emb) - cos_4d = op.Unsqueeze(cos, 1) - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) x1 = op.Slice(x, [0], [4], [3], [1]) x2 = op.Slice(x, [4], [8], [3], [1]) @@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1 emb = op.Concat(freqs, freqs, axis=-1) cos = op.Cos(emb) sin = op.Sin(emb) - cos_4d = op.Unsqueeze(cos, 1) - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) x1 = op.Slice(x, [0], [4], [3], [1]) x2 = op.Slice(x, [4], [8], [3], [1]) @@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query): # Split the query for partial embedding to_embed = op.Slice(query, [0], [32], [3], [1]) unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) - cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] - sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd] # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X) # essentially represents X rotated by 90 degrees to_embed_times_cos = op.Mul(to_embed, cos_4d) diff --git a/onnxscript/rewriter/models/_smollm_1.py b/onnxscript/rewriter/models/_smollm_1.py index d592eb2572..e3efecfe17 100644 --- a/onnxscript/rewriter/models/_smollm_1.py +++ b/onnxscript/rewriter/models/_smollm_1.py @@ -59,8 +59,8 @@ def main_graph( minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38]) mask_10x10 = opset18.Trilu(minus_inf_10x10, 1) slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10]) - unsqueeze_2 = opset18.Unsqueeze(input1, 1) - unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2) + unsqueeze_2 = opset18.Unsqueeze(input1, [1]) + unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2]) add = slice_5 + unsqueeze_3 eq = add == 0.0 slice_10 = slice_5 @@ -69,7 +69,7 @@ def main_graph( slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3]) val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3]) slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3]) - unsqueeze_6 = opset18.Unsqueeze(input2, 1) + unsqueeze_6 = opset18.Unsqueeze(input2, [1]) to_copy_1 = opset18.Cast(unsqueeze_6, to=1) view_1 = opset18.Constant( value=ir.tensor( @@ -138,8 +138,8 @@ def main_graph( transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3]) view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0) transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) - unsqueeze_7 = opset18.Unsqueeze(cos, 1) - unsqueeze_8 = opset18.Unsqueeze(sin, 1) + unsqueeze_7 = opset18.Unsqueeze(cos, [1]) + unsqueeze_8 = opset18.Unsqueeze(sin, [1]) mul_5 = transpose_1 * unsqueeze_7 val_267 = opset18.Constant(value_ints=[1]) slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267) diff --git a/onnxscript/rewriter/models/_smollm_2.py b/onnxscript/rewriter/models/_smollm_2.py index 62d857a2d6..47ad451895 100644 --- a/onnxscript/rewriter/models/_smollm_2.py +++ b/onnxscript/rewriter/models/_smollm_2.py @@ -51,7 +51,7 @@ def main_graph( gt = arange_1 > view convert_element_type_default = opset18.Cast(gt, to=1) mul = triu * convert_element_type_default - dim__2 = opset18.Constant(value_int=0) + dim__2 = opset18.Constant(value_ints=[0]) dim_0__2 = opset18.Cast(dim__2, to=7) unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2) val_15 = opset18.Cast(0, to=7) @@ -65,7 +65,7 @@ def main_graph( val_25 = opset18.Reshape(val_23, val_24, allowzero=0) val_26 = opset18.Constant(value_ints=[1]) slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26) - dim__3 = opset18.Constant(value_int=2) + dim__3 = opset18.Constant(value_ints=[2]) dim_0__3 = opset18.Cast(dim__3, to=7) unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3) _to_copy = opset18.Cast(unsqueeze_1, to=1) @@ -83,7 +83,7 @@ def main_graph( val_36 = opset18.Reshape(val_34, val_35, allowzero=0) val_37 = opset18.Constant(value_ints=[1]) slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37) - dim__5 = opset18.Constant(value_int=1) + dim__5 = opset18.Constant(value_ints=[1]) dim_0__5 = opset18.Cast(dim__5, to=7) unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5) val_38 = opset18.Cast(0, to=7) @@ -160,10 +160,10 @@ def main_graph( val_71 = opset18.Cast([1, 30, 32, 64], to=7) view_12 = opset18.Reshape(view_9, val_71, allowzero=0) transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) - dim__8 = opset18.Constant(value_int=1) + dim__8 = opset18.Constant(value_ints=[1]) dim_0__8 = opset18.Cast(dim__8, to=7) unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8) - dim__9 = opset18.Constant(value_int=1) + dim__9 = opset18.Constant(value_ints=[1]) dim_0__9 = opset18.Cast(dim__9, to=7) unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9) mul_5 = transpose_1 * unsqueeze_3 @@ -222,10 +222,10 @@ def main_graph( add_2 = mul_7 + mul_8 cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2) cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2) - dim__10 = opset18.Constant(value_int=0) + dim__10 = opset18.Constant(value_ints=[0]) dim_0__10 = opset18.Cast(dim__10, to=7) unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10) - dim__11 = opset18.Constant(value_int=1) + dim__11 = opset18.Constant(value_ints=[1]) dim_0__11 = opset18.Cast(dim__11, to=7) unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11) val_114 = opset18.Cast(0, to=7) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index cba06d2fb7..8e6ec1d9da 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -148,8 +148,8 @@ def pattern( sin = op.Sin(emb) if self._cast: sin = op.Cast(sin, to=dtype) - cos_4d = op.Unsqueeze(cos, 1) # convert - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) # convert + sin_4d = op.Unsqueeze(sin, [1]) return op.RotaryEmbedding( x, cos_4d, diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py index 48842aa429..4245916c64 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py @@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor): original_outputs = ort_run("original", model, inputs) count = fuse_rotary_embedding(model) self.assertGreater(count, 0) - count = fuse_cos_sin_cache(model) + count = fuse_cos_sin_cache(model, debug=True) self.assertGreater(count, 0) new_outputs = ort_run("optimized", model, inputs) assert_allclose(new_outputs, original_outputs) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 907ffe27bc..bf883c58bc 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -223,7 +223,7 @@ def pattern( key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) # Concat with past_key is optional: key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope]) - key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, [2]) key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) key_seq_BHTDh = op.Reshape( key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"] @@ -234,7 +234,7 @@ def pattern( value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) # Concat with past_value is optional: value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh]) - value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, [2]) value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) value_seq_BHTDh = op.Reshape( value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"] diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 038749c017..1a79b9c29f 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -195,11 +195,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin): value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) # Now, expand from shared heads to all heads - key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2]) key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2]) value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) @@ -527,11 +527,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal value_seq_BHkvSkvDh = value_BHkvSDh # Now, expand from shared heads to all heads - key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2]) key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2]) value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) diff --git a/onnxscript/rewriter/rules/common/_no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py index 7815473e34..2c2f9e6e2b 100644 --- a/onnxscript/rewriter/rules/common/_no_op_test.py +++ b/onnxscript/rewriter/rules/common/_no_op_test.py @@ -15,6 +15,11 @@ def _check(self, model_text: str) -> None: self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") + def _check_no_optimization(self, model_text: str) -> None: + model = ir.from_onnx_text(model_text) + count = _no_op.rules.apply_to_model(model) + self.assertEqual(count, 0) + @parameterized.parameterized.expand( [ ("float one input", "float[M]", "value_float=1.0", "one, input"), @@ -195,6 +200,17 @@ def test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: st ) # TODO: Test the negative cases + def test_broadcast_is_not_eliminated(self): + model_text = """ + + agraph (float[M] input) => (float[1, 1, M] output) + + { + output = Add(zero, input) + } + """ + self._check_no_optimization(model_text) + if __name__ == "__main__": unittest.main() From 93783ee94e4b951cdb6ab92fa8ad9d8240b3212a Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Mon, 3 Nov 2025 13:53:23 -0800 Subject: [PATCH 094/123] Capture rewrite rule name as metadata (#2675) Capture rewrite rule name as metadata to simplify debugging issues from rewrites. This is just a basic version. TODO / Extensions: * Sometimes we apply a sequence of rewrite-rules one after another, to perform complex fusions. This currently records only the last rule applied. * This can be solved when we merge metadata from original nodes into new nodes. (See https://github.com/microsoft/onnxscript/pull/2375 ) * May be useful standardize on a single ONNX metadata key for "source" info (that can be used by torchlib/other exporters/rewriter/optimizer etc.) --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/__init__.py | 2 ++ onnxscript/rewriter/_rewrite_rule.py | 12 +++++++++ onnxscript/rewriter/pattern_test.py | 39 ++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 75f43bf3ea..78eb4398f3 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -16,6 +16,7 @@ "RewriterContext", "MatchingTracer", "MatchStatus", + "RULE_NAME_TAG", ] import onnx @@ -25,6 +26,7 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( + RULE_NAME_TAG, RewriterContext, RewriteRule, RewriteRuleClassBase, diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 8964230fe0..9c88aa848e 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -25,6 +25,11 @@ RewriterContext = _tape.Builder +# TODO(rama): Standardize metadata property keys. May be worth standardizing at ONNX level for +# source/producer metadata. + +RULE_NAME_TAG = "pkg.onnxscript.rewriter.rule_name" + @dataclasses.dataclass class ReplacementSubgraph: @@ -719,6 +724,13 @@ def _apply_to_graph_or_function( _ir_utils.display_nodes(delta.new_nodes) print("++++End Replacement Nodes++++") + # Capture rewrite rule name as metadata. + # TODO(rama): This is just a basic version. We may wish to compose "source" metadata + # from multiple rules in future. + if rule.name: + for n in delta.new_nodes: + n.metadata_props[RULE_NAME_TAG] = rule.name + convenience.replace_nodes_and_values( graph_or_function, node, diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0a29080b4d..f296b5320c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -10,6 +10,7 @@ import onnx.parser import onnxscript.optimizer +import onnxscript.rewriter from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op from onnxscript.rewriter import pattern @@ -936,6 +937,44 @@ def add_pattern(op, x, y): match_result = rule_pattern.match(model, model.graph, add_nodes[2]) self.assertFalse(bool(match_result)) + def test_rule_name_metadata(self): + """Test that RewriteRule carries name metadata.""" + + class ReciprocalMulRule(pattern.RewriteRuleClassBase): + def __init__(self, name: str | None = None): + super().__init__(name) + + def pattern(self, op, x, y): + return (1 / x) * y + + def rewrite(self, op, x, y): + return op.Div(y, x) + + @script() + def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: + return op.Mul(op.Div(op.Constant(value_float=1.0), x), y) + + rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv") + model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv") + + # By default, the rule name is the class name (if not provided) + rule = ReciprocalMulRule.rule() + model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): From 1a27df145b7ec03da7d316a38c2cb005cf0a45b7 Mon Sep 17 00:00:00 2001 From: Om Biradar <152481302+ombrdr47@users.noreply.github.com> Date: Wed, 5 Nov 2025 00:10:34 +0530 Subject: [PATCH 095/123] feat: implement LSTM and GRU operators for torchlib (#2674) Implement aten_lstm and aten_gru operators to enable torch.onnx.export for PyTorch LSTM and GRU layers. This addresses issue #2546. Key features: - Full support for multi-layer RNNs (num_layers > 1) - Bidirectional support (forward and backward directions) - Handles both biased and non-biased configurations - batch_first parameter support with automatic transposition - Dropout support between layers (nondeterministic seeded) - Proper gate reordering for ONNX compatibility: * LSTM: PyTorch [i,f,g,o] -> ONNX [i,o,f,g] * GRU: PyTorch [r,z,n] -> ONNX [z,r,n] Implementation details: - Uses ONNX LSTM/GRU operators with proper parameter formatting - Handles weight matrix transposition and reshaping - Correctly concatenates biases using op.Concat - Processes each layer independently with proper state management - Returns outputs in PyTorch-compatible format Closes: #2546 Also resolves: - pytorch/pytorch#120626 (GRU export) - pytorch/pytorch#123089 (LSTM export) - pytorch/pytorch#164834 (GRU dynamo export) --- .../function_libs/torch_lib/ops/core.py | 392 ++++++++++++++++++ .../function_libs/torch_lib/e2e_ops_tests.py | 104 +++++ 2 files changed, 496 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index be30520878..96f64bbb8a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3761,6 +3761,192 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::gru.input", trace_only=True) +def aten_gru( + input: TFloat, + hx: TFloat, + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat]: + """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(hx, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + + for layer_idx in range(num_layers): + # Extract hidden state for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX GRU + # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] + # PyTorch provides: W_ih shape [3*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[ + dir_param_start + 1 + ] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + + # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] + # Split into individual gates + W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + W_ih_reordered = op.Concat( + W_iz, W_ir, W_in, axis=0 + ) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hz, W_hr, W_hn, axis=0 + ) # [3*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 3*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] + b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] + + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] + b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + b_ih_reordered = op.Concat( + b_iz, b_ir, b_in, axis=0 + ) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hz, b_hr, b_hn, axis=0 + ) # [3*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [6*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) + + # Call ONNX GRU operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = hx.shape[2] + + if B is not None: + Y, Y_h = op.GRU( + current_input, + W, + R, + B, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h = op.GRU( + current_input, + W, + R, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden state + output_h_list.append(Y_h) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h + + @torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -4991,6 +5177,212 @@ def aten_lstm_mps_backward( raise NotImplementedError() +@torch_op("aten::lstm.input", trace_only=True) +def aten_lstm( + input: TFloat, + hx: Sequence[TFloat], + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat, TFloat]: + """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)""" + + # Extract initial hidden and cell states + initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] + initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(initial_h, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + output_c_list = [] + + for layer_idx in range(num_layers): + # Extract hidden and cell states for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) + layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX LSTM + # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] + # PyTorch provides: W_ih shape [4*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] + W_hh = params[ + dir_param_start + 1 + ] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + # Split into individual gates + W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + W_ih_reordered = op.Concat( + W_ii, W_io, W_if, W_ig, axis=0 + ) # [4*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hi, W_ho, W_hf, W_hg, axis=0 + ) # [4*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 4*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[ + dir_param_start + 2 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_hh = params[ + dir_param_start + 3 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + b_ih_reordered = op.Concat( + b_ii, b_io, b_if, b_ig, axis=0 + ) # [4*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hi, b_ho, b_hf, b_hg, axis=0 + ) # [4*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [8*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) + + # Call ONNX LSTM operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = initial_h.shape[2] + + if B is not None: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + B, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden and cell states + output_h_list.append(Y_h) + output_c_list.append(Y_c) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + final_c = ( + output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h, final_c + + @torch_op( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 24ccaf4b40..f74dda699d 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -302,6 +302,110 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_lstm_unidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_bidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_multilayer(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_unidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_bidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_multilayer(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From d80575dce25e8a3db2f697bf1f5c6d15e4b91643 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 4 Nov 2025 17:13:40 -0800 Subject: [PATCH 096/123] Keep creating constants when constants are folded inside ir.Function (#2679) Fixes #2673 --------- Co-authored-by: Justin Chu --- onnxscript/optimizer/_constant_folding.py | 100 ++++++++++++------ .../optimizer/_constant_folding_test.py | 20 ++++ 2 files changed, 90 insertions(+), 30 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 2c6d9b46ff..03536cc9ce 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1050,46 +1050,79 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) - def new_initializer(self, node: ir.Node, array) -> ir.Value | None: - original_value = node.outputs[0] - if not isinstance(array, np.ndarray): - # ONNX does not have a way to represent non-tensor constants, eg. a sequence. - # So, a constant-value of type sequence is not folded, but it can be used - # to optimize subsequent operations when possible. + def _prepare_folded_tensor( + self, node: ir.Node, output_name: str, output_array: np.ndarray | Any + ) -> ir.Tensor | None: + """ + Shared helper for constant/init creation: + - Validates the folded Python value is a numpy ndarray. + - Wraps it in an ir.Tensor and names it. + - Applies output_size_limit logic with input-usage compensation. + Returns the ir.Tensor or None if it should be skipped. + """ + if not isinstance(output_array, np.ndarray): logger.info( "Skip storing constant folded value %s due to unsupported type %s.", - original_value.name, - type(array), + output_name, + type(output_array), ) return None - tensor = ir.tensor(array) - tensor.name = original_value.name - initializer = ir.Value( - name=original_value.name, - type=ir.TensorType(ir.DataType(tensor.dtype)), - shape=tensor.shape, # type: ignore[arg-type] - const_value=tensor, - ) + tensor = ir.tensor(output_array) + tensor.name = output_name - if array.size > self.output_size_limit: - # Handle examples like Transpose(weight) to be folded even if the size is large, - # as long as weight has no other uses. This won't increase model size. + # Size gating (shared logic) + if output_array.size > self.output_size_limit: removed_input_size = 0 - for input in node.inputs: - if (input is not None) and (len(input.uses()) == 1): - array = _get_numpy_value(input) - if array is not None: - removed_input_size += array.size - increased_size = array.size - removed_input_size + for input_val in node.inputs: + if (input_val is not None) and (len(input_val.uses()) == 1): + input_array = _get_numpy_value(input_val) + if input_array is not None: + removed_input_size += input_array.size + increased_size = output_array.size - removed_input_size if increased_size > 0: logger.info( - "Skip storing constant folded nvalue %s due to large size %s.", - original_value.name, - array.size, + "Skip storing constant folded array %s due to large size %s.", + output_name, + output_array.size, ) return None + return tensor + + def new_constant(self, node: ir.Node, array: np.ndarray | Any) -> ir.Node | None: + """Create a new Constant node with the given array as its value.""" + original_value = node.outputs[0] + + tensor = self._prepare_folded_tensor(node, original_value.name, array) + if tensor is None: + return None + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + original_value.name, + array.dtype, + array.shape, + ) + + node = ir.Node("", "Constant", inputs=[], attributes=(ir.AttrTensor("value", tensor),)) + return node + + def new_initializer(self, node: ir.Node, array: np.ndarray | Any) -> ir.Value | None: + """Create a new initializer value with the given array as its value.""" + original_value = node.outputs[0] + + tensor = self._prepare_folded_tensor(node, original_value.name, array) + if tensor is None: + return None + + initializer = ir.Value( + name=original_value.name, + type=ir.TensorType(ir.DataType(tensor.dtype)), + shape=tensor.shape, # type: ignore[arg-type] + const_value=tensor, + ) + logger.debug( "New Initializer for value %s dtype: %s shape: %s", original_value.name, @@ -1099,7 +1132,7 @@ def new_initializer(self, node: ir.Node, array) -> ir.Value | None: return initializer - def process_node(self, node: ir.Node) -> Replacement | None: + def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None: """Process a node and return a Replacement if the node can be replaced.""" for i, value in enumerate(node.inputs): sym_value = self._state.get_sym_value(value) @@ -1252,6 +1285,12 @@ def convert(av): if outputs is None: return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): + # We don't support initializers in functions, so we need to create Constant nodes + if is_function: + replacement = self.new_constant(node, outputs) + if replacement is None: + return None + return Replacement(replacement.outputs, [replacement]) new_initializer_value = self.new_initializer(node, outputs) if new_initializer_value is None: return None @@ -1301,7 +1340,8 @@ def visit_attribute(self, attr: ir.Attr) -> None: self.visit_graph(graph) def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None: - replacement = self.process_node(node) + is_function = isinstance(root, ir.Function) + replacement = self.process_node(node, is_function=is_function) if replacement is None: # No change. Process attributes. for attr in node.attributes.values(): diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 96a143f81a..ae5c9901bd 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -721,6 +721,26 @@ def test_attribute_reference(self): optimized = self._fold(model) self.assertEqual(len(optimized.graph), 2) + def test_constant_folding_creates_constant_nodes_in_function(self): + model = """ + + model (float x) => (float return_val) { + return_val = this.function (x) + } + + function (x) => (return_val) { + tmp = Constant () + tmp_0 = Cast (tmp) + return_val = Sub (tmp_0, x) + } + """ + optimized = self._fold(model) + self.assertEqual(len(optimized.functions), 1) + for func in optimized.functions.values(): + # Ensure that constant folding has created constant nodes in the function + constant_nodes = [n for n in func.graph if n.op_type == "Constant"] + self.assertEqual(len(constant_nodes), 1) + if __name__ == "__main__": unittest.main() From 8845fb22911efda49aa22770ddafbfce48d38821 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Wed, 5 Nov 2025 08:55:26 -0800 Subject: [PATCH 097/123] Avoid initializer name collision in _fuse_batchnorm.py (#2680) Fixes https://github.com/pytorch/pytorch/issues/166797 The original naming collides when there are multiple matched patterns sharing the same parent node. This PR changes the naming to depend on their own Conv weight name, which should be non-duplicated identifier. ~~NOTE: I don't know if my understanding is correct. It seems x is an input of the pattern, which x.name + "_bias" collides with `max_pool` bias (see the pic in the original issue)? If we check the output model after _fuse_batchnorm.py, the bias would be correct with a name `val_17` (the name may be collided and given by NameAuthority?). However, when the following rule _remove_optional_bias tries to fetch the bias, it would see all zero for some reasons.~~ --- .../rewriter/rules/common/_fuse_batchnorm.py | 5 +- .../rules/common/_fuse_batchnorm_test.py | 58 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index 9d8b8f23f4..e3298ffbd8 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -68,7 +68,10 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu bias_name = inbound_node.inputs[2].name else: original_bias = np.zeros_like(input_mean) - bias_name = x.name + "_bias" + # Use inbound input 1 (should be weight) to derive a name for the bias + # to avoid name collision on initializer creation when there are multiple patterns + # sharing the same parent nodes. + bias_name = inbound_node.inputs[1].name + "_bias" fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta) return op.op( diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 3e617340ff..2007033ef6 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -253,6 +253,64 @@ def test_fuse_batchnorm_graph_inputs(self): # No changes were applied as W is a graph input self.assertEqual(count, 0) + def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y1, float [N, ?, ?, ?] Y2) + { + X1 = MaxPool(X) + X2 = Conv(X1, W1) + Y1 = BatchNormalization(X2, gamma_64, beta_64, input_mean_64, input_var_64) + X3 = Conv(X1, W2) + Y2 = BatchNormalization(X3, gamma_256, beta_256, input_mean_256, input_var_256) + } + """) + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W1" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="gamma_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="beta_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="input_mean_64" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(64)).astype(np.float32), name="input_var_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(256, 32, 3, 3).astype(np.float32), name="W2" + ), + onnx.numpy_helper.from_array( + np.random.randn(256).astype(np.float32), name="gamma_256" + ), + onnx.numpy_helper.from_array( + np.random.randn(256).astype(np.float32), name="beta_256" + ), + onnx.numpy_helper.from_array( + np.random.randn(256).astype(np.float32), name="input_mean_256" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(256)).astype(np.float32), name="input_var_256" + ), + ] + model_proto.graph.initializer.extend(initializers) + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Applied twice, once for each BatchNorm + self.assertEqual(count, 2) + # it should have different bias names for the two fused Conv nodes + conv_nodes = [node for node in model.graph if node.op_type == "Conv"] + self.assertEqual(len(conv_nodes), 2) + bias_names_1 = conv_nodes[0].inputs[2].name + bias_names_2 = conv_nodes[1].inputs[2].name + self.assertNotEqual(bias_names_1, bias_names_2) + if __name__ == "__main__": unittest.main() From 971f9bb6f6af6f2563ba1592ccc18ab096713678 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 5 Nov 2025 11:11:02 -0800 Subject: [PATCH 098/123] Merge metadata props in rewriter (#2682) Introduce basic infrastructure for merging metadata props (for use in rewriter/optimizer etc.) A basic version added to rewriter. TODO: * Allow user control over this: should this be configurable at the level of a RewriteRuleSet? Or, perhaps at a global level (given that ORT fusions uses a number of rewrite-rule-sets for various reasons)? * This [line](https://github.com/microsoft/onnxscript/blob/1a27df145b7ec03da7d316a38c2cb005cf0a45b7/onnxscript/rewriter/ort_fusions/_core.py#L148) should also be factored out or made user-controllable in some fashion. Otherwise, the metadata gets lost anyway. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/_rewrite_rule.py | 15 +++++ onnxscript/utils/metadata_merger.py | 99 ++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 onnxscript/utils/metadata_merger.py diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 78eb4398f3..fb93bc703f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -5,6 +5,7 @@ from typing import Sequence, TypeVar, Union __all__ = [ + "merge_metadata", "pattern", "rewrite", "RewritePass", @@ -31,6 +32,7 @@ RewriteRule, RewriteRuleClassBase, RewriteRuleSet, + merge_metadata, ) from onnxscript.rewriter.rules.common import ( _basic_rules, diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 9c88aa848e..7c73a738ce 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -18,6 +18,7 @@ import onnxscript.rewriter._ir_utils as _ir_utils import onnxscript.rewriter._matcher as _matcher import onnxscript.rewriter._pattern_ir as _pattern_ir +import onnxscript.utils.metadata_merger as metadata_merger from onnxscript import ir from onnxscript.ir import _tape, convenience @@ -614,6 +615,15 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: overload += 1 +_default_metadata_merger: metadata_merger.MetadataMerger = metadata_merger.MetadataMerger( + {RULE_NAME_TAG: metadata_merger.comma_separator_merger} +) + +# TODO(rama): Generalize this to support custom metadata mergers. For now, we just allow +# enabling/disabling the default merger. +merge_metadata: bool = True + + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if not rules: @@ -740,6 +750,11 @@ def _apply_to_graph_or_function( delta.new_outputs, ) + if merge_metadata: + _default_metadata_merger.copy_merged_metadata( + delta.match.nodes, delta.new_nodes + ) + count += 1 break diff --git a/onnxscript/utils/metadata_merger.py b/onnxscript/utils/metadata_merger.py new file mode 100644 index 0000000000..121d8db8c8 --- /dev/null +++ b/onnxscript/utils/metadata_merger.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Merging metadata_props""" + +from __future__ import annotations + +from typing import Callable, Iterable + +import onnx_ir as ir + +# Utilities for merging metadata properties, represented as strings. +# The merging-logic will take care of special cases like missing metadata or +# empty string metadata, and so the functions defined below need not handle +# special cases like empty string. (This does assume that an empty string is +# the same as no metadata, which is a reasonable assumption for most metadata.) + +StringMerger = Callable[[str, str], str] + + +def overwrite(_: str, new: str) -> str: + return new + + +def join(separator: str) -> StringMerger: + """Creates a StringMerger that joins two strings with the given separator. + + Args: + separator (str): The separator to use when joining the strings. + + Returns: + StringMerger: A function that joins two strings with the specified separator. + """ + + def merger(first: str, second: str) -> str: + return f"{first}{separator}{second}" + + return merger + + +comma_separator_merger = join(", ") + + +class MetadataMerger: + """Merges metadata properties using specified merging logic. + + Attributes: + mergers: A mapping from metadata property keys to their corresponding merging functions. + default: The default merging function to use when a specific key does not have a defined merger. + If None, the first value is used. (Specify `overwrite` to always use the second value.) + """ + + def __init__( + self, mergers: dict[str, StringMerger], default: StringMerger | None = None + ) -> None: + self.mergers = mergers + self.default = default + + def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None: + """Updates the first metadata property dictionary with values from the second. + + Args: + updated: The metadata dictionary to be updated. + updates: The updates metadata dictionary. + """ + for key, new_value in updates.items(): + if new_value == "": + continue + if (key in updated) and ((updated_value := updated[key]) != ""): + merger = self.mergers.get(key, self.default) + if merger is not None: + updated[key] = merger(updated_value, new_value) + else: + updated[key] = new_value + + def copy_merged_metadata( + self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node] + ) -> None: + """Merges metadata from multiple nodes and assigns it to one or more target nodes. + + Args: + from_nodes: The source nodes from which to merge metadata. + to: The target node(s) to which the merged metadata will be assigned. + """ + if isinstance(to, ir.Node): + updated = to.metadata_props + for node in from_nodes: + self.update_dict(updated, node.metadata_props) + elif len(to) == 1: + # Handle single node in iterable case + target_node = next(iter(to)) + updated = target_node.metadata_props + for node in from_nodes: + self.update_dict(updated, node.metadata_props) + else: + merged_metadata: dict[str, str] = {} + for node in from_nodes: + self.update_dict(merged_metadata, node.metadata_props) + for target_node in to: + self.update_dict(target_node.metadata_props, merged_metadata) From 478acf741471a4764d507d52a4ca6b544a63fc86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Mo=C3=9Fburger?= Date: Thu, 6 Nov 2025 19:31:21 +0100 Subject: [PATCH 099/123] [torchlib] Fix unbind.int if num_outputs=1 (#2684) This fixes the issue of ``` return [op.Squeeze(out, [dim]) for out in outputs] ^^^^^^^ TypeError: 'SymbolicTensor' object is not iterable ``` when trying to export LSTM modules in `torch`. This also already appeared in torch issues in https://github.com/pytorch/pytorch/issues/126339 The core seems to be the changes in #2597. To my understanding the split returns a single `SymbolicTensor` instead of a sequence when `dim=1`. The fix implemented here is the casting of the return type to a list. I struggled with writing a test that reproduces this nicely in here, any guidance on that would be welcome. --------- Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 96f64bbb8a..767dffacf7 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8957,7 +8957,12 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): # We can create a definitive split op if the input shape is static # Only torch>=2.7 supports correctly generating the correct number of outputs for Split - outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim]) + num_outputs = self.shape[dim] + if num_outputs != 1: + outputs = op.Split(self, axis=dim, num_outputs=num_outputs) + else: + outputs = [self] + return [op.Squeeze(out, [dim]) for out in outputs] return op.SplitToSequence(self, axis=dim, keepdims=False) From ea8cb3ee599a9a2d9dabb573126ac6c128e7255e Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 6 Nov 2025 16:52:32 -0800 Subject: [PATCH 100/123] Add option to clear metadata in ort fusion (#2685) Add option to clear metadata in ort fusion. Otherwise, we lose all metadata after running ORT fusion --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/_core.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index ea7af31b3e..8280b1c39c 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -115,6 +115,7 @@ def optimize_for_ort( config_name: str | None = None, *, debug: bool = False, + clear_metadata: bool = False, ) -> tuple[ir.Model, dict[str, int]]: """ Optimize the model for ORT backend. @@ -128,6 +129,7 @@ def optimize_for_ort( Typically it identifies the Execution Provider (EP) to optimize for. If None, the default configuration will be used. debug: If debug is True, enable pattern matching tracer for debugging. + clear_metadata: If True, clear metadata and doc strings from the model. Returns: A tuple containing: @@ -145,7 +147,6 @@ def optimize_for_ort( passes = ir.passes.Sequential( # Apply the ORT optimization passes. # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 - common_passes.ClearMetadataAndDocStringPass(), # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), common_passes.RemoveInitializersFromInputsPass(), @@ -154,4 +155,8 @@ def optimize_for_ort( assert passes.in_place result = passes(model) assert result.model is model + + if clear_metadata: + common_passes.ClearMetadataAndDocStringPass()(model) + return model, fusion_count From 70e751ae4a630eb444f80bfd4af435a94ca747f3 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 7 Nov 2025 10:34:00 -0800 Subject: [PATCH 101/123] Implement SDPA via MHA (#2683) Implement SDPA via MHA. This handles the case when earlier fusion rules do not map larger patterns containing SDPA into MHA or GQA or Attention (from ORT contrib ops). It implements SDPA via MHA. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/ort_fusions/_core.py | 2 + onnxscript/rewriter/ort_fusions/sdpa.py | 12 ++++++ onnxscript/rewriter/ort_fusions/sdpa_test.py | 38 ++++++++++++++++--- .../rewriter/ort_fusions/sdpa_via_mha.py | 26 ++++++++++--- 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 8280b1c39c..fa1f0c109b 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -29,6 +29,7 @@ fuse_rotary_embedding, ) from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa +from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha from onnxscript.rewriter.ort_fusions.skip_normalization import ( fuse_skip_layer_normalization, fuse_skip_rms_normalization, @@ -104,6 +105,7 @@ def fuse(func, **kwargs): fusion_count["attention"] = fuse(fuse_attention) fusion_count["gelu"] = fuse(fuse_gelu) fusion_count["bias_gelu"] = fuse(fuse_bias_gelu) + fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha) # Finally: inline any intermediate fusion functions introduced that were not # consumed by other fusions, and eliminate any remaining unused nodes. optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 55b38e9ad4..821537afe5 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -12,6 +12,18 @@ Dim = Union[int, ir.SymbolicDim] +# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention +# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion +# op defined in the ai.onnxruntime._fusion domain. Subsequent fusion rules will map it into one +# of the various ops defined in ORT: MHA, GQA, or Attention depending on the input patterns. +# The SDPA is a standard scalar dot-product attention with an optional mask input and scaling factor. +# Currently, it is restricted to query, key, and values of rank 4 with shapes: +# Query: [batch_size, num_heads, seq_len, head_size_qk] +# Key: [batch_size, num_heads, seq_len_kv, head_size_qk] +# or [batch_size, seq_len_kv, num_heads, head_size_qk]) +# Value: [batch_size, num_heads, seq_len_kv, head_size_v] +# The key_format attribute indicates which of the two formats the key uses and can be either "BHSd" or "BSHd". + class SDPA(pattern.RewriteRuleClassBase): _scale: float | None diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index c5326a77b9..3b29418cc6 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask): return attn_output +# This tests a scenario where the key is in BSHd format instead of BHSd, which +# happens due to an optimization that fuses two transposes together, the one +# to convert from BSHd to BHSd and then to BHdS before MatMul. Hence, the first +# transpose down below is different from other test cases. +@script() +def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + class SDPATestCase: - def __init__(self, script_func, *, with_mask): + def __init__(self, script_func, *, with_mask, BSHd_key=False): self.script_func = script_func self.with_mask = with_mask + self.BSHd_key = BSHd_key def get_onnx_model(self): if not hasattr(self, "_onnx_model"): - qkv_type = FLOAT[B, N, S, H] + qv_type = FLOAT[B, N, S, H] mask_type = FLOAT[B, N, S, S] - input_types = [qkv_type, qkv_type, qkv_type] + k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H] + input_types = [qv_type, k_type, qv_type] if self.with_mask: input_types.append(mask_type) model_proto = self.script_func.to_model_proto( - input_types=input_types, output_types=[qkv_type] + input_types=input_types, output_types=[qv_type] ) self._onnx_model = ir.serde.deserialize_model(model_proto) return self._onnx_model @@ -314,7 +335,9 @@ def get_ort_inputs(self): if not hasattr(self, "_ort_inputs"): inputs = { "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), - "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, S, N, H).astype(numpy.float32) + if self.BSHd_key + else numpy.random.rand(B, N, S, H).astype(numpy.float32), "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), } if self.with_mask: @@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase): "_custom_multi_scale_pre_mul_sdpa_script", _custom_multi_scale_pre_mul_sdpa_script, ), + ("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script), ] ) def test_sdpa_fusion(self, name, script_func): - test_case = SDPATestCase(script_func, with_mask="masked" in name) + test_case = SDPATestCase( + script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name + ) model = test_case.get_onnx_model() onnxscript.optimizer.optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py index e6484406a9..acbc0705fa 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py @@ -7,43 +7,57 @@ import onnx_ir as ir from onnxscript.rewriter import _fusion_utils, pattern +from onnxscript.rewriter._basics import MatchFailureError Dim = Union[int, ir.SymbolicDim] class SDPAImplementation(pattern.RewriteRuleClassBase): - def pattern(self, op, query, key, value): + def pattern(self, op, query, key, value, key_format): + """Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op.""" return op.SDPA( query, key, value, - key_format="BHSd", + key_format=key_format, _allow_other_inputs=True, # Mask is optional _outputs=["sdpa_output"], _domain="ai.onnxruntime._fusion", ) - def check(self, context, query, key, value, sdpa_output): + def check(self, context, query, key, value, key_format, sdpa_output): bindings: dict[str, Dim] = {} _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) - _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + if key_format.value == "BHSd": + _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) + elif key_format.value == "BSHd": + _fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"]) + else: + raise MatchFailureError( + f"Unexpected key_format value: {key_format.value}", key_format + ) + self._num_heads = bindings["H"] if not isinstance(self._num_heads, int): return False self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed return isinstance(self._num_heads, int) - def rewrite(self, op, query, key, value, sdpa_output): + def rewrite(self, op, query, key, value, key_format, sdpa_output): sdpa_node = sdpa_output.producer() scale = sdpa_node.attributes.get("scale", None) to_3d_shape = op.Constant(value_ints=[0, 0, -1]) to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1]) query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape) - key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape) + if key_format.value == "BHSd": + key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) + else: # BSHd + key_3d = op.Reshape(key, to_3d_shape) + inputs = [query_3d, key_3d, value_3d] if len(sdpa_node.inputs) > 3: mask = sdpa_node.inputs[3] From a1be5c804361228a687a927b67904e16b6a6b265 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 7 Nov 2025 12:44:04 -0800 Subject: [PATCH 102/123] [torchlib] Fix mod on SymInt (#2686) Fix the error: `: 'int' object has no attribute 'dtype'` by splitting the implementation out for `operator.*` ops. Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 767dffacf7..2cbecdcfc2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3732,7 +3732,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"), trace_only=True, ) def aten_ge(self: TTensor, other: TTensor) -> BOOL: @@ -3749,6 +3749,12 @@ def aten_ge(self: TTensor, other: TTensor) -> BOOL: return op.GreaterOrEqual(self, other) +@torch_op("_operator::ge", trace_only=True) +def operator_ge(self: TTensor, other: TTensor) -> BOOL: + # operator.ge for SymInt + return op.GreaterOrEqual(self, other) + + def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: """geqrf(Tensor self) -> (Tensor a, Tensor tau)""" @@ -4058,7 +4064,7 @@ def aten_gru_cell( @torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"), trace_only=True, ) def aten_gt(self: TTensor, other: TTensor) -> BOOL: @@ -4076,6 +4082,12 @@ def aten_gt(self: TTensor, other: TTensor) -> BOOL: return op.Greater(self, other) +@torch_op("_operator::gt", trace_only=True) +def operator_gt(self: TTensor, other: TTensor) -> BOOL: + # operator.gt for SymInt + return op.Greater(self, other) + + @torch_op("aten::hamming_window", trace_only=True) def aten_hamming_window( window_length: int, @@ -4891,7 +4903,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"), trace_only=True, ) def aten_le(self: TTensor, other: TTensor) -> BOOL: @@ -4909,6 +4921,12 @@ def aten_le(self: TTensor, other: TTensor) -> BOOL: return op.LessOrEqual(self, other) +@torch_op("_operator::le", trace_only=True) +def operator_le(self: TTensor, other: TTensor) -> BOOL: + # operator.le for SymInt + return op.LessOrEqual(self, other) + + @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor: """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor""" @@ -5384,7 +5402,7 @@ def aten_lstm( @torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"), trace_only=True, ) def aten_lt(self: TTensor, other: TTensor) -> BOOL: @@ -5401,6 +5419,12 @@ def aten_lt(self: TTensor, other: TTensor) -> BOOL: return op.Less(self, other) +@torch_op("_operator::lt", trace_only=True) +def operator_lt(self: TTensor, other: TTensor) -> BOOL: + # operator.lt for SymInt + return op.Less(self, other) + + def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: """lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor""" @@ -7468,9 +7492,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) +@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -7486,6 +7508,12 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) +@torch_op("_operator::mod", trace_only=True) +def operator_mod(self: TTensor, other: TTensor) -> TTensor: + # Modulus operator % on SymInt + return op.Mod(self, other) + + def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)""" From 10e541ef1d333038b8b44cbe8801c26adb04d404 Mon Sep 17 00:00:00 2001 From: Tomoaki KOBAYASHI <37573952+moatom@users.noreply.github.com> Date: Mon, 10 Nov 2025 02:34:30 +0900 Subject: [PATCH 103/123] Implement aten.stft (#2645) Fixed https://github.com/pytorch/pytorch/issues/147052 ```bash $ python -m pytest tests/function_libs/torch_lib/ops_test.py -k ops_aten_stft ====================================================================================================================================================================================================== test session starts ====================================================================================================================================================================================================== platform linux -- Python 3.13.1, pytest-8.4.1, pluggy-1.6.0 Using --randomly-seed=371864411 rootdir: /home/moatom/github/onnxscript configfile: pyproject.toml plugins: randomly-3.16.0, xdist-3.8.0, subtests-0.14.2, cov-6.2.1, hypothesis-6.138.2 collected 2158 items / 2154 deselected / 4 selected tests/function_libs/torch_lib/ops_test.py s..x [100%] ======================================================================================================================================================================================================= warnings summary ======================================================================================================================================================================================================== onnxscript/converter.py:457: 429 warnings tests/function_libs/torch_lib/ops_test.py: 15 warnings /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'lineno'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15. expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) onnxscript/converter.py:457: 429 warnings tests/function_libs/torch_lib/ops_test.py: 15 warnings /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'col_offset'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15. expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32 tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32 tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32 /home/moatom/github/onnxscript/tests/function_libs/torch_lib/ops_test_common.py:329: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword value = np.array(value.cpu()) -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ==================================================================================================================================================================================================== short test summary info ==================================================================================================================================================================================================== SKIPPED [1] tests/function_libs/torch_lib/ops_test.py:101: Traced functions does not have a function proto =================================================================================================================================================================== 2 passed, 1 skipped, 2154 deselected, 1 xfailed, 891 warnings, 7 subtests passed in 4.42s =================================================================================================================================================================== ``` --- .../function_libs/torch_lib/ops/core.py | 97 +++++++++++++++++++ .../function_libs/torch_lib/e2e_ops_tests.py | 69 +++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 8 ++ 3 files changed, 174 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2cbecdcfc2..09704199f9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8548,6 +8548,103 @@ def aten_std_mean_correction( return op.Sqrt(var), mean +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: + left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), win_length) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + window_list = op.Expand(op.Constant(value_ints=[1]), win_length) + return op.Concat(left_win, window_list, right_win, axis=0) + + +def _create_window_from_n_fft(n_fft: int) -> TFloat: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) + return window + + +def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) + result = op.Div(result, sqrt_nfft) + return result + + +@torch_op("aten::stft", trace_only=True) +def aten_stft( + self: TFloat, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[TFloat] = None, + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> TFloat: + """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" + + # NOTE: regardless of the value of return_complex, we always return a real representation. + del return_complex + + # Get STFT sizes + if hop_length is None: + # core dump + # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + hop_length = n_fft // 4 + frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) + + # Pre-process input if needed + is_signal_rank1 = len(self.shape) == 1 + if is_signal_rank1: + # Add a batch dimension + self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + if window is not None and window.shape[0] is not None: + # first dimension + n_win = op.Shape(window, start=0, end=1) + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), n_win) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window = op.Concat(left_win, window, right_win, axis=0) + elif window is None: + if win_length is not None: + window = _create_window_from_win_length(win_length, n_fft) + else: + window = _create_window_from_n_fft(n_fft) + + if onesided is None or onesided: + onesided = 1 + else: + onesided = 0 + window = op.CastLike(window, self) + result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided) + result = op.Transpose(result, perm=[0, 2, 1, 3]) + # Remove batch dimension, if needed + if is_signal_rank1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + + # Normalize, if needed + if normalized: + result = _normalize_fft_result(self, result, n_fft) + + return result + + @torch_op( ( "aten::sub.Tensor", diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index f74dda699d..cb272a98a6 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -406,6 +406,75 @@ def forward(self, x): onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) + def test_aten_stft_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=True) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=False) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_3(self): + class Model(torch.nn.Module): + def forward(self, x): + window = torch.ones(16, dtype=torch.float32) + return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False) + + x = torch.randn(100, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_4(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft( + x, + n_fft=4, + hop_length=1, + win_length=4, + center=True, + onesided=True, + return_complex=True, + ) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf31..4ef7550b6e 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1760,6 +1760,14 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), + TorchLibOpInfo( + "ops.aten.stft", # Custom from extra_opinfo + core_ops.aten_stft, + tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + ).xfail( + dtypes=(torch.float16,), + reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, From 1dd9d049705c508b173ed2523aeede0a0fa95d99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 10 Nov 2025 16:35:17 +0100 Subject: [PATCH 104/123] Add converter for unique_consecutive (#2694) Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> Co-authored-by: G. Ramalingam Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- .../function_libs/torch_lib/ops/core.py | 47 ++++++++++++++++++- .../function_libs/torch_lib/e2e_ops_tests.py | 45 ++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 09704199f9..a25015b232 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -51,6 +51,7 @@ from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +_INT32_MAX = 2147483647 _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi @@ -9183,15 +9184,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) -> raise NotImplementedError() +@torch_op("aten::unique_consecutive", trace_only=True) def aten_unique_consecutive( - self: TensorType, + x: TensorType, return_inverse: bool = False, return_counts: bool = False, dim: Optional[int] = None, ) -> tuple[TensorType, TensorType, TensorType]: """unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)""" + assert x.dtype in {INT64.dtype, INT32.dtype}, ( + "unique_consecutive not implemented for other type than int32, int64" + ) + rank_x = len(x.shape) - raise NotImplementedError() + zero = op.Constant(value=ir.tensor([0], dtype=x.dtype)) + zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype)) + minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype)) + + if dim is None: + if rank_x != 1: + x = op.Reshape(x, minus_one) + else: + assert rank_x == 1 and dim == 0, ( + f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}." + ) + + lag = op.Concat( + # Hopefully this will never be equal to the first value of the tensor x + # ideally we could do differently but with a higher cost + op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)), + op.Slice(x, zero64, minus_one, zero64), + axis=0, + ) + eq = op.Equal(x, lag) + diff = op.Not(eq) + res = op.Compress(x, diff, axis=0) + + zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype)) + one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype)) + one = op.Constant(value=ir.tensor([1], dtype=x.dtype)) + + inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one) + shape_x = op.Shape(x) + indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim) + points = op.Compress(indices, diff, axis=0) + lagp = op.Concat( + op.Slice(points, one, op.Shape(points), zero), + shape_x, + axis=0, + ) + counts = op.Sub(lagp, points) + return res, inverse, counts @torch_op("aten::_unique", trace_only=True) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index cb272a98a6..1546de59bd 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -406,6 +406,51 @@ def forward(self, x): onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) + def test_aten_unique_consecutive(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive_int32(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int32) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive_return(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x, return_inverse=True, return_counts=True) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_aten_stft_1(self): class Model(torch.nn.Module): def forward(self, x): From 4042df33ea48edfa5976d260d63a89d231226fdd Mon Sep 17 00:00:00 2001 From: Yuan Yao <99693700+yuanyao-nv@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:28:20 -0800 Subject: [PATCH 105/123] Add missing output_size kwarg to repeat_interleave (#2691) Fixes https://github.com/microsoft/onnxscript/issues/2687 --------- Signed-off-by: Yuan Yao --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a25015b232..326075b2fe 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7539,7 +7539,10 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: @torch_op("aten::repeat_interleave.self_int", trace_only=True) def aten_repeat_interleave_self_int( - self: TensorType, repeats: int, dim: Optional[int] = None + self: TensorType, + repeats: int, + dim: Optional[int] = None, + output_size: Optional[int] = None, ) -> TensorType: """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor From cfb52e23817260dab07f4fd636a19395a0b67da1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:29:26 -0800 Subject: [PATCH 106/123] chore(deps): bump ruff from 0.14.2 to 0.14.3 in /requirements/lintrunner (#2676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.14.2 to 0.14.3.
Release notes

Sourced from ruff's releases.

0.14.3

Release Notes

Released on 2025-10-30.

Preview features

  • Respect --output-format with --watch (#21097)
  • [pydoclint] Fix false positive on explicit exception re-raising (DOC501, DOC502) (#21011)
  • [pyflakes] Revert to stable behavior if imports for module lie in alternate branches for F401 (#20878)
  • [pylint] Implement stop-iteration-return (PLR1708) (#20733)
  • [ruff] Add support for additional eager conversion patterns (RUF065) (#20657)

Bug fixes

  • Fix finding keyword range for clause header after statement ending with semicolon (#21067)
  • Fix syntax error false positive on nested alternative patterns (#21104)
  • [ISC001] Fix panic when string literals are unclosed (#21034)
  • [flake8-django] Apply DJ001 to annotated fields (#20907)
  • [flake8-pyi] Fix PYI034 to not trigger on metaclasses (PYI034) (#20881)
  • [flake8-type-checking] Fix TC003 false positive with future-annotations (#21125)
  • [pyflakes] Fix false positive for __class__ in lambda expressions within class definitions (F821) (#20564)
  • [pyupgrade] Fix false positive for TypeVar with default on Python <3.13 (UP046,UP047) (#21045)

Rule changes

  • Add missing docstring sections to the numpy list (#20931)
  • [airflow] Extend airflow.models..Param check (AIR311) (#21043)
  • [airflow] Warn that airflow....DAG.create_dagrun has been removed (AIR301) (#21093)
  • [refurb] Preserve digit separators in Decimal constructor (FURB157) (#20588)

Server

  • Avoid sending an unnecessary "clear diagnostics" message for clients supporting pull diagnostics (#21105)

Documentation

  • [flake8-bandit] Fix correct example for S308 (#21128)

Other changes

  • Clearer error message when line-length goes beyond threshold (#21072)

Contributors

... (truncated)

Changelog

Sourced from ruff's changelog.

0.14.3

Released on 2025-10-30.

Preview features

  • Respect --output-format with --watch (#21097)
  • [pydoclint] Fix false positive on explicit exception re-raising (DOC501, DOC502) (#21011)
  • [pyflakes] Revert to stable behavior if imports for module lie in alternate branches for F401 (#20878)
  • [pylint] Implement stop-iteration-return (PLR1708) (#20733)
  • [ruff] Add support for additional eager conversion patterns (RUF065) (#20657)

Bug fixes

  • Fix finding keyword range for clause header after statement ending with semicolon (#21067)
  • Fix syntax error false positive on nested alternative patterns (#21104)
  • [ISC001] Fix panic when string literals are unclosed (#21034)
  • [flake8-django] Apply DJ001 to annotated fields (#20907)
  • [flake8-pyi] Fix PYI034 to not trigger on metaclasses (PYI034) (#20881)
  • [flake8-type-checking] Fix TC003 false positive with future-annotations (#21125)
  • [pyflakes] Fix false positive for __class__ in lambda expressions within class definitions (F821) (#20564)
  • [pyupgrade] Fix false positive for TypeVar with default on Python <3.13 (UP046,UP047) (#21045)

Rule changes

  • Add missing docstring sections to the numpy list (#20931)
  • [airflow] Extend airflow.models..Param check (AIR311) (#21043)
  • [airflow] Warn that airflow....DAG.create_dagrun has been removed (AIR301) (#21093)
  • [refurb] Preserve digit separators in Decimal constructor (FURB157) (#20588)

Server

  • Avoid sending an unnecessary "clear diagnostics" message for clients supporting pull diagnostics (#21105)

Documentation

  • [flake8-bandit] Fix correct example for S308 (#21128)

Other changes

  • Clearer error message when line-length goes beyond threshold (#21072)

Contributors

... (truncated)

Commits
  • 8737a2d Bump v0.14.3 (#21152)
  • 3be3a10 [ty] Don't provide completions when in class or function definition (#21146)
  • 13375d0 [ty] Use the top materialization of classes for narrowing in class-patterns f...
  • c0b04d4 [ty] Update "constraint implication" relation to work on constraints between ...
  • 1c7ea69 [flake8-type-checking] Fix TC003 false positive with future-annotations...
  • 9bacd19 [ty] Fix lookup of __new__ on instances (#21147)
  • f0fe6d6 Fix syntax error false positive on nested alternative patterns (#21104)
  • 10bda3d [pyupgrade] Fix false positive for TypeVar with default on Python <3.13 (...
  • e55bc94 [ty] Reachability and narrowing for enum methods (#21130)
  • 1b0ee46 [ty] Use range instead of custom IntIterable (#21138)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.14.2&new-version=0.14.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 2a913e68f4..c2c0c54364 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.14.2 +ruff==0.14.3 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250915 From f1a6ec4ff062ec516662df8a09665971d7b0c23b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:29:34 -0800 Subject: [PATCH 107/123] chore(deps): bump editorconfig-checker from 3.4.0 to 3.4.1 in /requirements/lintrunner (#2677) Bumps [editorconfig-checker](https://github.com/editorconfig-checker/editorconfig-checker.python) from 3.4.0 to 3.4.1.
Commits
  • 62dcf36 Merge pull request #47 from editorconfig-checker/chore/bump-core-3.4.1
  • f906601 chore(release): bump core package to 3.4.1
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=editorconfig-checker&package-manager=pip&previous-version=3.4.0&new-version=3.4.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index c2c0c54364..f7549e39e3 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -8,4 +8,4 @@ types-PyYAML==6.0.12.20250915 # PYLINT pylint==3.3.9 # EDITORCONFIG-CHECKER -editorconfig-checker==3.4.0 +editorconfig-checker==3.4.1 From cba13256fb656920123ad0dd38d5c049c2442c12 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:29:46 -0800 Subject: [PATCH 108/123] chore(deps): bump onnx-weekly from 1.20.0.dev20251027 to 1.21.0.dev20251103 in /requirements/ci (#2678) Bumps [onnx-weekly](https://github.com/onnx/onnx) from 1.20.0.dev20251027 to 1.21.0.dev20251103.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=onnx-weekly&package-manager=pip&previous-version=1.20.0.dev20251027&new-version=1.21.0.dev20251103)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/ci/requirements-onnx-weekly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 728f319adf..d206c9fcd6 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.20.0.dev20251027 +onnx-weekly==1.21.0.dev20251103 From 97513c747d22adfbb1600fad9232d9280d80dc1f Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 11 Nov 2025 17:44:28 -0800 Subject: [PATCH 109/123] Bump version (#2702) Bump version Signed-off-by: Ganesan Ramalingam --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index b49b25336d..d3532a107e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.6 +0.5.7 From c1bfdfc5baf2dabf09191ad2ab7603ea21cfe614 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 12 Nov 2025 10:33:54 -0800 Subject: [PATCH 110/123] Utility and example for custom op expansion (#2701) A utility and an example showing how onnxscript functions can be used to define function expansions and be used with the inliner to replace calls to the custom function with an expanded subgraph. This is useful to perform certain classes of graph surgery easily. --- .lintrunner.toml | 1 + examples/custom_op_expansion.py | 61 +++++++++++++++++++++++++++++++++ onnxscript/utils/replace.py | 35 +++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 examples/custom_op_expansion.py create mode 100644 onnxscript/utils/replace.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 907f3bfce6..ed937d352c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -39,6 +39,7 @@ include_patterns = [ exclude_patterns = [ 'tests/**', # Skip linting test files for speed # FIXME: Fix typing annotations in these files + 'examples/custom_op_expansion.py', 'onnxscript/converter_test.py', 'onnxscript/converter.py', 'onnxscript/evaluator_test.py', diff --git a/examples/custom_op_expansion.py b/examples/custom_op_expansion.py new file mode 100644 index 0000000000..c261ff18d7 --- /dev/null +++ b/examples/custom_op_expansion.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ruff: noqa + +"""A utility and an example showing how onnxscript functions can be used to define function expansions +and be used with the inliner to replace calls to the custom function with an expanded subgraph. +This is useful to perform certain classes of graph surgery easily. +""" + +import onnx + +import onnxscript +import onnxscript.utils.replace as replace + +script = onnxscript.script +FLOAT = onnxscript.FLOAT +op = onnxscript.values.opset22 +local = onnxscript.values.Opset("local", 1) + + +# Example Model: Actual models can come from ModelBuilder or Exporter or any other source. +# Models can contain calls to custom operations (from a custom domain like 'local' here or +# even "com.microsoft" etc.) +@script() +def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]: + DoubleX = op.Add(X, X) + YSquare = op.Mul(Y, Y) + # Example call to a custom operation + Temp1 = local.CustomOp1(DoubleX, YSquare) + # Another call to a custom operation with an attribute + Temp2 = local.CustomOp2(Temp1, alp=0.9) + return Temp2 + + +# Define expansions for custom operations as onnxscript functions +@script(opset=local) +def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]: + Temp1 = op.Sub(X, Y) + return op.Div(Temp1, X) + + +@script(opset=local) +def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]: + Temp2 = op.Elu(X, alpha=alp) + return op.Mul(Temp2, Temp2) + + +# Now, we can replace the custom operations in the model with their expansions: + +functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()] + +model = model_script.to_model_proto() + +print("Original Model with custom operations:") +print(onnx.printer.to_text(model)) + + +updated_model = replace.replace_functions(model, functions) + +print("\nUpdated Model after replacing custom operations with their expansions:") +print(onnx.printer.to_text(updated_model)) diff --git a/onnxscript/utils/replace.py b/onnxscript/utils/replace.py new file mode 100644 index 0000000000..d3af1a37a0 --- /dev/null +++ b/onnxscript/utils/replace.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""A utility function to replace custom operations in a model with their expansions""" + +from typing import Sequence + +import onnx +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + + +def replace_functions( + model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto] +) -> onnx.ModelProto: + """A utility function to replace custom operations in a model with their expansions: + Args: + model: An ONNX ModelProto possibly containing calls to custom operations. + functions: A sequence of FunctionProto defining the expansions for the custom operations. + + Returns: + An updated ModelProto with custom operations replaced by their expansions. + """ + irmodel = ir.from_proto(model) + irfunctions = [ir.from_proto(func) for func in functions] + model_functions = irmodel.functions + if len(model_functions) != 0: + # Since we use inlining, check that there are no model-local functions. + raise ValueError("Input model cannot have model-local functions.") + for func in irfunctions: + model_functions[func.identifier()] = func + + # TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values. + common_passes.InlinePass()(irmodel) + common_passes.RemoveUnusedOpsetsPass()(irmodel) + return ir.to_proto(irmodel) From 53af80035d69a4cb777593540dd72f69f99f9ec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 18 Nov 2025 17:23:12 +0100 Subject: [PATCH 111/123] add converter for aten::sym_storage_offset (#2697) --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 326075b2fe..9532a3b564 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8753,6 +8753,14 @@ def aten_sym_size(self: TensorType, dim: int = 0) -> INT64: return op.Squeeze(op.Shape(self, end=dim + 1, start=dim)) +@torch_op("aten::sym_storage_offset", trace_only=True) +def aten_sym_storage_offset(self: TensorType, dim: int = 0) -> INT64: + """sym_storage_offset(Tensor self, int dim) -> SymInt""" + # storage offset is not used in onnx world. + # the output of this function is not used. + return op.Constant(value_int=0) + + def aten_symeig( self: TensorType, eigenvectors: bool = False, upper: bool = True ) -> tuple[TensorType, TensorType]: From 6247ac18d1822dc054e9ca86a8d21da622d28a4f Mon Sep 17 00:00:00 2001 From: ruro Date: Wed, 19 Nov 2025 00:25:26 +0300 Subject: [PATCH 112/123] Implement ONNX export for `fake_quantize_per_*_affine` (#2696) See pytorch/pytorch#167063. ~The tests in this PR rely on pytorch/pytorch#167465.~ --- .../function_libs/torch_lib/ops/core.py | 123 +++++++++++++++++- tests/function_libs/torch_lib/extra_opinfo.py | 119 +++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 11 ++ 3 files changed, 246 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9532a3b564..4d83675589 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -23,6 +23,7 @@ COMPLEX128, DOUBLE, FLOAT, + FLOAT16, INT8, INT16, INT32, @@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType: raise NotImplementedError() +@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True) def aten_fake_quantize_per_channel_affine( - self: TensorType, - scale: TensorType, - zero_point: TensorType, + self: TFloat, + scale: FLOAT, # float32 specifically! + zero_point: Union[INT32, FLOAT, FLOAT16], # int32, float32 or float16 only! axis: int, quant_min: int, quant_max: int, ) -> TensorType: """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + if zero_point.type.dtype == ir.DataType.INT32: + zero_point = op.Cast(zero_point, to=int_dtype) + else: + raise NotImplementedError( + "ONNX only supports integer values for the zero_point parameter. " + f"Got {zero_point.type.dtype}", + ) + + quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_channel_affine_cachemask( @@ -3351,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward( raise NotImplementedError() +@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True) def aten_fake_quantize_per_tensor_affine( - self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int -) -> TensorType: + self: TFloat, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, +) -> TFloat: """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True) +def aten_fake_quantize_per_tensor_affine_tensor_qparams( + self: TFloat, + scale: TReal, + zero_point: TReal, + quant_min: int, + quant_max: int, +) -> TFloat: + """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor""" + + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +def _aten_fake_quantize_per_tensor_affine( + self: TFloat, + scale: Union[float, TReal], + zero_point: Union[int, TReal], + quant_min: int, + quant_max: int, +) -> TFloat: + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + # TODO: When opset >= 19, relex the condition for this cast + if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT: + scale = op.Cast(scale, to=ir.DataType.FLOAT) + + if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype: + zero_point = op.Cast(zero_point, to=int_dtype) + + quantized = op.QuantizeLinear(self, scale, zero_point) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_tensor_affine_cachemask( diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 5d7deb1695..2ce015b363 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -779,6 +779,109 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs_fake_quantize_per_tensor_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary) + shapes = [ + (S,), + (1, 0, 3), + (), + ] + + scale_zero_point_dtypes = [ + # default (float, int) + (None, None) + ] + [ + # tensor_qparams (tensor, tensor) + (t1, t2) + for t1 in common_dtype.all_types_and() + for t2 in common_dtype.all_types_and() + ] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals) + for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases: + scale = make_arg( + (), + dtype=scale_dtype or torch.float64, + ) + if scale_dtype is None: + scale = scale.item() + + zero_point = make_arg( + (), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + if zero_point_dtype is None: + zero_point = zero_point.item() + + args = (scale, zero_point, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + +def sample_inputs_fake_quantize_per_channel_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, 2D, 4D and empty tensors (scalar tensors not supported) + axes_and_shapes = [ + # 1D, 2D, 4D + (axis, (S,) * dims) + for dims in (1, 2, 4) + for axis in range(dims) + ] + [ + # empty + (0, (1, 0, 3)), + (2, (1, 0, 3)), + # empty channel axis causes an error due to + # an internal zero_point.min() calculation + # (1, (1, 0, 3)), + ] + + # tensor_qparams + scale_dtype = torch.float + zero_point_dtypes = [torch.int32, torch.float, torch.half] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals) + for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases: + scale = make_arg((shape[axis],), dtype=scale_dtype) + + zero_point = make_arg( + (shape[axis],), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + + args = (scale, zero_point, axis, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + def _index_variable_bool(shape, max_indices, device): if not isinstance(shape, tuple): shape = (shape,) @@ -2408,6 +2511,22 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + aten_name="fake_quantize_per_tensor_affine", + op=torch.fake_quantize_per_tensor_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_channel_affine", + aten_name="fake_quantize_per_channel_affine", + op=torch.fake_quantize_per_channel_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine, + supports_out=False, + ), opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4ef7550b6e..e87a0cc232 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -698,6 +698,17 @@ def _where_input_wrangler( TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_channel_affine", + core_ops.aten_fake_quantize_per_channel_affine, + ).xfail( + reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values", + matcher=lambda sample: sample.args[1].dtype != torch.int32, + ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + core_ops.aten_fake_quantize_per_tensor_affine, + ), TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip).skip( reason="fixme: size 0 inputs are not handled yet", From 9dbf6858dbe01a10abda78312a4c2cfcd7e54a93 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 18 Nov 2025 13:26:48 -0800 Subject: [PATCH 113/123] Provide inplace replacement util (#2708) Redo https://github.com/microsoft/onnxscript/pull/2703, thanks to Titai. Signed-off-by: Ganesan Ramalingam --- onnxscript/utils/replace.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/onnxscript/utils/replace.py b/onnxscript/utils/replace.py index d3af1a37a0..d46493155d 100644 --- a/onnxscript/utils/replace.py +++ b/onnxscript/utils/replace.py @@ -9,19 +9,17 @@ import onnx_ir.passes.common as common_passes -def replace_functions( - model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto] -) -> onnx.ModelProto: +def replace_functions_inplace(irmodel: ir.Model, irfunctions: Sequence[ir.Function]) -> None: """A utility function to replace custom operations in a model with their expansions: + + The model is updated in-place. + Args: - model: An ONNX ModelProto possibly containing calls to custom operations. - functions: A sequence of FunctionProto defining the expansions for the custom operations. + irmodel: An ONNX model possibly containing calls to custom operations. + irfunctions: A sequence of functions defining the expansions for the custom operations. + - Returns: - An updated ModelProto with custom operations replaced by their expansions. """ - irmodel = ir.from_proto(model) - irfunctions = [ir.from_proto(func) for func in functions] model_functions = irmodel.functions if len(model_functions) != 0: # Since we use inlining, check that there are no model-local functions. @@ -32,4 +30,20 @@ def replace_functions( # TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values. common_passes.InlinePass()(irmodel) common_passes.RemoveUnusedOpsetsPass()(irmodel) + + +def replace_functions( + model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto] +) -> onnx.ModelProto: + """A utility function to replace custom operations in a model with their expansions: + Args: + model: An ONNX ModelProto possibly containing calls to custom operations. + functions: A sequence of FunctionProto defining the expansions for the custom operations. + + Returns: + An updated ModelProto with custom operations replaced by their expansions. + """ + irmodel = ir.from_proto(model) + irfunctions = [ir.from_proto(func) for func in functions] + replace_functions_inplace(irmodel, irfunctions) return ir.to_proto(irmodel) From 597d5f7273b5f6a1a758dcc4d2b1cdf0c1c84d75 Mon Sep 17 00:00:00 2001 From: Afshin Paydar <60913143+afshin-paydar@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:29:51 +0800 Subject: [PATCH 114/123] Fix aten_unbind for torch >= 2.7 dynamo export (#2719) Fix aten_unbind for torch >= 2.7 dynamo export ## Problem When exporting PyTorch models to ONNX with `dynamo=True` on torch >= 2.7, the `aten_unbind` operation fails with: ``` TypeError: 'SymbolicTensor' object is not iterable ``` This occurs because `op.Split(self, axis=dim, num_outputs=num_outputs)` returns a single `SymbolicTensor` object rather than an iterable sequence during dynamo export. The subsequent list comprehension `[op.Squeeze(out, [dim]) for out in outputs]` attempts to iterate over this non-iterable object, causing that error. This affects models using LSTM and other operations that internally call `unbind`, preventing successful ONNX export. ## Solution Replace the `Split` op approach with explicit `Slice` operations: - For each output along the unbind dimension, create an individual `Slice` operation to extract one element - Apply `Squeeze` to remove the size-1 dimension - Collect all results in a list ## Test coverage: - num_outputs = 1: test_unbind_size_one - num_outputs = 2: test_unbind_with_lstm, test_unbind_dim1 (size 2 along some dims) - num_outputs = 3: test_unbind_dim0, test_unbind_dim1 - num_outputs = 4: test_unbind_negative_dim - Negative dimensions: test_unbind_negative_dim ## Related Issues Fixes pytorch/pytorch#168969 --- .../function_libs/torch_lib/ops/core.py | 23 ++-- .../function_libs/torch_lib/e2e_ops_tests.py | 106 ++++++++++++++++++ 2 files changed, 120 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4d83675589..b287cec057 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -9200,16 +9200,21 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"): - # We can create a definitive split op if the input shape is static - # Only torch>=2.7 supports correctly generating the correct number of outputs for Split + if isinstance(self.shape[dim], int): num_outputs = self.shape[dim] - if num_outputs != 1: - outputs = op.Split(self, axis=dim, num_outputs=num_outputs) - else: - outputs = [self] - - return [op.Squeeze(out, [dim]) for out in outputs] + results = [] + for i in range(num_outputs): + # Slice to get a single element at position i along dim + sliced = op.Slice( + self, + starts=op.Constant(value_ints=[i]), + ends=op.Constant(value_ints=[i + 1]), + axes=op.Constant(value_ints=[dim]), + ) + # Squeeze to remove the dimension of size 1 + squeezed = op.Squeeze(sliced, axes=[dim]) + results.append(squeezed) + return results return op.SplitToSequence(self, axis=dim, keepdims=False) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1546de59bd..a2ced58c44 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -520,6 +520,112 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_unbind_dim0(self): + """Test unbind along dimension 0""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dim1(self): + """Test unbind along dimension 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_negative_dim(self): + """Test unbind with negative dimension""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=-1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_size_one(self): + """Test unbind with dimension of size 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return tensors[0] + + model = UnbindModel() + x = torch.randn(1, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_with_lstm(self): + """Test unbind in LSTM context""" + + class LSTMDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(100, 64) + self.lstm = torch.nn.LSTM(64, 64, 2, batch_first=True) # 2 layers + self.fc = torch.nn.Linear(64, 100) + + def forward(self, tokens, h, c): + embedded = self.embedding(tokens).unsqueeze(0) + output, (h_out, c_out) = self.lstm(embedded, (h, c)) + logits = self.fc(output.squeeze(0).squeeze(0)) + return logits, h_out, c_out + + model = LSTMDecoder() + model.eval() + tokens = torch.tensor([1]) + h = torch.randn(2, 1, 64) # 2 layers + c = torch.randn(2, 1, 64) # 2 layers + onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim0(self): + """Test unbind with dynamic dimension 0 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({0: "batch_size"},) + ) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim1(self): + """Test unbind with dynamic dimension 1 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({1: "seq_len"},) + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From 7dab831510855491466b0aaed0c5d8d0a2847170 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 13:20:52 -0800 Subject: [PATCH 115/123] chore(deps): bump ruff from 0.14.3 to 0.14.6 in /requirements/lintrunner (#2716) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.14.3 to 0.14.6.
Release notes

Sourced from ruff's releases.

0.14.6

Release Notes

Released on 2025-11-21.

Preview features

  • [flake8-bandit] Support new PySNMP API paths (S508, S509) (#21374)

Bug fixes

  • Adjust own-line comment placement between branches (#21185)
  • Avoid syntax error when formatting attribute expressions with outer parentheses, parenthesized value, and trailing comment on value (#20418)
  • Fix panic when formatting comments in unary expressions (#21501)
  • Respect fmt: skip for compound statements on a single line (#20633)
  • [refurb] Fix FURB103 autofix (#21454)
  • [ruff] Fix false positive for complex conversion specifiers in logging-eager-conversion (RUF065) (#21464)

Rule changes

  • [ruff] Avoid false positive on ClassVar reassignment (RUF012) (#21478)

CLI

  • Render hyperlinks for lint errors (#21514)
  • Add a ruff analyze option to skip over imports in TYPE_CHECKING blocks (#21472)

Documentation

  • Limit eglot-format hook to eglot-managed Python buffers (#21459)
  • Mention force-exclude in "Configuration > Python file discovery" (#21500)

Contributors

Install ruff 0.14.6

Install prebuilt binaries via shell script

curl --proto '=https' --tlsv1.2 -LsSf
https://github.com/astral-sh/ruff/releases/download/0.14.6/ruff-installer.sh
| sh
</tr></table>

... (truncated)

Changelog

Sourced from ruff's changelog.

0.14.6

Released on 2025-11-21.

Preview features

  • [flake8-bandit] Support new PySNMP API paths (S508, S509) (#21374)

Bug fixes

  • Adjust own-line comment placement between branches (#21185)
  • Avoid syntax error when formatting attribute expressions with outer parentheses, parenthesized value, and trailing comment on value (#20418)
  • Fix panic when formatting comments in unary expressions (#21501)
  • Respect fmt: skip for compound statements on a single line (#20633)
  • [refurb] Fix FURB103 autofix (#21454)
  • [ruff] Fix false positive for complex conversion specifiers in logging-eager-conversion (RUF065) (#21464)

Rule changes

  • [ruff] Avoid false positive on ClassVar reassignment (RUF012) (#21478)

CLI

  • Render hyperlinks for lint errors (#21514)
  • Add a ruff analyze option to skip over imports in TYPE_CHECKING blocks (#21472)

Documentation

  • Limit eglot-format hook to eglot-managed Python buffers (#21459)
  • Mention force-exclude in "Configuration > Python file discovery" (#21500)

Contributors

0.14.5

Released on 2025-11-13.

Preview features

  • [flake8-simplify] Apply SIM113 when index variable is of type int (#21395)

... (truncated)

Commits
  • 59c6cb5 Bump 0.14.6 (#21558)
  • 54dba15 [ty] Improve debug messages when imports fail (#21555)
  • 1af3185 [ty] Add support for relative import completions
  • 553e568 [ty] Refactor detection of import statements for completions
  • cdef3f5 [ty] Use dedicated collector for completions
  • 6178822 [ty] Attach subdiagnostics to unresolved-import errors for relative imports...
  • 6b7adb0 [ty] support PEP 613 type aliases (#21394)
  • 06941c1 [ty] More low-hanging fruit for inlay hint goto-definition (#21548)
  • eb7c098 [ty] implement TypedDict structural assignment (#21467)
  • 1b28fc1 [ty] Add more random TypeDetails and tests (#21546)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.14.3&new-version=0.14.6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index f7549e39e3..5b00aa3c0f 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.14.3 +ruff==0.14.6 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250915 From c8bfe719161c54daf28d1a00ccaffc93c65d8488 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 13:21:16 -0800 Subject: [PATCH 116/123] chore(deps): bump actions/checkout from 5 to 6 (#2715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6.
Release notes

Sourced from actions/checkout's releases.

v6.0.0

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5.0.0...v6.0.0

v6-beta

What's Changed

Updated persist-credentials to store the credentials under $RUNNER_TEMP instead of directly in the local git config.

This requires a minimum Actions Runner version of v2.329.0 to access the persisted credentials for Docker container action scenarios.

v5.0.1

What's Changed

Full Changelog: https://github.com/actions/checkout/compare/v5...v5.0.1

Changelog

Sourced from actions/checkout's changelog.

Changelog

V6.0.0

V5.0.1

V5.0.0

V4.3.1

V4.3.0

v4.2.2

v4.2.1

v4.2.0

v4.1.7

v4.1.6

v4.1.5

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=actions/checkout&package-manager=github_actions&previous-version=5&new-version=6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql-analysis.yml | 2 +- .github/workflows/lint.yaml | 4 ++-- .github/workflows/main.yaml | 6 +++--- .github/workflows/pages.yaml | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 6953a76929..c169029a58 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -41,7 +41,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 3fe51a3a5a..d1f165afa5 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -20,7 +20,7 @@ jobs: pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spelling uses: reviewdog/action-misspell@v1 with: @@ -43,7 +43,7 @@ jobs: permissions: security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 85d2a0b331..fcff6d2dd4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -57,7 +57,7 @@ jobs: nox-tag: test-onnx-ir-git runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v6 with: @@ -95,7 +95,7 @@ jobs: os: [ubuntu-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 with: @@ -119,7 +119,7 @@ jobs: update_readme: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python uses: actions/setup-python@v6 - name: Update readme diff --git a/.github/workflows/pages.yaml b/.github/workflows/pages.yaml index ce638dc60d..51ae68abcc 100644 --- a/.github/workflows/pages.yaml +++ b/.github/workflows/pages.yaml @@ -25,14 +25,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Pages uses: actions/configure-pages@v4 - name: Setup Python uses: actions/setup-python@v6 with: python-version: "3.10" - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel From 45ba02d399430a2b4109e9f23c7f4555e1e28f8e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 13:33:38 -0800 Subject: [PATCH 117/123] chore(deps): bump onnxruntime from 1.23.1 to 1.23.2 in /requirements/ci (#2652) Bumps [onnxruntime](https://github.com/microsoft/onnxruntime) from 1.23.1 to 1.23.2.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=onnxruntime&package-manager=pip&previous-version=1.23.1&new-version=1.23.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/ci/requirements-ort-nightly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index cb16597719..f2e801846a 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.1 +onnxruntime==1.23.2 From 3364ada90a087f4118426c22f22371a74193d851 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:53:40 +0000 Subject: [PATCH 118/123] chore(deps): bump github/codeql-action from 3 to 4 (#2626) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 3 to 4.
Release notes

Sourced from github/codeql-action's releases.

v3.30.8

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.30.8 - 10 Oct 2025

No user facing changes.

See the full CHANGELOG.md for more information.

v3.30.7

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.30.7 - 06 Oct 2025

No user facing changes.

See the full CHANGELOG.md for more information.

v3.30.6

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.30.6 - 02 Oct 2025

  • Update default CodeQL bundle version to 2.23.2. #3168

See the full CHANGELOG.md for more information.

v3.30.5

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.30.5 - 26 Sep 2025

  • We fixed a bug that was introduced in 3.30.4 with upload-sarif which resulted in files without a .sarif extension not getting uploaded. #3160

See the full CHANGELOG.md for more information.

v3.30.4

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

3.30.4 - 25 Sep 2025

... (truncated)

Changelog

Sourced from github/codeql-action's changelog.

3.29.4 - 23 Jul 2025

No user facing changes.

3.29.3 - 21 Jul 2025

No user facing changes.

3.29.2 - 30 Jun 2025

  • Experimental: When the quality-queries input for the init action is provided with an argument, separate .quality.sarif files are produced and uploaded for each language with the results of the specified queries. Do not use this in production as it is part of an internal experiment and subject to change at any time. #2935

3.29.1 - 27 Jun 2025

  • Fix bug in PR analysis where user-provided include query filter fails to exclude non-included queries. #2938
  • Update default CodeQL bundle version to 2.22.1. #2950

3.29.0 - 11 Jun 2025

  • Update default CodeQL bundle version to 2.22.0. #2925
  • Bump minimum CodeQL bundle version to 2.16.6. #2912

3.28.21 - 28 July 2025

No user facing changes.

3.28.20 - 21 July 2025

3.28.19 - 03 Jun 2025

  • The CodeQL Action no longer includes its own copy of the extractor for the actions language, which is currently in public preview. The actions extractor has been included in the CodeQL CLI since v2.20.6. If your workflow has enabled the actions language and you have pinned your tools: property to a specific version of the CodeQL CLI earlier than v2.20.6, you will need to update to at least CodeQL v2.20.6 or disable actions analysis.
  • Update default CodeQL bundle version to 2.21.4. #2910

3.28.18 - 16 May 2025

  • Update default CodeQL bundle version to 2.21.3. #2893
  • Skip validating SARIF produced by CodeQL for improved performance. #2894
  • The number of threads and amount of RAM used by CodeQL can now be set via the CODEQL_THREADS and CODEQL_RAM runner environment variables. If set, these environment variables override the threads and ram inputs respectively. #2891

3.28.17 - 02 May 2025

  • Update default CodeQL bundle version to 2.21.2. #2872

3.28.16 - 23 Apr 2025

... (truncated)

Commits
  • a841c54 Scratch uploadSpecifiedFiles tests, make uploadPayload tests instead
  • aeb12f6 Merge branch 'main' into redsun82/skip-sarif-upload-tests
  • 6fd4ceb Merge pull request #3189 from github/henrymercer/download-codeql-rate-limit
  • 196a3e5 Merge pull request #3188 from github/mbg/telemetry/partial-config
  • 98abb87 Add configuration error for rate limited CodeQL download
  • bdd2cdf Also include language in error status report for start-proxy, if available
  • fb14878 Include languages in start-proxy telemetry
  • 2ff418f Parse language before calling getCredentials
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=3&new-version=4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) You can trigger a rebase of this PR by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
> **Note** > Automatic rebases have been disabled on this pull request as it has been open for over 30 days. Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ti-Tai Wang --- .github/workflows/codeql-analysis.yml | 6 +++--- .github/workflows/lint.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index c169029a58..30a97315d0 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -45,7 +45,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -59,7 +59,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v3 + uses: github/codeql-action/autobuild@v4 # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun @@ -72,4 +72,4 @@ jobs: # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index d1f165afa5..a87792fd2f 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -78,7 +78,7 @@ jobs: # To toggle linter comments in the files page, press `i` on the keyboard if: always() continue-on-error: true - uses: github/codeql-action/upload-sarif@v3 + uses: github/codeql-action/upload-sarif@v4 with: # Path to SARIF file relative to the root of the repository sarif_file: lintrunner.sarif From 3e7d9fb55922d3d32a52c0721a2c758971b75daa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:10:12 -0800 Subject: [PATCH 119/123] chore(deps): bump ruff from 0.14.6 to 0.14.7 in /requirements/lintrunner (#2721) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [ruff](https://github.com/astral-sh/ruff) from 0.14.6 to 0.14.7.
Release notes

Sourced from ruff's releases.

0.14.7

Release Notes

Released on 2025-11-28.

Preview features

  • [flake8-bandit] Handle string literal bindings in suspicious-url-open-usage (S310) (#21469)
  • [pylint] Fix PLR1708 false positives on nested functions (#21177)
  • [pylint] Fix suppression for empty dict without tuple key annotation (PLE1141) (#21290)
  • [ruff] Add rule RUF066 to detect unnecessary class properties (#21535)
  • [ruff] Catch more dummy variable uses (RUF052) (#19799)

Bug fixes

  • [server] Set severity for non-rule diagnostics (#21559)
  • [flake8-implicit-str-concat] Avoid invalid fix in (ISC003) (#21517)
  • [parser] Fix panic when parsing IPython escape command expressions (#21480)

CLI

  • Show partial fixability indicator in statistics output (#21513)

Contributors

Install ruff 0.14.7

Install prebuilt binaries via shell script

curl --proto '=https' --tlsv1.2 -LsSf
https://github.com/astral-sh/ruff/releases/download/0.14.7/ruff-installer.sh
| sh

Install prebuilt binaries via powershell script

powershell -ExecutionPolicy Bypass -c "irm
https://github.com/astral-sh/ruff/releases/download/0.14.7/ruff-installer.ps1
| iex"

... (truncated)

Changelog

Sourced from ruff's changelog.

0.14.7

Released on 2025-11-28.

Preview features

  • [flake8-bandit] Handle string literal bindings in suspicious-url-open-usage (S310) (#21469)
  • [pylint] Fix PLR1708 false positives on nested functions (#21177)
  • [pylint] Fix suppression for empty dict without tuple key annotation (PLE1141) (#21290)
  • [ruff] Add rule RUF066 to detect unnecessary class properties (#21535)
  • [ruff] Catch more dummy variable uses (RUF052) (#19799)

Bug fixes

  • [server] Set severity for non-rule diagnostics (#21559)
  • [flake8-implicit-str-concat] Avoid invalid fix in (ISC003) (#21517)
  • [parser] Fix panic when parsing IPython escape command expressions (#21480)

CLI

  • Show partial fixability indicator in statistics output (#21513)

Contributors

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ruff&package-manager=pip&previous-version=0.14.6&new-version=0.14.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/lintrunner/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 5b00aa3c0f..f140d45917 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.14.6 +ruff==0.14.7 # MYPY mypy==1.10.1 types-PyYAML==6.0.12.20250915 From bbe9c2bcfe7d0b917446007bb9030c849a2c1673 Mon Sep 17 00:00:00 2001 From: ruro Date: Thu, 4 Dec 2025 00:36:48 +0300 Subject: [PATCH 120/123] Don't constant fold Quantize/DequantizeLinear nodes by default (#2713) I added support for exporting `QuantizeLinear`/`DequantizeLinear` nodes (from `fake_quantize_per_*_affine` torch operators) in a previous PR. Unfortunately, the current default onnxscript optimizer settings tend to automatically remove any weight quantization. This is because the `Weight -> QDQ -> ...` pattern looks like it can be just constant folded to `QDQ(Weight) -> ...`. I believe that this behavior is not desirable, since the presence of `QDQ` nodes in the graph is what allows inference engines to run the supported computations using quantized data types. So the purpose of `QDQ` nodes is to hold the relevant quantization "metadata". As such, they normally shouldn't be constant folded. I have extended the existing logic in `FoldConstantsPass` that was used to exclude `ConstantOfShape` from constant folding. I haven't found any tests verifying this behavior for `ConstantOfShape` and I'm not sure, how to set up such a unit test, so I have left this code untested for now. If adding tests is mandatory, please give me a hint on where should I add such a test and what would be the best way to check/assert that the optimized graph matches the expectations (hopefully without reinventing the wheel or manually introspecting the `ir.Model` object). --- onnxscript/optimizer/_constant_folding.py | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 03536cc9ce..27b09557e7 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -26,6 +26,14 @@ import onnxscript.utils.utils as utils from onnxscript.ir import _tape +DEFAULT_CONSTANT_FOLD_BLACKLIST = [ + # ConstantOfShape is preserved to avoid increasing model size unnecessarily + "ConstantOfShape", + # Quantize/DequantizeLinear are preserved to keep the quantization info + "QuantizeLinear", + "DequantizeLinear", +] + DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192 DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512 @@ -1226,14 +1234,17 @@ def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None: elif should_fold is None: # Use default rules to decide whether to fold the node: - # - ConstantOfShape is preserved to avoid increasing model size unnecessarily + # - Nodes in the DEFAULT_CONSTANT_FOLD_BLACKLIST list are not folded # - If the any tensor input size exceeds the input_size_limit, skip folding the node - if _is_onnx_op(node, "ConstantOfShape"): - logger.info( - "Skipping constant folding for node %r because ConstantOfShape is preserved by default", - node.name, - ) - return None + for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST: + if _is_onnx_op(node, op_type): + logger.info( + "Skipping constant folding for node %r because " + "%s is preserved by default", + node.name, + op_type, + ) + return None input_tensors = [x.const_value if x is not None else None for x in node.inputs] large_inputs = [ From 5583f96f5593373fa5c66e0d475bd1ac91b6bb2e Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 8 Dec 2025 20:15:16 -0800 Subject: [PATCH 121/123] support opset23 (#2725) --- onnxscript/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py index b839093d2b..bccfd84cd4 100644 --- a/onnxscript/__init__.py +++ b/onnxscript/__init__.py @@ -55,6 +55,7 @@ "opset20", "opset21", "opset22", + "opset23", "opset_ai_onnx_ml1", "opset_ai_onnx_ml2", "opset_ai_onnx_ml3", @@ -92,6 +93,7 @@ opset20, opset21, opset22, + opset23, opset_ai_onnx_ml1, opset_ai_onnx_ml2, opset_ai_onnx_ml3, From a3883a6cc4d4e5c4222d9638d665252cb6386601 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 11 Dec 2025 15:54:30 -0500 Subject: [PATCH 122/123] Update aten_index_put implementation (#2712) --- .../function_libs/torch_lib/ops/core.py | 187 ++++++++----- .../function_libs/torch_lib/e2e_ops_tests.py | 249 +++++++++++++++++- 2 files changed, 368 insertions(+), 68 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b287cec057..099b786d74 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4541,80 +4541,135 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - - def _make_reshape_list_broadcastable(reshape_list, values_shape): - # Remove ones until the rank of reshape_list matches values_shape. - while len(reshape_list) > len(values_shape) and 1 in reshape_list: - reshape_list.remove(1) - - # Now ensure each dimension is broadcastable: - # This is mandatory when mixing basic and advanced indexing - # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) - # the reshape list should be : [[2, 1], [1, 3], [2, 1]] - for i, r in enumerate(reshape_list): - if r not in (1, values_shape[i]): - value_index = values_shape.index(r) - # Swap elements - # For the example above the current reshape list is [1, 2] for last dim, - # to make it broadcastable, we swap the elements - reshape_list[value_index], reshape_list[i] = r, 1 - - return reshape_list - - # Ensure the number of indices matches the tensor rank. + # Ensure the number of indices matches the tensor rank by appending trailing Nones. self_rank = len(self.shape) if len(indices) < self_rank: indices = list(indices) + [None] * (self_rank - len(indices)) - # Get values shape - values_shape = tuple(values.shape) + # The behavior of the op is dependent on whether there are advanced indices (i.e., non-scalar tensors) + # and whether these advanced indices are contiguous. + + # Identify advanced indices. + def is_advanced_index(index): + # Note: In this function, the index is assumed to be either None or an int64 Tensor. + return index is not None + + advanced_indices: list[int] = [] + none_indices: list[int] = [] + num_advanced_indices = 0 + num_none_indices = 0 + + for i, index in enumerate(indices): + if is_advanced_index(index): + advanced_indices.append(i) + num_advanced_indices += 1 + elif index is None: + none_indices.append(i) + num_none_indices += 1 + else: + raise ValueError(f"Unhandled index at position {i}: {index}") - index_vectors = [] - for i in range(self_rank): - if indices[i] is None: - # For a full slice along dim i, create a range index [0, self.shape[i]). - idx = op.Range(0, self.shape[i], 1) - reshape_update = self.shape[i] + self_shape = op.Shape(self) + if num_advanced_indices == 0: + return op.Expand(values, self_shape) + + # More than one advanced index may require broadcasting of index values + if num_advanced_indices > 1: + # Check for special case where all advanced indices have same shape. + # But need to ensure none of the shapes have None as a dimension, which + # will invalidate equality-based check. + first_shape = indices[advanced_indices[0]].shape + + def same_shape(other_shape: ir.Shape) -> bool: + return (not any(d is None for d in other_shape)) and other_shape == first_shape + + all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) + if not all_same_shape: + # Broadcast advanced indices to a common shape. + advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices) + shapes = [] + for i in advanced_indices: + index = indices[i] + index_rank = len(index.shape) + index_shape = op.Shape(index) + if index_rank < advanced_index_rank: + padding = op.Constant( + value_ints=[1 for _ in range(advanced_index_rank - index_rank)] + ) + index_shape = op.Concat(padding, index_shape, axis=0) + shapes.append(index_shape) + advanced_indices_shape = op.Max(*shapes) + indices = [ + op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index + for index in indices + ] else: - idx = indices[i] - reshape_update = math.prod(idx.shape) - # when Index is more than 1D, flatten it and also the values shape - # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # Indices -> (2*4,) and values shape (2*4, 32) - if len(idx.shape) > 1: - values_shape = (reshape_update, *values_shape[len(idx.shape) :]) - - # Flatten index (always working with 1D index in each dim) - idx = op.Reshape(idx, [-1]) - - # Create a reshape pattern: one value per index dimension, - # with the current dimension set to the update size. - reshape_list = [1] * len(indices) - reshape_list[i] = reshape_update - - # Adjust the reshape list to match the values shape. - reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) - - # Reshape and expand the index. - idx = op.Reshape(idx, reshape_list, allowzero=True) - idx = op.Expand(idx, values_shape) - - # Flatten the index to 1D and unsqueeze to form a column vector. - idx = op.Reshape(idx, [-1]) - idx = op.Unsqueeze(idx, axes=[1]) - index_vectors.append(idx) - - # Concatenate the index vectors along axis=1 to form the final indices. - new_index = op.Concat(*index_vectors, axis=1) - - # Flatten values to match the indices - flat_values = op.Reshape(values, [-1]) - - if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) else: - result = op.ScatterND(self, new_index, flat_values) + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) + + # ONNX ScatterND supports only the case where all advanced indices appear first, + # followed by None indices. So, we need to transpose self and values so that the + # advanced indices appear first, and then transpose the result back to original + # order at the end. + + none_indices_constant = op.Constant(value_ints=none_indices) + none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0) + target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0) + target_rank = advanced_index_rank + num_none_indices + + # Generate indices tensor required by ONNX ScatterND by unsqueezing an extra dimension and + # concatenating all advanced indices along this new dimension. + minus_one = op.Constant(value_ints=[-1]) + advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices] + onnx_index = op.Concat(*advanced_index_values, axis=-1) + + # Check if advanced indices are contiguous: + contiguous = True + if advanced_indices: + if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices): + contiguous = False + + # Bring advanced indices to front: + perm = advanced_indices + none_indices + transposed = op.Transpose(self, perm=perm) + + # Expand values to match target shape: + # First, transpose values if necessary to match advanced indices order! + if contiguous: + # values may need to be transposed before expanding to target shape + num_padded_dims = target_rank - len(values.shape) + if num_padded_dims > 0: + unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims))) + values = op.Unsqueeze(values, unsqueezed_dims) + initial_none_index_positions = list(range(advanced_indices[0])) + advanced_index_replacement_positions = list( + range(advanced_indices[0], advanced_indices[0] + advanced_index_rank) + ) + final_none_index_positions = list( + range(advanced_indices[0] + advanced_index_rank, target_rank) + ) + values_perm = ( + advanced_index_replacement_positions + + initial_none_index_positions + + final_none_index_positions + ) + values = op.Transpose(values, perm=values_perm) + + expanded_values = op.Expand(values, target_shape) + + updated = op.ScatterND( + transposed, onnx_index, expanded_values, reduction="add" if accumulate else None + ) + + # Inverse transpose to restore original dimension order: + inverse_perm = [0] * self_rank + for i, p in enumerate(perm): + inverse_perm[p] = i + result = op.Transpose(updated, perm=inverse_perm) return result diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index a2ced58c44..d344723408 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations -# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo - +import math import unittest +import parameterized + +# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch from torch.onnx._internal.exporter import _testing @@ -626,6 +629,248 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + @parameterized.parameterized.expand( + [ + # Multiple advanced indices, all 1D tensors. + # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 6), + "non_contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 1), + "non_contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (1, 6), + "non_contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (6,), + "non_contiguous_non_broadcast_indices_new_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (), + "non_contiguous_non_broadcast_indices_scalar", + ), + # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2) + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 2), + "contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 1), + "contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (1, 2), + "contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (2,), + "contiguous_non_broadcast_indices_new_dim1", + ), + ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"), + # Multiple advanced indices, with broadcasting among indices. + # Contiguous advanced indices: + # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) + # The update values must be broadcastable to (6,2,2) + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 2, 2), + "contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 1, 1), + "contiguous_broadcast_indices_expand_dim2_dim3", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (2,), + "contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Non-contiguous advanced indices versions of above tests: + # Here, update values must be broadcastable to (2,2,6) + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (2, 2, 6), + "non_contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (1, 1, 6), + "non_contiguous_broadcast_indices_expand_dim1_dim2", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (6,), + "non_contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Other test cases + ( + (4, 4, 4, 4), + [None, [0, 1], None, [2, 3]], + (2, 4, 4), + "non_contiguous_non_first", + ), + ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"), + ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"), + ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"), + # (TODO): Exporter doesn't yet support all None indices + # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"), + ] + ) + def test_index_put(self, x_shape, index_list, update_shape, _: str): + indices = [ + (torch.tensor(index, dtype=torch.int64) if index is not None else None) + for index in index_list + ] + + class Model(torch.nn.Module): + def forward(self, x, update): + return torch.ops.aten.index_put(x, indices, update, accumulate=True) + + x = torch.zeros(x_shape, dtype=torch.float32) + update = torch.randn(update_shape, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x, update), + input_names=["x", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_dynamic(self): + for dimension in [3, 4, 2]: + with self.subTest(dimension=dimension): + + class Model(torch.nn.Module): + def __init__(self, dimension): + super().__init__() + self.params = torch.zeros( + (4, 5) + if dimension == 2 + else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5)) + ) + self.dimension = dimension + + def forward(self, update, index1, index2): + copy = self.params.clone() + if self.dimension == 2: + copy[index1, index2] = update + elif self.dimension == 3: + copy[:, index1, index2] = update + else: + copy[:, :, index1, index2] = update + return copy + + update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) + index1 = torch.tensor([1, 2], dtype=torch.int64) + index2 = torch.tensor([3, 4], dtype=torch.int64) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) + onnx_program = torch.onnx.export( + Model(dimension), + tuple(feeds.values()), + input_names=["update", "index1", "index2"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes={ + "update": {0: "dn"}, + "index1": {0: "dn"}, + "index2": {0: "dn"}, + }, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_55_12_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update) + + x = torch.zeros((6, 5), dtype=torch.float32) + index = torch.tensor([[2, 1]], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_55_2_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update, accumulate=True) + + x = torch.ones((6, 5), dtype=torch.float32) + index = torch.tensor([4, 3], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_scatter_nd(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + x = x.clone() + return torch.ops.aten.index_put(x, [None, index, None], update) + + shape = (2, 3, 2) + N = math.prod(shape) + x = torch.arange(N, dtype=torch.float32).reshape(shape) + update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100 + index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2] + + feeds = dict(zip(["x", "index", "update"], (x, index, update))) + onnx_program = torch.onnx.export( + Model(), + tuple(feeds.values()), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() From da967e3c314023a533786c6a50bb41895ca3f07d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:20:05 -0800 Subject: [PATCH 123/123] [torchlib] Fix and implement overloads for aten::remainder (#2727) Previously the Scalar_Tensor overload will fail because the first arg will be a scalar which does not have dtype. Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 099b786d74..254378bf09 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TTensor, other: TTensor) -> TTensor: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - if self.dtype.is_integer(): +def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor: + if integer: return op.Mod(self, other) # TODO(justinchuby): Improve fp16 precision by following the logic in @@ -7673,6 +7670,29 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor: return op.Sub(self, op.Mul(rounded_quotient, other)) +@torch_op("aten::remainder.Tensor", trace_only=True) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: + """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + + return _aten_remainder(self, other, integer=self.dtype.is_integer()) + + +@torch_op("aten::remainder.Scalar", trace_only=True) +def aten_remainder_scalar(self: TTensor, other: float) -> TTensor: + """remainder.Scalar(Tensor self, Scalar other) -> Tensor""" + + other_tensor = ir.tensor(other, dtype=self.dtype) + return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer()) + + +@torch_op("aten::remainder.Scalar_Tensor", trace_only=True) +def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor: + """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + + self_tensor = ir.tensor(self, dtype=other.dtype) + return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer()) + + @torch_op("_operator::mod", trace_only=True) def operator_mod(self: TTensor, other: TTensor) -> TTensor: # Modulus operator % on SymInt