onnxscript/function_libs/torch_lib/ops/core.py (117 changes: 63 additions & 54 deletions)
@@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat:


@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision

if self.dtype == ir.DataType.BOOL:
# alpha can also be bool
if alpha == 0:
return op.Identity(self)
return op.Or(self, other)

if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
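For context on the new BOOL branch: PyTorch's add saturates on bool tensors, so it coincides with elementwise OR, and a zeroed-out alpha reduces it to the identity. A quick sanity check of those semantics (a hedged sketch that assumes only a local PyTorch install, no onnxscript needed):

```python
import torch

a = torch.tensor([True, False, True])
b = torch.tensor([False, False, True])

# Bool addition saturates: True + True == True, i.e. elementwise OR,
# which is why the converter emits op.Or for BOOL inputs.
assert torch.equal(a + b, a | b)

# A false alpha zeroes out `other`, so add degenerates to the identity,
# matching the op.Identity(self) branch above.
assert torch.equal(torch.add(a, b, alpha=False), a)
```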
@@ -1233,15 +1239,19 @@ def aten_binomial(
"aten::bitwise_and.Tensor",
"aten::bitwise_and.Scalar",
"aten::bitwise_and.Scalar_Tensor",
"_operator::and_",
),
trace_only=True,
)
def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
"""bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
# logical_and implements the BOOL variant

return op.BitwiseAnd(self, other)
assert self.dtype == other.dtype

if self.dtype.is_integer():
return op.BitwiseAnd(self, other)
if self.dtype == ir.DataType.BOOL:
return op.And(self, other)
raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


@torch_op(
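The dtype dispatch above exists because ONNX splits the two cases: `BitwiseAnd` (opset 18) is defined for integer tensors only, while `And` accepts only bools; PyTorch's `bitwise_and`, by contrast, covers both. A small illustration of the PyTorch-side behavior the lowering has to preserve (PyTorch only, no onnxscript assumed):

```python
import torch

x = torch.tensor([0b1100, 0b1010], dtype=torch.int32)
m = torch.tensor(0b1010, dtype=torch.int32)

# Integer inputs: genuine bitwise AND, which maps to ONNX BitwiseAnd.
assert torch.equal(
    torch.bitwise_and(x, m),
    torch.tensor([0b1000, 0b1010], dtype=torch.int32),
)

# Bool inputs: torch.bitwise_and behaves as logical AND, so it must
# map to ONNX And (BitwiseAnd rejects bool tensors).
p = torch.tensor([True, True, False])
q = torch.tensor([True, False, False])
assert torch.equal(torch.bitwise_and(p, q), p & q)
```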
@@ -1329,27 +1339,34 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:


@torch_op("aten::bitwise_not", trace_only=True)
def aten_bitwise_not(self: TInt) -> TInt:
def aten_bitwise_not(self: TTensor) -> TTensor:
"""bitwise_not(Tensor self) -> Tensor"""
# logical_not implements the BOOL variant

return op.BitwiseNot(self)
if self.dtype == ir.DataType.BOOL:
return op.Not(self)
if self.dtype.is_integer():
return op.BitwiseNot(self)
raise NotImplementedError(f"Not implemented for type {self.dtype}")


@torch_op(
(
"aten::bitwise_or.Tensor",
"aten::bitwise_or.Scalar",
"aten::bitwise_or.Scalar_Tensor",
"_operator::or_",
),
trace_only=True,
)
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
"""bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
# logical_or implements the BOOL variant

return op.BitwiseOr(self, other)
assert self.dtype == other.dtype

if self.dtype.is_integer():
return op.BitwiseOr(self, other)
if self.dtype == ir.DataType.BOOL:
return op.Or(self, other)
raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")


@torch_op(
@@ -1487,11 +1504,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
     ),
     trace_only=True,
 )
-def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_xor implements the BOOL variant
-
-    return op.BitwiseXor(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseXor(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
 
 
 @torch_op("aten::blackman_window", trace_only=True)
@@ -5010,58 +5031,46 @@ def aten_logdet(self: TFloat) -> TFloat:
     return op.Log(op.Det(self))
 
 
-@torch_op(
-    (
-        "aten::logical_and",
-        "aten::bitwise_and.Tensor",
-        "aten::bitwise_and.Scalar",
-        "aten::bitwise_and.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_and", trace_only=True)
+def aten_logical_and(self: TTensor, other: TTensor) -> BOOL:
     """logical_and(Tensor self, Tensor other) -> Tensor"""
-
-    return op.And(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
 
-@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True)
-def aten_logical_not(self: BOOL) -> BOOL:
+@torch_op("aten::logical_not", trace_only=True)
+def aten_logical_not(self: TTensor) -> BOOL:
     """logical_not(Tensor self) -> Tensor"""
-
-    return op.Not(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    return op.Not(op.Cast(self, to=BOOL.dtype))
 
 
-@torch_op(
-    (
-        "aten::logical_or",
-        "aten::bitwise_or.Tensor",
-        "aten::bitwise_or.Scalar",
-        "aten::bitwise_or.Scalar_Tensor",
-        "aten::add.Tensor",
-        "aten::add.Scalar",
-    ),
-    trace_only=True,
-)
-def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
+@torch_op(("aten::logical_or"), trace_only=True)
+def aten_logical_or(self: TTensor, other: TTensor) -> BOOL:
     """logical_or(Tensor self, Tensor other) -> Tensor"""
-
-    return op.Or(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
 
-@torch_op(
-    (
-        "aten::logical_xor",
-        "aten::bitwise_xor.Tensor",
-        "aten::bitwise_xor.Scalar",
-        "aten::bitwise_xor.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_xor", trace_only=True)
+def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL:
     """logical_xor(Tensor self, Tensor other) -> Tensor"""
-
-    return op.Xor(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
 
 @torch_op("aten::logit", private=True)
tests/function_libs/torch_lib/ops_test_data.py (4 changes: 4 additions & 0 deletions)
@@ -1631,6 +1631,10 @@ def _where_input_wrangler(
         dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,),
         reason="fixme: test is unstable on macosx, windows",
     ),
+    TorchLibOpInfo("logical_and", core_ops.aten_logical_and),
+    TorchLibOpInfo("logical_not", core_ops.aten_logical_not),
+    TorchLibOpInfo("logical_or", core_ops.aten_logical_or),
+    TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor),
     TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}),
     TorchLibOpInfo("max_dim", core_ops.aten_max_dim)
     .xfail(
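These `TorchLibOpInfo` entries hook the four conversions into the OpInfo-driven test suite; the first argument names an entry in PyTorch's OpInfo database, which supplies the sample inputs. One way to confirm those names exist (a debugging aid only, since it pokes at a private torch module whose layout may change):

```python
from torch.testing._internal.common_methods_invocations import op_db

# The OpInfo database is keyed by op name; the TorchLib harness uses these
# entries to generate sample inputs for each registered conversion.
available = {info.name for info in op_db}
assert {"logical_and", "logical_not", "logical_or", "logical_xor"} <= available
```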