42 changes: 35 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3732,7 +3732,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"),
trace_only=True,
)
def aten_ge(self: TTensor, other: TTensor) -> BOOL:
@@ -3749,6 +3749,12 @@ def aten_ge(self: TTensor, other: TTensor) -> BOOL:
    return op.GreaterOrEqual(self, other)


@torch_op("_operator::ge", trace_only=True)
def operator_ge(self: TTensor, other: TTensor) -> BOOL:
Contributor: Should this be TInt?

Collaborator (Author): This is more lenient, which I think is ok.

+    # operator.ge for SymInt
+    return op.GreaterOrEqual(self, other)
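For readers weighing the TInt question above, here is a minimal sketch of the stricter signature the reviewer suggests, assuming the TInt type variable from torch_lib's tensor_typing module and the imports already present in core.py; the merged code keeps the more lenient TTensor bound instead:

    # Hypothetical stricter variant discussed in the thread above; NOT what the PR merged.
    # TInt would restrict inputs to integer tensor types, which is all that
    # _operator::ge sees in practice (it is dispatched for SymInt comparisons).
    @torch_op("_operator::ge", trace_only=True)
    def operator_ge_strict(self: TInt, other: TInt) -> BOOL:
        # Identical lowering to operator_ge; only the declared type bound changes.
        return op.GreaterOrEqual(self, other)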


def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]:
"""geqrf(Tensor self) -> (Tensor a, Tensor tau)"""

@@ -4058,7 +4064,7 @@ def aten_gru_cell(


@torch_op(
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"),
trace_only=True,
)
def aten_gt(self: TTensor, other: TTensor) -> BOOL:
@@ -4076,6 +4082,12 @@ def aten_gt(self: TTensor, other: TTensor) -> BOOL:
    return op.Greater(self, other)


@torch_op("_operator::gt", trace_only=True)
def operator_gt(self: TTensor, other: TTensor) -> BOOL:
# operator.gt for SymInt
return op.Greater(self, other)


@torch_op("aten::hamming_window", trace_only=True)
def aten_hamming_window(
    window_length: int,
@@ -4891,7 +4903,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:


@torch_op(
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"),
trace_only=True,
)
def aten_le(self: TTensor, other: TTensor) -> BOOL:
@@ -4909,6 +4921,12 @@ def aten_le(self: TTensor, other: TTensor) -> BOOL:
    return op.LessOrEqual(self, other)


@torch_op("_operator::le", trace_only=True)
def operator_le(self: TTensor, other: TTensor) -> BOOL:
# operator.le for SymInt
return op.LessOrEqual(self, other)


@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""
@@ -5384,7 +5402,7 @@ def aten_lstm(


@torch_op(
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"),
trace_only=True,
)
def aten_lt(self: TTensor, other: TTensor) -> BOOL:
@@ -5401,6 +5419,12 @@ def aten_lt(self: TTensor, other: TTensor) -> BOOL:
    return op.Less(self, other)


@torch_op("_operator::lt", trace_only=True)
def operator_lt(self: TTensor, other: TTensor) -> BOOL:
# operator.lt for SymInt
return op.Less(self, other)


def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
"""lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"""

@@ -7468,9 +7492,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
    raise NotImplementedError()


-@torch_op(
-    ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True
-)
+@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True)
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -7486,6 +7508,12 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
    return op.Sub(self, op.Mul(rounded_quotient, other))


@torch_op("_operator::mod", trace_only=True)
def operator_mod(self: TTensor, other: TTensor) -> TTensor:
# Modulus operator % on SymInt
return op.Mod(self, other)
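Note the asymmetry with aten_remainder above: the aten op emulates Python-style modulus for arbitrary tensors via floor division, while operator_mod maps straight to ONNX Mod. That works because SymInt is an integer, and ONNX Mod with its default attribute fmod=0 follows the sign of the divisor, which is exactly Python's % convention on ints. A quick plain-Python check of that convention:

    # Python % on integers takes the sign of the divisor, matching ONNX Mod
    # with the default fmod=0 (integer mode).
    for a, b in [(7, 3), (-7, 3), (7, -3), (-7, -3)]:
        print(f"{a} % {b} = {a % b}")
    # Prints: 7 % 3 = 1, -7 % 3 = 2, 7 % -3 = -2, -7 % -3 = -1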


def aten_rename(self: TensorType, names: Optional[str]) -> TensorType:
"""rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)"""

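To see where these _operator::* registrations fire at all, here is a rough sketch; the model is hypothetical and the export call is illustrative, since the exporter API varies by PyTorch version. Arithmetic and comparisons on a symbolic dimension size are captured as operator.* calls in the FX graph rather than aten ops, and the dedicated functions above give those calls ONNX lowerings:

    import torch

    class Trim(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            n = x.shape[0]         # a SymInt when the batch dim is dynamic
            return x[: n - n % 4]  # % on a SymInt traces as _operator::mod

    # With dynamic shapes enabled, n % 4 cannot be constant-folded, so the
    # captured graph holds an operator.mod node that torch_lib now lowers
    # through operator_mod above (same story for ge/gt/le/lt guards).
    # torch.onnx.dynamo_export(
    #     Trim(), torch.randn(10, 3),
    #     export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
    # )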