Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
c6505cf
Move version_utils to _internal | chore
justinchuby Apr 12, 2023
268ad8f
Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)
justinchuby Apr 12, 2023
1f900c1
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 12, 2023
01f9378
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 12, 2023
1b7391e
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 12, 2023
304fa0d
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 12, 2023
ac60dce
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 12, 2023
3eb45c6
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 12, 2023
f7b9b79
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
59523f5
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
758ce59
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
fdd2f22
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
d4b0db6
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
cc977b1
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
8ffd7e7
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
831a416
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
1fd4d2e
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
62105f3
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
9c6bfae
Update base for Update on "Merge `attrs` and `attr_protos` in `IRFunc…
justinchuby Apr 13, 2023
c63ccd5
Update on "Merge `attrs` and `attr_protos` in `IRFunction` | chore(ir…
justinchuby Apr 13, 2023
4e38a29
Analyze type annotation for input types | feat(op_schema)
justinchuby Apr 14, 2023
b34551c
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 14, 2023
f3911bf
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 14, 2023
b48d6ef
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 14, 2023
9018509
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 17, 2023
f6a934d
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 17, 2023
7014658
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 18, 2023
40c29c6
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 18, 2023
a80742f
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 18, 2023
509aea0
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 18, 2023
f2046cc
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 18, 2023
b8c3450
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 18, 2023
d6ec312
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 20, 2023
b50b72c
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 20, 2023
84f7014
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 20, 2023
e9f383b
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 20, 2023
03f64d4
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 20, 2023
7804c7f
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 20, 2023
6538959
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 21, 2023
f31f30c
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 21, 2023
8648dd3
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 22, 2023
2b90e31
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 22, 2023
5f04381
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 22, 2023
0f605cd
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 22, 2023
d272e20
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 22, 2023
cbcf5b1
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 22, 2023
ca76465
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 25, 2023
724c0ef
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 25, 2023
5d8f60e
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 27, 2023
9d7d12f
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 27, 2023
e61bffe
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 27, 2023
52c198e
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 27, 2023
a37fb92
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 27, 2023
d1a24ef
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 27, 2023
dcd74de
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 28, 2023
2375946
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 28, 2023
6160dbc
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 28, 2023
83cc202
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 28, 2023
d69a130
Update base for Update on "Analyze type annotation for input types | …
justinchuby Apr 28, 2023
9dc0937
Update on "Analyze type annotation for input types | feat(op_schema)"
justinchuby Apr 28, 2023
22ac02a
Merge branch 'main' into gh/justinchuby/17/head
justinchuby Apr 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,5 +526,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut
proto = onnx.AttributeProto()
proto.name = attrname
proto.ref_attr_name = refname
proto.type = ta.pytype_to_attrtype(pytype)
attr_type = ta.pytype_to_attrtype(pytype)
assert attr_type is not None
proto.type = attr_type
return IRAttributeValue(proto)
4 changes: 4 additions & 0 deletions onnxscript/onnx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def to_type_proto(cls) -> onnx.TypeProto:
shape = [cls.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)

@classmethod
def to_string(cls) -> str:
return f"tensor({cls.__name__.lower()})"


class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
pass
Expand Down
137 changes: 113 additions & 24 deletions onnxscript/type_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
import collections
import inspect
import typing
from typing import Optional, Sequence, Union

import onnx
from typing_extensions import get_args, get_origin

from onnxscript.onnx_types import TensorType
from onnxscript import onnx_types

# TypeAnnotationValue represents the (value of) valid type-annotations recognized
# by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
# * float, int, str (primitive attribute types)
# * Sequence[float], Sequence[int], Sequence[str] (attribute types)
# * Tensor types
# * Sequence[Tensor] types
# * Union of above 2
# * TypeVars with above bounds
# * Above types with annotation attached
# - float, int, str (primitive attribute types)
# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
# - Tensor types
# - Sequence[Tensor] types
# - Union of above 2
# - TypeVars with above bounds
# - Above types with annotation attached
TypeAnnotationValue = typing.Any

# Map from python type to corresponding ONNX AttributeProto type
Expand All @@ -43,13 +43,18 @@

_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])

# A sorted list of all type strings used in an OpSchema
ALL_TENSOR_TYPE_STRINGS = tuple(
sorted(tensor_type.to_string() for tensor_type in onnx_types.tensor_type_registry.values())
)


def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
"""Remove Annotated wrapper if present, otherwise return typeinfo as is."""
if hasattr(typing, "Annotated"):
# Present in Python 3.9+
if get_origin(typeinfo) is typing.Annotated:
return get_args(typeinfo)[0]
if typing.get_origin(typeinfo) is typing.Annotated:
return typing.get_args(typeinfo)[0]
return typeinfo


Expand All @@ -59,28 +64,28 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:

