diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 99fc6fb44f..96b92c2e8e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat: @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) -def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: +def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # TODO(microsoft/onnxruntime#15977): Improve fp16 precision + + if self.dtype == ir.DataType.BOOL: + # alpha can also be bool + if alpha == 0: + return op.Identity(self) + return op.Or(self, other) + if alpha != 1.0: alpha = op.CastLike(alpha, other) other = op.Mul(other, alpha) @@ -1233,15 +1239,19 @@ def aten_binomial( "aten::bitwise_and.Tensor", "aten::bitwise_and.Scalar", "aten::bitwise_and.Scalar_Tensor", - "_operator::and_", ), trace_only=True, ) -def aten_bitwise_and(self: TInt, other: TInt) -> TInt: +def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_and implements the BOOL variant - return op.BitwiseAnd(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseAnd(self, other) + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1329,11 +1339,14 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: @torch_op("aten::bitwise_not", trace_only=True) -def aten_bitwise_not(self: TInt) -> TInt: +def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" - # logical_not implements the BOOL variant - return op.BitwiseNot(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if self.dtype.is_integer(): + return op.BitwiseNot(self) + raise NotImplementedError(f"Not implemented for type {self.dtype}") @torch_op( @@ -1341,15 +1354,19 @@ def aten_bitwise_not(self: TInt) -> TInt: "aten::bitwise_or.Tensor", "aten::bitwise_or.Scalar", "aten::bitwise_or.Scalar_Tensor", - "_operator::or_", ), trace_only=True, ) -def aten_bitwise_or(self: TInt, other: TInt) -> TInt: +def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_or implements the BOOL variant - return op.BitwiseOr(self, other) + assert self.dtype == other.dtype + + if self.dtype.is_integer(): + return op.BitwiseOr(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op( @@ -1487,11 +1504,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: ), trace_only=True, ) -def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: +def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_xor implements the BOOL variant + assert self.dtype == other.dtype - return op.BitwiseXor(self, other) + if self.dtype.is_integer(): + return op.BitwiseXor(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") @torch_op("aten::blackman_window", trace_only=True) @@ -5010,58 +5031,46 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op( - ( - "aten::logical_and", - "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_and", trace_only=True) +def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - return op.And(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) -def aten_logical_not(self: BOOL) -> BOOL: +@torch_op("aten::logical_not", trace_only=True) +def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" - return op.Not(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_or", - "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", - "aten::add.Tensor", - "aten::add.Scalar", - ), - trace_only=True, -) -def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: +@torch_op(("aten::logical_or"), trace_only=True) +def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - return op.Or(self, other) + assert self.dtype == other.dtype + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: + +@torch_op("aten::logical_xor", trace_only=True) +def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: """logical_xor(Tensor self, Tensor other) -> Tensor""" - return op.Xor(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op("aten::logit", private=True) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b1e0c529ec..98d10d9e5b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1631,6 +1631,10 @@ def _where_input_wrangler( dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), reason="fixme: test is unstable on macosx, windows", ), + TorchLibOpInfo("logical_and", core_ops.aten_logical_and), + TorchLibOpInfo("logical_not", core_ops.aten_logical_not), + TorchLibOpInfo("logical_or", core_ops.aten_logical_or), + TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail(