Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use IR in onnxscript #1409

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 66 additions & 200 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
import io
import logging
import warnings
from typing import Any, Optional, Protocol, Sequence, Union
from typing import Any, Mapping, Optional, Protocol, Sequence, Union

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused Protocol imported from typing (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

typing.Protocol imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import

import onnx
from onnx import ValueInfoProto, helper
from onnx.defs import onnx_opset_version

import onnxscript
from onnxscript import type_annotation as ta
from onnxscript import values
from onnxscript import values, ir, sourceinfo
from onnxscript._internal import version_utils
from onnxscript.onnx_types import ONNXType
from onnxscript.sourceinfo import SourceInfo
from onnxscript.ir import _convenience as ir_convenience

# A simple IR (Function, Stmt, Attr, Var):

Expand All @@ -40,39 +40,14 @@
return helper.OP_SET_ID_VERSION_MAP[domain, version]


class IRType:
def __init__(self):
self.onnx_type = onnx.TypeProto()

def to_type_proto(self):
return self.onnx_type

def __repr__(self) -> str:
return "IRType()"


class IRTensorType(IRType):
def __init__(self, elem_type: onnx.TensorProto.DataType) -> None:
super().__init__()
self.onnx_type.tensor_type.elem_type = elem_type

def __repr__(self) -> str:
return f"IRTensorType({self.onnx_type.tensor_type.elem_type})"


class IRTypeLike(Protocol):
def to_type_proto(self) -> onnx.TypeProto:
"""Converts IR type representation to onnx.TypeProto"""


class IRVar:
"""A variable (representing a formal parameter)."""

def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
def __init__(self, varname: str, typeinfo: ir.TypeProtocol, source_info: sourceinfo.SourceInfo) -> None:
if not isinstance(varname, str):
raise TypeError(f"varname must be a string not {type(varname)!r}.")
self.name = varname
self.info = sourceinfo
self.info = source_info
self.typeinfo = typeinfo

def __str__(self):
Expand All @@ -81,128 +56,34 @@
def __repr__(self):
return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})"

def typed_str(self):
return f"{self.name} : {self.typeinfo}"

def to_value_info(self, use_default_type: bool = True):
"""Converts the content of this class into :class:`onnx.ValueInfoProto`.
def typed_str(self) -> str:
return f"{self.name}: {self.typeinfo}"

Check warning on line 60 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L60

Added line #L60 was not covered by tests

Args:
use_default_type: if True, use a default type if an explicit type
is not known. Otherwise, returns a ValueInfoProto without type.

Returns:
an instance of :class:`onnx.ValueInfoProto`
"""
if self.name is None:
raise ValueError(self.info.msg("name cannot be None."))
value_info_proto = ValueInfoProto()
value_info_proto.name = self.name
if self.typeinfo is not None:
value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto())
elif use_default_type:
value_info_proto.type.CopyFrom(IRType().to_type_proto())
return value_info_proto


def _opt_var_to_str(x):
return "" if x is None else str(x)


class IRAttributeValue:
"""An attribute value (representing an actual parameter).

Attributes:
name: The name of the attribute.
type: The type of the attribute.
attr_proto: The attribute proto.
"""

def __init__(self, attrproto: onnx.AttributeProto) -> None:
self.attr_proto = attrproto

def __str__(self):
if self.attr_proto.HasField("ref_attr_name"):
return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}"
# self.name + " = " + self.value
return helper.printable_attribute(self.attr_proto)

@property
def name(self) -> str:
return self.attr_proto.name

@property
def type(self) -> onnx.AttributeProto.AttributeType:
return self.attr_proto.type


@dataclasses.dataclass(frozen=True)
class IRAttributeParameter:
"""An attribute parameter (representing a formal parameter).

It may or may not carry a default value.

Attributes:
name: The name of the attribute.
type: The type of the attribute.
default_value: The default value of the attribute.
has_default: Whether the attribute has a default value.
attr_proto: The attribute proto.
"""

name: str
type: onnx.AttributeProto.AttributeType
default_value: str | int | float | None = None

# TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.

def __str__(self):
if self.has_default:
return helper.printable_attribute(self.attr_proto)
# TODO(justinchuby): Include a readable type name.
return self.name

@property
def has_default(self):
return self.default_value is not None

@property
def attr_proto(self) -> onnx.AttributeProto:
if not self.has_default:
raise ValueError(
"Attribute has no default value. Only attributes with default "
"values can be converted to AttributeProto."
)
if version_utils.onnx_older_than("1.15"):
# TODO(after 1.14 is deprecated): Remove this branch.
# Argument 'attr_type' was added after version 1.14.
return helper.make_attribute(self.name, self.default_value)
# pylint: disable=unexpected-keyword-arg
return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg]
# pylint: enable=unexpected-keyword-arg


