From 503cb04c314ed3eea53492ab7ed030b919bc84dc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:16:46 -0700 Subject: [PATCH 1/9] [torchlib] Support integers in logical_and/or ops Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 81 ++++++++++++------- 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 99fc6fb44..acf6e3a6c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1233,15 +1233,19 @@ def aten_binomial( "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", - "_operator::and_", ), trace_only=True, ) -def aten_bitwise_and(self: TInt, other: TInt) -> TInt: +def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_and implements the BOOL variant - return op.BitwiseAnd(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseAnd(self, other) + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1341,15 +1345,19 @@ def aten_bitwise_not(self: TInt) -> TInt: "aten::bitwise_or.Tensor", "aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", - "_operator::or_", ), trace_only=True, ) def aten_bitwise_or(self: TInt, other: TInt) -> TInt: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_or implements the BOOL variant - return op.BitwiseOr(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseOr(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1489,9 +1497,13 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: ) def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_xor implements the BOOL variant + assert self.dtype == other.dtype - return op.BitwiseXor(self, other) + if self.dtype.is_integer(): + return op.BitwiseXor(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op("aten::blackman_window", trace_only=True) @@ -5011,57 +5023,66 @@ def aten_logdet(self: TFloat) -> TFloat: @torch_op( - ( - "aten::logical_and", - "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", - ), + ("aten::logical_and") trace_only=True, ) -def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: +def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - return op.And(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + if self.dtype.is_integer(): + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + raise NotImplementedError(f"Not implemented for dtype {self.dtype} and {other.dtype}") @torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) def aten_logical_not(self: BOOL) -> BOOL: """logical_not(Tensor self) -> Tensor""" - return op.Not(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if 
self.dtype.is_integer(): + return op.Not(op.Cast(self, to=BOOL.dtype)) + raise NotImplementedError(f"Not implemented for dtype {self.dtype}") @torch_op( ( "aten::logical_or", - "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", "aten::add.Tensor", - "aten::add.Scalar", + "aten::add.Scalar" ), trace_only=True, ) def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - return op.Or(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + if self.dtype.is_integer(): + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + raise NotImplementedError(f"Not implemented for dtype {self.dtype} and {other.dtype}") @torch_op( - ( - "aten::logical_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), + ("aten::logical_xor") trace_only=True, ) def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" - return op.Xor(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + if self.dtype.is_integer(): + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + raise NotImplementedError(f"Not implemented for dtype {self.dtype} and {other.dtype}") @torch_op("aten::logit", private=True) From b212b2713c02709529c2989d4798dbb4cfa81fe6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:19:32 -0700 Subject: [PATCH 2/9] Simplify Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index acf6e3a6c..bcbca5180 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5033,9 +5033,7 @@ def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: if self.dtype == ir.DataType.BOOL: return op.And(self, other) - if self.dtype.is_integer(): - return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) - raise NotImplementedError(f"Not implemented for dtype {self.dtype} and {other.dtype}") + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) @@ -5044,9 +5042,7 @@ def aten_logical_not(self: BOOL) -> BOOL: if self.dtype == ir.DataType.BOOL: return op.Not(self) - if self.dtype.is_integer(): - return op.Not(op.Cast(self, to=BOOL.dtype)) - raise NotImplementedError(f"Not implemented for dtype {self.dtype}") + return op.Not(op.Cast(self, to=BOOL.dtype)) @torch_op( @@ -5064,9 +5060,7 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: if self.dtype == ir.DataType.BOOL: return op.Or(self, other) - if self.dtype.is_integer(): - return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) - raise NotImplementedError(f"Not implemented for dtype {self.dtype} and {other.dtype}") + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op( @@ -5080,9 +5074,7 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: if self.dtype == ir.DataType.BOOL: return op.Xor(self, other) - if self.dtype.is_integer(): - return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) - raise NotImplementedError(f"Not implemented for dtype {self.dtype} and 
{other.dtype}") + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op("aten::logit", private=True) From af18ec76ea5ade599438d157b837f427b540070f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:20:58 -0700 Subject: [PATCH 3/9] Format Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index bcbca5180..a6a25147a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5022,10 +5022,7 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op( - ("aten::logical_and") - trace_only=True, -) +@torch_op(("aten::logical_and"), trace_only=True) def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" @@ -5045,14 +5042,7 @@ def aten_logical_not(self: BOOL) -> BOOL: return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_or", - "aten::add.Tensor", - "aten::add.Scalar" - ), - trace_only=True, -) +@torch_op(("aten::logical_or", "aten::add.Tensor", "aten::add.Scalar"), trace_only=True) def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" @@ -5063,10 +5053,7 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op( - ("aten::logical_xor") - trace_only=True, -) +@torch_op(("aten::logical_xor"), trace_only=True) def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" From de794223e140870b6025fbf040dd76dbddf68f9a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:24:19 -0700 Subject: [PATCH 4/9] Tests Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- tests/function_libs/torch_lib/ops_test_data.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a6a25147a..781ba9b1b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5034,7 +5034,7 @@ def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: @torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) -def aten_logical_not(self: BOOL) -> BOOL: +def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" if self.dtype == ir.DataType.BOOL: @@ -5043,7 +5043,7 @@ def aten_logical_not(self: BOOL) -> BOOL: @torch_op(("aten::logical_or", "aten::add.Tensor", "aten::add.Scalar"), trace_only=True) -def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: +def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" assert self.dtype == other.dtype @@ -5054,7 +5054,7 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: @torch_op(("aten::logical_xor"), trace_only=True) -def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: +def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" assert self.dtype == other.dtype diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b1e0c529e..98d10d9e5 100644 --- 
a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1631,6 +1631,10 @@ def _where_input_wrangler( dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), reason="fixme: test is unstable on macosx, windows", ), + TorchLibOpInfo("logical_and", core_ops.aten_logical_and), + TorchLibOpInfo("logical_not", core_ops.aten_logical_not), + TorchLibOpInfo("logical_or", core_ops.aten_logical_or), + TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail( From da7a40bd75e0d1db7eef62925697350cf642d405 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:28:53 -0700 Subject: [PATCH 5/9] aten_bitwise_not Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 781ba9b1b..2d667630e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1333,11 +1333,14 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: @torch_op("aten::bitwise_not", trace_only=True) -def aten_bitwise_not(self: TInt) -> TInt: +def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" - # logical_not implements the BOOL variant - return op.BitwiseNot(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if self.dtype.is_integer(): + return op.BitwiseNot(self) + raise NotImplementedError(f"Not implemented for type {self.dtype}") @torch_op( @@ -1348,7 +1351,7 @@ def aten_bitwise_not(self: TInt) -> TInt: ), trace_only=True, ) -def aten_bitwise_or(self: TInt, other: TInt) -> TInt: +def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" assert self.dtype == other.dtype @@ -1495,7 +1498,7 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: ), trace_only=True, ) -def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: +def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" assert self.dtype == other.dtype @@ -5033,7 +5036,7 @@ def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) +@torch_op("aten::logical_not", trace_only=True) def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" From 9351d517a9e80b987e3f8cdbdf074b3373460b26 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 11:59:53 -0700 Subject: [PATCH 6/9] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2d667630e..01e5e7592 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5025,7 +5025,7 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op(("aten::logical_and"), trace_only=True) +@torch_op("aten::logical_and", 
trace_only=True) def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" From b5ca7296e24d1585a0b565aaef23f172f71c0cc6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 12:00:00 -0700 Subject: [PATCH 7/9] Update onnxscript/function_libs/torch_lib/ops/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 01e5e7592..23e467580 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5056,7 +5056,7 @@ def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_xor"), trace_only=True) +@torch_op("aten::logical_xor", trace_only=True) def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" From b779ab84855328389bf1ce06dd583a31c20d96fb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 15:23:18 -0700 Subject: [PATCH 8/9] Move bool overload to aten_add to avoid dispatcher conflicts Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 23e467580..e67f27e92 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -162,9 +162,12 @@ def aten_acosh(self: TFloat) -> TFloat: @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) -def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: +def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # TODO(microsoft/onnxruntime#15977): Improve fp16 precision + + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + if alpha != 1.0: alpha = op.CastLike(alpha, other) other = op.Mul(other, alpha) @@ -5045,7 +5048,7 @@ def aten_logical_not(self: TTensor) -> BOOL: return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op(("aten::logical_or", "aten::add.Tensor", "aten::add.Scalar"), trace_only=True) +@torch_op(("aten::logical_or"), trace_only=True) def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" From ec8bda322e65653ad0c59df4fd7f8ff4c8b27a39 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 29 Sep 2025 15:26:21 -0700 Subject: [PATCH 9/9] Fix alpha Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e67f27e92..96b92c2e8 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -166,6 +166,9 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" if self.dtype == ir.DataType.BOOL: + # alpha can also be bool + if alpha == 0: + return op.Identity(self) return op.Or(self, other) if alpha != 1.0:
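
The net effect of the series: the logical_* ops accept bool and integer tensors (casting non-bool inputs to BOOL before And/Or/Not/Xor), the bitwise_* ops dispatch by dtype between the Bitwise* ONNX ops and their boolean counterparts, and the bool overload of aten::add lowers to Or. A minimal eager-mode PyTorch sketch of the semantics these lowerings mirror (the tensors and values below are illustrative only, not taken from the patches):

    import torch

    # logical_and on integer tensors: any nonzero value counts as True and
    # the result is a bool tensor -- hence the Cast-to-BOOL before And.
    a = torch.tensor([1, 0, 3])
    b = torch.tensor([2, 0, 0])
    print(torch.logical_and(a, b))  # tensor([ True, False, False])

    # logical_not likewise treats nonzero as True, so non-bool inputs
    # are cast to BOOL before Not.
    print(torch.logical_not(a))  # tensor([False,  True, False])

    # bitwise_and on bool tensors is elementwise logical and, which is
    # why BOOL inputs map to And rather than BitwiseAnd.
    p = torch.tensor([True, True, False])
    q = torch.tensor([True, False, False])
    print(torch.bitwise_and(p, q))  # tensor([ True, False, False])

    # add on bool tensors saturates to logical or, matching the
    # aten::add -> Or mapping introduced in patches 8 and 9.
    print(p + q)  # tensor([ True,  True, False])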