From c6505cf19829d264c37873a123404cc22bac401b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 12 Apr 2023 17:30:24 +0000 Subject: [PATCH 01/28] Move version_utils to _internal | chore Move version_utils to `_internal` so that it can be used by onnxscript [ghstack-poisoned] --- onnxscript/{tests/common => _internal}/version_utils.py | 0 onnxscript/function_libs/torch_aten/graph_building_test.py | 2 +- .../tests/function_libs/torch_aten/ops_correctness_test.py | 2 +- onnxscript/tests/functions/onnxfns1A_test.py | 3 ++- onnxscript/tests/functions/onnxfns2_test.py | 3 ++- onnxscript/tests/functions/onnxfns_test.py | 3 ++- 6 files changed, 8 insertions(+), 5 deletions(-) rename onnxscript/{tests/common => _internal}/version_utils.py (100%) diff --git a/onnxscript/tests/common/version_utils.py b/onnxscript/_internal/version_utils.py similarity index 100% rename from onnxscript/tests/common/version_utils.py rename to onnxscript/_internal/version_utils.py diff --git a/onnxscript/function_libs/torch_aten/graph_building_test.py b/onnxscript/function_libs/torch_aten/graph_building_test.py index 0257fa3315..60646b8768 100644 --- a/onnxscript/function_libs/torch_aten/graph_building_test.py +++ b/onnxscript/function_libs/torch_aten/graph_building_test.py @@ -10,8 +10,8 @@ import onnxscript.testing from onnxscript import FLOAT, evaluator from onnxscript import opset17 as op +from onnxscript._internal import version_utils from onnxscript.function_libs.torch_aten import graph_building, ops -from onnxscript.tests.common import version_utils @unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported") diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py index 011ba96a43..897a8cdcdb 100644 --- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py +++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py @@ -58,11 +58,11 @@ import onnxscript import onnxscript.evaluator +from onnxscript._internal import version_utils from onnxscript.function_libs.torch_aten import graph_building from onnxscript.function_libs.torch_aten.ops import core as core_ops from onnxscript.function_libs.torch_aten.ops import nn as nn_ops from onnxscript.function_libs.torch_aten.ops import special as special_ops -from onnxscript.tests.common import version_utils from onnxscript.tests.function_libs.torch_aten import extra_opinfo T = TypeVar("T") diff --git a/onnxscript/tests/functions/onnxfns1A_test.py b/onnxscript/tests/functions/onnxfns1A_test.py index 00148e5632..a9c22aba43 100644 --- a/onnxscript/tests/functions/onnxfns1A_test.py +++ b/onnxscript/tests/functions/onnxfns1A_test.py @@ -3,7 +3,8 @@ import onnx import pytest -from onnxscript.tests.common import onnx_script_test_case, version_utils +from onnxscript._internal import version_utils +from onnxscript.tests.common import onnx_script_test_case from onnxscript.tests.models import onnxfns1A diff --git a/onnxscript/tests/functions/onnxfns2_test.py b/onnxscript/tests/functions/onnxfns2_test.py index db4aa40ea7..55ccaeea17 100644 --- a/onnxscript/tests/functions/onnxfns2_test.py +++ b/onnxscript/tests/functions/onnxfns2_test.py @@ -3,7 +3,8 @@ import onnxruntime import pytest -from onnxscript.tests.common import onnx_script_test_case, version_utils +from onnxscript._internal import version_utils +from onnxscript.tests.common import onnx_script_test_case from onnxscript.tests.models import onnxfns2 diff --git
a/onnxscript/tests/functions/onnxfns_test.py b/onnxscript/tests/functions/onnxfns_test.py index fd0c5ff3a8..e0b37fb0f4 100644 --- a/onnxscript/tests/functions/onnxfns_test.py +++ b/onnxscript/tests/functions/onnxfns_test.py @@ -8,7 +8,8 @@ import onnx import pytest -from onnxscript.tests.common import onnx_script_test_case, version_utils +from onnxscript._internal import version_utils +from onnxscript.tests.common import onnx_script_test_case from onnxscript.tests.models import onnxfns1 From 268ad8f142194994c5d920ab10a0e6daee9b3914 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 12 Apr 2023 17:53:18 +0000 Subject: [PATCH 02/28] Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder) [ghstack-poisoned] --- onnxscript/converter.py | 1 + onnxscript/irbuilder.py | 66 +++++++++++++++++++++-------------- onnxscript/type_annotation.py | 4 +-- onnxscript/values.py | 41 ++++++++++++---------- 4 files changed, 64 insertions(+), 48 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 278989a6fc..75f3a8dafb 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1326,6 +1326,7 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: self.ir_builder.add_attr_parameter( self.current_fn, x.arg, + ta.pytype_to_attrtype(typeinfo), default_value, ) self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x))) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 2481b42ae6..bcb2f9d3f6 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -108,10 +108,16 @@ def _opt_var_to_str(x): class IRAttributeValue: - """An attribute value (representing an actual parameter).""" + """An attribute value (representing an actual parameter). - def __init__(self, attrproto) -> None: + Attributes: + attr_proto: The attribute proto + has_default: Whether the attribute has a default value. 
+ """ + + def __init__(self, attrproto, has_default: bool) -> None: self.attr_proto = attrproto + self.has_default = has_default def __str__(self): if self.attr_proto.HasField("ref_attr_name"): @@ -191,9 +197,7 @@ def __init__(self, name: str, domain: str = "") -> None: self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] # attribute parameters - self.attrs: list[str] = [] - # attribute parameters with default value - self.attr_protos: list[IRAttributeValue] = [] + self.attrs: list[IRAttributeValue] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} self.docstring: str = "" # a dictionary of nested function-definitions @@ -207,11 +211,10 @@ def assigned_names(self) -> Sequence[str]: def __str__(self): attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else "" - attr_protos = _format(self.attr_protos, "<", ", ", ">") if self.attr_protos else "" inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")") outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")") stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n") - return f"{self.name} {attrs}{attr_protos}{inputs} => {outputs}{stmts}" + return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" def append_docstring(self, docstring): self.docstring += docstring @@ -225,11 +228,8 @@ def append_input(self, name: IRVar) -> None: def append_output(self, name: IRVar) -> None: self.outputs.append(name) - def add_attr_parameter(self, attr: str | IRAttributeValue) -> None: - if isinstance(attr, IRAttributeValue): - self.attr_protos.append(attr) - else: - self.attrs.append(attr) + def add_attr_parameter(self, attr: IRAttributeValue) -> None: + self.attrs.append(attr) def debug_print(self): if logger.isEnabledFor(logging.DEBUG): @@ -398,19 +398,19 @@ def to_function_proto(self) -> onnx.FunctionProto: onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() ] - # attribute_proto is introduced in version onnx==1.13.0. + # attribute_proto is introduced in version onnx==1.14.0. # If this attribute is available, onnxscript uses it to # default values for attributes. The function has then two # lists, one list for attributes without default values, # another one for attributes with default values. # If this *attribute_proto* is not available, - # all attributes with a default value are moved to the first + # all attributes are moved to the first # list, default values are removed. # TODO: remove this when onnx with attribute_proto is released. 
if hasattr(onnx.FunctionProto, "attribute_proto"): - atts = self.attrs + attribute_names = [attr.name for attr in self.attrs if not attr.has_default] else: - atts = self.attrs + [a.attr_proto.name for a in self.attr_protos] + attribute_names = [attr.name for attr in self.attrs] f = helper.make_function( self.domain, @@ -419,11 +419,13 @@ def to_function_proto(self) -> onnx.FunctionProto: outputs=[y.name for y in self.outputs], nodes=nodes, opset_imports=opset_imports, # TODO - attributes=atts, + attributes=attribute_names, doc_string=self.docstring, ) if hasattr(onnx.FunctionProto, "attribute_proto"): - f.attribute_proto.extend([a.attr_proto for a in self.attr_protos]) + f.attribute_proto.extend( + [attr.attr_proto for attr in self.attrs if attr.has_default] + ) return f @@ -463,25 +465,35 @@ def add_input( v = IRVar(varname, type, info) fn.append_input(v) - def add_attr_parameter(self, fn: IRFunction, varname: str, default_value) -> None: + def add_attr_parameter( + self, + fn: IRFunction, + varname: str, + attribute_type: onnx.AttributeProto.AttributeType, + default_value, + ) -> None: if default_value is not None: - a = IRAttributeValue(helper.make_attribute(varname, default_value)) - fn.add_attr_parameter(a) + fn.add_attr_parameter( + IRAttributeValue( + helper.make_attribute(varname, default_value), has_default=True + ) + ) else: - fn.add_attr_parameter(varname) + proto = onnx.AttributeProto() + proto.name = varname + proto.type = attribute_type + fn.add_attr_parameter(IRAttributeValue(proto, has_default=False)) def add_output(self, fn: IRFunction, varname: str, type, info) -> None: v = IRVar(varname, type, info) fn.append_output(v) def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue: - return IRAttributeValue(helper.make_attribute(attrname, attrval)) + return IRAttributeValue(helper.make_attribute(attrname, attrval), has_default=True) def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: a = onnx.AttributeProto() a.name = attrname a.ref_attr_name = refname - type_ = ta.pytype_to_attrtype(pytype) - assert type_ is not None - a.type = type_ - return IRAttributeValue(a) + a.type = ta.pytype_to_attrtype(pytype) + return IRAttributeValue(a, has_default=False) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index ee7a2cc554..c5e9f0a704 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -59,7 +59,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> typing.Optional[onnx.AttributeProto.AttributeType]: +) -> onnx.AttributeProto.AttributeType: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] @@ -74,7 +74,7 @@ def pytype_to_attrtype( elt_type = get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None + return onnx.AttributeProto.UNDEFINED def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: diff --git a/onnxscript/values.py b/onnxscript/values.py index 04379685bd..121cac2d8a 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -311,11 +311,10 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # The first len(func_ir.inputs) arguments are onnx inputs inputs = function_ir.inputs # The rest is onnx attributes - attributes = function_ir.attrs # Construct a dictionary of attributes with their names specified in the function # definition attr_name_to_protos = 
collections.OrderedDict( - (attr.name, attr) for attr in function_ir.attr_protos + (attr.name, attr) for attr in function_ir.attrs ) # args with default value are attributes @@ -325,26 +324,30 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: required = False else: required = True - param_schema = ParamSchema( - name=arg.name, type=arg.typeinfo, is_input=True, required=required + schemas.append( + ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) ) - schemas.append(param_schema) - - for attr_name in attributes: - # Attributes without default values - # FIXME(justinchuby): Where can we find the type? - param_schema = ParamSchema(name=attr_name, type=None, is_input=False) - schemas.append(param_schema) for name, attr_value in attr_name_to_protos.items(): - param_schema = ParamSchema( - name=name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], - default=_get_attribute_value(attr_value.attr_proto), - is_input=False, - # All function attributes are required - ) - schemas.append(param_schema) + if not attr_value.has_default: + schemas.append( + ParamSchema( + name=name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], + is_input=False, + required=True, + ) + ) + else: + schemas.append( + ParamSchema( + name=name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], + default=_get_attribute_value(attr_value.attr_proto), + is_input=False, + required=True, + ) + ) self._param_schemas = tuple(schemas) return self._param_schemas # type: ignore[return-value] From 327a7d6287f426180c51119dcc4c0804ec65d689 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 12 Apr 2023 18:04:08 +0000 Subject: [PATCH 03/28] Auto generate OpSchema for functions | feat This change adds the capability to auto generate `OpSchema`. 
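For illustration, a minimal usage sketch of the new property (assuming onnx>=1.14, where `OpSchema` can be constructed from Python; on older onnx releases `opschema` returns `None`):

```python
# Minimal sketch: inspect the auto-generated OpSchema of a torch_aten library function.
# Assumes onnx >= 1.14; on older releases `opschema` is None.
from onnxscript.function_libs.torch_aten.ops import core

schema = core.aten_abs.opschema
if schema is not None:
    print(schema.name, schema.domain)
    print([c.type_param_str for c in schema.type_constraints])
```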
### Changes - Implement the `opschema` property in `OnnxFunction` [ghstack-poisoned] --- onnxscript/autocast.py | 2 +- onnxscript/onnx_types.py | 87 ++++++++++++++++++----- onnxscript/type_annotation.py | 60 ++++++++++++++-- onnxscript/values.py | 126 ++++++++++++++++++++++++++++++++-- 4 files changed, 248 insertions(+), 27 deletions(-) diff --git a/onnxscript/autocast.py b/onnxscript/autocast.py index 40854eb75d..03aaf967f9 100644 --- a/onnxscript/autocast.py +++ b/onnxscript/autocast.py @@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor: return cast_inputs(get_type_info, cast, op_schema, *args) -def static_cast_inputs(converter, op_schema: OpSchema, *args): +def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args): """Used for autocast during script-translation.""" if op_schema is None: return args diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 32e48bf744..0e8f8b8b8c 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = [cls.shape] # example: "FLOAT[10]" return onnx.helper.make_tensor_type_proto(cls.dtype, shape) + @classmethod + def to_string(cls) -> str: + raise NotImplementedError() + class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): - pass + @classmethod + def to_string(cls): + return "tensor(float)" class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): - pass + @classmethod + def to_string(cls): + return "tensor(uint8)" class INT8(TensorType, dtype=onnx.TensorProto.INT8): - pass + @classmethod + def to_string(cls): + return "tensor(int8)" class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): - pass + @classmethod + def to_string(cls): + return "tensor(uint16)" class INT16(TensorType, dtype=onnx.TensorProto.INT16): - pass + @classmethod + def to_string(cls): + return "tensor(int16)" class INT32(TensorType, dtype=onnx.TensorProto.INT32): - pass + @classmethod + def to_string(cls): + return "tensor(int32)" class INT64(TensorType, dtype=onnx.TensorProto.INT64): - pass + @classmethod + def to_string(cls): + return "tensor(int64)" class STRING(TensorType, dtype=onnx.TensorProto.STRING): - pass + @classmethod + def to_string(cls): + return "tensor(string)" class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): - pass + @classmethod + def to_string(cls): + return "tensor(bool)" class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): - pass + @classmethod + def to_string(cls): + return "tensor(float16)" class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): - pass + @classmethod + def to_string(cls): + return "tensor(double)" class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): - pass + @classmethod + def to_string(cls): + return "tensor(uint32)" class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): - pass + @classmethod + def to_string(cls): + return "tensor(uint64)" class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): - pass + @classmethod + def to_string(cls): + return "tensor(complex64)" class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): - pass + @classmethod + def to_string(cls): + return "tensor(complex128)" class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): - pass + @classmethod + def to_string(cls): + return "tensor(bfloat16)" def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: @@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: # Currently, only tensor types are supported. Need to expand support for other ONNX types. 
ONNXType = TensorType + +ALL_TENSOR_TYPES = ( + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, +) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index c5e9f0a704..498a7e2971 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -7,11 +7,12 @@ import collections import inspect import typing +from typing import Any, TypeVar, Union import onnx from typing_extensions import get_args, get_origin -from onnxscript.onnx_types import TensorType +from onnxscript import onnx_types # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports @@ -59,7 +60,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> onnx.AttributeProto.AttributeType: +) -> typing.Optional[onnx.AttributeProto.AttributeType]: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] @@ -74,13 +75,13 @@ def pytype_to_attrtype( elt_type = get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return onnx.AttributeProto.UNDEFINED + return None def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: - if isinstance(typeinfo, TensorType): + if isinstance(typeinfo, onnx_types.TensorType): return True - if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType): + if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True return False @@ -146,3 +147,52 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[ if get_origin(typeinfo) is tuple: return get_args(typeinfo) return (typeinfo,) + + +def _reduce_type_var_to_union(hint: typing.TypeVar): + """Reduce a TypeVar to a Union type on which we can use issubclass to check membership.""" + assert isinstance(hint, TypeVar) + + # If the TypeVar has a bound, use that. + if hint.__bound__ is not None: + return hint.__bound__ + + # If the TypeVar has no bound, use the first constraint. + if hint.__constraints__: + return Union.__getitem__(hint.__constraints__) + + return Any + + +def get_supported_input_types(pytype) -> list[str]: + """Returns a list of all supported input types for a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + A list of all supported input types for the given type annotation. 
+ """ + supported_types: list[str] = [] + if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar): + # Recursively unpack TypeVars inside an Optional + for arg in typing.get_args(pytype): + supported_types.extend(get_supported_input_types(arg)) + return supported_types + + if isinstance(pytype, TypeVar): + pytype = _reduce_type_var_to_union(pytype) + + for tensor_type in onnx_types.ALL_TENSOR_TYPES: + if pytype is None: + # The same as Any + supported_types.append(tensor_type.to_string()) + elif pytype == onnx_types.TensorType: + supported_types.append(tensor_type.to_string()) + elif isinstance(pytype, tensor_type): + supported_types.append(tensor_type.to_string()) + elif issubclass(tensor_type, pytype): + supported_types.append(tensor_type.to_string()) + # TODO(justinchuby): Handle sequence types + + return supported_types diff --git a/onnxscript/values.py b/onnxscript/values.py index 121cac2d8a..5b35b8c2e6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -8,13 +8,16 @@ import dataclasses import logging import types +import typing from enum import IntFlag -from typing import Any, Optional, Sequence, _GenericAlias # type: ignore[attr-defined] +from typing import _GenericAlias # type: ignore[attr-defined] +from typing import Any, Optional, Sequence, TypeVar, Union import onnx import onnx.defs -from onnxscript import irbuilder, sourceinfo +from onnxscript import irbuilder, sourceinfo, type_annotation +from onnxscript._internal import version_utils _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -35,6 +38,7 @@ # A special value to indicate that the default value is not specified _EmptyDefault = object() +_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14") class Opset: @@ -174,21 +178,27 @@ def __init__( ) -> None: self.opset = opset self.opname = opname - self.opschema = opschema + self._opschema = opschema self._param_schemas: Optional[tuple[ParamSchema, ...]] = None def __call__(self, *args, **kwargs): # FIXME(after #225): Move import to the top of the file. from onnxscript import evaluator # pylint: disable=import-outside-toplevel - return evaluator.default().eval(self.get_schema(), args, kwargs) + schema = self.get_schema() + assert schema is not None + return evaluator.default().eval(schema, args, kwargs) def is_single_op(self) -> bool: return isinstance(self.opname, str) + @property + def opschema(self) -> Optional[onnx.defs.OpSchema]: + return self._opschema + def get_schema(self) -> Optional[onnx.defs.OpSchema]: """Returns the ONNX OpSchema for this op.""" - if self.opschema: + if self.opschema is not None: return self.opschema return self.opset[self.opname] @@ -245,6 +255,24 @@ class OnnxClosure: function: Any +@dataclasses.dataclass +class TypeConstraint: + """Represents a type constraint for an ONNX op. + + Attributes: + name: The name of the type constraint. + allowed_types: The allowed types for the type constraint. + """ + + name: str + allowed_types: list[str] + description: str = "" + + def as_tuple(self) -> tuple[str, list[str], str]: + """Returns the type constraint as a tuple.""" + return (self.name, self.allowed_types, self.description) + + class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. 
@@ -272,12 +300,97 @@ def __init__( self.source = source self.kwargs = kwargs self._param_schemas: Optional[tuple[ParamSchema, ...]] = None + self._opschema: Optional[onnx.defs.OpSchema] = None @property def name(self): """Returns the function name.""" return self.opname + @property + def opschema(self) -> Optional[onnx.defs.OpSchema]: + """Construct an OpSchema from function_ir.""" + if self._opschema is not None: + return self._opschema + + if not _ONNX_OP_SCHEMA_WRITABLE: + return None + + function_ir = self.function_ir + # Find all distinct types in the inputs and outputs + distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union( + {arg.typeinfo for arg in function_ir.outputs} + ) + # Create a mapping from type to a unique name + type_to_constraint = {} + for i, type_ in enumerate(distinct_types): + if isinstance(type_, TypeVar): + name = type_.__name__ + # FIXME(justinchuby): Handle Optional[TypeVar] + else: + name = f"T{i}" + type_to_constraint[type_] = TypeConstraint( + name=name, + allowed_types=type_annotation.get_supported_input_types(type_), + ) + + formal_inputs = [ + onnx.defs.OpSchema.FormalParameter( + arg.name, + type_to_constraint[arg.typeinfo].name, + param_option=onnx.defs.OpSchema.FormalParameterOption.Optional + if typing.get_origin(arg.typeinfo) is Union and typing.get_args(arg.typeinfo) + else onnx.defs.OpSchema.FormalParameterOption.Single, + # TODO(justinchu): Check this is_homogeneous thing + is_homogeneous=True, + ) + for arg in function_ir.inputs + ] + formal_outputs = [ + onnx.defs.OpSchema.FormalParameter( + arg.name, + type_to_constraint[arg.typeinfo].name, + param_option=onnx.defs.OpSchema.FormalParameterOption.Optional + if typing.get_origin(arg.typeinfo) is Union and typing.get_args(arg.typeinfo) + else onnx.defs.OpSchema.FormalParameterOption.Single, + # TODO(justinchu): Check this is_homogeneous thing + is_homogeneous=True, + ) + for arg in function_ir.outputs + ] + + self._opschema = onnx.defs.OpSchema( + self.name, + self.opset.domain, + since_version=self.opset.version, + doc=function_ir.docstring, + inputs=formal_inputs, + outputs=formal_outputs, + type_constraints=[ + constraint.as_tuple() for constraint in type_to_constraint.values() + ], + attributes=[ + *[ + onnx.defs.OpSchema.Attribute( + attr.name, + type=onnx.defs.OpSchema.AttrType(attr.type), + ) + for attr in function_ir.attrs + if not attr.has_default + ], + *[ + onnx.defs.OpSchema.Attribute( + attr.name, + default_value=attr.attr_proto, + ) + for attr in function_ir.attrs + if attr.has_default + ], + ], + ) + + return self._opschema + def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. @@ -307,6 +420,9 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: if self._param_schemas is not None: return self._param_schemas + # NOTE: We generate the parameter schemas from the function_ir instead + # of relying on the auto generated OpSchema because we need to preserve the keyword + # argument order from the Python function definition, which is lost in OpSchema. function_ir = self.function_ir # The first len(func_ir.inputs) arguments are onnx inputs inputs = function_ir.inputs From 962c13b0b3942a63f3c23d0b1b85385f70440f47 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 12 Apr 2023 22:08:18 +0000 Subject: [PATCH 04/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. 
### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/converter_test.py | 6 ++---- onnxscript/values.py | 8 ++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index f01392a3cf..8a4f10d153 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -269,7 +269,7 @@ def test_renaming(self): self.validate_save(renaming, shape_inference=False) - @unittest.skipIf(True, reason="TypeError: val must be numeric not ") + @unittest.skip(reason="TypeError: val must be numeric not ") def test_opt_output(self): from onnxscript.tests.models import opt_output @@ -280,9 +280,7 @@ def test_opt_input(self): self.validate_save(opt_input, shape_inference=False) - 
@unittest.skipIf( - True, reason="ValueError: A function with attributes " "cannot be exported as a model." - ) + @unittest.skip("ValueError: A function with attributes " "cannot be exported as a model.") def test_onnxfns2(self): from onnxscript.tests.models import onnxfns2 diff --git a/onnxscript/values.py b/onnxscript/values.py index 121cac2d8a..681150c63a 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -358,8 +358,12 @@ def to_function_proto(self): def to_model_proto(self, **kwargs): """Converts the function into :class:`onnx.ModelProto`.""" - if self.function_ir.attrs: - raise ValueError("A function with attributes cannot be exported as a model.") + if self.function_ir.attrs and any( + not attr.has_default for attr in self.function_ir.attrs + ): + raise ValueError( + "A function with required attributes cannot be exported as a model." + ) # Note: The function must also have monomorphic type annotation for inputs/outputs # to be converted into a valid model. Otherwise, we can still produce an ONNX # model, but it will not pass the ONNX model checker. We do not report an error From 821821c4a93ed396ece3149624e7f62d737b0030 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 12 Apr 2023 22:10:08 +0000 Subject: [PATCH 05/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/converter_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 8a4f10d153..16f4168718 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -280,7 +280,7 @@ def test_opt_input(self): self.validate_save(opt_input, shape_inference=False) - @unittest.skip("ValueError: A function with attributes " "cannot be exported as a model.") + @unittest.skip("A function with attributes cannot be exported as a model.") def test_onnxfns2(self): from onnxscript.tests.models import onnxfns2 From 6b9106bbfd8b760f558d05f040480e953c93b517 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 04:56:58 +0000 Subject: [PATCH 06/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. 
### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/irbuilder.py | 71 +++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index bcb2f9d3f6..6c1d7011d0 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -111,13 +111,13 @@ class IRAttributeValue: """An attribute value (representing an actual parameter). Attributes: - attr_proto: The attribute proto - has_default: Whether the attribute has a default value. + attr_proto: The attribute proto. + name: The name of the attribute. + type: The type of the attribute. 
""" - def __init__(self, attrproto, has_default: bool) -> None: + def __init__(self, attrproto: onnx.AttributeProto) -> None: self.attr_proto = attrproto - self.has_default = has_default def __str__(self): if self.attr_proto.HasField("ref_attr_name"): @@ -126,14 +126,31 @@ def __str__(self): return helper.printable_attribute(self.attr_proto) @property - def name(self): + def name(self) -> str: return self.attr_proto.name @property - def type(self): + def type(self) -> onnx.AttributeProto.AttributeType: return self.attr_proto.type +class IRAttributeParameter(IRAttributeValue): + """An attribute parameter (representing a formal parameter). + + It may carry a formal parameter. + + Attributes: + attr_proto: The attribute proto + has_default: Whether the attribute has a default value. + name: The name of the attribute. + type: The type of the attribute. + """ + + def __init__(self, attrproto: onnx.AttributeProto, has_default: bool) -> None: + super().__init__(attrproto) + self.has_default = has_default + + class IRStmt: def __init__( self, @@ -197,7 +214,7 @@ def __init__(self, name: str, domain: str = "") -> None: self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] # attribute parameters - self.attrs: list[IRAttributeValue] = [] + self.attrs: list[IRAttributeParameter] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} self.docstring: str = "" # a dictionary of nested function-definitions @@ -228,7 +245,7 @@ def append_input(self, name: IRVar) -> None: def append_output(self, name: IRVar) -> None: self.outputs.append(name) - def add_attr_parameter(self, attr: IRAttributeValue) -> None: + def add_attr_parameter(self, attr: IRAttributeParameter) -> None: self.attrs.append(attr) def debug_print(self): @@ -439,10 +456,10 @@ def __init__(self): def new_function(self, name: str, domain: str = "", register: bool = False): if register and (domain, name) in self.functions: raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.") - fct = IRFunction(name, domain) + function = IRFunction(name, domain) if register: - self.functions[domain, name] = fct - return fct + self.functions[domain, name] = function + return function def add_docstring(self, fn: IRFunction, docstring: str): fn.append_docstring(docstring) @@ -456,25 +473,25 @@ def add_stmt( attrs: Sequence[IRAttributeValue], sub_functions=None, ) -> None: - s = IRStmt(results, callee, args, attrs, sub_functions=sub_functions) - fn.append_stmt(s) + stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions) + fn.append_stmt(stmt) def add_input( self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo ) -> None: - v = IRVar(varname, type, info) - fn.append_input(v) + var = IRVar(varname, type, info) + fn.append_input(var) def add_attr_parameter( self, fn: IRFunction, varname: str, attribute_type: onnx.AttributeProto.AttributeType, - default_value, + default_value: int | float | str | None, ) -> None: if default_value is not None: fn.add_attr_parameter( - IRAttributeValue( + IRAttributeParameter( helper.make_attribute(varname, default_value), has_default=True ) ) @@ -482,18 +499,18 @@ def add_attr_parameter( proto = onnx.AttributeProto() proto.name = varname proto.type = attribute_type - fn.add_attr_parameter(IRAttributeValue(proto, has_default=False)) + fn.add_attr_parameter(IRAttributeParameter(proto, has_default=False)) - def add_output(self, fn: IRFunction, varname: str, type, info) -> None: - v = IRVar(varname, type, info) - fn.append_output(v) + def add_output(self, fn: IRFunction, 
varname: str, typeinfo, sourceinfo) -> None: + var = IRVar(varname, typeinfo, sourceinfo) + fn.append_output(var) def make_attr(self, attrname: str, attrval: Any) -> IRAttributeValue: - return IRAttributeValue(helper.make_attribute(attrname, attrval), has_default=True) + return IRAttributeValue(helper.make_attribute(attrname, attrval)) def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: - a = onnx.AttributeProto() - a.name = attrname - a.ref_attr_name = refname - a.type = ta.pytype_to_attrtype(pytype) - return IRAttributeValue(a, has_default=False) + proto = onnx.AttributeProto() + proto.name = attrname + proto.ref_attr_name = refname + proto.type = ta.pytype_to_attrtype(pytype) + return IRAttributeValue(proto) From bbe8e7e531d126019f2914a667d486abdaf838e7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 15:42:08 +0000 Subject: [PATCH 07/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/irbuilder.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 6c1d7011d0..427dcf8d3e 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- from __future__ import annotations +import dataclasses import io import logging import warnings @@ -134,7 +135,8 @@ def type(self) -> onnx.AttributeProto.AttributeType: return self.attr_proto.type -class IRAttributeParameter(IRAttributeValue): +@dataclasses.dataclass(frozen=True) +class IRAttributeParameter: """An attribute parameter (representing a formal parameter). It may carry a formal parameter. @@ -146,9 +148,22 @@ class IRAttributeParameter(IRAttributeValue): type: The type of the attribute. 
""" - def __init__(self, attrproto: onnx.AttributeProto, has_default: bool) -> None: - super().__init__(attrproto) - self.has_default = has_default + name: str + type: onnx.AttributeProto.AttributeType + default_value: str | int | float | None = None + + @property + def has_default(self): + return self.default_value is not None + + @property + def attr_proto(self): + if self.default_value is None: + proto = onnx.AttributeProto() + proto.name = self.name + proto.type = self.type + return self + return helper.make_attribute(self.name, self.default_value) class IRStmt: @@ -489,17 +504,7 @@ def add_attr_parameter( attribute_type: onnx.AttributeProto.AttributeType, default_value: int | float | str | None, ) -> None: - if default_value is not None: - fn.add_attr_parameter( - IRAttributeParameter( - helper.make_attribute(varname, default_value), has_default=True - ) - ) - else: - proto = onnx.AttributeProto() - proto.name = varname - proto.type = attribute_type - fn.add_attr_parameter(IRAttributeParameter(proto, has_default=False)) + fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value)) def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: var = IRVar(varname, typeinfo, sourceinfo) From 26d5caa96f865561995da004a3a67ac6a97872cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 15:46:40 +0000 Subject: [PATCH 08/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/irbuilder.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 427dcf8d3e..f0b407378f 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -113,8 +113,6 @@ class IRAttributeValue: Attributes: attr_proto: The attribute proto. - name: The name of the attribute. - type: The type of the attribute. """ def __init__(self, attrproto: onnx.AttributeProto) -> None: @@ -142,27 +140,30 @@ class IRAttributeParameter: It may carry a formal parameter. Attributes: - attr_proto: The attribute proto - has_default: Whether the attribute has a default value. name: The name of the attribute. type: The type of the attribute. + has_default: Whether the attribute has a default value. + attr_proto: The attribute proto. """ name: str type: onnx.AttributeProto.AttributeType default_value: str | int | float | None = None + def __str__(self): + return helper.printable_attribute(self.attr_proto) + @property def has_default(self): return self.default_value is not None @property - def attr_proto(self): + def attr_proto(self) -> onnx.AttributeProto: if self.default_value is None: proto = onnx.AttributeProto() proto.name = self.name proto.type = self.type - return self + return proto return helper.make_attribute(self.name, self.default_value) From b3a035d2dd52bfcabaa04c83d30ef43b3f0955e3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 15:50:27 +0000 Subject: [PATCH 09/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. 
### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/irbuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index f0b407378f..eac0a241d0 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -137,7 +137,7 @@ def type(self) -> onnx.AttributeProto.AttributeType: class IRAttributeParameter: """An attribute parameter (representing a formal parameter). - It may carry a formal parameter. + It may or may not carry a default value. Attributes: name: The name of the attribute. 
From 03cc7f431e5fa36c0745688834f85b0e102b98c5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 16:00:30 +0000 Subject: [PATCH 10/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/values.py | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index 681150c63a..86436177a2 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- from __future__ import annotations -import collections import 
dataclasses import logging import types @@ -311,13 +310,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # The first len(func_ir.inputs) arguments are onnx inputs inputs = function_ir.inputs # The rest is onnx attributes - # Construct a dictionary of attributes with their names specified in the function - # definition - attr_name_to_protos = collections.OrderedDict( - (attr.name, attr) for attr in function_ir.attrs - ) - # args with default value are attributes schemas = [] for arg in inputs: if isinstance(arg.typeinfo, onnx.TypeProto.Optional): @@ -328,26 +321,20 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) ) - for name, attr_value in attr_name_to_protos.items(): - if not attr_value.has_default: - schemas.append( - ParamSchema( - name=name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], - is_input=False, - required=True, - ) - ) - else: - schemas.append( - ParamSchema( - name=name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attr_value.type], - default=_get_attribute_value(attr_value.attr_proto), - is_input=False, - required=True, - ) + for attr_parameter in function_ir.attrs: + schemas.append( + ParamSchema( + name=attr_parameter.name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( + onnx.defs.OpSchema.AttrType(attr_parameter.type) + ), + default=_EmptyDefault + if attr_parameter.default_value is None + else attr_parameter.default_value, + is_input=False, + required=not attr_parameter.has_default, ) + ) self._param_schemas = tuple(schemas) return self._param_schemas # type: ignore[return-value] From 7321b224cb02652e9f2160614c91b32f6bc083aa Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Apr 2023 23:37:20 +0000 Subject: [PATCH 11/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/irbuilder.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index eac0a241d0..0d4f649331 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -112,6 +112,8 @@ class IRAttributeValue: """An attribute value (representing an actual parameter). Attributes: + name: The name of the attribute. + type: The type of the attribute. attr_proto: The attribute proto. """ @@ -151,7 +153,10 @@ class IRAttributeParameter: default_value: str | int | float | None = None def __str__(self): - return helper.printable_attribute(self.attr_proto) + if self.has_default: + return helper.printable_attribute(self.attr_proto) + # TODO(justinchuby): Include a readable type name. + return self.name @property def has_default(self): @@ -159,11 +164,11 @@ def has_default(self): @property def attr_proto(self) -> onnx.AttributeProto: - if self.default_value is None: - proto = onnx.AttributeProto() - proto.name = self.name - proto.type = self.type - return proto + if not self.has_default: + raise ValueError( + "Attribute has no default value. Only attributes with default " + "values can be converted to AttributeProto." + ) return helper.make_attribute(self.name, self.default_value) From efc770896bb053aa52f441fe101dde94ff7f17e5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Apr 2023 01:41:26 +0000 Subject: [PATCH 12/28] Update base for Update on "Auto generate OpSchema for functions | feat" This change adds the capability to auto generate `OpSchema`. 
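As a companion to the `irbuilder.py` hunk above, a simplified stand-in for the attribute-parameter behavior it introduces; the `has_default` body below is an assumption, and the class is illustrative only (the real implementation lives in `IRAttributeParameter`):

```python
import dataclasses
from typing import Optional, Union

import onnx
from onnx import helper


@dataclasses.dataclass
class _AttrParamSketch:
    """Simplified stand-in for a formal attribute parameter (not the real class)."""

    name: str
    type: onnx.AttributeProto.AttributeType
    default_value: Optional[Union[str, int, float]] = None

    @property
    def has_default(self) -> bool:
        # Assumption: "has a default" means a default value was supplied.
        return self.default_value is not None

    @property
    def attr_proto(self) -> onnx.AttributeProto:
        # Mirrors the new guard: only a parameter with a default can be rendered
        # as a concrete AttributeProto.
        if not self.has_default:
            raise ValueError("Attribute has no default value.")
        return helper.make_attribute(self.name, self.default_value)


print(_AttrParamSketch("reduction", onnx.AttributeProto.INT, 1).attr_proto)
print(_AttrParamSketch("target", onnx.AttributeProto.TENSOR).has_default)  # False
```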
### Changes - Implement the `opschema` property in `OnnxFunction` ### TODO Test on all torch_lib functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/onnx_types.py | 87 ++++++++++++++++++++++++++++------- onnxscript/type_annotation.py | 64 ++++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 21 deletions(-) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 32e48bf744..0e8f8b8b8c 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = [cls.shape] # example: "FLOAT[10]" return onnx.helper.make_tensor_type_proto(cls.dtype, shape) + @classmethod + def to_string(cls) -> str: + raise NotImplementedError() + class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): - pass + @classmethod + def to_string(cls): + return 
"tensor(float)" class UINT8(TensorType, dtype=onnx.TensorProto.UINT8): - pass + @classmethod + def to_string(cls): + return "tensor(uint8)" class INT8(TensorType, dtype=onnx.TensorProto.INT8): - pass + @classmethod + def to_string(cls): + return "tensor(int8)" class UINT16(TensorType, dtype=onnx.TensorProto.UINT16): - pass + @classmethod + def to_string(cls): + return "tensor(uint16)" class INT16(TensorType, dtype=onnx.TensorProto.INT16): - pass + @classmethod + def to_string(cls): + return "tensor(int16)" class INT32(TensorType, dtype=onnx.TensorProto.INT32): - pass + @classmethod + def to_string(cls): + return "tensor(int32)" class INT64(TensorType, dtype=onnx.TensorProto.INT64): - pass + @classmethod + def to_string(cls): + return "tensor(int64)" class STRING(TensorType, dtype=onnx.TensorProto.STRING): - pass + @classmethod + def to_string(cls): + return "tensor(string)" class BOOL(TensorType, dtype=onnx.TensorProto.BOOL): - pass + @classmethod + def to_string(cls): + return "tensor(bool)" class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16): - pass + @classmethod + def to_string(cls): + return "tensor(float16)" class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE): - pass + @classmethod + def to_string(cls): + return "tensor(double)" class UINT32(TensorType, dtype=onnx.TensorProto.UINT32): - pass + @classmethod + def to_string(cls): + return "tensor(uint32)" class UINT64(TensorType, dtype=onnx.TensorProto.UINT64): - pass + @classmethod + def to_string(cls): + return "tensor(uint64)" class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64): - pass + @classmethod + def to_string(cls): + return "tensor(complex64)" class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128): - pass + @classmethod + def to_string(cls): + return "tensor(complex128)" class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16): - pass + @classmethod + def to_string(cls): + return "tensor(bfloat16)" def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: @@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: # Currently, only tensor types are supported. Need to expand support for other ONNX types. ONNXType = TensorType + +ALL_TENSOR_TYPES = ( + BFLOAT16, + BOOL, + COMPLEX128, + COMPLEX64, + DOUBLE, + FLOAT, + FLOAT16, + INT16, + INT32, + INT64, + INT8, + STRING, + UINT16, + UINT32, + UINT64, + UINT8, +) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index c5e9f0a704..2afd54b5dd 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -7,11 +7,12 @@ import collections import inspect import typing +from typing import Any, TypeVar, Union import onnx from typing_extensions import get_args, get_origin -from onnxscript.onnx_types import TensorType +from onnxscript import onnx_types # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. 
Currently, it supports @@ -59,7 +60,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> onnx.AttributeProto.AttributeType: +) -> typing.Optional[onnx.AttributeProto.AttributeType]: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] @@ -74,13 +75,13 @@ def pytype_to_attrtype( elt_type = get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return onnx.AttributeProto.UNDEFINED + return None def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: - if isinstance(typeinfo, TensorType): + if isinstance(typeinfo, onnx_types.TensorType): return True - if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType): + if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True return False @@ -146,3 +147,56 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[ if get_origin(typeinfo) is tuple: return get_args(typeinfo) return (typeinfo,) + + +def _reduce_type_var_to_union(hint: typing.TypeVar): + """Reduce a TypeVar to a Union type on which we can use issubclass to check membership.""" + assert isinstance(hint, TypeVar) + + # If the TypeVar has a bound, use that. + if hint.__bound__ is not None: + return hint.__bound__ + + # If the TypeVar has no bound, use the first constraint. + if hint.__constraints__: + return Union.__getitem__(hint.__constraints__) + + return Any + + +def get_supported_input_types(pytype) -> list[str]: + """Returns a list of all supported input types for a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + A list of all supported input types for the given type annotation. + """ + # TODO: Change this to + supported_types: list[str] = [] + if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar): + # Recursively unpack TypeVars inside an Optional + for arg in typing.get_args(pytype): + supported_types.extend(get_supported_input_types(arg)) + return supported_types + + if isinstance(pytype, TypeVar): + pytype = _reduce_type_var_to_union(pytype) + + for tensor_type in onnx_types.ALL_TENSOR_TYPES: + if pytype is None: + # The same as Any + supported_types.append(tensor_type.to_string()) + elif pytype == onnx_types.TensorType: + supported_types.append(tensor_type.to_string()) + elif isinstance(pytype, tensor_type): + supported_types.append(tensor_type.to_string()) + elif issubclass(tensor_type, pytype): + supported_types.append(tensor_type.to_string()) + # TODO(justinchuby): Handle sequence types + + return supported_types + + +def get_type_var_name From 67a8ee07f6dc17a03b838329691b0ac10479aebd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Apr 2023 02:24:06 +0000 Subject: [PATCH 13/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. 
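Assuming the `onnx_types.py` change above is applied, the new helpers can be exercised directly; a short sketch:

```python
from onnxscript import onnx_types

print(onnx_types.FLOAT.to_string())     # "tensor(float)"
print(onnx_types.BFLOAT16.to_string())  # "tensor(bfloat16)"

# ALL_TENSOR_TYPES enumerates every concrete TensorType subclass, which is what the
# OpSchema type-constraint strings are built from.
print(sorted(t.to_string() for t in onnx_types.ALL_TENSOR_TYPES))
```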
### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 22 +++++++++++----------- onnxscript/type_annotation_test.py | 6 +++++- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 2afd54b5dd..b6034ad506 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -16,13 +16,13 @@ # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. 
Currently, it supports -# * float, int, str (primitive attribute types) -# * Sequence[float], Sequence[int], Sequence[str] (attribute types) -# * Tensor types -# * Sequence[Tensor] types -# * Union of above 2 -# * TypeVars with above bounds -# * Above types with annotation attached +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached TypeAnnotationValue = typing.Any # Map from python type to corresponding ONNX AttributeProto type @@ -164,8 +164,8 @@ def _reduce_type_var_to_union(hint: typing.TypeVar): return Any -def get_supported_input_types(pytype) -> list[str]: - """Returns a list of all supported input types for a given type annotation. +def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: + """Returns a list of all supported input types in string representation for a given type annotation. Args: pytype: A type annotation. @@ -173,7 +173,6 @@ def get_supported_input_types(pytype) -> list[str]: Returns: A list of all supported input types for the given type annotation. """ - # TODO: Change this to supported_types: list[str] = [] if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar): # Recursively unpack TypeVars inside an Optional @@ -199,4 +198,5 @@ def get_supported_input_types(pytype) -> list[str]: return supported_types -def get_type_var_name +def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: + pass diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 2d8ce5bf80..a2d5e79ebb 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -14,7 +14,7 @@ from onnxscript.tests.common import testutils -class TypeAnnotationTester(testutils.TestBase): +class TypeAnnotationTest(testutils.TestBase): def test_type_annotation(self): """Test type annotations.""" @@ -92,5 +92,9 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: self.assertSameFunction(bool_type_for_attribute, bool_type_for_attribute_txt) +class UtilityFunctionsTest(unittest.TestCase): + def test_pytype_to_input_strings(self): + pass + if __name__ == "__main__": unittest.main() From 622b68827a266f771c9656870f1771ab5547d47e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 00:28:31 +0000 Subject: [PATCH 14/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. 
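A small usage sketch of the renamed `pytype_to_input_strings` helper; the expected values are taken from the tests added later in this stack, and the exact ordering of the returned list is refined in follow-up commits:

```python
from typing import TypeVar, Union

from onnxscript import FLOAT, INT64, type_annotation

print(type_annotation.pytype_to_input_strings(INT64))
# ['tensor(int64)']
print(type_annotation.pytype_to_input_strings(Union[INT64, FLOAT]))
# ['tensor(float)', 'tensor(int64)'] after deduplication and sorting
TRealSketch = TypeVar("TRealSketch", INT64, FLOAT)  # hypothetical TypeVar for illustration
print(type_annotation.pytype_to_input_strings(TRealSketch))
# TypeVar constraints are reduced to a Union and expanded the same way
```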
### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation_test.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 68790f643c..5d5494d4bc 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -3,14 +3,20 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +from typing import TypeVar, Union import unittest +import parameterized + import onnxscript.testing +import onnxscript from onnxscript import script from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import FLOAT from onnxscript.tests.common import testutils +from onnxscript import type_annotation + class TypeAnnotationTest(testutils.TestBase): def test_type_annotation(self): @@ -87,9 +93,27 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: bool_type_for_attribute, bool_type_for_attribute_txt ) +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT]) + class UtilityFunctionsTest(unittest.TestCase): - def test_pytype_to_input_strings(self): + @parameterized.parameterized.expand( + [ + ("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS), + ("tensor_type", onnxscript.INT64, ["tensor(int64)"]), + ("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]), + ("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]), + ("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]), + ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), + ("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]), + ] + ) + def test_pytype_to_input_strings(self, _, pytype: Any, expected) + pass + + def get_type_constraint_name(self, _: str, typevar, expected): pass if __name__ == "__main__": From 31bb69c4ec5ff876a850269ac84e382597457c4a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 16:23:06 +0000 Subject: [PATCH 15/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 4 +- onnxscript/type_annotation_test.py | 155 ++++++++++++++++++++++++++--- 2 files changed, 144 insertions(+), 15 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index b6034ad506..7764e6dd18 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -199,4 +199,6 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: - pass + if isinstance(pytype, TypeVar): + return pytype.__name__ + return None diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 5d5494d4bc..28d0cb8f18 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from typing import TypeVar, Union +from typing import Optional, Sequence, TypeVar, Union import unittest import parameterized @@ -12,7 +12,7 @@ import onnxscript from onnxscript import script from onnxscript.onnx_opset import opset15 as op -from onnxscript.onnx_types import FLOAT +from onnxscript import FLOAT, INT64 from onnxscript.tests.common import testutils from onnxscript import type_annotation @@ -93,28 +93,155 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: bool_type_for_attribute, bool_type_for_attribute_txt ) -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT) -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64) -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT]) + +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) class UtilityFunctionsTest(unittest.TestCase): @parameterized.parameterized.expand( [ - ("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS), - ("tensor_type", onnxscript.INT64, ["tensor(int64)"]), - ("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]), - ("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]), - ("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]), + ( + "tensor_type_all", + onnxscript.onnx_types.TensorType, + type_annotation.ALL_TYPE_STRINGS, + ), + ("tensor_type", INT64, ["tensor(int64)"]), + ("tensor_type_union", Union[INT64, FLOAT], ["tensor(int64)", "tensor(float)"]), + ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), + ("tensor_type_shape", INT64[10], ["tensor(int64)"]), + ( + "type_var_constraints", + _TestTypeVarConstraints, + ["tensor(int64)", "tensor(float)"], + ), ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), ("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]), + ( + "optional_tensor_type_all", + Optional[onnxscript.onnx_types.TensorType], + type_annotation.ALL_TYPE_STRINGS + + [ + f"optional({tensor_type})" + for tensor_type in type_annotation.ALL_TYPE_STRINGS + ], + ), + ( + "optional_tensor_type", + Optional[INT64], + ["tensor(int64)", "optional(tensor(int64))"], + ), + ( + "optional_tensor_type_union", + Optional[Union[INT64, FLOAT]], + [ + "tensor(int64)", + "tensor(float)", + "optional(tensor(int64))", + "optional(tensor(float))", + ], + ), + ( + "optional_tensor_type_variadic_shape", + Optional[INT64[...]], + ["tensor(int64)", "optional(tensor(int64))"], + ), + ( + "optional_tensor_type_shape", + Optional[INT64[10]], + ["tensor(int64)", "optional(tensor(int64))"], + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + [ + "tensor(int64)", + "tensor(float)", + "optional(tensor(int64))", + "optional(tensor(float))", + ], + ), + ( + "optional_type_bound_one", + Optional[_TestTypeVarOneBound], + ["tensor(int64)", "optional(tensor(int64))"], + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + [ + "tensor(int64)", + "tensor(float)", + "optional(tensor(int64))", + "optional(tensor(float))", + ], + ), + ( + "sequence_type_all", + Sequence[onnxscript.onnx_types.TensorType], + [ + f"sequence({tensor_type})" + for tensor_type in type_annotation.ALL_TYPE_STRINGS + ], + 
), + ("sequence_type", Sequence[INT64], ["sequence(tensor(int64))"]), + ( + "union_sequence_type", + Union[Sequence[INT64], Sequence[FLOAT]], + ["sequence(tensor(int64))", "sequence(tensor(float))"], + ), + ( + "sequence_type_variadic_shape", + Sequence[INT64[...]], + ["sequence(tensor(int64))"], + ), + ("sequence_type_shape", Sequence[INT64[10]], ["sequence(tensor(int64))"]), + ( + "sequence_type_var_constraints", + Sequence[_TestTypeVarConstraints], + ["sequence(tensor(int64))", "sequence(tensor(float))"], + ), + ( + "sequence_type_bound_one", + Sequence[_TestTypeVarOneBound], + ["sequence(tensor(int64))"], + ), + ( + "sequence_type_bound_two", + Sequence[_TestTypeVarTwoBound], + ["sequence(tensor(int64))", "sequence(tensor(float))"], + ), + ] + ) + def test_pytype_to_input_strings(self, _, pytype: Any, expected): + self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected) + + @parameterized.parameterized.expand( + [ + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), + ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), + ( + "optional_type_var", + Optional[_TestTypeVarOneBound], + "Optional_TestTypeVarOneBound", + ), + ( + "sequence_type_var", + Sequence[_TestTypeVarOneBound], + "Sequence_TestTypeVarOneBound", + ), + ("normal_type", INT64, "None"), + ("union_type", Union[INT64, FLOAT], None), + ("optional_type", Optional[INT64], None), + ("sequence_type", Sequence[INT64], None), + ("optional_sequence_type", Optional[Sequence[INT64]], None), + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), ] ) - def test_pytype_to_input_strings(self, _, pytype: Any, expected) - pass + def get_type_constraint_name(self, _: str, pytype, expected): + self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) - def get_type_constraint_name(self, _: str, typevar, expected): - pass if __name__ == "__main__": unittest.main() From 2c6be9280782db98740abb0bcf9537a3a2f23c89 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Apr 2023 19:16:16 +0000 Subject: [PATCH 16/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 17 +++++++++++++++++ onnxscript/type_annotation_test.py | 27 +++++++++++++-------------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 7764e6dd18..89b6bdeab5 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -44,6 +44,23 @@ _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) +ALL_TYPE_STRINGS = ( + "tensor(bfloat16)", + "tensor(bool)", + "tensor(double)", + "tensor(float)", + "tensor(float16)", + "tensor(int16)", + "tensor(int32)", + "tensor(int64)", + "tensor(int8)", + "tensor(string)", + "tensor(uint16)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(uint8)", +) + def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: """Remove Annotated wrapper if present, otherwise return typeinfo as is.""" diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 28d0cb8f18..768a0812c5 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -3,20 +3,17 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from typing import Optional, Sequence, TypeVar, Union import unittest +from typing import Any, Optional, Sequence, TypeVar, Union import parameterized -import onnxscript.testing import onnxscript -from onnxscript import script +import onnxscript.testing +from onnxscript import FLOAT, INT64, script, type_annotation from onnxscript.onnx_opset import opset15 as op -from onnxscript import FLOAT, INT64 from onnxscript.tests.common import testutils -from onnxscript import type_annotation - class TypeAnnotationTest(testutils.TestBase): def test_type_annotation(self): @@ -99,7 +96,7 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: _TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) -class UtilityFunctionsTest(unittest.TestCase): +class TypeConversionFunctionsTest(unittest.TestCase): @parameterized.parameterized.expand( [ ( @@ -121,10 +118,12 @@ class UtilityFunctionsTest(unittest.TestCase): ( "optional_tensor_type_all", Optional[onnxscript.onnx_types.TensorType], - type_annotation.ALL_TYPE_STRINGS - + [ - f"optional({tensor_type})" - for tensor_type in type_annotation.ALL_TYPE_STRINGS + [ + *type_annotation.ALL_TYPE_STRINGS, + *[ + f"optional({tensor_type})" + for tensor_type in type_annotation.ALL_TYPE_STRINGS + ], ], ), ( @@ -214,7 +213,7 @@ class UtilityFunctionsTest(unittest.TestCase): ), ] ) - def test_pytype_to_input_strings(self, _, pytype: Any, expected): + def test_pytype_to_input_strings(self, _, pytype: Any, expected: list[str]): self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected) @parameterized.parameterized.expand( @@ -231,7 +230,7 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected): Sequence[_TestTypeVarOneBound], "Sequence_TestTypeVarOneBound", ), - ("normal_type", INT64, "None"), + ("normal_type", INT64, None), ("union_type", Union[INT64, FLOAT], None), ("optional_type", Optional[INT64], None), ("sequence_type", Sequence[INT64], None), @@ -239,7 +238,7 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected): ("optional_union_type", Optional[Union[INT64, FLOAT]], None), ] ) - def get_type_constraint_name(self, _: str, pytype, expected): + def get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) From 79d3605bf4b2ee280c1347703fe1225bdc14533b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 20 Apr 2023 17:16:57 +0000 Subject: [PATCH 17/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. 
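At this point in the stack the constraint-name helper is still minimal: it returns the `TypeVar`'s own name and `None` for everything else, while `ALL_TYPE_STRINGS` provides the full tensor-type vocabulary. A quick sketch, assuming these commits are applied:

```python
from typing import TypeVar, Union

from onnxscript import FLOAT, INT64, type_annotation

TRealSketch = TypeVar("TRealSketch", bound=Union[INT64, FLOAT])  # illustrative TypeVar

print(type_annotation.get_type_constraint_name(TRealSketch))  # "TRealSketch"
print(type_annotation.get_type_constraint_name(INT64))        # None
print(len(type_annotation.ALL_TYPE_STRINGS))                  # 14 tensor type strings
```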
### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 6 ++++-- onnxscript/type_annotation_test.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index f01a5890b0..5a295d461e 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -204,10 +204,12 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: if isinstance(pytype, onnx_types.TensorType): return [pytype.to_string()] if isinstance(pytype, TypeVar): + constraints = pytype.__constraints__ + if constraints: + return pytype_to_input_strings(Union.__getitem__(constraints)) bound = pytype.__bound__ if bound is 
None: return list(ALL_TYPE_STRINGS) - # TODO(justinchuby): Handle __constraints__ return pytype_to_input_strings(bound) if typing.get_origin(pytype) is Union: options = [] @@ -216,7 +218,7 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: is_optional = any(subtype is type(None) for subtype in subtypes) for subtype in subtypes: if subtype is type(None): - # Skip None type + # Skip None type because we are handling it with is_optional continue if is_optional: options += [ diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 9fe011baf0..015c6dde67 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -102,7 +102,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): ( "tensor_type_all", onnxscript.onnx_types.TensorType, - type_annotation.ALL_TYPE_STRINGS, + list(type_annotation.ALL_TYPE_STRINGS), ), ("tensor_type", INT64, ["tensor(int64)"]), ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), From 1439aaf76dda9b179f58b55393b351dbc30bb9b2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 20 Apr 2023 17:35:57 +0000 Subject: [PATCH 18/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 25 ++++++++++++++++++++++++- onnxscript/type_annotation_test.py | 26 ++++++++++++-------------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 5a295d461e..921d56b2a9 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -195,6 +195,8 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: A list of all supported input types for the given type annotation. Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS. """ + if pytype is None: + return list(ALL_TYPE_STRINGS) if pytype is type(None): return list(ALL_TYPE_STRINGS) if pytype is onnx_types.TensorType: @@ -231,12 +233,33 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: return sorted(set(options)) if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: subtypes = typing.get_args(pytype) - return [f"sequence({s})" for s in pytype_to_input_strings(subtypes[0])] + return [f"seq({s})" for s in pytype_to_input_strings(subtypes[0])] raise ValueError(f"Unsupported type: {pytype}") def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: + """Returns the name of the type constraint for a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + - Returns None if the type annotation does not have a type constraint. 
+ """ if isinstance(pytype, TypeVar): return pytype.__name__ + if typing.get_origin(pytype) is Union: + subtypes = typing.get_args(pytype) + if len(subtypes) == 2 and type(None) in subtypes: + for subtype in subtypes: + if isinstance(subtype, TypeVar): + return f"Optional_{subtype.__name__}" + if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: + subtypes = typing.get_args(pytype) + if len(subtypes) == 1 and isinstance(subtypes[0], TypeVar): + return f"Sequence_{get_type_constraint_name(subtypes[0])}" return None diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 015c6dde67..50a8c31497 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -104,6 +104,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): onnxscript.onnx_types.TensorType, list(type_annotation.ALL_TYPE_STRINGS), ), + ("none", None, list(type_annotation.ALL_TYPE_STRINGS)), ("tensor_type", INT64, ["tensor(int64)"]), ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), @@ -179,37 +180,34 @@ class TypeConversionFunctionsTest(unittest.TestCase): ( "sequence_type_all", Sequence[onnxscript.onnx_types.TensorType], - [ - f"sequence({tensor_type})" - for tensor_type in type_annotation.ALL_TYPE_STRINGS - ], + [f"seq({tensor_type})" for tensor_type in type_annotation.ALL_TYPE_STRINGS], ), - ("sequence_type", Sequence[INT64], ["sequence(tensor(int64))"]), + ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]), ( "union_sequence_type", Union[Sequence[INT64], Sequence[FLOAT]], - ["sequence(tensor(float))", "sequence(tensor(int64))"], + ["seq(tensor(float))", "seq(tensor(int64))"], ), ( "sequence_type_variadic_shape", Sequence[INT64[...]], - ["sequence(tensor(int64))"], + ["seq(tensor(int64))"], ), - ("sequence_type_shape", Sequence[INT64[10]], ["sequence(tensor(int64))"]), + ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]), ( "sequence_type_var_constraints", Sequence[_TestTypeVarConstraints], - ["sequence(tensor(float))", "sequence(tensor(int64))"], + ["seq(tensor(float))", "seq(tensor(int64))"], ), ( "sequence_type_bound_one", Sequence[_TestTypeVarOneBound], - ["sequence(tensor(int64))"], + ["seq(tensor(int64))"], ), ( "sequence_type_bound_two", Sequence[_TestTypeVarTwoBound], - ["sequence(tensor(float))", "sequence(tensor(int64))"], + ["seq(tensor(float))", "seq(tensor(int64))"], ), ] ) @@ -223,12 +221,12 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected: list[str]): ( "optional_type_var", Optional[_TestTypeVarOneBound], - "Optional_TestTypeVarOneBound", + "Optional__TestTypeVarOneBound", ), ( "sequence_type_var", Sequence[_TestTypeVarOneBound], - "Sequence_TestTypeVarOneBound", + "Sequence__TestTypeVarOneBound", ), ("normal_type", INT64, None), ("union_type", Union[INT64, FLOAT], None), @@ -238,7 +236,7 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected: list[str]): ("optional_union_type", Optional[Union[INT64, FLOAT]], None), ] ) - def get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): + def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) From 138c2ed38ad8d88d36df5f6d9698755b03205fd2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:14:52 +0000 Subject: [PATCH 19/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This 
change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] From 90208e164825faba8610b2165d8f3f167c6ac0b9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:23:17 +0000 Subject: [PATCH 20/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. 
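With the Optional/Sequence handling added earlier in this stack, a final usage sketch; the expected strings follow the updated tests, and a TypeVar whose name starts with an underscore yields a double underscore after the prefix:

```python
from typing import Optional, Sequence, TypeVar

from onnxscript import INT64, type_annotation

_TIntSketch = TypeVar("_TIntSketch", bound=INT64)  # illustrative TypeVar

print(type_annotation.get_type_constraint_name(Optional[_TIntSketch]))  # "Optional__TIntSketch"
print(type_annotation.get_type_constraint_name(Sequence[_TIntSketch]))  # "Sequence__TIntSketch"
print(type_annotation.pytype_to_input_strings(Sequence[INT64]))         # ['seq(tensor(int64))']
```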
### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 6f2970e14d..4359637e0c 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- import unittest -from typing import Any, Optional, Sequence, TypeVar, Union +from typing import Any, List, Optional, Sequence, TypeVar, Union import parameterized @@ -214,7 +214,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): ), ] ) - def test_pytype_to_input_strings(self, 
_, pytype: Any, expected: list[str]): + def test_pytype_to_input_strings(self, _, pytype: Any, expected: List[str]): self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected) @parameterized.parameterized.expand( From 2d9627e47bff48ce2234494d56809078c565665a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:25:19 +0000 Subject: [PATCH 21/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/onnx_types.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 09a05e17c0..4951a3ff18 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -207,22 +207,3 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str: # Currently, only tensor types are supported. Need to expand support for other ONNX types. ONNXType = TensorType - -ALL_TENSOR_TYPES = ( - BFLOAT16, - BOOL, - COMPLEX128, - COMPLEX64, - DOUBLE, - FLOAT, - FLOAT16, - INT16, - INT32, - INT64, - INT8, - STRING, - UINT16, - UINT32, - UINT64, - UINT8, -) From 8f5f7ba82594b3c2f9ffa908583d2d788d20ab69 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sat, 22 Apr 2023 00:30:28 +0000 Subject: [PATCH 22/28] Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)" This change adds the capability to auto generate `OpSchema`. 
### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=, is_homogeneous=True, min_arity=1, differentiation_category=)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=, description='', default_value=, required=True)} ) ``` Fixes https://github.com/microsoft/onnx-script/issues/476 [ghstack-poisoned] --- onnxscript/type_annotation.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 2e5c42f873..8ce0fb38eb 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -151,6 +151,11 @@ def is_valid_type(typeinfo: TypeAnnotationValue): return False +def is_optional(pytype) -> bool: + """Returns whether a pytype is an Optional.""" + return typing.get_origin(pytype) is typing.Union and type(None) in typing.get_args(pytype) + + def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: """Converts return-type annotation into a 
sequence of types. @@ -196,12 +201,12 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]: options = [] subtypes = typing.get_args(pytype) # A None type in a Union is equivalent to an optional type - is_optional = any(subtype is type(None) for subtype in subtypes) + optional = is_optional(pytype) for subtype in subtypes: if subtype is type(None): # Skip None type because we are handling it with is_optional continue - if is_optional: + if optional: options += [ *pytype_to_input_strings(subtype), *[f"optional({s})" for s in pytype_to_input_strings(subtype)], @@ -231,17 +236,13 @@ def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: """ if isinstance(pytype, typing.TypeVar): return pytype.__name__ - if typing.get_origin(pytype) is Union: + if is_optional(pytype): subtypes = typing.get_args(pytype) - if len(subtypes) == 2 and type(None) in subtypes: - # An Optional type - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = get_type_constraint_name(subtype) - return f"Optional_{type_param_name}" if type_param_name else None - else: - return None + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = get_type_constraint_name(subtype) + return f"Optional_{type_param_name}" if type_param_name else None if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: subtypes = typing.get_args(pytype) type_param_name = get_type_constraint_name(subtypes[0]) From dc0963b1699907abec79fc9cf1580d10b954a2a8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 02:20:59 +0000 Subject: [PATCH 23/28] Refactor converter to isolate translate_function_signature logic | feat(converter) Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/converter.py | 103 +++++++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 44 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 75f3a8dafb..2369180cd2 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -180,6 +180,13 @@ def __init__( self.this_module = opset self.default_opset_ = default_opset + # States initialized by `init_function_translation` + self._outer = [] + self._current_fn = None + self._nextvar = 0 + self._used_vars = set() + self._locals: List[Dict[Any, Any]] = [{}] + @property def default_opset(self): if self.default_opset_ is None: @@ -222,14 +229,14 @@ def find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: def init_function_translation(self): """Initialize self for translating a new (top-level) function.""" - self.outer = [] - self.current_fn = None - self.nextvar = 0 - self.used_vars = set() - self.locals: List[Dict[Any, Any]] = [{}] + self._outer = [] + self._current_fn = None + self._nextvar = 0 + self._used_vars = set() + self._locals: List[Dict[Any, Any]] = [{}] def source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: - return sourceinfo.SourceInfo(node, self.source, self.current_fn.name) + return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) def message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error message containing source information about an ast node.""" @@ -252,28 +259,28 @@ def enter_scope(self, name, parent_node): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. 
""" - self.outer.insert(0, self.current_fn) - self.current_fn = self.ir_builder.new_function(name) - self.locals.insert(0, {}) - logger.debug("Converter:enter_scope:%d:node:%s", len(self.locals), type(parent_node)) + self._outer.insert(0, self._current_fn) + self._current_fn = self.ir_builder.new_function(name) + self._locals.insert(0, {}) + logger.debug("Converter:enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def exit_scope(self): """Exit from a control-flow block (a loop body or if-then-else branch).""" - logger.debug("Converter:exit_scope:%d", len(self.locals)) - graph = self.current_fn - self.current_fn = self.outer.pop(0) - self.locals.pop(0) + logger.debug("Converter:exit_scope:%d", len(self._locals)) + graph = self._current_fn + self._current_fn = self._outer.pop(0) + self._locals.pop(0) return graph def current_scope(self): - return self.locals[0] + return self._locals[0] def bind(self, name, val): logger.debug("Converter:bind:%s", name) - self.locals[0][name] = val + self._locals[0][name] = val def lookup(self, name, info, raise_exception=True): - for scope in self.locals: + for scope in self._locals: if name in scope: return scope[name] if name in self.globals: @@ -285,10 +292,10 @@ def lookup(self, name, info, raise_exception=True): def generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate - while r in self.used_vars: - r = f"{candidate}_{self.nextvar}" - self.nextvar = self.nextvar + 1 - self.used_vars.add(r) + while r in self._used_vars: + r = f"{candidate}_{self._nextvar}" + self._nextvar = self._nextvar + 1 + self._used_vars.add(r) return r def to_onnx_attr_ref(self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]): @@ -326,11 +333,11 @@ def py_var_to_onnx_var(self, py_var, info: sourceinfo.SourceInfo): return self.to_onnx_var(self.lookup(py_var, info), target=py_var, info=info) def emit_docstring(self, docstring): - self.ir_builder.add_docstring(self.current_fn, docstring) + self.ir_builder.add_docstring(self._current_fn, docstring) def emit(self, outputs, callee, inputs, attrs, sub_functions=None): self.ir_builder.add_stmt( - self.current_fn, + self._current_fn, outputs, callee, inputs, @@ -898,7 +905,7 @@ def translate_callee_expr(self, node) -> values.Op: # pylint: disable=R1710 function_name = node.id found = self.lookup(function_name, self.source_of(node), raise_exception=False) if isinstance(found, onnxscript.OnnxFunction): - self.current_fn.add_called_function(found) + self._current_fn.add_called_function(found) return found if isinstance(found, values.Op): return found @@ -1016,7 +1023,7 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output(self.current_fn, return_var, t, self.source_of(stmt)) + self.ir_builder.add_output(self._current_fn, return_var, t, self.source_of(stmt)) return return_var val = stmt.value @@ -1123,7 +1130,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): self.enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) self.ir_builder.add_input( - self.current_fn, + self._current_fn, o_loop_var, onnx_types.INT64, self.source_of(loop_stmt), @@ -1134,7 +1141,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): ) self.ir_builder.add_input( - self.current_fn, + self._current_fn, i_cond_var, onnx_types.BOOL, self.source_of(loop_stmt), @@ -1145,7 +1152,9 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): 
# TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self.eval_constant_expr(pv.annotation) typeinfo = None - self.ir_builder.add_input(self.current_fn, ov, typeinfo, self.source_of(loop_stmt)) + self.ir_builder.add_input( + self._current_fn, ov, typeinfo, self.source_of(loop_stmt) + ) self.bind( pv, values.Dynamic(ov, values.DynamicKind.Loop, self.source_of(loop_stmt)), @@ -1201,14 +1210,14 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): ) self.ir_builder.add_output( - self.current_fn, + self._current_fn, o_cond_out, onnx_types.BOOL, self.source_of(loop_stmt), ) for pv in loop_state_vars: ov = self.py_var_to_onnx_var(pv, self.source_of(loop_stmt)) - if ov not in self.current_fn.assigned_names: + if ov not in self._current_fn.assigned_names: # When converting the loop-body into a graph, we need to handle # identity assignments of the form "x = y" inside the loop body # specially if y represents a value computed outside the loop body. @@ -1218,7 +1227,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]): # TODO: retrieve variable type for the annotation if any. typeinfo = None self.ir_builder.add_output( - self.current_fn, ov, typeinfo, self.source_of(loop_stmt) + self._current_fn, ov, typeinfo, self.source_of(loop_stmt) ) body = self.exit_scope() inputs = [o_loop_bound, o_true] + [ @@ -1245,19 +1254,19 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): if pvar in self.current_scope(): pv_val = self.current_scope()[pvar] output = self.to_onnx_var(pv_val, pvar) - if output not in self.current_fn.assigned_names: + if output not in self._current_fn.assigned_names: # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self.emit_copy(output, pvar) self.ir_builder.add_output( - self.current_fn, + self._current_fn, output, pv_val.typeinfo, self.source_of(info_stmt), ) else: pv_val = None - for scope in self.locals: # TODO: skip current_scope + for scope in self._locals: # TODO: skip current_scope if pvar in scope: pv_val = scope[pvar] break @@ -1265,7 +1274,7 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): self.fail( stmts[0], f"Variable {pvar} is not assigned a value along a conditional " - f"branch, known variables: {list(self.locals)}.", + f"branch, known variables: {list(self._locals)}.", ) # introduce a copy ovar = self.generate_unique_name(pvar) @@ -1278,7 +1287,7 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None): # TODO: retrieve the annotation if any. typeinfo = None self.ir_builder.add_output( - self.current_fn, ovar, typeinfo, self.source_of(info_stmt) + self._current_fn, ovar, typeinfo, self.source_of(info_stmt) ) graph = self.exit_scope() return graph.to_graph_and_functions() @@ -1294,15 +1303,15 @@ def translate_nested_function_def(self, fn: ast.FunctionDef): ] self.bind(fn.name, function_ir) # TODO: Does not yet handle nested functions within nested functions. 
- self.current_fn.add_nested_function(function_ir) + self._current_fn.add_nested_function(function_ir) - def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: - logger.debug("Converter:translate_function_def:%s", fn.name) + def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + """Translate a function signature.""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: warn(f"{fn.name}: Unsupported feature in function signature.") domain = self.this_module.domain - self.current_fn = self.ir_builder.new_function(fn.name, domain, True) + self._current_fn = self.ir_builder.new_function(fn.name, domain, True) for i, x in enumerate(args.args): arg_with_default_start_index = len(args.args) - len(args.defaults) if args.defaults and i >= arg_with_default_start_index: @@ -1324,15 +1333,15 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): self.ir_builder.add_attr_parameter( - self.current_fn, + self._current_fn, x.arg, ta.pytype_to_attrtype(typeinfo), default_value, ) self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x))) else: - self.ir_builder.add_input(self.current_fn, x.arg, typeinfo, self.source_of(x)) - self.used_vars.add(x.arg) + self.ir_builder.add_input(self._current_fn, x.arg, typeinfo, self.source_of(x)) + self._used_vars.add(x.arg) self.bind( x.arg, values.Dynamic(x.arg, values.DynamicKind.Input, self.source_of(x)), @@ -1352,9 +1361,15 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: self.returntype = None else: self.returntype = None + + return self._current_fn + + def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + logger.debug("Converter:translate_function_def:%s", fn.name) + _ = self.translate_function_signature(fn) for i, s in enumerate(fn.body): self.translate_stmt(s, index_of_stmt=i) - return self.current_fn + return self._current_fn def top_level_stmt(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if isinstance(stmt, ast.FunctionDef): From 2ab1af3a69ab669f0c73770fc89d908fea06f886 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 03:58:11 +0000 Subject: [PATCH 24/28] Update base for Update on "Refactor converter to isolate translate_function_signature logic | feat(converter)" This change refactors the `translate_function_def` method in `Converter` to isolate the signature handling logic to `translate_function_signature`. `translate_function_signature` is used in #674 to handle function signatures so we do not need to translate the function body for general python functions incompatible with ONNX. 
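To make the motivation concrete, here is a minimal standalone sketch of what "signature-only" handling means. It uses only the Python standard library (`ast.unparse` needs Python 3.9+), and the names `describe_signature` and `resize_like` are invented for this illustration; this is not the onnxscript `Converter` API, just the underlying idea of reading a function's signature without translating its body.

```python
import ast
import inspect
from typing import List, Optional, Tuple


def resize_like(data, scale: Optional[float] = None) -> int:
    # The body may use arbitrary Python with no ONNX equivalent;
    # signature-only processing never has to look at it.
    raise NotImplementedError


def describe_signature(fn) -> List[Tuple[str, str]]:
    """Return (name, annotation) pairs for fn's parameters without touching its body."""
    fn_def = ast.parse(inspect.getsource(fn)).body[0]
    assert isinstance(fn_def, ast.FunctionDef)
    return [
        (arg.arg, ast.unparse(arg.annotation) if arg.annotation else "<unannotated>")
        for arg in fn_def.args.args
    ]


print(describe_signature(resize_like))
# [('data', '<unannotated>'), ('scale', 'Optional[float]')]
```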
Signed-off-by: Justin Chu [ghstack-poisoned] --- .../tests/function_libs/torch_lib/ops_test.py | 49 +++--- onnxscript/values.py | 149 +++++++++--------- 2 files changed, 99 insertions(+), 99 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 9191db0a3d..106e794839 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -17,7 +17,7 @@ import unittest import warnings -from typing import Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence import numpy as np import onnx @@ -55,14 +55,23 @@ def _should_skip_test_sample(op_name: str, sample) -> Optional[str]: return None +def _split_function_and_wrangler( + onnx_function_and_wrangler: Callable[..., Any] + | tuple[Callable[..., Any], Callable[..., Any]] +) -> tuple[Callable[..., Any], Callable[..., Any] | None]: + """Splits a function with an optional input wrangler into a function and an input wrangler.""" + if isinstance(onnx_function_and_wrangler, tuple): + return onnx_function_and_wrangler + + assert callable(onnx_function_and_wrangler) + return onnx_function_and_wrangler, None + + class TestFunctionValidity(unittest.TestCase): def test_all_script_functions_are_onnx_functions(self): functions = set() for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.values(): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) functions.add(func) # TODO(justinchuby): Add from the registry @@ -76,10 +85,7 @@ def test_all_script_functions_are_onnx_functions(self): def test_all_trace_only_functions_are_not_onnx_functions(self): for func_with_wrangler in ops_test_data.OPINFO_FUNCTION_MAPPING_TRACE_ONLY.values(): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) if isinstance(func, onnxscript.OnnxFunction): raise AssertionError( f"'{func.name}' is an OnnxFunction. 
" @@ -95,10 +101,7 @@ def test_all_trace_only_functions_are_not_onnx_functions(self): "Function checker is not available before ONNX 1.14", ) def test_script_function_passes_checker(self, _, func_with_wrangler): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) function_proto = func.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @@ -110,10 +113,7 @@ def test_script_function_passes_checker(self, _, func_with_wrangler): "OpSchema is not writable before ONNX 1.15", ) def test_script_function_has_op_schema(self, _, func_with_wrangler): - if isinstance(func_with_wrangler, tuple): - func = func_with_wrangler[0] - else: - func = func_with_wrangler + func, _ = _split_function_and_wrangler(func_with_wrangler) schema = func.opschema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) @@ -143,16 +143,11 @@ def run_test_output_match( ) onnx_function_and_wrangler = ops_test_data.OPINFO_FUNCTION_MAPPING[op.name] - input_wrangler = None - if isinstance(onnx_function_and_wrangler, tuple): - # Obtain the input_wrangler that manipulates the OpInfo inputs - # to match the aten operator signature - # An example is nn.functional.upsample_nearest2d, which has a different signature - # than the aten operator upsample_nearest2d - onnx_function, input_wrangler = onnx_function_and_wrangler - else: - assert callable(onnx_function_and_wrangler) - onnx_function = onnx_function_and_wrangler + # Obtain the input_wrangler that manipulates the OpInfo inputs + # to match the aten operator signature + # An example is nn.functional.upsample_nearest2d, which has a different signature + # than the aten operator upsample_nearest2d + onnx_function, input_wrangler = _split_function_and_wrangler(onnx_function_and_wrangler) for i, cpu_sample in enumerate(samples): inputs = (cpu_sample.input, *cpu_sample.args) diff --git a/onnxscript/values.py b/onnxscript/values.py index 82be4cc42e..10df0019cf 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -274,6 +274,82 @@ def as_tuple(self) -> tuple[str, list[str], str]: return (self.name, self.allowed_types, self.description) +def op_schema_from_function_ir( + function_ir: irbuilder.IRFunction, opset: Opset +) -> onnx.defs.OpSchema: + """Construct an ONNX OpSchema from an IRFunction.""" + + # Find all distinct types in the inputs and outputs + distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union( + {arg.typeinfo for arg in function_ir.outputs} + ) + # Create a mapping from type to a unique name + type_to_constraint = {} + for i, type_ in enumerate(distinct_types): + name = f"T{i}" + type_to_constraint[type_] = TypeConstraint( + name=type_annotation.get_type_constraint_name(type_) or name, + allowed_types=type_annotation.pytype_to_input_strings(type_), + ) + + formal_inputs = [ + onnx.defs.OpSchema.FormalParameter( + arg.name, + type_to_constraint[arg.typeinfo].name, + param_option=( + onnx.defs.OpSchema.FormalParameterOption.Optional + if type_annotation.is_optional(arg.typeinfo) + else onnx.defs.OpSchema.FormalParameterOption.Single + ), + # TODO(justinchu): Check this is_homogeneous thing + is_homogeneous=True, + ) + for arg in function_ir.inputs + ] + formal_outputs = [ + onnx.defs.OpSchema.FormalParameter( + arg.name, + type_to_constraint[arg.typeinfo].name, + param_option=( + onnx.defs.OpSchema.FormalParameterOption.Optional + if type_annotation.is_optional(arg.typeinfo) + 
else onnx.defs.OpSchema.FormalParameterOption.Single + ), + # TODO(justinchu): Check this is_homogeneous thing + is_homogeneous=True, + ) + for arg in function_ir.outputs + ] + + return onnx.defs.OpSchema( + function_ir.name, + opset.domain, + since_version=opset.version, + doc=function_ir.docstring, + inputs=formal_inputs, + outputs=formal_outputs, + type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], + attributes=[ + *[ + onnx.defs.OpSchema.Attribute( + attr.name, + type=onnx.defs.OpSchema.AttrType(attr.type), + ) + for attr in function_ir.attrs + if not attr.has_default + ], + *[ + onnx.defs.OpSchema.Attribute( + attr.name, + default_value=attr.attr_proto, + ) + for attr in function_ir.attrs + if attr.has_default + ], + ], + ) + + class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -317,78 +393,7 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]: if not _ONNX_OP_SCHEMA_WRITABLE: return None - function_ir = self.function_ir - # Find all distinct types in the inputs and outputs - distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union( - {arg.typeinfo for arg in function_ir.outputs} - ) - # Create a mapping from type to a unique name - type_to_constraint = {} - for i, type_ in enumerate(distinct_types): - name = f"T{i}" - type_to_constraint[type_] = TypeConstraint( - name=type_annotation.get_type_constraint_name(type_) or name, - allowed_types=type_annotation.pytype_to_input_strings(type_), - ) - - formal_inputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[arg.typeinfo].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.inputs - ] - formal_outputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[arg.typeinfo].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.outputs - ] - - self._opschema = onnx.defs.OpSchema( - self.name, - self.opset.domain, - since_version=self.opset.version, - doc=function_ir.docstring, - inputs=formal_inputs, - outputs=formal_outputs, - type_constraints=[ - constraint.as_tuple() for constraint in type_to_constraint.values() - ], - attributes=[ - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), - ) - for attr in function_ir.attrs - if not attr.has_default - ], - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - default_value=attr.attr_proto, - ) - for attr in function_ir.attrs - if attr.has_default - ], - ], - ) + self._opschema = op_schema_from_function_ir(self.function_ir, self.opset) return self._opschema From b9f7a5075a8d21f5c2718c5b369b9138ebc77bbc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 04:08:31 +0000 Subject: [PATCH 25/28] Update base for Update on "Refactor converter to isolate translate_function_signature logic | feat(converter)" This change refactors the `translate_function_def` method in `Converter` to isolate the signature handling logic to `translate_function_signature`. 
`translate_function_signature` is used in #674 to handle function signatures so we do not need to translate the function body for general python functions incompatible with ONNX. Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/irbuilder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 9c4e5f5d9e..443f18412e 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -526,5 +526,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut proto = onnx.AttributeProto() proto.name = attrname proto.ref_attr_name = refname - proto.type = ta.pytype_to_attrtype(pytype) + attr_type = ta.pytype_to_attrtype(pytype) + assert attr_type is not None + proto.type = attr_type return IRAttributeValue(proto) From a491eb83773386e501edce803f5a7adbffbda0ea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Apr 2023 02:47:42 +0000 Subject: [PATCH 26/28] Update base for Update on "Refactor converter to isolate translate_function_signature logic | feat(converter)" This change refactors the `translate_function_def` method in `Converter` to isolate the signature handling logic to `translate_function_signature`. `translate_function_signature` is used in #674 to handle function signatures so we do not need to translate the function body for general python functions incompatible with ONNX. Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/autocast.py | 2 +- onnxscript/irbuilder.py | 4 +- onnxscript/onnx_types.py | 4 - .../tests/function_libs/torch_lib/ops_test.py | 13 -- onnxscript/type_annotation.py | 150 +++-------------- onnxscript/type_annotation_test.py | 159 +----------------- onnxscript/values.py | 126 +------------- 7 files changed, 32 insertions(+), 426 deletions(-) diff --git a/onnxscript/autocast.py b/onnxscript/autocast.py index 7dd91304ae..edd0f9bfb3 100644 --- a/onnxscript/autocast.py +++ b/onnxscript/autocast.py @@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor: return cast_inputs(get_type_info, cast, op_schema, *args) -def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args): +def static_cast_inputs(converter, op_schema: OpSchema, *args): """Used for autocast during script-translation.""" if op_schema is None: return args diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 443f18412e..9c4e5f5d9e 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -526,7 +526,5 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut proto = onnx.AttributeProto() proto.name = attrname proto.ref_attr_name = refname - attr_type = ta.pytype_to_attrtype(pytype) - assert attr_type is not None - proto.type = attr_type + proto.type = ta.pytype_to_attrtype(pytype) return IRAttributeValue(proto) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 4951a3ff18..32e48bf744 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -105,10 +105,6 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = [cls.shape] # example: "FLOAT[10]" return onnx.helper.make_tensor_type_proto(cls.dtype, shape) - @classmethod - def to_string(cls) -> str: - return f"tensor({cls.__name__.lower()})" - class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): pass diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index df3da0a0e8..5b317aebd2 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ 
b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -106,19 +106,6 @@ def test_script_function_passes_checker(self, _, func_with_wrangler): function_proto = func.to_function_proto() onnx.checker.check_function(function_proto) # type: ignore[attr-defined] - @parameterized.parameterized.expand( - list(ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.items()) - ) - @unittest.skipIf( - version_utils.onnx_older_than("1.15"), - "OpSchema is not writable before ONNX 1.15", - ) - def test_script_function_has_op_schema(self, _, func_with_wrangler): - func, _ = _split_function_and_wrangler(func_with_wrangler) - schema = func.opschema - self.assertIsNotNone(schema) - self.assertEqual(schema.name, func.name) - def run_test_output_match( test_suite: unittest.TestCase, diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 9f78222a87..c5e9f0a704 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -7,21 +7,21 @@ import collections import inspect import typing -from typing import Optional, Sequence, Union import onnx +from typing_extensions import get_args, get_origin -from onnxscript import onnx_types +from onnxscript.onnx_types import TensorType # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports -# - float, int, str (primitive attribute types) -# - Sequence[float], Sequence[int], Sequence[str] (attribute types) -# - Tensor types -# - Sequence[Tensor] types -# - Union of above 2 -# - TypeVars with above bounds -# - Above types with annotation attached +# * float, int, str (primitive attribute types) +# * Sequence[float], Sequence[int], Sequence[str] (attribute types) +# * Tensor types +# * Sequence[Tensor] types +# * Union of above 2 +# * TypeVars with above bounds +# * Above types with annotation attached TypeAnnotationValue = typing.Any # Map from python type to corresponding ONNX AttributeProto type @@ -43,31 +43,13 @@ _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) -# A sorted list of all type strings used in an OpSchema -ALL_TENSOR_TYPE_STRINGS = ( - "tensor(bfloat16)", - "tensor(bool)", - "tensor(double)", - "tensor(float)", - "tensor(float16)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(int8)", - "tensor(string)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)", - "tensor(uint8)", -) - def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: """Remove Annotated wrapper if present, otherwise return typeinfo as is.""" if hasattr(typing, "Annotated"): # Present in Python 3.9+ - if typing.get_origin(typeinfo) is typing.Annotated: - return typing.get_args(typeinfo)[0] + if get_origin(typeinfo) is typing.Annotated: + return get_args(typeinfo)[0] return typeinfo @@ -77,28 +59,28 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> Optional[onnx.AttributeProto.AttributeType]: +) -> onnx.AttributeProto.AttributeType: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] - type_constructor = typing.get_origin(pytype) + type_constructor = get_origin(pytype) # Remove Optional wrapper if present, which is represented as an Union[..., type(None)] if type_constructor is typing.Union: # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)] - args = [x for x in typing.get_args(pytype) 
if x is not type(None)] + args = [x for x in get_args(pytype) if x is not type(None)] if len(args) == 1: return pytype_to_attrtype(args[0]) if type_constructor in _LIST_CONSTRUCTORS: - elt_type = typing.get_args(pytype)[0] + elt_type = get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None + return onnx.AttributeProto.UNDEFINED def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: - if isinstance(typeinfo, onnx_types.TensorType): + if isinstance(typeinfo, TensorType): return True - if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): + if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType): return True return False @@ -112,16 +94,16 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool: return True if _is_primitive_attr_type(typeinfo): return False - type_constructor = typing.get_origin(typeinfo) + type_constructor = get_origin(typeinfo) # Handle List-like type-constructor # Eg. List[INT32] is a value type, while List[int] is an attribute type if type_constructor in _LIST_CONSTRUCTORS: - elt_type = typing.get_args(typeinfo)[0] + elt_type = get_args(typeinfo)[0] return is_value_type(elt_type) # Handle Union and Optional type-constructors if type_constructor is typing.Union: # Filter out None, since typing.Optional[X] evaluates to Union[X, None] - args = [x for x in typing.get_args(typeinfo) if x is not type(None)] + args = [x for x in get_args(typeinfo) if x is not type(None)] args_value_check = [is_value_type(x) for x in args] if all(args_value_check): # Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE] @@ -151,12 +133,7 @@ def is_valid_type(typeinfo: TypeAnnotationValue): return False -def is_optional(pytype) -> bool: - """Returns whether a pytype is an Optional.""" - return typing.get_origin(pytype) is typing.Union and type(None) in typing.get_args(pytype) - - -def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: +def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[type]: """Converts return-type annotation into a sequence of types. The return type annotation can be either a single type (for a single output) @@ -166,85 +143,6 @@ def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: """ if isinstance(typeinfo, typing.Sequence): return typeinfo - if typing.get_origin(typeinfo) is tuple: - return typing.get_args(typeinfo) + if get_origin(typeinfo) is tuple: + return get_args(typeinfo) return (typeinfo,) - - -def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]: - """Returns a list of type-strings corresponding to a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - A list of all supported input types for the given type annotation. - Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS. 
- """ - if pytype is None: - return list(ALL_TENSOR_TYPE_STRINGS) - if pytype is onnx_types.TensorType: - return list(ALL_TENSOR_TYPE_STRINGS) - if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, onnx_types.TensorType): - return [pytype.to_string()] - if isinstance(pytype, typing.TypeVar): - constraints = pytype.__constraints__ - if constraints: - return pytype_to_type_strings(Union.__getitem__(constraints)) - bound = pytype.__bound__ - if bound is None: - return list(ALL_TENSOR_TYPE_STRINGS) - return pytype_to_type_strings(bound) - if typing.get_origin(pytype) is Union: - options = [] - subtypes = typing.get_args(pytype) - # A None type in a Union is equivalent to an optional type - optional = is_optional(pytype) - for subtype in subtypes: - if subtype is type(None): - # Skip None type because we are handling it with is_optional - continue - if optional: - options += [ - *pytype_to_type_strings(subtype), - *[f"optional({s})" for s in pytype_to_type_strings(subtype)], - ] - else: - options += pytype_to_type_strings(subtype) - # Remove duplicates - return sorted(set(options)) - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])] - - raise ValueError(f"Unsupported type: {pytype}") - - -def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: - """Returns the name of the type constraint for a given type annotation. - - Args: - pytype: A type annotation. - - Returns: - The name of the type constraint if it is a TypeVar. - - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. - - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. - - Returns None if the type annotation does not have a type constraint. 
- """ - if isinstance(pytype, typing.TypeVar): - return pytype.__name__ - if is_optional(pytype): - subtypes = typing.get_args(pytype) - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = get_type_constraint_name(subtype) - return f"Optional_{type_param_name}" if type_param_name else None - if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: - subtypes = typing.get_args(pytype) - type_param_name = get_type_constraint_name(subtypes[0]) - return f"Sequence_{type_param_name}" if type_param_name else None - return None diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index a4b64485e1..17f04deff8 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -4,14 +4,11 @@ # -------------------------------------------------------------------------- import unittest -from typing import Any, List, Optional, Sequence, TypeVar, Union -import parameterized - -import onnxscript import onnxscript.testing -from onnxscript import FLOAT, INT64, script, type_annotation +from onnxscript import script from onnxscript.onnx_opset import opset15 as op +from onnxscript.onnx_types import FLOAT from onnxscript.tests.common import testutils @@ -91,157 +88,5 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: ) -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) - - -class TypeConversionFunctionsTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ( - "tensor_type_all", - onnxscript.onnx_types.TensorType, - list(type_annotation.ALL_TENSOR_TYPE_STRINGS), - ), - ("none", None, list(type_annotation.ALL_TENSOR_TYPE_STRINGS)), - ("tensor_type", INT64, ["tensor(int64)"]), - ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), - ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), - ("tensor_type_shape", INT64[10], ["tensor(int64)"]), - ( - "type_var_constraints", - _TestTypeVarConstraints, - ["tensor(float)", "tensor(int64)"], - ), - ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), - ("type_bound_two", _TestTypeVarTwoBound, ["tensor(float)", "tensor(int64)"]), - ( - "optional_tensor_type_all", - Optional[onnxscript.onnx_types.TensorType], - [ - *[ - f"optional({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - *type_annotation.ALL_TENSOR_TYPE_STRINGS, - ], - ), - ( - "optional_tensor_type", - Optional[INT64], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_union", - Optional[Union[INT64, FLOAT]], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_tensor_type_variadic_shape", - Optional[INT64[...]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_tensor_type_shape", - Optional[INT64[10]], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_var_constraints", - Optional[_TestTypeVarConstraints], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", - ], - ), - ( - "optional_type_bound_one", - Optional[_TestTypeVarOneBound], - ["optional(tensor(int64))", "tensor(int64)"], - ), - ( - "optional_type_bound_two", - Optional[_TestTypeVarTwoBound], - [ - "optional(tensor(float))", - "optional(tensor(int64))", - "tensor(float)", - "tensor(int64)", 
- ], - ), - ( - "sequence_type_all", - Sequence[onnxscript.onnx_types.TensorType], - [ - f"seq({tensor_type})" - for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS - ], - ), - ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]), - ( - "union_sequence_type", - Union[Sequence[INT64], Sequence[FLOAT]], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_variadic_shape", - Sequence[INT64[...]], - ["seq(tensor(int64))"], - ), - ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]), - ( - "sequence_type_var_constraints", - Sequence[_TestTypeVarConstraints], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ( - "sequence_type_bound_one", - Sequence[_TestTypeVarOneBound], - ["seq(tensor(int64))"], - ), - ( - "sequence_type_bound_two", - Sequence[_TestTypeVarTwoBound], - ["seq(tensor(float))", "seq(tensor(int64))"], - ), - ] - ) - def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]): - self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected) - - @parameterized.parameterized.expand( - [ - ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), - ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), - ( - "optional_type_var", - Optional[_TestTypeVarOneBound], - "Optional__TestTypeVarOneBound", - ), - ( - "sequence_type_var", - Sequence[_TestTypeVarOneBound], - "Sequence__TestTypeVarOneBound", - ), - ("normal_type", INT64, None), - ("union_type", Union[INT64, FLOAT], None), - ("optional_type", Optional[INT64], None), - ("sequence_type", Sequence[INT64], None), - ("optional_sequence_type", Optional[Sequence[INT64]], None), - ("optional_union_type", Optional[Union[INT64, FLOAT]], None), - ] - ) - def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): - self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) - - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/values.py b/onnxscript/values.py index 10df0019cf..eb766d4b9e 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -8,14 +8,12 @@ import logging import types from enum import IntFlag -from typing import _GenericAlias # type: ignore[attr-defined] -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, _GenericAlias # type: ignore[attr-defined] import onnx import onnx.defs -from onnxscript import irbuilder, sourceinfo, type_annotation -from onnxscript._internal import version_utils +from onnxscript import irbuilder, sourceinfo _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -36,7 +34,6 @@ # A special value to indicate that the default value is not specified _EmptyDefault = object() -_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14") class Opset: @@ -176,7 +173,7 @@ def __init__( ) -> None: self.opset = opset self.opname = opname - self._opschema = opschema + self.opschema = opschema self._param_schemas: Optional[tuple[ParamSchema, ...]] = None def __call__(self, *args, **kwargs): @@ -193,13 +190,9 @@ def __call__(self, *args, **kwargs): def is_single_op(self) -> bool: return isinstance(self.opname, str) - @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: - return self._opschema - def get_schema(self) -> Optional[onnx.defs.OpSchema]: """Returns the ONNX OpSchema for this op.""" - if self.opschema is not None: + if self.opschema: return self.opschema return self.opset[self.opname] @@ -256,100 +249,6 @@ class OnnxClosure: function: Any -@dataclasses.dataclass 
-class TypeConstraint: - """Represents a type constraint for an ONNX op. - - Attributes: - name: The name of the type constraint. - allowed_types: The allowed types for the type constraint. - """ - - name: str - allowed_types: list[str] - description: str = "" - - def as_tuple(self) -> tuple[str, list[str], str]: - """Returns the type constraint as a tuple.""" - return (self.name, self.allowed_types, self.description) - - -def op_schema_from_function_ir( - function_ir: irbuilder.IRFunction, opset: Opset -) -> onnx.defs.OpSchema: - """Construct an ONNX OpSchema from an IRFunction.""" - - # Find all distinct types in the inputs and outputs - distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union( - {arg.typeinfo for arg in function_ir.outputs} - ) - # Create a mapping from type to a unique name - type_to_constraint = {} - for i, type_ in enumerate(distinct_types): - name = f"T{i}" - type_to_constraint[type_] = TypeConstraint( - name=type_annotation.get_type_constraint_name(type_) or name, - allowed_types=type_annotation.pytype_to_input_strings(type_), - ) - - formal_inputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[arg.typeinfo].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.inputs - ] - formal_outputs = [ - onnx.defs.OpSchema.FormalParameter( - arg.name, - type_to_constraint[arg.typeinfo].name, - param_option=( - onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) - else onnx.defs.OpSchema.FormalParameterOption.Single - ), - # TODO(justinchu): Check this is_homogeneous thing - is_homogeneous=True, - ) - for arg in function_ir.outputs - ] - - return onnx.defs.OpSchema( - function_ir.name, - opset.domain, - since_version=opset.version, - doc=function_ir.docstring, - inputs=formal_inputs, - outputs=formal_outputs, - type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], - attributes=[ - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - type=onnx.defs.OpSchema.AttrType(attr.type), - ) - for attr in function_ir.attrs - if not attr.has_default - ], - *[ - onnx.defs.OpSchema.Attribute( - attr.name, - default_value=attr.attr_proto, - ) - for attr in function_ir.attrs - if attr.has_default - ], - ], - ) - - class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. @@ -377,26 +276,12 @@ def __init__( self.source = source self.kwargs = kwargs self._param_schemas: Optional[tuple[ParamSchema, ...]] = None - self._opschema: Optional[onnx.defs.OpSchema] = None @property def name(self): """Returns the function name.""" return self.opname - @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: - """Construct an OpSchema from function_ir.""" - if self._opschema is not None: - return self._opschema - - if not _ONNX_OP_SCHEMA_WRITABLE: - return None - - self._opschema = op_schema_from_function_ir(self.function_ir, self.opset) - - return self._opschema - def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. 
@@ -426,9 +311,6 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: if self._param_schemas is not None: return self._param_schemas - # NOTE: We generate the parameter schemas from the function_ir instead - # of relying on the auto generated OpSchema because we need to preserve the keyword - # argument order from the Python function definition, which is lost in OpSchema. function_ir = self.function_ir # The first len(func_ir.inputs) arguments are onnx inputs inputs = function_ir.inputs From 88a14d05f59112cbbbbd4630911bed18bed15cff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Apr 2023 02:48:18 +0000 Subject: [PATCH 27/28] Update base for Update on "Refactor converter to isolate translate_function_signature logic | feat(converter)" This change refactors the `translate_function_def` method in `Converter` to isolate the signature handling logic to `translate_function_signature`. `translate_function_signature` is used in #674 to handle function signatures so we do not need to translate the function body for general python functions incompatible with ONNX. Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/irbuilder.py | 4 +- onnxscript/onnx_types.py | 4 + onnxscript/type_annotation.py | 150 ++++++++++++++++++++++----- onnxscript/type_annotation_test.py | 159 ++++++++++++++++++++++++++++- 4 files changed, 290 insertions(+), 27 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 9c4e5f5d9e..443f18412e 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -526,5 +526,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut proto = onnx.AttributeProto() proto.name = attrname proto.ref_attr_name = refname - proto.type = ta.pytype_to_attrtype(pytype) + attr_type = ta.pytype_to_attrtype(pytype) + assert attr_type is not None + proto.type = attr_type return IRAttributeValue(proto) diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index 32e48bf744..4951a3ff18 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -105,6 +105,10 @@ def to_type_proto(cls) -> onnx.TypeProto: shape = [cls.shape] # example: "FLOAT[10]" return onnx.helper.make_tensor_type_proto(cls.dtype, shape) + @classmethod + def to_string(cls) -> str: + return f"tensor({cls.__name__.lower()})" + class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT): pass diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index c5e9f0a704..9f78222a87 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -7,21 +7,21 @@ import collections import inspect import typing +from typing import Optional, Sequence, Union import onnx -from typing_extensions import get_args, get_origin -from onnxscript.onnx_types import TensorType +from onnxscript import onnx_types # TypeAnnotationValue represents the (value of) valid type-annotations recognized # by ONNX Script. TODO: Flesh out a formal definition. 
Currently, it supports -# * float, int, str (primitive attribute types) -# * Sequence[float], Sequence[int], Sequence[str] (attribute types) -# * Tensor types -# * Sequence[Tensor] types -# * Union of above 2 -# * TypeVars with above bounds -# * Above types with annotation attached +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached TypeAnnotationValue = typing.Any # Map from python type to corresponding ONNX AttributeProto type @@ -43,13 +43,31 @@ _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) +# A sorted list of all type strings used in an OpSchema +ALL_TENSOR_TYPE_STRINGS = ( + "tensor(bfloat16)", + "tensor(bool)", + "tensor(double)", + "tensor(float)", + "tensor(float16)", + "tensor(int16)", + "tensor(int32)", + "tensor(int64)", + "tensor(int8)", + "tensor(string)", + "tensor(uint16)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(uint8)", +) + def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: """Remove Annotated wrapper if present, otherwise return typeinfo as is.""" if hasattr(typing, "Annotated"): # Present in Python 3.9+ - if get_origin(typeinfo) is typing.Annotated: - return get_args(typeinfo)[0] + if typing.get_origin(typeinfo) is typing.Annotated: + return typing.get_args(typeinfo)[0] return typeinfo @@ -59,28 +77,28 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, -) -> onnx.AttributeProto.AttributeType: +) -> Optional[onnx.AttributeProto.AttributeType]: pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] - type_constructor = get_origin(pytype) + type_constructor = typing.get_origin(pytype) # Remove Optional wrapper if present, which is represented as an Union[..., type(None)] if type_constructor is typing.Union: # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)] - args = [x for x in get_args(pytype) if x is not type(None)] + args = [x for x in typing.get_args(pytype) if x is not type(None)] if len(args) == 1: return pytype_to_attrtype(args[0]) if type_constructor in _LIST_CONSTRUCTORS: - elt_type = get_args(pytype)[0] + elt_type = typing.get_args(pytype)[0] if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return onnx.AttributeProto.UNDEFINED + return None def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: - if isinstance(typeinfo, TensorType): + if isinstance(typeinfo, onnx_types.TensorType): return True - if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType): + if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True return False @@ -94,16 +112,16 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool: return True if _is_primitive_attr_type(typeinfo): return False - type_constructor = get_origin(typeinfo) + type_constructor = typing.get_origin(typeinfo) # Handle List-like type-constructor # Eg. 
List[INT32] is a value type, while List[int] is an attribute type if type_constructor in _LIST_CONSTRUCTORS: - elt_type = get_args(typeinfo)[0] + elt_type = typing.get_args(typeinfo)[0] return is_value_type(elt_type) # Handle Union and Optional type-constructors if type_constructor is typing.Union: # Filter out None, since typing.Optional[X] evaluates to Union[X, None] - args = [x for x in get_args(typeinfo) if x is not type(None)] + args = [x for x in typing.get_args(typeinfo) if x is not type(None)] args_value_check = [is_value_type(x) for x in args] if all(args_value_check): # Handles cases like Optional[INT32] as well as Union[FLOAT16, FLOAT, DOUBLE] @@ -133,7 +151,12 @@ def is_valid_type(typeinfo: TypeAnnotationValue): return False -def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[type]: +def is_optional(pytype) -> bool: + """Returns whether a pytype is an Optional.""" + return typing.get_origin(pytype) is typing.Union and type(None) in typing.get_args(pytype) + + +def get_return_types(typeinfo: type | Sequence[type]) -> Sequence[type]: """Converts return-type annotation into a sequence of types. The return type annotation can be either a single type (for a single output) @@ -143,6 +166,85 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[ """ if isinstance(typeinfo, typing.Sequence): return typeinfo - if get_origin(typeinfo) is tuple: - return get_args(typeinfo) + if typing.get_origin(typeinfo) is tuple: + return typing.get_args(typeinfo) return (typeinfo,) + + +def pytype_to_type_strings(pytype: TypeAnnotationValue) -> list[str]: + """Returns a list of type-strings corresponding to a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + A list of all supported input types for the given type annotation. + Ensures that the list is sorted in the same order as ALL_TYPE_STRINGS. + """ + if pytype is None: + return list(ALL_TENSOR_TYPE_STRINGS) + if pytype is onnx_types.TensorType: + return list(ALL_TENSOR_TYPE_STRINGS) + if isinstance(pytype, type) and issubclass(pytype, onnx_types.TensorType): + return [pytype.to_string()] + if isinstance(pytype, onnx_types.TensorType): + return [pytype.to_string()] + if isinstance(pytype, typing.TypeVar): + constraints = pytype.__constraints__ + if constraints: + return pytype_to_type_strings(Union.__getitem__(constraints)) + bound = pytype.__bound__ + if bound is None: + return list(ALL_TENSOR_TYPE_STRINGS) + return pytype_to_type_strings(bound) + if typing.get_origin(pytype) is Union: + options = [] + subtypes = typing.get_args(pytype) + # A None type in a Union is equivalent to an optional type + optional = is_optional(pytype) + for subtype in subtypes: + if subtype is type(None): + # Skip None type because we are handling it with is_optional + continue + if optional: + options += [ + *pytype_to_type_strings(subtype), + *[f"optional({s})" for s in pytype_to_type_strings(subtype)], + ] + else: + options += pytype_to_type_strings(subtype) + # Remove duplicates + return sorted(set(options)) + if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: + subtypes = typing.get_args(pytype) + return [f"seq({s})" for s in pytype_to_type_strings(subtypes[0])] + + raise ValueError(f"Unsupported type: {pytype}") + + +def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]: + """Returns the name of the type constraint for a given type annotation. + + Args: + pytype: A type annotation. + + Returns: + The name of the type constraint if it is a TypeVar. 
+ - Prefixes the name with "Optional_" if the type annotation is Optional[TypeVar]. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + - Returns None if the type annotation does not have a type constraint. + """ + if isinstance(pytype, typing.TypeVar): + return pytype.__name__ + if is_optional(pytype): + subtypes = typing.get_args(pytype) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = get_type_constraint_name(subtype) + return f"Optional_{type_param_name}" if type_param_name else None + if typing.get_origin(pytype) in _LIST_CONSTRUCTORS: + subtypes = typing.get_args(pytype) + type_param_name = get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 17f04deff8..a4b64485e1 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -4,11 +4,14 @@ # -------------------------------------------------------------------------- import unittest +from typing import Any, List, Optional, Sequence, TypeVar, Union +import parameterized + +import onnxscript import onnxscript.testing -from onnxscript import script +from onnxscript import FLOAT, INT64, script, type_annotation from onnxscript.onnx_opset import opset15 as op -from onnxscript.onnx_types import FLOAT from onnxscript.tests.common import testutils @@ -88,5 +91,157 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]: ) +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) + + +class TypeConversionFunctionsTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ( + "tensor_type_all", + onnxscript.onnx_types.TensorType, + list(type_annotation.ALL_TENSOR_TYPE_STRINGS), + ), + ("none", None, list(type_annotation.ALL_TENSOR_TYPE_STRINGS)), + ("tensor_type", INT64, ["tensor(int64)"]), + ("tensor_type_union", Union[INT64, FLOAT], ["tensor(float)", "tensor(int64)"]), + ("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]), + ("tensor_type_shape", INT64[10], ["tensor(int64)"]), + ( + "type_var_constraints", + _TestTypeVarConstraints, + ["tensor(float)", "tensor(int64)"], + ), + ("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]), + ("type_bound_two", _TestTypeVarTwoBound, ["tensor(float)", "tensor(int64)"]), + ( + "optional_tensor_type_all", + Optional[onnxscript.onnx_types.TensorType], + [ + *[ + f"optional({tensor_type})" + for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS + ], + *type_annotation.ALL_TENSOR_TYPE_STRINGS, + ], + ), + ( + "optional_tensor_type", + Optional[INT64], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_tensor_type_union", + Optional[Union[INT64, FLOAT]], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", + ], + ), + ( + "optional_tensor_type_variadic_shape", + Optional[INT64[...]], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_tensor_type_shape", + Optional[INT64[10]], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", + ], + ), + ( + "optional_type_bound_one", + 
Optional[_TestTypeVarOneBound], + ["optional(tensor(int64))", "tensor(int64)"], + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + [ + "optional(tensor(float))", + "optional(tensor(int64))", + "tensor(float)", + "tensor(int64)", + ], + ), + ( + "sequence_type_all", + Sequence[onnxscript.onnx_types.TensorType], + [ + f"seq({tensor_type})" + for tensor_type in type_annotation.ALL_TENSOR_TYPE_STRINGS + ], + ), + ("sequence_type", Sequence[INT64], ["seq(tensor(int64))"]), + ( + "union_sequence_type", + Union[Sequence[INT64], Sequence[FLOAT]], + ["seq(tensor(float))", "seq(tensor(int64))"], + ), + ( + "sequence_type_variadic_shape", + Sequence[INT64[...]], + ["seq(tensor(int64))"], + ), + ("sequence_type_shape", Sequence[INT64[10]], ["seq(tensor(int64))"]), + ( + "sequence_type_var_constraints", + Sequence[_TestTypeVarConstraints], + ["seq(tensor(float))", "seq(tensor(int64))"], + ), + ( + "sequence_type_bound_one", + Sequence[_TestTypeVarOneBound], + ["seq(tensor(int64))"], + ), + ( + "sequence_type_bound_two", + Sequence[_TestTypeVarTwoBound], + ["seq(tensor(float))", "seq(tensor(int64))"], + ), + ] + ) + def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]): + self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected) + + @parameterized.parameterized.expand( + [ + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), + ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), + ( + "optional_type_var", + Optional[_TestTypeVarOneBound], + "Optional__TestTypeVarOneBound", + ), + ( + "sequence_type_var", + Sequence[_TestTypeVarOneBound], + "Sequence__TestTypeVarOneBound", + ), + ("normal_type", INT64, None), + ("union_type", Union[INT64, FLOAT], None), + ("optional_type", Optional[INT64], None), + ("sequence_type", Sequence[INT64], None), + ("optional_sequence_type", Optional[Sequence[INT64]], None), + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), + ] + ) + def test_get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]): + self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected) + + if __name__ == "__main__": unittest.main() From 3e60a7883bc48ef62a39e6d99ca02e914d9d0fd2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 28 Apr 2023 02:55:56 +0000 Subject: [PATCH 28/28] Update base for Update on "Refactor converter to isolate translate_function_signature logic | feat(converter)" This change refactors the `translate_function_def` method in `Converter` to isolate the signature handling logic to `translate_function_signature`. `translate_function_signature` is used in #674 to handle function signatures so we do not need to translate the function body for general python functions incompatible with ONNX. 
Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/type_annotation.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 9f78222a87..1f87ee8878 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -44,21 +44,8 @@ _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) # A sorted list of all type strings used in an OpSchema -ALL_TENSOR_TYPE_STRINGS = ( - "tensor(bfloat16)", - "tensor(bool)", - "tensor(double)", - "tensor(float)", - "tensor(float16)", - "tensor(int16)", - "tensor(int32)", - "tensor(int64)", - "tensor(int8)", - "tensor(string)", - "tensor(uint16)", - "tensor(uint32)", - "tensor(uint64)", - "tensor(uint8)", +ALL_TENSOR_TYPE_STRINGS = tuple( + sorted(tensor_type.to_string() for tensor_type in onnx_types.tensor_type_registry.values()) )
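
For readers following the stack end-to-end, the net effect of patches 27 and 28 on `onnxscript.type_annotation` can be summarized with a small usage sketch. It relies only on behaviour that the parameterized tests above pin down; the TypeVar name `_MyVar` is illustrative and not part of any patch:

    from typing import Optional, Sequence, TypeVar, Union

    from onnxscript import FLOAT, INT64, type_annotation

    # Tensor annotations map to sorted lists of ONNX type strings.
    type_annotation.pytype_to_type_strings(INT64)                # ["tensor(int64)"]
    type_annotation.pytype_to_type_strings(Union[INT64, FLOAT])  # ["tensor(float)", "tensor(int64)"]

    # Optional[...] contributes both the optional() and the plain tensor strings.
    type_annotation.pytype_to_type_strings(Optional[INT64])      # ["optional(tensor(int64))", "tensor(int64)"]

    # TypeVars keep their Python name as the OpSchema type-constraint name,
    # with a "Sequence_" or "Optional_" prefix added by get_type_constraint_name.
    _MyVar = TypeVar("_MyVar", bound=Union[INT64, FLOAT])
    type_annotation.get_type_constraint_name(_MyVar)             # "_MyVar"
    type_annotation.get_type_constraint_name(Sequence[_MyVar])   # "Sequence__MyVar"

    # pytype_to_attrtype now returns None (rather than AttributeProto.UNDEFINED)
    # when the annotation is not an attribute type, e.g. a tensor type.
    type_annotation.pytype_to_attrtype(INT64)                    # None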