def pytype_to_attrtype(
pytype: TypeAnnotationValue,
) -> onnx.AttributeProto.AttributeType:
) -> Optional[onnx.AttributeProto.AttributeType]:
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
type_constructor = get_origin(pytype)
type_constructor = typing.get_origin(pytype)
# Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
if type_constructor is typing.Union:
# Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
args = [x for x in get_args(pytype) if x is not type(None)]
args = [x for x in typing.get_args(pytype) if x is not type(None)]
if len(args) == 1:
return pytype_to_attrtype(args[0])
if type_constructor in _LIST_CONSTRUCTORS:
elt_type = get_args(pytype)[0]
elt_type = typing.get_args(pytype)[0]
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
return onnx.AttributeProto.UNDEFINED
return None


def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
if isinstance(typeinfo, TensorType):
if isinstance(typeinfo, onnx_types.TensorType):
return True
if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
return True
return False

Expand All @@ -94,16 +99,16 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool:
return True
if _is_primitive_attr_type(typeinfo):
return False
type_constructor = get_origin(typeinfo)
type_constructor = typing.get_origin(typeinfo)
# Handle List-like type-constructor
# Eg. List[INT32] is a value type, while List[int] is an attribute type
if type_constructor in _LIST_CONSTRUCTORS:
elt_type = get_args(typeinfo)[0]
elt_type = typing.get_args(typeinfo)[0]
return is_value_type(elt_type)
# Handle Union and Optional type-constructors
if type_constructor is typing.Union:
# Filter out None, since typing.Optional[X] evaluates to Union[X, None]
args = [x for x in get_args(typeinfo) if x is not type(None)]
args = [x for x in typing.get_args(typeinfo) if x is not type(None)]
args_value_check = [is_value_type(x) for x in args]
if all(args_value_check):
# Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE]
Expand Down Expand Up @@ -133,7 +138,12 @@ def is_valid_type(typeinfo: TypeAnnotationValue):
return False


def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[type]:
def is_optional(pytype) -> bool:
"""Returns whether a pytype is an Optional."""
return typing.get_origin(pytype) is typing.Union and type(None) in typing.get_args(pytype)


def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]:
"""Converts return-type annotation into a sequence of types.

The return type annotation can be either a single type (for a single output)
Expand All @@ -143,6 +153,85 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[
"""
if isinstance(typeinfo, typing.Sequence):
return typeinfo
if get_origin(typeinfo) is tuple:
return get_args(typeinfo)
if typing.get_origin(typeinfo) is tuple:
return typing.get_args(typeinfo)
return (typeinfo,)


def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]:
"""Returns a list of type-strings corresponding to a given type annotation.

Args:
pytype: A type annotation.

Returns:
A list of all supported input types for the given type annotation.
Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS.
"""
if pytype is None:
return list(ALL_TENSOR_TYPE_STRINGS)
if pytype is onnx_types.TensorType:
return list(ALL_TENSOR_TYPE_STRINGS)
if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType):
return [pytype.to_string()]
if isinstance(pytype, onnx_types.TensorType):
return [pytype.to_string()]
if isinstance(pytype, typing.TypeVar):
constraints = pytype.__constraints__
if constraints:
return pytype_to_type_strings(Union.__getitem__(constraints))
bound = pytype.__bound__
if bound is None:
return list(ALL_TENSOR_TYPE_STRINGS)
return pytype_to_type_strings(bound)
if typing.get_origin(pytype) is Union:
options = []
subtypes = typing.get_args(pytype)
# A None type in a Union is equivalent to an optional type
optional = is_optional(pytype)
for subtype in subtypes:
if subtype is type(None):
# Skip None type because we are handling it with is_optional
continue
if optional:
options += [
*pytype_to_type_strings(subtype),
*[f"optional({s})" for s in pytype_to_type_strings(subtype)],
]
else:
options += pytype_to_type_strings(subtype)
# Remove duplicates
return sorted(set(options))
if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
subtypes = typing.get_args(pytype)
return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])]

raise ValueError(f"Unsupported type: {pytype}")


def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this function used/useful?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s used for converting Python function signatures into onnx type constraints. The tests show usage. You can find actual usage in #626 and #674

"""Returns the name of the type constraint for a given type annotation.

Args:
pytype: A type annotation.

Returns:
The name of the type constraint if it is a TypeVar.
- Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar].
- Prefixes the name with "Sequence_" if the type annotation is a Sequence[].
- Returns None if the type annotation does not have a type constraint.
"""
if isinstance(pytype, typing.TypeVar):
return pytype.__name__
if is_optional(pytype):
subtypes = typing.get_args(pytype)
for subtype in subtypes:
if subtype is type(None):
continue
type_param_name = get_type_constraint_name(subtype)
return f"Optional_{type_param_name}" if type_param_name else None
if typing.get_origin(pytype) in _LIST_CONSTRUCTORS:
subtypes = typing.get_args(pytype)
type_param_name = get_type_constraint_name(subtypes[0])
return f"Sequence_{type_param_name}" if type_param_name else None
return None
Loading