From efe2a58de3429079ceac07c442d7aaebf5d5cf3e Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Sun, 5 Nov 2023 18:27:21 +0000
Subject: [PATCH 1/2] add builtin ops

---
 .../function_libs/torch_lib/ops/core.py | 37 +++++++++++--------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index add9dd3980..cd74fa45fe 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -124,7 +124,7 @@ def aten__softmax(
     return aten_softmax_no_dtype(self, dim)
 
 
-@torch_op("aten::abs")
+@torch_op(("aten::abs", "_operator::abs"))
 def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
     """abs(Tensor self) -> Tensor"""
 
@@ -158,7 +158,7 @@ def aten_acosh(self: TFloat) -> TFloat:
     return op.Acosh(self)
 
 
-@torch_op(("aten::add", "aten::add.Tensor"))
+@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
 def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
@@ -1163,6 +1163,7 @@ def aten_binomial(
         "aten::bitwise_and.Tensor",
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
+        "_operator::and_",
     )
 )
 def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
@@ -1234,6 +1235,7 @@ def aten_bitwise_not(self: TInt) -> TInt:
         "aten::bitwise_or.Tensor",
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
+        "_operator::or_",
     )
 )
 def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
@@ -1436,7 +1438,7 @@ def aten_cdist(
     raise NotImplementedError()
 
 
-@torch_op("aten::ceil")
+@torch_op(("aten::ceil", "math::ceil"))
 def aten_ceil(self: TFloat) -> TFloat:
     """ceil(Tensor self) -> Tensor"""
 
@@ -2617,6 +2619,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
         "aten::div.Scalar_mode",
         "aten::divide",
         "aten::true_divide",
+        "_operator::truediv",
     )
 )
 def aten_div(self: TFloat, other: TFloat) -> TFloat:
@@ -3106,7 +3109,7 @@ def aten_empty_strided(
     return op.Expand(zero, size)
 
 
-@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
+@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"))
 def aten_eq(self: TTensor, other: TTensor) -> BOOL:
     """eq.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -3361,14 +3364,14 @@ def aten_flipud(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::floor")
+@torch_op(("aten::floor", "math::floor"))
 def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """floor(Tensor self) -> Tensor"""
 
     return op.Floor(self)
 
 
-@torch_op("aten::floor_divide")
+@torch_op(("aten::floor_divide", "_operator::floordiv"))
 def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
@@ -3510,7 +3513,9 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal"))
+@torch_op(
+    ("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge")
+)
 def aten_ge(self: TReal, other: TReal) -> BOOL:
     """ge.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -3666,7 +3671,7 @@ def aten_gru_cell(
     raise NotImplementedError()
 
 
-@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater"))
+@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt"))
 def aten_gt(self: TReal, other: TReal) -> BOOL:
     """gt.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -4378,7 +4383,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::le", "aten::le.Tensor"))
+@torch_op(("aten::le", "aten::le.Tensor", "_operator::le"))
 def aten_le(self: TReal, other: TReal) -> BOOL:
     """le.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -4682,7 +4687,7 @@ def aten_lstm_mps_backward(
     raise NotImplementedError()
 
 
-@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less"))
+@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt"))
 def aten_lt(self: TReal, other: TReal) -> BOOL:
     """lt.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -5284,7 +5289,7 @@ def aten_msort(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::mul", "aten::mul.Tensor"))
+@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
 def aten_mul(self: TReal, other: TReal) -> TReal:
     """mul.Tensor(Tensor self, Tensor other) -> Tensor"""
     # FIXME(titaiwang): get rid of this when we have type_promotion
@@ -5735,14 +5740,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
+@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"))
 def aten_ne(self: TReal, other: TReal) -> BOOL:
     """ne.Tensor(Tensor self, Tensor other) -> Tensor"""
 
     return op.Not(op.Equal(self, other))
 
 
-@torch_op("aten::neg")
+@torch_op(("aten::neg", "_operator::neg"))
 def aten_neg(self: TReal) -> TReal:
     """neg(Tensor self) -> Tensor"""
 
@@ -6122,7 +6127,9 @@ def aten_positive(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
+@torch_op(
+    ("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow")
+)
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""
 
@@ -7402,7 +7409,7 @@ def aten_stft(
     return result
 
 
-@torch_op(("aten::sub", "aten::sub.Tensor"))
+@torch_op(("aten::sub", "aten::sub.Tensor", "_operator::sub"))
 def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     alpha = op.CastLike(alpha, other)

From 975979d72a218840e3d22ed8740a46eb5056a05c Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Tue, 7 Nov 2023 01:15:38 +0000
Subject: [PATCH 2/2] split python math functions

---
 .../function_libs/torch_lib/ops/core.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index cd74fa45fe..7e108b99c8 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1438,13 +1438,20 @@ def aten_cdist(
     raise NotImplementedError()
 
 
-@torch_op(("aten::ceil", "math::ceil"))
+@torch_op("aten::ceil")
 def aten_ceil(self: TFloat) -> TFloat:
     """ceil(Tensor self) -> Tensor"""
 
     return op.Ceil(self)
 
 
+@torch_op("math::ceil")
+def python_math_ceil(self: TFloat) -> TInt:
+    """ceil(Tensor self) -> Tensor"""
+    ceil = op.Ceil(self)
+    return op.Cast(ceil, to=INT64.dtype)
+
+
 def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
     """chain_matmul(Tensor[] matrices) -> Tensor"""
 
@@ -3109,7 +3116,7 @@ def aten_empty_strided(
     return op.Expand(zero, size)
 
 
-@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"))
+@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
 def aten_eq(self: TTensor, other: TTensor) -> BOOL:
     """eq.Tensor(Tensor self, Tensor other) -> Tensor"""
 
@@ -3364,13 +3371,20 @@ def aten_flipud(self: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::floor", "math::floor"))
+@torch_op("aten::floor")
 def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """floor(Tensor self) -> Tensor"""
 
     return op.Floor(self)
 
 
+@torch_op("math::floor")
+def python_math_floor(self: TFloatOrBFloat16) -> TInt:
+    """floor(Tensor self) -> Tensor"""
+    floor = op.Floor(self)
+    return op.Cast(floor, to=INT64.dtype)
+
+
 @torch_op(("aten::floor_divide", "_operator::floordiv"))
 def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""