diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py
index 512357f1ac..c405ed4224 100644
--- a/onnxscript/__init__.py
+++ b/onnxscript/__init__.py
@@ -54,7 +54,7 @@
 # isort: on

 from .utils import external_tensor, proto2text
-from .values import OnnxFunction
+from .values import OnnxFunction, TracedOnnxFunction

 try:
     __version__ = importlib.metadata.version("onnxscript")
@@ -66,6 +66,7 @@
     "script",
     "export_onnx_lib",
     "OnnxFunction",
+    "TracedOnnxFunction",
     "proto2python",
     "proto2text",
     "external_tensor",
diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index ebe7b43c4a..aea9125c8f 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -127,12 +127,15 @@ def shape(self, shape: Tuple[int | None, ...]):
         self._shape = shape
         self._torch_value.setType(self._torch_value.type().with_sizes(list(shape)))

-    @property
-    def dtype(self):
+    @property  # type: ignore[override]
+    def dtype(self) -> torch.dtype | None:
         # TODO: Return numpy dtype
-        return _type_utils.JitScalarType.from_value(  # type: ignore[attr-defined]
+        torch_dtype = _type_utils.JitScalarType.from_value(  # type: ignore[attr-defined]
             self._torch_value, default=_type_utils.JitScalarType.UNDEFINED
-        ).dtype()
+        )
+        if torch_dtype == _type_utils.JitScalarType.UNDEFINED:
+            return None
+        return torch_dtype.dtype()

     @dtype.setter
     def dtype(self, dtype: torch.dtype):
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b4bea9c2f3..1dc96985be 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -38,14 +38,14 @@
 @torch_op("aten::abs")
-def aten_abs(self: TReal) -> TReal:
+def aten_abs(self: TrealOrUInt8) -> TrealOrUInt8:
     """abs(Tensor self) -> Tensor"""

     return op.Abs(self)


-@torch_op("aten::abs")
-def aten_abs_complex(self: TReal) -> TReal:
+@torch_op("aten::abs", complex=True)
+def aten_abs_complex(self: TrealOrUInt8) -> TrealOrUInt8:
     """abs(Tensor self) -> Tensor"""

     # self_real = self[..., 0]
     self_real = op.Gather(self, 0, axis=-1)
@@ -250,7 +250,7 @@ def aten_alpha_dropout(input: TensorType, p: float, train: bool) -> TensorType:


 @torch_op("aten::amax")
-def aten_amax(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
+def aten_amax(self: TrealOrUInt8, dim: INT64, keepdim: bool = False) -> TrealOrUInt8:
     """amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"""

     # ReduceMax reduces all dimensions when dim is empty
@@ -258,7 +258,7 @@ def aten_amax(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:


 @torch_op("aten::amin")
-def aten_amin(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
+def aten_amin(self: TrealOrUInt8, dim: INT64, keepdim: bool = False) -> TrealOrUInt8:
     """amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"""

     # ReduceMin reduces all dimensions when dim is empty
@@ -469,17 +469,19 @@ def aten_arctanh(self: TensorType) -> TensorType:


 @torch_op("aten::argmax", trace_only=True)
-def aten_argmax(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
+def aten_argmax(
+    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
+) -> TrealOrUInt8:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     if dim is None:  # TODO: use OptionalHasElement(dim)
         self = op.Reshape(self, op.Constant(value_ints=[-1]))

-    return aten_argmax_dim(self, dim=dim, keepdim=keepdim)
+    return _aten_argmax_dim(self, dim=dim, keepdim=keepdim)


-@torch_op("aten::argmax")
-def aten_argmax_dim(self: TReal, dim: int, keepdim: bool = False) -> TReal:
+@torch_op("aten::argmax", private=True)
+def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
     """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = op.Size(op.Shape(self)) == 0
@@ -494,17 +496,19 @@


 @torch_op("aten::argmin", trace_only=True)
-def aten_argmin(self: TReal, dim: Optional[int] = None, keepdim: bool = False) -> TReal:
+def aten_argmin(
+    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
+) -> TrealOrUInt8:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     if dim is None:  # TODO: use OptionalHasElement(dim)
         self = op.Reshape(self, op.Constant(value_ints=[-1]))

-    return aten_argmin_dim(self, dim=dim, keepdim=keepdim)
+    return _aten_argmin_dim(self, dim=dim, keepdim=keepdim)


-@torch_op("aten::argmin")
-def aten_argmin_dim(self: TReal, dim: int, keepdim: bool = False) -> TReal:
+@torch_op("aten::argmin", private=True)
+def _aten_argmin_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
     """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"""

     self_is_scaler = op.Size(op.Shape(self)) == 0
@@ -2458,7 +2462,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:


 @torch_op("aten::fmod")
-def aten_fmod(self: TReal, other: TReal) -> TReal:
+def aten_fmod(self: TrealOrUInt8, other: TrealOrUInt8) -> TrealOrUInt8:
     """fmod.Tensor(Tensor self, Tensor other) -> Tensor"""

     return op.Mod(self, other, fmod=1)
@@ -2586,7 +2590,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType:


 # NOTE: The name is made up for `getitem` to be included in the registry
 @torch_op("aten::getitem")
-def aten_getitem(self: Sequence[TReal], i: INT64) -> TReal:
+def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
     return op.SequenceAt(self, i)
@@ -2876,7 +2880,10 @@ def aten_index_put(
     if op.Cast(accumulate, to=BOOL.dtype):
         # put values into zeros array first, then add to input
         zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
+        zeros = op.CastLike(zeros, values)
         result = op.ScatterElements(zeros, new_ind_t, values)
+        # FIXME: type promotion
+        result = op.CastLike(result, self)
         result = op.Add(result, self)
     else:
         result = op.ScatterElements(self, new_ind_t, values)
@@ -2917,7 +2924,10 @@ def aten_index_put_bool(
     if op.Cast(accumulate, to=BOOL.dtype):
         zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
+        zeros = op.CastLike(zeros, values)
         result = op.ScatterElements(zeros, new_ind_t, values)
+        # FIXME: type promotion
+        result = op.CastLike(result, self)
         result = op.Add(result, self)
     else:
         result = op.ScatterElements(self, new_ind_t, values)
@@ -3348,7 +3358,7 @@ def aten_linspace(start: float, end: float, steps: int) -> TensorType:
     raise NotImplementedError()


-@torch_op("log")
+@torch_op("aten::log")
 def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
     """log(Tensor self) -> Tensor"""

diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py
index d7f32df042..541ca00cec 100644
--- a/onnxscript/function_libs/torch_lib/registration.py
+++ b/onnxscript/function_libs/torch_lib/registration.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from types import FunctionType
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Generator, Optional

 import onnxscript

@@ -15,12 +15,14 @@ class OverloadedFunction:
         name: Name of the op. E.g. "aten::add".
         overloads: Overloads function.
         privates: Private functions not exposed to users.
+        complex: Overloads that support complex inputs.
     """

     def __init__(self, name: str):
         self.name = name
         self.overloads: list[Any] = []
         self.privates: list[Any] = []
+        self.complex: list[Any] = []


 class Registry:
@@ -29,11 +31,15 @@ class Registry:
     def __init__(self):
         self._registry: dict[str, OverloadedFunction] = {}

-    def register(self, func: Any, name: str, *, private: bool = False) -> None:
+    def register(
+        self, func: Any, name: str, *, private: bool = False, complex: bool = False
+    ) -> None:
         """Register a function."""

         if private:
             self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
+        elif complex:
+            self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
         else:
             self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
@@ -49,6 +55,9 @@ def __iter__(self):
     def __repr__(self):
         return repr(self._registry)

+    def items(self) -> Generator[tuple[str, OverloadedFunction], None, None]:
+        yield from self._registry.items()
+

 # Default registry
 default_registry = Registry()
@@ -60,6 +69,7 @@ def torch_op(
     registry: Optional[Registry] = None,
     trace_only: bool = False,
     private: bool = False,
+    complex: bool = False,
 ) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]:
     """Register a torch op.

@@ -69,6 +79,7 @@ def torch_op(
         trace_only: Whether the function should only be traced and not compiled.
         private: Whether the function is private (not directly exposed). It should
             be true for all functions with names starting with "_".
+        complex: Whether the function supports complex inputs.
     """
     if registry is None:
         registry = default_registry
@@ -87,7 +98,7 @@ def wrapper(
             processed_func = onnxscript.script(opset=custom_opset)(func)

         assert registry is not None
-        registry.register(processed_func, name, private=private)
+        registry.register(processed_func, name, private=private, complex=complex)
         return processed_func

     return wrapper
diff --git a/onnxscript/values.py b/onnxscript/values.py
index fa256703b3..71016f1466 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -396,7 +396,6 @@ def op_schema_from_function_ir(
         )
         for arg in function_ir.outputs
     ]
-
     return onnx.defs.OpSchema(
         function_ir.name,
         opset.domain,
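
Note for reviewers: with the graph_building.py change, `TorchScriptTensor.dtype` is now `torch.dtype | None` rather than always converting the JIT scalar type. A minimal caller sketch (the `format_dtype` helper is hypothetical, not part of this PR):

def format_dtype(tensor) -> str:
    # After this change, dtype is `torch.dtype | None`; None means the
    # underlying torch.Value carries no static scalar type (UNDEFINED).
    dtype = tensor.dtype
    if dtype is None:
        return "dtype statically unknown"
    return str(dtype)  # e.g. "torch.float32"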
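
The paired `op.CastLike` calls added to `aten_index_put` and `aten_index_put_bool` are needed because `op.Expand(op.Constant(value_float=0.0), ...)` always produces a float tensor. A dtype-bookkeeping sketch, using numpy as a stand-in for the ONNX ops (names and values illustrative):

import numpy as np

self_tensor = np.array([10, 20, 30, 40], dtype=np.int64)
values = np.array([1.5, 2.5], dtype=np.float32)

zeros = np.zeros_like(self_tensor, dtype=np.float32)  # op.Expand(Constant(0.0), Shape(self))
zeros = zeros.astype(values.dtype)        # op.CastLike(zeros, values): match values' dtype
zeros[[0, 2]] = values                    # stand-in for op.ScatterElements(zeros, new_ind_t, values)
result = zeros.astype(self_tensor.dtype)  # op.CastLike(result, self): back to self's dtype
result = result + self_tensor             # op.Add now sees matching dtypes

The truncation of 1.5 to 1 in the final cast is exactly the kind of case the `# FIXME: type promotion` comments flag as unresolved.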
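
Finally, a usage sketch for the registration.py changes (`abs_real` and `abs_complex` are hypothetical stand-ins, not part of this PR): `complex=True` routes a function into the new `complex` bucket of `OverloadedFunction`, and the new `Registry.items()` generator makes the registry easy to inspect.

from onnxscript.function_libs.torch_lib import registration

registry = registration.Registry()

# Registry.register() stores any object, so plain callables suffice here.
def abs_real(self):
    ...

def abs_complex(self):
    ...

registry.register(abs_real, "aten::abs")
registry.register(abs_complex, "aten::abs", complex=True)

# Both land under one OverloadedFunction name, in separate buckets, so a
# dispatcher can pick the complex variant for complex-typed inputs.
for name, overloaded in registry.items():
    print(name, overloaded.overloads, overloaded.complex, overloaded.privates)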