diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 99fc6fb44..90c560d9f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5025,6 +5025,13 @@ def aten_logical_and(self: BOOL, other: BOOL) -> BOOL: return op.And(self, other) +@torch_op("aten::logical_and", trace_only=True) +def aten_logical_and_non_bool_type(self: TTensor, other: TTensor2) -> BOOL: + """logical_and(Tensor self, Tensor other) -> Tensor""" + + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + + @torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) def aten_logical_not(self: BOOL) -> BOOL: """logical_not(Tensor self) -> Tensor""" @@ -5032,6 +5039,13 @@ def aten_logical_not(self: BOOL) -> BOOL: return op.Not(self) +@torch_op("aten::logical_not", trace_only=True) +def aten_logical_not_non_bool_type(self: TTensor) -> BOOL: + """logical_not(Tensor self) -> Tensor""" + + return op.Equal(self, 0) + + @torch_op( ( "aten::logical_or", @@ -5049,6 +5063,13 @@ def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: return op.Or(self, other) +@torch_op("aten::logical_or", trace_only=True) +def aten_logical_or_non_bool_type(self: TTensor, other: TTensor2) -> BOOL: + """logical_or(Tensor self, Tensor other) -> Tensor""" + + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + + @torch_op( ( "aten::logical_xor", @@ -5064,6 +5085,13 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: return op.Xor(self, other) +@torch_op("aten::logical_xor", trace_only=True) +def aten_logical_xor_non_bool_type(self: TTensor, other: TTensor2) -> BOOL: + """logical_xor(Tensor self, Tensor other) -> Tensor""" + + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) + + @torch_op("aten::logit", private=True) def _aten_logit_onnx(self: TFloat) -> TFloat: return op.Log(op.Div(self, op.Sub(1.0, self)))