Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def aten__softmax(
return aten_softmax_no_dtype(self, dim)


@torch_op("aten::abs")
@torch_op(("aten::abs", "_operator::abs"))
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
"""abs(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -158,7 +158,7 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op(("aten::add", "aten::add.Tensor"))
@torch_op(("aten::add", "aten::add.Tensor", "_operator::add"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
Expand Down Expand Up @@ -1163,6 +1163,7 @@ def aten_binomial(
"aten::bitwise_and.Tensor",
"aten::bitwise_and.Scalar",
"aten::bitwise_and.Scalar_Tensor",
"_operator::and_",
)
)
def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
Expand Down Expand Up @@ -1234,6 +1235,7 @@ def aten_bitwise_not(self: TInt) -> TInt:
"aten::bitwise_or.Tensor",
"aten::bitwise_or.Scalar",
"aten::bitwise_or.Scalar_Tensor",
"_operator::or_",
)
)
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
Expand Down Expand Up @@ -1443,6 +1445,13 @@ def aten_ceil(self: TFloat) -> TFloat:
return op.Ceil(self)


@torch_op("math::ceil")
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
return op.Cast(ceil, to=INT64.dtype)


def aten_chain_matmul(matrices: Sequence[TensorType]) -> TensorType:
"""chain_matmul(Tensor[] matrices) -> Tensor"""

Expand Down Expand Up @@ -2617,6 +2626,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
"aten::div.Scalar_mode",
"aten::divide",
"aten::true_divide",
"_operator::truediv",
)
)
def aten_div(self: TFloat, other: TFloat) -> TFloat:
Expand Down Expand Up @@ -3372,7 +3382,14 @@ def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Floor(self)


@torch_op("aten::floor_divide")
@torch_op("math::floor")
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)


@torch_op(("aten::floor_divide", "_operator::floordiv"))
def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3514,7 +3531,9 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal"))
@torch_op(
("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal", "_operator::ge")
)
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3670,7 +3689,7 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater"))
@torch_op(("aten::gt", "aten::gt.Scalar", "aten::greater", "_operator::gt"))
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -4382,7 +4401,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::le", "aten::le.Tensor"))
@torch_op(("aten::le", "aten::le.Tensor", "_operator::le"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -4686,7 +4705,7 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less"))
@torch_op(("aten::lt", "aten::lt.Scalar", "aten::less", "_operator::lt"))
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -5288,7 +5307,7 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::mul", "aten::mul.Tensor"))
@torch_op(("aten::mul", "aten::mul.Tensor", "_operator::mul"))
def aten_mul(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
Expand Down Expand Up @@ -5739,14 +5758,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"))
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Not(op.Equal(self, other))


@torch_op("aten::neg")
@torch_op(("aten::neg", "_operator::neg"))
def aten_neg(self: TReal) -> TReal:
"""neg(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -6126,7 +6145,9 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
@torch_op(
("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar", "_operator::pow")
)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

Expand Down Expand Up @@ -7406,7 +7427,7 @@ def aten_stft(
return result


@torch_op(("aten::sub", "aten::sub.Tensor"))
@torch_op(("aten::sub", "aten::sub.Tensor", "_operator::sub"))
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
Expand Down