From 3dbb59ea0ef25a767159c2fafd824485c752dc2a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Apr 2023 15:56:14 +0000 Subject: [PATCH] Create the OpLike protocol and refactor Op | feat(values) - Removes `is_single_op` because it is unused. Signed-off-by: Justin Chu [ghstack-poisoned] --- onnxscript/function_libs/torch_lib/tracing.py | 42 +++- onnxscript/irbuilder.py | 4 +- onnxscript/values.py | 196 +++++++++++------- 3 files changed, 153 insertions(+), 89 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/tracing.py b/onnxscript/function_libs/torch_lib/tracing.py index 3b61a9e60f..d89d6eb664 100644 --- a/onnxscript/function_libs/torch_lib/tracing.py +++ b/onnxscript/function_libs/torch_lib/tracing.py @@ -5,7 +5,8 @@ import inspect import textwrap import types -from typing import Optional +import typing +from typing import Optional, Tuple import onnx @@ -13,6 +14,9 @@ from onnxscript import converter as ons_converter from onnxscript._internal import version_utils +if typing.TYPE_CHECKING: + from onnxscript import irbuilder + _ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14") @@ -33,7 +37,7 @@ def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]: return src, f_ast -class TraceOnlyFunction: +class TraceOnlyFunction(onnxscript.values.OpLike): """TraceOnlyFunction. Attributes: @@ -44,9 +48,11 @@ class TraceOnlyFunction: def __init__(self, opset: onnxscript.values.Opset, func: types.FunctionType): self._opset = opset self._func = func - self._opschema: Optional[onnx.defs.OpSchema] = None # Set the signature of the class to function's self.__signature__ = inspect.signature(func) + # Cached computed fields + self._opschema: Optional[onnx.defs.OpSchema] = None + self._param_schemas: Optional[Tuple[onnxscript.values.ParamSchema, ...]] = None def __call__(self, *args, **kwargs): return self._func(*args, **kwargs) @@ -72,13 +78,32 @@ def opset(self) -> onnxscript.values.Opset: @property def opschema(self) -> Optional[onnx.defs.OpSchema]: """Return the opschema.""" - if self._opschema is not None: return self._opschema - if not _ONNX_OP_SCHEMA_WRITABLE: return None + # FIXME(justinchuby): outputs are empty. Need to fix. + self._opschema = onnxscript.values.op_schema_from_function_ir( + self._function_ir(), self._opset + ) + + return self._opschema + + def param_schemas(self) -> tuple[onnxscript.values.ParamSchema, ...]: + """Generate param_schemas for the TraceOnlyFunction.""" + if self._param_schemas is None: + self._param_schemas = onnxscript.values.param_schemas_from_function_ir( + self._function_ir() + ) + + return self._param_schemas + + def _function_ir(self) -> irbuilder.IRFunction: + """Return the IRFunction of the function. + + This IRFunction contains only the function signature. + """ src, func_ast = _get_src_and_ast(self._func) module = inspect.getmodule(self._func) closure = inspect.getclosurevars(self._func) @@ -90,9 +115,4 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]: source=src, ) - function_ir = converter.translate_function_signature(func_ast) - - # FIXME(justinchuby): outputs are empty. Need to fix. - self._opschema = onnxscript.values.op_schema_from_function_ir(function_ir, self._opset) - - return self._opschema + return converter.translate_function_signature(func_ast) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 443f18412e..98cd6a462b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -202,7 +202,7 @@ def __str__(self): args = _format(self.args, "(", ", ", ")", _opt_var_to_str) domain = self.callee.opset.domain - opname = self.callee.opname + opname = self.callee.name callee = f"{domain}.{opname}" if (domain != "") else opname return f"{lhs} = {callee} {attrs}{args}" @@ -212,7 +212,7 @@ def debug_print(self): def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( - self.callee.opname, + self.callee.name, [_opt_var_to_str(x) for x in self.args], [str(x) for x in self.result], domain=self.callee.opset.domain, diff --git a/onnxscript/values.py b/onnxscript/values.py index 47591a6a15..2cac25080e 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -163,21 +163,105 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: return onnx.helper.get_attribute_value(attr_proto) -class Op: +def param_schemas_from_op_schema( + op_schema: onnx.defs.OpSchema, +) -> tuple[ParamSchema, ...]: + """Get the parameter schemas from an ONNX OpSchema.""" + schemas = [] + for input_ in op_schema.inputs: + param_schema = ParamSchema( + name=input_.name, + is_input=True, + required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional), + is_variadic_input=( + input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic + ), + ) + schemas.append(param_schema) + for attr_name, attribute in op_schema.attributes.items(): + default_attr_proto = attribute.default_value + param_schema = ParamSchema( + name=attr_name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], + default=_get_attribute_value(default_attr_proto), + is_input=False, + required=attribute.required, + ) + schemas.append(param_schema) + + return tuple(schemas) + + +def param_schemas_from_function_ir( + function_ir: irbuilder.IRFunction, +) -> tuple[ParamSchema, ...]: + """Get the parameter schemas from a FunctionIR.""" + # The first len(func_ir.inputs) arguments are onnx inputs + # The rest is onnx attributes + + schemas = [] + for arg in function_ir.inputs: + if isinstance(arg.typeinfo, onnx.TypeProto.Optional): + required = False + else: + required = True + schemas.append( + ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) + ) + + for attr_parameter in function_ir.attrs: + schemas.append( + ParamSchema( + name=attr_parameter.name, + type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( + onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg] + ), + default=_EmptyDefault + if attr_parameter.default_value is None + else attr_parameter.default_value, + is_input=False, + required=not attr_parameter.has_default, + ) + ) + + return tuple(schemas) + + +@typing.runtime_checkable +class OpLike(Protocol): + """A protocol for objects that have an ONNX OpSchema.""" + + @property + def name(self) -> str: + ... + + @property + def opset(self) -> Opset: + ... + + @property + def opschema(self) -> onnx.defs.OpSchema: + ... + + def param_schemas(self) -> tuple[ParamSchema, ...]: + ... + + +class Op(OpLike): """Represents an ONNX op instance (for example, the MatMul op from ONNX opset version 13). It belongs to a particular Opset and has a name. Attributes: opset: The Opset that this op belongs to. - opname: The name of the op. + name: The name of the op. opschema: The ONNX OpSchema for the op. """ def __init__( - self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None ) -> None: - self.opset = opset - self.opname = opname + self._opset = opset + self._name = opname self._opschema = opschema self._param_schemas: Optional[tuple[ParamSchema, ...]] = None @@ -188,12 +272,17 @@ def __call__(self, *args, **kwargs): schema = self.get_schema() if schema is None: raise RuntimeError( - f"Op '{self.opname}' does not have an OpSchema and cannot be evaluated." + f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." ) return evaluator.default().eval(schema, args, kwargs) - def is_single_op(self) -> bool: - return isinstance(self.opname, str) + @property + def name(self) -> str: + return self._name + + @property + def opset(self) -> Opset: + return self._opset @property def opschema(self) -> Optional[onnx.defs.OpSchema]: @@ -203,7 +292,7 @@ def get_schema(self) -> Optional[onnx.defs.OpSchema]: """Returns the ONNX OpSchema for this op.""" if self.opschema is not None: return self.opschema - return self.opset[self.opname] + return self.opset[self.name] def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" @@ -217,30 +306,9 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: op_schema = self.get_schema() if op_schema is None: return None - schemas = [] - for input_ in op_schema.inputs: - param_schema = ParamSchema( - name=input_.name, - is_input=True, - required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional), - is_variadic_input=( - input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic - ), - ) - schemas.append(param_schema) - for attr_name, attribute in op_schema.attributes.items(): - default_attr_proto = attribute.default_value - param_schema = ParamSchema( - name=attr_name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type], - default=_get_attribute_value(default_attr_proto), - is_input=False, - required=attribute.required, - ) - schemas.append(param_schema) - self._param_schemas = tuple(schemas) - return self._param_schemas # type: ignore[return-value] + self._param_schemas = param_schemas_from_op_schema(op_schema) + return self._param_schemas @dataclasses.dataclass(repr=False, eq=False) @@ -355,13 +423,14 @@ def op_schema_from_function_ir( class OnnxFunction(Op): """Represents an ONNX op for which a function-body has been defined in onnxscript. - Args: - opset: opset the function belongs to - pyfun: python function - irfun: python code parsed by class - :class:`onnxscript.converter.Converter` - source: source code used to generate the function - kwargs: additional properties used to construct a ModelProto + Attributes: + opset: Opset the function belongs to. + name: Name of the function. + function: Python function. + function_ir: Python code parsed as an :class:`irbuilder.IRFunction`. + source: Source code used to generate the function. + kwargs: Additional properties used to construct a ModelProto. + opschema: Generated ONNX OpSchema for this op. """ def __init__( @@ -372,6 +441,16 @@ def __init__( source: str, kwargs: dict[str, Any], ): + """Constructs an OnnxFunction. + + Args: + opset: opset the function belongs to + pyfun: python function + irfun: python code parsed by class + :class:`onnxscript.converter.Converter` + source: source code used to generate the function + kwargs: additional properties used to construct a ModelProto + """ opset = opset or Opset(irfun.domain, 1) super().__init__(opset, irfun.name) self.function = pyfun @@ -383,11 +462,6 @@ def __init__( # Set the signature of the class to function's self.__signature__ = inspect.signature(pyfun) - @property - def name(self): - """Returns the function name.""" - return self.opname - @property def opschema(self) -> Optional[onnx.defs.OpSchema]: """Construct an OpSchema from function_ir.""" @@ -433,38 +507,8 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - function_ir = self.function_ir - # The first len(func_ir.inputs) arguments are onnx inputs - inputs = function_ir.inputs - # The rest is onnx attributes - - schemas = [] - for arg in inputs: - if isinstance(arg.typeinfo, onnx.TypeProto.Optional): - required = False - else: - required = True - schemas.append( - ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required) - ) - - for attr_parameter in function_ir.attrs: - schemas.append( - ParamSchema( - name=attr_parameter.name, - type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( - onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg] - ), - default=_EmptyDefault - if attr_parameter.default_value is None - else attr_parameter.default_value, - is_input=False, - required=not attr_parameter.has_default, - ) - ) - - self._param_schemas = tuple(schemas) - return self._param_schemas # type: ignore[return-value] + self._param_schemas = param_schemas_from_function_ir(self.function_ir) + return self._param_schemas def to_function_proto(self): """Converts the function into :class:`onnx.FunctionProto`."""