Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ jobs:
pip list | grep torch

- name: pytest
run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n auto
run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n=auto

- name: Install package
run: pip install .

- name: Test examples
if: ${{ matrix.test_examples }}
run: pytest -v docs/test
run: pytest -v docs/test -n=auto

- name: Build package
run: python -m build
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _read_proto_from_file(full):
loaded = to_list(seq) # type: ignore[assignment]
except Exception: # pylint: disable=W0703
try:
loaded = onnx.load_model_from_string(serialized)
except Exception: # pragma: no cover
loaded = onnx.load_model_from_string(serialized) # type: ignore[assignment]
except Exception:
raise RuntimeError(
f"Unable to read {full!r}, error is {e}, "
f"content is {serialized[:100]!r}."
Expand Down
30 changes: 17 additions & 13 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
from onnxscript.function_libs.torch_aten.registration import torch_op
from onnxscript.function_libs.torch_aten.typing import (
IntType,
TFloat,
TFloatOrBFloat16,
TInt,
Expand Down Expand Up @@ -1642,10 +1643,10 @@ def aten_exp2(self: TFloat) -> TFloat:


@torch_op("aten::expand")
def aten_expand(self: TTensor, size: INT64) -> TTensor:
def aten_expand(self: TTensor, size: TInt) -> TTensor:
# expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)

size = op.Cast(size, to=INT64.dtype) # to INT64
size = op.Cast(size, to=INT64.dtype)
return op.Expand(self, size)


Expand Down Expand Up @@ -3518,10 +3519,11 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens

@torch_op("aten::new_full")
def aten_new_full(
self, size: INT64, fill_value, dtype: int = FLOAT.dtype
self, size: IntType, fill_value, dtype: int = FLOAT.dtype
): # pylint: disable=unused-argument
# new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

size = op.Cast(size, to=INT64.dtype)
fill_value = op.Cast(fill_value, to=dtype)

return op.Expand(fill_value, size)
Expand Down Expand Up @@ -3585,12 +3587,12 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:


@torch_op("aten::ones")
def aten_ones(size: INT64, dtype: int = -1):
def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
# ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

size = op.Cast(size, to=INT64.dtype)
one = op.Constant(value_float=1)
if dtype != -1:
one = op.Cast(one, to=dtype)
one = op.Cast(one, to=dtype)
return op.Expand(one, size)


Expand Down Expand Up @@ -4088,13 +4090,14 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT


@torch_op("aten::repeat")
def aten_repeat(self: TTensor, repeats: INT64) -> TTensor:
def aten_repeat(self: TTensor, repeats: TInt) -> TTensor:
# repeat(Tensor self, SymInt[] repeats) -> Tensor

if op.Size(repeats) == 0:
result = self
else:
# TODO(justinchuby): Make ones_like a function when onnxscript supports it
repeats = op.Cast(repeats, to=INT64.dtype)
# shape = ones_like(repeats) := {
one = op.Constant(value_int=1)
repeats_shape = op.Shape(repeats)
Expand All @@ -4114,10 +4117,11 @@ def aten_repeat_interleave(


@torch_op("aten::reshape")
def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
# reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)

shape = op.Cast(shape, to=INT64.dtype) # Reshape only support INT64 as 'shape'
# Reshape only support INT64 as 'shape'
shape = op.Cast(shape, to=INT64.dtype)
return op.Reshape(self, shape)


Expand Down Expand Up @@ -4975,7 +4979,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:


@torch_op("aten::view")
def aten_view(self: TTensor, size: INT64) -> TTensor:
def aten_view(self: TTensor, size: IntType) -> TTensor:
# view(Tensor(a) self, SymInt[] size) -> Tensor(a)

size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
Expand Down Expand Up @@ -5044,12 +5048,12 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:


@torch_op("aten::zeros")
def aten_zeros(size: INT64, dtype: int = -1):
def aten_zeros(size: IntType, dtype: int = FLOAT.dtype):
# zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

size = op.Cast(size, to=INT64.dtype)
zero = op.Constant(value_float=0)
if dtype != -1:
zero = op.Cast(zero, to=dtype)
zero = op.Cast(zero, to=dtype)

return op.Expand(zero, size)

Expand Down
4 changes: 2 additions & 2 deletions onnxscript/function_libs/torch_aten/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
UINT8,
]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
_IntType = Union[INT8, INT16, INT32, INT64]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
FLOAT16,
Expand All @@ -56,7 +56,7 @@
TTensor = TypeVar("TTensor", bound=_TensorType)
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
TInt = TypeVar("TInt", bound=_IntType)
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
TRealUnlessInt16OrInt8 = TypeVar(
"TRealUnlessInt16OrInt8", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64]
Expand Down
21 changes: 12 additions & 9 deletions onnxscript/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from __future__ import annotations

import numbers
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence

import numpy as np
import onnx
import onnx.helper
import onnx.mapping
from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_sequence_type_proto, make_tensor_type_proto

from onnxscript import tensor

Expand Down Expand Up @@ -82,22 +83,24 @@ def value_to_type_proto(val):
if isinstance(val, (np.ndarray, tensor.Tensor)):
elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val.dtype]
shape = val.shape
return make_tensor_type_proto(elem_type, shape)
return onnx.helper.make_tensor_type_proto(elem_type, shape)
if isinstance(val, int):
return make_tensor_type_proto(TensorProto.INT32, [])
return onnx.helper.make_tensor_type_proto(TensorProto.INT32, [])
if isinstance(val, (float, np.float32)):
return make_tensor_type_proto(TensorProto.FLOAT, [])
return onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, [])
if isinstance(val, list):
if len(val) > 0:
return make_sequence_type_proto(value_to_type_proto(val[0]))
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
# Should be using a typed-value instead.
# Treated as a sequence of tensors of float-type.
return make_sequence_type_proto(make_tensor_type_proto(TensorProto.FLOAT, None))
return onnx.helper.make_sequence_type_proto(
onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, None)
)
if isinstance(val, numbers.Number):
nparray = np.array(val)
elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[nparray.dtype]
return make_tensor_type_proto(elem_type, [])
return onnx.helper.make_tensor_type_proto(elem_type, [])
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")


Expand Down Expand Up @@ -144,7 +147,7 @@ def make_model_from_function_proto(
**(attrs or {}),
)
graph = onnx.helper.make_graph([node], "node_graph", input_value_infos, output_value_infos)
model_proto_opset = function_proto.opset_import
model_proto_opset: Iterable[onnx.OperatorSetIdProto] = function_proto.opset_import
if all(o.domain != function_proto.domain for o in model_proto_opset):
model_proto_opset = [
*model_proto_opset,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ onnx = ["py.typed"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
addopts = "-ra --tb=short --color=yes"

[tool.mypy]
follow_imports = "silent" # TODO: Remove when we fix all the mypy errors
Expand Down