diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 9c4e5f5d9e..443f18412e 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -526,5 +526,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut
         proto = onnx.AttributeProto()
         proto.name = attrname
         proto.ref_attr_name = refname
-        proto.type = ta.pytype_to_attrtype(pytype)
+        attr_type = ta.pytype_to_attrtype(pytype)
+        assert attr_type is not None
+        proto.type = attr_type
         return IRAttributeValue(proto)
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index 32e48bf744..4951a3ff18 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -105,6 +105,10 @@ def to_type_proto(cls) -> onnx.TypeProto:
             shape = [cls.shape]  # example: "FLOAT[10]"
         return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
 
+    @classmethod
+    def to_string(cls) -> str:
+        return f"tensor({cls.__name__.lower()})"
+
 
 class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
     pass
diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py
index c5e9f0a704..1f87ee8878 100644
--- a/onnxscript/type_annotation.py
+++ b/onnxscript/type_annotation.py
@@ -7,21 +7,21 @@
 import collections
 import inspect
 import typing
+from typing import Optional, Sequence, Union
 
 import onnx
-from typing_extensions import get_args, get_origin
 
-from onnxscript.onnx_types import TensorType
+from onnxscript import onnx_types
 
 # TypeAnnotationValue represents the (value of) valid type-annotations recognized
 # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
-# * float, int, str (primitive attribute types)
-# * Sequence[float], Sequence[int], Sequence[str] (attribute types)
-# * Tensor types
-# * Sequence[Tensor] types
-# * Union of above 2
-# * TypeVars with above bounds
-# * Above types with annotation attached
+# - float, int, str (primitive attribute types)
+# - Sequence[float], Sequence[int], Sequence[str] (attribute types)
+# - Tensor types
+# - Sequence[Tensor] types
+# - Union of above 2
+# - TypeVars with above bounds
+# - Above types with annotation attached
 TypeAnnotationValue = typing.Any
 
 # Map from python type to corresponding ONNX AttributeProto type
@@ -43,13 +43,18 @@
 
 _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])
 
+# A sorted tuple of all tensor type strings used in an OpSchema
+ALL_TENSOR_TYPE_STRINGS = tuple(
+    sorted(tensor_type.to_string() for tensor_type in onnx_types.tensor_type_registry.values())
+)
+
 
 def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
     """Remove Annotated wrapper if present, otherwise return typeinfo as is."""
     if hasattr(typing, "Annotated"):
         # Present in Python 3.9+
-        if get_origin(typeinfo) is typing.Annotated:
-            return get_args(typeinfo)[0]
+        if typing.get_origin(typeinfo) is typing.Annotated:
+            return typing.get_args(typeinfo)[0]
     return typeinfo
 
 
@@ -59,28 +64,28 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
 
 def pytype_to_attrtype(
     pytype: TypeAnnotationValue,
-) -> onnx.AttributeProto.AttributeType:
+) -> Optional[onnx.AttributeProto.AttributeType]:
     pytype = _remove_annotation(pytype)
     if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
         return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
-    type_constructor = get_origin(pytype)
+    type_constructor = typing.get_origin(pytype)
     # Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
     if type_constructor is typing.Union:
         # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
-        args = [x for x in get_args(pytype) if x is not type(None)]
+        args = [x for x in typing.get_args(pytype) if x is not type(None)]
         if len(args) == 1:
             return pytype_to_attrtype(args[0])
     if type_constructor in _LIST_CONSTRUCTORS:
-        elt_type = get_args(pytype)[0]
+        elt_type = typing.get_args(pytype)[0]
         if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
             return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
-    return onnx.AttributeProto.UNDEFINED
+    return None
 
 
 def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
-    if isinstance(typeinfo, TensorType):
+    if isinstance(typeinfo, onnx_types.TensorType):
         return True
-    if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
+    if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
         return True
     return False
 
@@ -94,16 +99,16 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool:
         return True
     if _is_primitive_attr_type(typeinfo):
         return False
-    type_constructor = get_origin(typeinfo)
+    type_constructor = typing.get_origin(typeinfo)
     # Handle List-like type-constructor
     # Eg. List[INT32] is a value type, while List[int] is an attribute type
     if type_constructor in _LIST_CONSTRUCTORS:
-        elt_type = get_args(typeinfo)[0]
+        elt_type = typing.get_args(typeinfo)[0]
         return is_value_type(elt_type)
     # Handle Union and Optional type-constructors
     if type_constructor is typing.Union:
         # Filter out None, since typing.Optional[X] evaluates to Union[X, None]
-        args = [x for x in get_args(typeinfo) if x is not type(None)]
+        args = [x for x in typing.get_args(typeinfo) if x is not type(None)]
         args_value_check = [is_value_type(x) for x in args]
         if all(args_value_check):
             # Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE]
@@ -133,7 +138,12 @@ def is_valid_type(typeinfo: TypeAnnotationValue):
     return False
 
 
-def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[type]:
+def is_optional(pytype) -> bool:
+    """Returns whether a pytype is an Optional."""
+    return typing.get_origin(pytype) is typing.Union and type(None) in typing.get_args(pytype)
+
+
+def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]:
     """Converts return-type annotation into a sequence of types.
 
     The return type annotation can be either a single type (for a single output)
@@ -143,6 +153,85 @@ def get_return_types(typeinfo: type | typing.Sequence[
     """
     if isinstance(typeinfo, typing.Sequence):
         return typeinfo
-    if get_origin(typeinfo) is tuple:
-        return get_args(typeinfo)
+    if typing.get_origin(typeinfo) is tuple:
+        return typing.get_args(typeinfo)
     return (typeinfo,)
+
+
+def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]:
+    """Returns a list of type-strings corresponding to a given type annotation.
+
+    Args:
+        pytype: A type annotation.
+
+    Returns:
+        A list of all supported input types for the given type annotation.
+        Ensures that the list is sorted in the same order as ALL_TENSOR_TYPE_STRINGS.
+ """ + if pytype is None: + return list(ALL_TENSOR_TYPE_STRINGS) + if pytype is onnx_types.TensorType: + return list(ALL_TENSOR_TYPE_STRINGS) + if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType): + return [pytype.to_string()] + if isinstance(pytype, onnx_types.TensorType): + return [pytype.to_string()] + if isinstance(pytype, typing.TypeVar): + constraints = pytype.__constraints__ + if constraints: + return pytype_to_type_strings(Union.__getitem__(constraints)) + bound = pytype.__bound__ + if bound is None: + return list(ALL_TENSOR_TYPE_STRINGS) + return pytype_to_type_strings(bound) + if typing.get_origin(pytype) is Union: + options = [] + subtypes = typing.get_args(pytype) + # A None type in a Union is equivalent to an optional type + optional = is_optional(pytype) + for subtype in subtypes: + if subtype is type(None): + # Skip None type because we are handling it with is_optional + continue + if optional: + options += [ + *pytype_to_type_strings(subtype), + *[f"optional({s})" for s in pytype_to_type_strings(subtype)], + ] + else: + options += pytype_to_type_strings(subtype) + # Remove duplicates + return sorted(set(options)) + if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: + subtypes = typing.get_args(pytype) + return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])] + + raise ValueError(f"Unsupported type: {pytype}") + + +def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: + """Returns the name of the type constraint for a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + - Returns None if the type annotation does not have a type constraint. 
+ """ + if isinstance(pytype, typing.TypeVar): + return pytype.__name__ + if is_optional(pytype): + subtypes = typing.get_args(pytype) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = get_type_constraint_name(subtype) + return f"Optional_{type_param_name}" if type_param_name else None + if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: + subtypes = typing.get_args(pytype) + type_param_name = get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 17f04deff8..a4b64485e1 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -4,11 +4,14 @@ # -------------------------------------------------------------------------- import unittest +from typing import Any, List, Optional, Sequence, TypeVar, Union +import parameterized + +import onnxscript import onnxscript.testing -from onnxscript import script +from onnxscript import FLOAT, INT64, script, type_annotation from onnxscript.onnx_opset import opset15 as op -from onnxscript.onnx_types import FLOAT from onnxscript.tests.common import testutils @@ -88,5 +91,157 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: ) +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) + + +class TypeConversionFunctionsTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ( + "tensor_type_all", + onnxscript.onnx_types.TensorType, + list(type_annotation.ALL_TENSOR_TYPE_STRINGS), + ), + ("none", None, list(type_annotation.ALL_TENSOR_TYPE_STRINGS)), + ("tensor_type", INT64, ["tensor(int64)"]), + ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), + ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), + ("tensor_type_shape", INT64[10], ["tensor(int64)"]), + ( + "type_var_constraints", + _TestTypeVarConstraints, + ["tensor(float)", "tensor(int64)"], + ), + ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), + ("type_bound_two", _TestTypeVarTwoBound, ["tensor(float)", "tensor(int64)"]), + ( + "optional_tensor_type_all", + Optional[onnxscript.onnx_types.TensorType], + [ + *[ + f"optional({tensor_type})" + for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS + ], + *type_annotation.ALL_TENSOR_TYPE_STRINGS, + ], + ), + ( + "optional_tensor_type", + Optional[INT64], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_tensor_type_union", + Optional[Union[INT64, FLOAT]], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", + ], + ), + ( + "optional_tensor_type_variadic_shape", + Optional[INT64[...]], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_tensor_type_shape", + Optional[INT64[10]], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", + ], + ), + ( + "optional_type_bound_one", + Optional[_TestTypeVarOneBound], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", 
+                ],
+            ),
+            (
+                "sequence_type_all",
+                Sequence[onnxscript.onnx_types.TensorType],
+                [
+                    f"seq({tensor_type})"
+                    for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS
+                ],
+            ),
+            ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]),
+            (
+                "union_sequence_type",
+                Union[Sequence[INT64], Sequence[FLOAT]],
+                ["seq(tensor(float))", "seq(tensor(int64))"],
+            ),
+            (
+                "sequence_type_variadic_shape",
+                Sequence[INT64[...]],
+                ["seq(tensor(int64))"],
+            ),
+            ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]),
+            (
+                "sequence_type_var_constraints",
+                Sequence[_TestTypeVarConstraints],
+                ["seq(tensor(float))", "seq(tensor(int64))"],
+            ),
+            (
+                "sequence_type_bound_one",
+                Sequence[_TestTypeVarOneBound],
+                ["seq(tensor(int64))"],
+            ),
+            (
+                "sequence_type_bound_two",
+                Sequence[_TestTypeVarTwoBound],
+                ["seq(tensor(float))", "seq(tensor(int64))"],
+            ),
+        ]
+    )
+    def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]):
+        self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected)
+
+    @parameterized.parameterized.expand(
+        [
+            ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"),
+            ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"),
+            (
+                "optional_type_var",
+                Optional[_TestTypeVarOneBound],
+                "Optional__TestTypeVarOneBound",
+            ),
+            (
+                "sequence_type_var",
+                Sequence[_TestTypeVarOneBound],
+                "Sequence__TestTypeVarOneBound",
+            ),
+            ("normal_type", INT64, None),
+            ("union_type", Union[INT64, FLOAT], None),
+            ("optional_type", Optional[INT64], None),
+            ("sequence_type", Sequence[INT64], None),
+            ("optional_sequence_type", Optional[Sequence[INT64]], None),
+            ("optional_union_type", Optional[Union[INT64, FLOAT]], None),
+        ]
+    )
+    def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]):
+        self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected)
+
+
 if __name__ == "__main__":
     unittest.main()
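For reviewers, a minimal usage sketch of the two helpers added in onnxscript/type_annotation.py; it is not part of the diff itself and only mirrors expectations already encoded in the new test cases, using the same imports as the test file above.

    from typing import Optional, Sequence, Union

    from onnxscript import FLOAT, INT64, type_annotation

    # A concrete tensor type maps to a single ONNX type string.
    assert type_annotation.pytype_to_type_strings(INT64) == ["tensor(int64)"]

    # Unions are flattened, de-duplicated, and sorted lexicographically.
    assert type_annotation.pytype_to_type_strings(Union[INT64, FLOAT]) == [
        "tensor(float)",
        "tensor(int64)",
    ]

    # Optional[X] yields both the optional() wrapper and the plain tensor strings.
    assert type_annotation.pytype_to_type_strings(Optional[INT64]) == [
        "optional(tensor(int64))",
        "tensor(int64)",
    ]

    # Sequence[X] wraps each element type string in seq().
    assert type_annotation.pytype_to_type_strings(Sequence[INT64]) == ["seq(tensor(int64))"]

    # Only TypeVar-based annotations produce a type-constraint name.
    assert type_annotation.get_type_constraint_name(Optional[INT64]) is None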