class IRStmt:
def __init__(
self,
result: Sequence[str],
callee: values.Op,
args: Sequence[Optional[str]],
attrs: Sequence[IRAttributeValue],
args: Sequence[str],
attrs: Sequence[ir.Attr | ir.RefAttr],
sub_functions=None,
) -> None:
if not isinstance(callee, values.Op):
raise TypeError(f"Unexpected type {type(callee)} for callee.")
self.result = result
self._output_names = result
self.callee = callee
self.args = args
self.attrs = attrs
self.functions = sub_functions or {}

def __str__(self):
if isinstance(self.result, str):
logger.debug("unexpected str type for self.result where type(self)=%r", type(self))
lhs = ", ".join(self.result)
lhs = ", ".join(self._output_names)

Check warning on line 86 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L86

Added line #L86 was not covered by tests
attrs = ""
if self.attrs:
attrs = _format(self.attrs, "<", ", ", ">")
Expand All @@ -217,22 +98,29 @@
if logger.isEnabledFor(logging.DEBUG):
logger.debug("%s: %s", type(self), str(self))

def to_node_proto(self, node_name: str) -> onnx.NodeProto:
n = helper.make_node(
self.callee.name,
[_opt_var_to_str(x) for x in self.args],
[str(x) for x in self.result],
def to_node(self, node_name: str, values: Mapping[str, ir.Value]) -> ir.Node:

Check warning

Code scanning / lintrunner

PYLINT/W0621 Warning

Redefining name 'values' from outer scope (line 19) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name
"""
Converts this statement into a node in the IR.

Args:
node_name: The name of the node.
values: A dictionary mapping value names to values.
"""
node = ir.Node(
domain=self.callee.opset.domain,
op_type=self.callee.name,
inputs=[values[x] if x != "" else None for x in self.args],
name=node_name,
attributes=self.attrs,
)
for a in self.attrs:
n.attribute.append(a.attr_proto)
return n
for name, output in zip(self._output_names, node.outputs):
output.name = name
return node

Check warning on line 118 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L117-L118

Added lines #L117 - L118 were not covered by tests

@property
def output_names(self) -> Sequence[str]:
"""Returns the list of variables assigned to by this statement."""
return [str(x) for x in self.result]
return self._output_names


class IRFunction:
Expand All @@ -248,7 +136,7 @@
# a dictionary of nested function-definitions
self.nested_functions: dict[str, IRFunction] = {}
self.outer_scope_variables: dict[Any, Any] = {}
self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []
self.ordered_inputs_and_attrs: list[Union[IRVar, ir.Attr]] = []

@property
def assigned_names(self) -> Sequence[str]:
Expand All @@ -260,11 +148,11 @@
return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]

@property
def attrs(self) -> Sequence[IRAttributeParameter]:
def attrs(self) -> Sequence[ir.Attr]:
return [
attr
for attr in self.ordered_inputs_and_attrs
if isinstance(attr, IRAttributeParameter)
if isinstance(attr, ir.Attr)
]

def __str__(self):
Expand All @@ -286,18 +174,9 @@
def append_output(self, name: IRVar) -> None:
self.outputs.append(name)

def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
def add_attr_parameter(self, attr: ir.Attr) -> None:
self.ordered_inputs_and_attrs.append(attr)

def debug_print(self):
if logger.isEnabledFor(logging.DEBUG):
st = io.StringIO()
for s in self.stmts:
for attr in s.attrs:
if attr.attr_proto.HasField("g"):
st.write(helper.printable_graph(attr.attr_proto.g))
st.write("\n")

def add_called_function(self, fun: values.OnnxFunction) -> None:
for name, fct in fun.function_ir.called_functions.items():
if name in self.called_functions:
Expand Down Expand Up @@ -440,40 +319,33 @@
)
return func_opset_imports

def to_function_proto(self) -> onnx.FunctionProto:
"""Converts this instance into a `onnx.FunctionProto`.

Note: Default values for attributes are an experimental feature in ONNX.
Conversion ignores default values for attributes if the ONNX version installed
doesn't support it.
"""
def to_ir_function(self) -> ir.Function:
"""Converts this instance into a `ir.Function`."""
opsets = self.get_opset_import()
nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)]
for n in nodes:
if n.domain not in opsets:
opsets[n.domain] = 1 # TODO: how to get n.version?
opset_imports = [
onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
]

attribute_names = [attr.name for attr in self.attrs if not attr.has_default]

