diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 80240880e0..816047d9a4 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -77,14 +77,14 @@ jobs:
         pip list | grep torch

       - name: pytest
-        run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n auto
+        run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n=auto

       - name: Install package
         run: pip install .

       - name: Test examples
         if: ${{ matrix.test_examples }}
-        run: pytest -v docs/test
+        run: pytest -v docs/test -n=auto

       - name: Build package
         run: python -m build
diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py
index 1d917ffdc0..7ccc5e9a63 100644
--- a/onnxscript/backend/onnx_backend.py
+++ b/onnxscript/backend/onnx_backend.py
@@ -77,8 +77,8 @@ def _read_proto_from_file(full):
             loaded = to_list(seq)  # type: ignore[assignment]
         except Exception:  # pylint: disable=W0703
             try:
-                loaded = onnx.load_model_from_string(serialized)
-            except Exception:  # pragma: no cover
+                loaded = onnx.load_model_from_string(serialized)  # type: ignore[assignment]
+            except Exception:
                 raise RuntimeError(
                     f"Unable to read {full!r}, error is {e}, "
                     f"content is {serialized[:100]!r}."
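Reviewer note: the onnx_backend.py hunk above moves the `# type: ignore[assignment]` onto the load call and drops the stale `# pragma: no cover`. For context, a minimal standalone sketch of the fallback-parsing pattern this code relies on (the helper name `load_model_bytes` is hypothetical, not part of this PR):

```python
import onnx


def load_model_bytes(serialized: bytes) -> onnx.ModelProto:
    # Try to parse raw bytes as a ModelProto; on failure, surface a readable
    # error with a content preview, as _read_proto_from_file does.
    try:
        return onnx.load_model_from_string(serialized)
    except Exception as e:
        raise RuntimeError(
            f"Unable to parse model, content is {serialized[:100]!r}."
        ) from e
```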
diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index 8622318eaa..9db8b7ab2f 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -16,6 +16,7 @@
 from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
 from onnxscript.function_libs.torch_aten.registration import torch_op
 from onnxscript.function_libs.torch_aten.typing import (
+    IntType,
     TFloat,
     TFloatOrBFloat16,
     TInt,
@@ -1642,10 +1643,10 @@ def aten_exp2(self: TFloat) -> TFloat:


 @torch_op("aten::expand")
-def aten_expand(self: TTensor, size: INT64) -> TTensor:
+def aten_expand(self: TTensor, size: TInt) -> TTensor:
     # expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

-    size = op.Cast(size, to=INT64.dtype)  # to INT64
+    size = op.Cast(size, to=INT64.dtype)
     return op.Expand(self, size)


@@ -3518,10 +3519,11 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens

 @torch_op("aten::new_full")
 def aten_new_full(
-    self, size: INT64, fill_value, dtype: int = FLOAT.dtype
+    self, size: IntType, fill_value, dtype: int = FLOAT.dtype
 ):  # pylint: disable=unused-argument
     # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

+    size = op.Cast(size, to=INT64.dtype)
     fill_value = op.Cast(fill_value, to=dtype)
     return op.Expand(fill_value, size)

@@ -3585,12 +3587,12 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:


 @torch_op("aten::ones")
-def aten_ones(size: INT64, dtype: int = -1):
+def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
     # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

+    size = op.Cast(size, to=INT64.dtype)
     one = op.Constant(value_float=1)
-    if dtype != -1:
-        one = op.Cast(one, to=dtype)
+    one = op.Cast(one, to=dtype)
     return op.Expand(one, size)


@@ -4088,13 +4090,14 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT


 @torch_op("aten::repeat")
-def aten_repeat(self: TTensor, repeats: INT64) -> TTensor:
+def aten_repeat(self: TTensor, repeats: TInt) -> TTensor:
     # repeat(Tensor self, SymInt[] repeats) -> Tensor

     if op.Size(repeats) == 0:
         result = self
     else:
         # TODO(justinchuby): Make ones_like a function when onnxscript supports it
+        repeats = op.Cast(repeats, to=INT64.dtype)
         # shape = ones_like(repeats) := {
         one = op.Constant(value_int=1)
         repeats_shape = op.Shape(repeats)
@@ -4114,10 +4117,11 @@ def aten_repeat_interleave(


 @torch_op("aten::reshape")
-def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
+def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
     # reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)

-    shape = op.Cast(shape, to=INT64.dtype)  # Reshape only support INT64 as 'shape'
+    # Reshape only support INT64 as 'shape'
+    shape = op.Cast(shape, to=INT64.dtype)
     return op.Reshape(self, shape)


@@ -4975,7 +4979,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:


 @torch_op("aten::view")
-def aten_view(self: TTensor, size: INT64) -> TTensor:
+def aten_view(self: TTensor, size: IntType) -> TTensor:
     # view(Tensor(a) self, SymInt[] size) -> Tensor(a)

     size = op.Cast(size, to=INT64.dtype)  # Reshape only support INT64 as second input
@@ -5044,12 +5048,12 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:


 @torch_op("aten::zeros")
-def aten_zeros(size: INT64, dtype: int = -1):
+def aten_zeros(size: IntType, dtype: int = FLOAT.dtype):
     # zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

+    size = op.Cast(size, to=INT64.dtype)
     zero = op.Constant(value_float=0)
-    if dtype != -1:
-        zero = op.Cast(zero, to=dtype)
+    zero = op.Cast(zero, to=dtype)
     return op.Expand(zero, size)
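Reviewer note: the core.py changes follow one pattern: accept any integer tensor (`IntType`/`TInt`) for `size`/`shape`/`repeats`, insert an explicit `op.Cast(..., to=INT64.dtype)` before ops that require INT64, and cast the constant unconditionally now that `dtype` defaults to `FLOAT.dtype` instead of the `-1` sentinel. For intuition only, a numpy sketch of the Constant + Cast + Expand sequence in `aten_ones` (`ones_like_expand` is an illustrative name, not code from this PR):

```python
import numpy as np


def ones_like_expand(size: np.ndarray, dtype=np.float32) -> np.ndarray:
    # Mirrors aten_ones: Cast(size -> INT64), Constant(1) cast to dtype,
    # then Expand, i.e. broadcasting the scalar to the requested shape.
    size = size.astype(np.int64)
    one = np.array(1, dtype=dtype)
    return np.broadcast_to(one, tuple(size))


print(ones_like_expand(np.array([2, 3], dtype=np.int32)))  # 2x3 array of ones
```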
diff --git a/onnxscript/function_libs/torch_aten/typing.py b/onnxscript/function_libs/torch_aten/typing.py
index ff4d8d1a32..7b4ce7a9b4 100644
--- a/onnxscript/function_libs/torch_aten/typing.py
+++ b/onnxscript/function_libs/torch_aten/typing.py
@@ -41,7 +41,7 @@
     UINT8,
 ]
 _FloatType = Union[FLOAT16, FLOAT, DOUBLE]
-_IntType = Union[INT8, INT16, INT32, INT64]
+IntType = Union[INT8, INT16, INT32, INT64]
 RealType = Union[
     BFLOAT16,
     FLOAT16,
@@ -56,7 +56,7 @@ TTensor = TypeVar("TTensor", bound=_TensorType)
 TFloat = TypeVar("TFloat", bound=_FloatType)
 TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
-TInt = TypeVar("TInt", bound=_IntType)
+TInt = TypeVar("TInt", bound=IntType)
 TReal = TypeVar("TReal", bound=RealType)
 TRealUnlessInt16OrInt8 = TypeVar(
     "TRealUnlessInt16OrInt8", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64]
 )
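Reviewer note: renaming `_IntType` to `IntType` makes the Union public so core.py can import it. The two annotations serve different purposes: `IntType` (a Union) marks inputs that may be any integer type and are cast internally, while `TInt` (a TypeVar bound by that Union) is for ops whose output type must match the input type. A toy illustration of the distinction, using stand-in types rather than onnxscript's:

```python
from typing import TypeVar, Union

INT32 = int  # stand-ins for onnxscript's tensor types
INT64 = int
IntType = Union[INT32, INT64]
TInt = TypeVar("TInt", bound=IntType)


def neg(x: TInt) -> TInt:
    # TypeVar: the checker knows the output has the same type as the input.
    return -x


def to_int64(x: IntType) -> INT64:
    # Union: any integer type in, one fixed type out (cf. op.Cast to INT64).
    return int(x)
```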
diff --git a/onnxscript/utils.py b/onnxscript/utils.py
index 988016bc79..e354c204da 100644
--- a/onnxscript/utils.py
+++ b/onnxscript/utils.py
@@ -5,12 +5,13 @@
 from __future__ import annotations

 import numbers
-from typing import Any, Optional, Sequence
+from typing import Any, Iterable, Optional, Sequence

 import numpy as np
 import onnx
+import onnx.helper
+import onnx.mapping
 from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
-from onnx.helper import make_sequence_type_proto, make_tensor_type_proto

 from onnxscript import tensor

@@ -82,22 +83,24 @@ def value_to_type_proto(val):
     if isinstance(val, (np.ndarray, tensor.Tensor)):
         elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val.dtype]
         shape = val.shape
-        return make_tensor_type_proto(elem_type, shape)
+        return onnx.helper.make_tensor_type_proto(elem_type, shape)
     if isinstance(val, int):
-        return make_tensor_type_proto(TensorProto.INT32, [])
+        return onnx.helper.make_tensor_type_proto(TensorProto.INT32, [])
     if isinstance(val, (float, np.float32)):
-        return make_tensor_type_proto(TensorProto.FLOAT, [])
+        return onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, [])
     if isinstance(val, list):
         if len(val) > 0:
-            return make_sequence_type_proto(value_to_type_proto(val[0]))
+            return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
         # Edge-case. Cannot determine a suitable ONNX type for an empty list.
         # Should be using a typed-value instead.
         # Treated as a sequence of tensors of float-type.
-        return make_sequence_type_proto(make_tensor_type_proto(TensorProto.FLOAT, None))
+        return onnx.helper.make_sequence_type_proto(
+            onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, None)
+        )
     if isinstance(val, numbers.Number):
         nparray = np.array(val)
         elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[nparray.dtype]
-        return make_tensor_type_proto(elem_type, [])
+        return onnx.helper.make_tensor_type_proto(elem_type, [])
     raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")

@@ -144,7 +147,7 @@ def make_model_from_function_proto(
         **(attrs or {}),
     )
     graph = onnx.helper.make_graph([node], "node_graph", input_value_infos, output_value_infos)
-    model_proto_opset = function_proto.opset_import
+    model_proto_opset: Iterable[onnx.OperatorSetIdProto] = function_proto.opset_import
     if all(o.domain != function_proto.domain for o in model_proto_opset):
         model_proto_opset = [
             *model_proto_opset,
diff --git a/pyproject.toml b/pyproject.toml
index 0e529617d3..83526d7359 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ onnx = ["py.typed"]

 [tool.pytest.ini_options]
 filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
+addopts = "-ra --tb=short --color=yes"

 [tool.mypy]
 follow_imports = "silent"  # TODO: Remove when we fix all the mypy errors
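Reviewer note: the utils.py hunks swap the `from onnx.helper import ...` names for fully qualified calls, and add explicit `import onnx.helper` / `import onnx.mapping` so the attribute access is robust regardless of what `import onnx` pulls in. A quick standalone check of the qualified APIs, with arbitrary example values:

```python
import onnx
import onnx.helper

# Build the same kind of protos value_to_type_proto returns.
tensor_type = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [2, 3])
sequence_type = onnx.helper.make_sequence_type_proto(tensor_type)
print(sequence_type.sequence_type.elem_type.tensor_type.elem_type)  # 1 == FLOAT
```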