3 changes: 2 additions & 1 deletion onnxscript/__init__.py
@@ -54,7 +54,7 @@
# isort: on

from .utils import external_tensor, proto2text
from .values import OnnxFunction
from .values import OnnxFunction, TracedOnnxFunction

try:
__version__ = importlib.metadata.version("onnxscript")
@@ -66,6 +66,7 @@
"script",
"export_onnx_lib",
"OnnxFunction",
"TracedOnnxFunction",
"proto2python",
"proto2text",
"external_tensor",
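TracedOnnxFunction is now part of the public surface alongside OnnxFunction; per the torch_op return annotation later in this PR, trace-only ops come back as TracedOnnxFunction instances. A minimal sketch of how downstream code might use the re-export (the helper functions are illustrative, not part of this PR, and assume an onnxscript build that includes this change):

```python
# Illustrative only: distinguish trace-only ops (TracedOnnxFunction) from
# compiled ops (OnnxFunction) via the new top-level re-export.
import onnxscript


def is_trace_only(fn: object) -> bool:
    return isinstance(fn, onnxscript.TracedOnnxFunction)


def is_compiled(fn: object) -> bool:
    return isinstance(fn, onnxscript.OnnxFunction)
```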
11 changes: 7 additions & 4 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -127,12 +127,15 @@ def shape(self, shape: Tuple[int | None, ...]):
self._shape = shape
self._torch_value.setType(self._torch_value.type().with_sizes(list(shape)))

@property
def dtype(self):
@property # type: ignore[override]
def dtype(self) -> torch.dtype | None:
# TODO: Return numpy dtype
return _type_utils.JitScalarType.from_value( # type: ignore[attr-defined]
torch_dtype = _type_utils.JitScalarType.from_value( # type: ignore[attr-defined]
Review discussion on this line:

Contributor: 🤔 Is it possible to extend _type_utils.JitScalarType.from_value to handle the sequence type properly?

Contributor: I see: there is no torch.dtype for a sequence to pass in from the dtype setter in the first place.

Contributor: Frankly, this feels a bit overloaded; covering the sequence type as well doesn't seem to match the design of this class.

Author: Should we add another attribute?

Author: One alternative I can think of is to give up this dtype and relax match_schema so that it finds the best match instead of "the 100% match", since I found some inputs are just None anyway.

Author: Reverted

self._torch_value, default=_type_utils.JitScalarType.UNDEFINED
).dtype()
)
if torch_dtype == _type_utils.JitScalarType.UNDEFINED:
return None
return torch_dtype.dtype()

@dtype.setter
def dtype(self, dtype: torch.dtype):
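The behavioral change in the dtype property above: an UNDEFINED JIT scalar type (for example a sequence value, which has no torch.dtype) now yields None instead of an undefined dtype. A self-contained sketch of that contract, using a stand-in enum rather than the real torch.onnx._type_utils API:

```python
# Stand-in sketch (not the real torch/_type_utils API): UNDEFINED scalar
# types now map to None, everything else maps to a concrete dtype.
from __future__ import annotations

from enum import Enum


class JitScalarType(Enum):  # illustrative substitute for the torch enum
    FLOAT = "float32"
    UNDEFINED = "undefined"

    def dtype(self) -> str:
        return self.value


def dtype_or_none(scalar_type: JitScalarType) -> str | None:
    if scalar_type == JitScalarType.UNDEFINED:  # e.g. sequence values
        return None
    return scalar_type.dtype()


assert dtype_or_none(JitScalarType.UNDEFINED) is None
assert dtype_or_none(JitScalarType.FLOAT) == "float32"
```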
42 changes: 26 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -38,14 +38,14 @@


@torch_op("aten::abs")
def aten_abs(self: TReal) -> TReal:
def aten_abs(self: TrealOrUInt8) -> TrealOrUInt8:
"""abs(Tensor self) -> Tensor"""

return op.Abs(self)


@torch_op("aten::abs")
def aten_abs_complex(self: TReal) -> TReal:
@torch_op("aten::abs", complex=True)
def aten_abs_complex(self: TrealOrUInt8) -> TrealOrUInt8:
"""abs(Tensor self) -> Tensor"""
# self_real = self[..., 0]
self_real = op.Gather(self, 0, axis=-1)
@@ -250,15 +250,15 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:


@torch_op("aten::amax")
def aten_amax(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
def aten_amax(self: TrealOrUInt8, dim: INT64, keepdim: bool = False) -> TrealOrUInt8:
"""amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"""

# ReduceMax reduces all dimensions when dim is empty
return op.ReduceMax(self, dim, keepdims=keepdim)


@torch_op("aten::amin")
def aten_amin(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
def aten_amin(self: TrealOrUInt8, dim: INT64, keepdim: bool = False) -> TrealOrUInt8:
"""amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"""

# ReduceMin reduces all dimensions when dim is empty
@@ -469,17 +469,19 @@ def aten_arctanh(self: TensorType) -> TensorType:


@torch_op("aten::argmax", trace_only=True)
def aten_argmax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
def aten_argmax(
self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

if dim is None: # TODO: use OptionalHasElement(dim)
self = op.Reshape(self, op.Constant(value_ints=[-1]))

return aten_argmax_dim(self, dim=dim, keepdim=keepdim)
return _aten_argmax_dim(self, dim=dim, keepdim=keepdim)


@torch_op("aten::argmax")
def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TReal:
@torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
"""argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = op.Size(op.Shape(self)) == 0
@@ -494,17 +496,19 @@ def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TReal:


@torch_op("aten::argmin", trace_only=True)
def aten_argmin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
def aten_argmin(
self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

if dim is None: # TODO: use OptionalHasElement(dim)
self = op.Reshape(self, op.Constant(value_ints=[-1]))

return aten_argmin_dim(self, dim=dim, keepdim=keepdim)
return _aten_argmin_dim(self, dim=dim, keepdim=keepdim)


@torch_op("aten::argmin")
def aten_argmin_dim(self: TReal, dim: int, keepdim: bool = False) -> TReal:
@torch_op("aten::argmin", private=True)
def _aten_argmin_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
"""argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

self_is_scaler = op.Size(op.Shape(self)) == 0
@@ -2458,7 +2462,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:


@torch_op("aten::fmod")
def aten_fmod(self: TReal, other: TReal) -> TReal:
def aten_fmod(self: TrealOrUInt8, other: TrealOrUInt8) -> TrealOrUInt8:
"""fmod.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Mod(self, other, fmod=1)
@@ -2586,7 +2590,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType:

# NOTE: The name is made up for `getitem` to be included in the registry
@torch_op("aten::getitem")
def aten_getitem(self: Sequence[TReal], i: INT64) -> TReal:
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
return op.SequenceAt(self, i)
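
Relaxing aten_getitem from Sequence[TReal] to Sequence[TTensor] means any tensor element type can be indexed; op.SequenceAt simply picks one element, with negative positions counting from the back. A plain-Python stand-in for that behavior (illustrative, not the ONNX runtime):

```python
# Plain-Python stand-in for op.SequenceAt (illustrative): index a sequence
# of tensors; negative positions count from the back, as in ONNX.
import numpy as np


def sequence_at(seq: list, i: int) -> np.ndarray:
    return seq[i]


tensors = [np.array([True, False]), np.array([1, 2], dtype=np.int32)]
print(sequence_at(tensors, -1))  # [1 2]
```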


@@ -2876,7 +2880,10 @@ def aten_index_put(
if op.Cast(accumulate, to=BOOL.dtype):
# put values into zeros array first, then add to input
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
zeros = op.CastLike(zeros, values)
result = op.ScatterElements(zeros, new_ind_t, values)
# FIXME: type promotion
result = op.CastLike(result, self)
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)
@@ -2917,7 +2924,10 @@ def aten_index_put_bool(

if op.Cast(accumulate, to=BOOL.dtype):
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
zeros = op.CastLike(zeros, values)
result = op.ScatterElements(zeros, new_ind_t, values)
# FIXME: type promotion
result = op.CastLike(result, self)
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)
@@ -3348,7 +3358,7 @@ def aten_linspace(start: float, end: float, steps: int) -> TensorType:
raise NotImplementedError()


@torch_op("log")
@torch_op("aten::log")
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
"""log(Tensor self) -> Tensor"""

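The two CastLike insertions in aten_index_put / aten_index_put_bool keep the accumulate path working when self and values have different dtypes: scatter into a zeros buffer of the values dtype, cast the scattered result back to self's dtype (the FIXME notes this is not true type promotion), then add. A numpy sketch of that sequence, with made-up shapes and indices rather than the real onnxscript code path:

```python
# numpy sketch of the accumulate branch after this change (shapes, dtypes and
# indices are invented for illustration; this is not the onnxscript code path).
import numpy as np

self_ = np.array([1, 2, 3, 4], dtype=np.int64)
values = np.array([10.5, 20.5], dtype=np.float32)
index = np.array([0, 2])

zeros = np.zeros(self_.shape, dtype=values.dtype)  # op.CastLike(zeros, values)
zeros[index] = values                              # op.ScatterElements(zeros, new_ind_t, values)
result = zeros.astype(self_.dtype)                 # op.CastLike(result, self); FIXME: not true promotion
result = result + self_                            # op.Add(result, self)
print(result)  # [11  2 23  4]
```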
17 changes: 14 additions & 3 deletions onnxscript/function_libs/torch_lib/registration.py
@@ -3,7 +3,7 @@
from __future__ import annotations

from types import FunctionType
from typing import Any, Callable, Optional
from typing import Any, Callable, Generator, Optional

import onnxscript

@@ -15,12 +15,14 @@ class OverloadedFunction:
name: Name of the op. E.g. "aten::add".
overloads: Overloads function.
privates: Private functions not exposed to users.
complex: Support complex functions.
"""

def __init__(self, name: str):
self.name = name
self.overloads: list[Any] = []
self.privates: list[Any] = []
self.complex: list[Any] = []


class Registry:
@@ -29,11 +31,15 @@ class Registry:
def __init__(self):
self._registry: dict[str, OverloadedFunction] = {}

def register(self, func: Any, name: str, *, private: bool = False) -> None:
def register(
self, func: Any, name: str, *, private: bool = False, complex: bool = False
) -> None:
"""Register a function."""

if private:
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
elif complex:
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
else:
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)

@@ -49,6 +55,9 @@ def __iter__(self):
def __repr__(self):
return repr(self._registry)

def items(self) -> Generator[tuple[str, OverloadedFunction], None, None]:
yield from self._registry.items()


# Default registry
default_registry = Registry()
@@ -60,6 +69,7 @@ def torch_op(
registry: Optional[Registry] = None,
trace_only: bool = False,
private: bool = False,
complex: bool = False,
) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
"""Register a torch op.

@@ -69,6 +79,7 @@
trace_only: Whether the function should only be traced and not compiled.
private: Whether the function is private (not directly exposed). It should
be true for all functions with names starting with "_".
complex: Whether the function supports complex.
"""
if registry is None:
registry = default_registry
@@ -87,7 +98,7 @@ def wrapper(
processed_func = onnxscript.script(opset=custom_opset)(func)

assert registry is not None
registry.register(processed_func, name, private=private)
registry.register(processed_func, name, private=private, complex=complex)
return processed_func

return wrapper
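
Putting the registration.py changes together: private= and complex= route a function into separate buckets on the same OverloadedFunction entry, and the new items() generator exposes the registry for iteration. Below is a self-contained mirror of that logic as a sketch of this diff, not the installed onnxscript API:

```python
# Self-contained mirror of the registry logic in this diff (illustrative).
from __future__ import annotations

from typing import Any, Generator


class OverloadedFunction:
    def __init__(self, name: str):
        self.name = name
        self.overloads: list[Any] = []  # public overloads
        self.privates: list[Any] = []   # private helpers (names starting with "_")
        self.complex: list[Any] = []    # complex-dtype variants


class Registry:
    def __init__(self):
        self._registry: dict[str, OverloadedFunction] = {}

    def register(self, func: Any, name: str, *, private: bool = False, complex: bool = False) -> None:
        entry = self._registry.setdefault(name, OverloadedFunction(name))
        if private:
            entry.privates.append(func)
        elif complex:
            entry.complex.append(func)
        else:
            entry.overloads.append(func)

    def items(self) -> Generator[tuple[str, OverloadedFunction], None, None]:
        yield from self._registry.items()


registry = Registry()
registry.register(lambda x: x, "aten::abs")                   # plain overload
registry.register(lambda x: x, "aten::abs", complex=True)     # complex variant
registry.register(lambda x: x, "aten::argmax", private=True)  # private helper
for name, entry in registry.items():
    print(name, len(entry.overloads), len(entry.privates), len(entry.complex))
# aten::abs 1 0 1
# aten::argmax 0 1 0
```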
1 change: 0 additions & 1 deletion onnxscript/values.py
@@ -396,7 +396,6 @@ def op_schema_from_function_ir(
)
for arg in function_ir.outputs
]

return onnx.defs.OpSchema(
function_ir.name,
opset.domain,