f = helper.make_function(
self.domain,
self.name,
inputs=[x.name for x in self.inputs],
outputs=[y.name for y in self.outputs],
values = {}

Check warning

Code scanning / lintrunner

PYLINT/W0621 Warning

Redefining name 'values' from outer scope (line 19) (redefined-outer-name)
See redefined-outer-name. To disable, use # pylint: disable=redefined-outer-name

Check failure

Code scanning / lintrunner

MYPY/var-annotated Error

Need type annotation for "values" (hint: "values: Dict[, ] = ...") To disable, use # type: ignore[var-annotated]
nodes = []

Check warning on line 326 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L325-L326

Added lines #L325 - L326 were not covered by tests
function_outputs: dict[str, ir.Value | None] = {x.name: None for x in self.outputs}
for i, s in enumerate(self.stmts):
node = s.to_node(f"n{i}", values)
nodes.append(node)

Check warning on line 330 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L329-L330

Added lines #L329 - L330 were not covered by tests
if node.domain not in opsets:
# FIXME(justinchuby): Node version
assert s.version is not None

Check failure

Code scanning / lintrunner

MYPY/attr-defined Error

"IRStmt" has no attribute "version" To disable, use # type: ignore[attr-defined]
opsets[node.domain] = s.version

Check warning on line 334 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L333-L334

Added lines #L333 - L334 were not covered by tests

Check failure

Code scanning / lintrunner

MYPY/attr-defined Error

"IRStmt" has no attribute "version" To disable, use # type: ignore[attr-defined]
for output in node.outputs:
values[output.name] = output

Check warning on line 336 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L336

Added line #L336 was not covered by tests
if output.name in function_outputs:
function_outputs[output.name] = output

Check warning on line 338 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L338

Added line #L338 was not covered by tests
inputs = [ir.Input(input.name) for input in self.inputs]
for name, output in function_outputs.items():

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "Value | None", variable has type "Value") To disable, use # type: ignore[assignment]
assert output is not None, f"Output {name!r} is an output of any node is the function."
graph = ir.Graph(

Check warning on line 342 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L341-L342

Added lines #L341 - L342 were not covered by tests
inputs=inputs,
outputs=function_outputs.values(), # type: ignore
nodes=nodes,
opset_imports=opset_imports, # TODO
attributes=attribute_names,
doc_string=self.docstring,
opset_imports=opsets,
)
# In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead
if hasattr(f, "attribute_proto"):
f.attribute_proto.extend(
[attr.attr_proto for attr in self.attrs if attr.has_default]
)
return f
return ir.Function(domain=self.domain, name=self.name, graph=graph, attributes=self.attrs)

Check warning on line 348 in onnxscript/irbuilder.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/irbuilder.py#L348

Added line #L348 was not covered by tests


# IRBuilder: abstracts out details of the IR in the python-to-IR converter
Expand All @@ -500,14 +372,15 @@
results: Sequence[str],
callee: values.Op,
args: Sequence[Optional[str]],
attrs: Sequence[IRAttributeValue],
attrs: Sequence[ir.Attr | ir.RefAttr],
sub_functions=None,
) -> None:
# TODO(justinchuby): Capture opset version here
stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions)
fn.append_stmt(stmt)

def add_input(
self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
self, fn: IRFunction, varname: str, type: ir.TypeProtocol, info: sourceinfo.SourceInfo
) -> None:
var = IRVar(varname, type, info)
fn.append_input(var)
Expand All @@ -516,23 +389,16 @@
self,
fn: IRFunction,
varname: str,
attribute_type: onnx.AttributeProto.AttributeType,
attribute_type: ir.AttributeType,
default_value: int | float | str | None,
) -> None:
fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value))
fn.add_attr_parameter(ir_convenience.convert_attribute(varname, default_value, attribute_type))

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "add_attr_parameter" of "IRFunction" has incompatible type "Attr | RefAttr"; expected "Attr" To disable, use # type: ignore[arg-type]

def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
var = IRVar(varname, typeinfo, sourceinfo)
def add_output(self, fn: IRFunction, varname: str, typeinfo, source_info) -> None:
var = IRVar(varname, typeinfo, source_info)
fn.append_output(var)

def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
return IRAttributeValue(attrproto)

def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
proto = onnx.AttributeProto()
proto.name = attrname
proto.ref_attr_name = refname
attr_type = ta.pytype_to_attrtype(pytype)
assert attr_type is not None
proto.type = attr_type
return IRAttributeValue(proto)
def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.RefAttr:
return ir.RefAttr(
attrname, refname, ta.pytype_to_attrtype(pytype)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 3 to "RefAttr" has incompatible type "onnx.onnx_ml_pb2.AttributeProto.AttributeType | None"; expected "onnxscript.ir._enums.AttributeType" To disable, use # type: ignore[arg-type]
)
Loading
Loading