Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for subfunctions included in subgraphs #41

Merged
merged 8 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _known_modules():
'onnx.helper': onnx.helper,
'onnxscript': onnxscript,
'onnxscript.onnx': onnxscript.onnx,
'onnxscript.values': onnxscript.values,
'onnxscript.onnx_types': onnxscript.onnx_types,
'onnxscript.onnx.opset15': onnxscript.onnx.opset15
}
Expand Down Expand Up @@ -246,7 +247,7 @@ def py_var_to_onnx_var(self, py_var, info):
def emit_docstring(self, docstring):
self.ir_builder.add_docstring(self.current_fn, docstring)

def emit(self, outputs, callee, inputs, attrs):
def emit(self, outputs, callee, inputs, attrs, sub_functions=None):
if callee.opname == 'NotEqual':
if len(attrs) != 0:
raise RuntimeError(
Expand All @@ -259,9 +260,9 @@ def emit(self, outputs, callee, inputs, attrs):
else:
self.ir_builder.add_stmt(
self.current_fn, outputs, callee.opset,
callee.opname, inputs, attrs)
callee.opname, inputs, attrs, sub_functions)

def emit_loop(self, outputs, callee, inputs, attrs, info):
def emit_loop(self, outputs, callee, inputs, attrs, info, sub_functions=None):
def rename(x):
r = self.generate_unique_name(x)
self.bind(x, Dynamic(r, DynamicKind.Output, info))
Expand All @@ -270,7 +271,8 @@ def rename(x):
# [ self.to_onnx_var(self.lookup(pvar)) for pvar in inputs ]
onnx_inputs = inputs
onnx_outputs = [rename(x) for x in outputs]
self.emit(onnx_outputs, Op(default_opset, callee), onnx_inputs, attrs)
self.emit(onnx_outputs, Op(default_opset, callee), onnx_inputs, attrs,
sub_functions=sub_functions)

def emit_const(self, pyvalue, suggested_name, info):
ovar = self.generate_unique_name(suggested_name)
Expand Down Expand Up @@ -550,9 +552,11 @@ def translate_if_stmt(self, stmt: ast.If):
live_defs = list(stmt.live_out.intersection(analysis.defs(stmt)))
test = self.translate_expr(stmt.test, "cond")
lineno = DebugInfo(stmt).lineno
thenGraph = self.translate_block(stmt.body, "thenGraph_%d" % lineno, live_defs)
thenGraph, sub_fct_then = self.translate_block(
stmt.body, "thenGraph_%d" % lineno, live_defs)
thenAttr = self.ir_builder.attr("then_branch", thenGraph)
elseGraph = self.translate_block(stmt.orelse, "elseGraph_%d" % lineno, live_defs)
elseGraph, sub_fct_else = self.translate_block(
stmt.orelse, "elseGraph_%d" % lineno, live_defs)
elseAttr = self.ir_builder.attr("else_branch", elseGraph)

def rename(x):
Expand All @@ -561,7 +565,11 @@ def rename(x):
return r

renamed = [rename(x) for x in live_defs]
self.emit(renamed, Op(default_opset, "If"), [test], [thenAttr, elseAttr])
sub_functions = {}
sub_functions.update(sub_fct_then)
sub_functions.update(sub_fct_else)
self.emit(renamed, Op(default_opset, "If"), [test], [thenAttr, elseAttr],
sub_functions=sub_functions)

def translate_for_stmt(self, for_stmt: ast.For):
# loop-variable
Expand Down Expand Up @@ -612,8 +620,11 @@ def translate_for_stmt(self, for_stmt: ast.For):

inputs = [o_loop_bound, o_true] + \
[self.py_var_to_onnx_var(pv, DebugInfo(for_stmt)) for pv in loop_state_vars]
attrs = [self.ir_builder.attr("body", body.to_graph_proto())]
return self.emit_loop(outputs, "Loop", inputs, attrs, DebugInfo(for_stmt))
graph, sub_functions = body.to_graph_proto()
attrs = [self.ir_builder.attr("body", graph)]
return self.emit_loop(outputs, "Loop", inputs, attrs,
sub_functions=sub_functions,
info=DebugInfo(for_stmt))

# Translation of a statement-block to GraphProto attribute
def translate_block(self, stmts, name, live_defs):
Expand Down
74 changes: 41 additions & 33 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from io import StringIO
import onnx
import onnx.helper as helper
from onnx.defs import onnx_opset_version
from . import type_annotation as ta
from .values import Opset

Expand Down Expand Up @@ -72,7 +73,7 @@ def __str__(self):


class Stmt:
def __init__(self, result, module, opname, args, attrs) -> None:
def __init__(self, result, module, opname, args, attrs, sub_functions=None) -> None:
if not isinstance(module, Opset):
raise TypeError(f"Unexpected type {type(module)} for module.")
if not isinstance(opname, str):
Expand All @@ -82,6 +83,7 @@ def __init__(self, result, module, opname, args, attrs) -> None:
self.opname = opname
self.args = args
self.attrs = attrs
self.functions = sub_functions or {}

def __str__(self):
if (isinstance(self.result, str)):
Expand Down Expand Up @@ -156,49 +158,56 @@ def debug_print(self):
st.write("\n")

def append_function(self, opf):
name = opf.name
if name in self.functions:
for name, fct in opf.function_ir.functions.items():
if name in self.functions:
continue
self.functions[name] = fct
if opf.name in self.functions:
# Already added.
return
try:
proto = opf.to_function_proto(opf.opset)
except (TypeError, AttributeError) as e:
raise TypeError(f"Issue with type f{type(opf)}.") from e
self.functions[name] = proto

def model_functions(self):
"""
The ONNX implementation may rely on additional functions
stored in `self.functions`. This method returns it.
"""
return self.functions

def to_model_proto(self, opsets=None, functions=None, **kwargs):
if opsets is None:
opsets = {'': 15}
elif isinstance(opsets, int):
opsets = {'': opsets}
else:
opsets = opsets.copy()
self.functions[opf.name] = proto

def to_model_proto(self, functions=None, **kwargs):
graph, sub_functions = self.to_graph_proto()
functions = [] if functions is None else list(functions)
functions.extend(sub_functions.values())

opsets = {}
for n in self.stmts:
if n.module.domain not in opsets:
opsets[n.module.domain] = n.module.version
if '' not in opsets:
# No operator is using the standard opset.
# A default value is given.
opsets[''] = onnx_opset_version()
for proto in functions:
if proto.domain not in opsets:
opsets[proto.domain] = 1

opset_imports = [onnx.helper.make_opsetid(domain, version)
for domain, version in opsets.items()]
graph = self.to_graph_proto()
functions = [] if functions is None else list(functions)
# TODO: the following is incomplete. we need to do this iteratively.
functions.extend(self.functions.values())

return helper.make_model(graph, opset_imports=opset_imports,
functions=functions, **kwargs)

def to_graph_proto(self):
return helper.make_graph([s.to_node_proto() for s in self.stmts],
self.name,
[x.to_value_info() for x in self.inputs],
[y.to_value_info() for y in self.outputs])

def to_function_proto_with_opset_imports(self, domain="", func_opset_imports=[]):
sub_functions = {}
for s in self.stmts:
sub_functions.update(s.functions)
sub_functions.update(self.functions)
graph = helper.make_graph([s.to_node_proto() for s in self.stmts],
self.name,
[x.to_value_info() for x in self.inputs],
[y.to_value_info() for y in self.outputs])
return graph, sub_functions

def to_function_proto_with_opset_imports(self, domain="", func_opset_imports=None):
if func_opset_imports is None:
func_opset_imports = []
# TODO: Ideally, in the long term, we should infer func_opset_imports
# from the set of calls within the function itself.
return helper.make_function(domain,
Expand Down Expand Up @@ -230,8 +239,7 @@ def to_function_proto(self, domain):
nodes=nodes,
opset_imports=opset_imports, # TODO
attributes=[a.name for a in self.attrs],
doc_string=self.docstring
)
doc_string=self.docstring)

# IRBuilder: abstracts out details of the IR in the python-to-IR converter

Expand All @@ -252,8 +260,8 @@ def new_function(self, name, domain="", register=False):
def add_docstring(self, fn, docstring):
fn.append_docstring(docstring)

def add_stmt(self, fn, results, module, opname, args, attrs):
s = Stmt(results, module, opname, args, attrs)
def add_stmt(self, fn, results, module, opname, args, attrs, sub_functions=None):
s = Stmt(results, module, opname, args, attrs, sub_functions=sub_functions)
fn.append_stmt(s)

def add_input(self, fn, varname, type):
Expand Down
30 changes: 20 additions & 10 deletions onnxscript/test/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from onnx.helper import printable_graph
from onnx.onnx_cpp2py_export.checker import ValidationError
import onnxruntime
from onnxruntime.capi.onnxruntime_pybind11_state import Fail, InvalidGraph
from onnxscript import script
from onnxscript.onnx import opset15 as op
from onnxscript.onnx_types import FLOAT, INT64
Expand Down Expand Up @@ -48,24 +49,29 @@ def validate_save(self, script, save_text=False, check_ort=False, shape_inferenc
with self.subTest(f=f.name):
model = f.to_model_proto()
if save_text:
with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".txt"), 'w') as f:
f.write(printable_graph(model.graph))
with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".txt"), 'w') as fi:
fi.write(printable_graph(model.graph))
for fct in model.functions:
f.write("\n-------------------------\n")
f.write(printable_graph(fct))
fi.write("\n-------------------------\n")
fi.write(printable_graph(fct))
if check_ort:
onnxruntime.InferenceSession(model.SerializeToString())
try:
onnxruntime.InferenceSession(model.SerializeToString())
except (Fail, InvalidGraph) as e:
raise AssertionError(
f"onnxruntime cannot load function "
f"{f.name}\n{str(model)}") from e
if shape_inference:
model = onnx.shape_inference.infer_shapes(model)
if save_text:
with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".shape.txt"), 'w') as f:
f.write(printable_graph(model.graph))
with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".shape.txt"), 'w') as fi:
fi.write(printable_graph(model.graph))
for fct in model.functions:
f.write("\n-------------------------\n")
f.write(printable_graph(fct))
try:
onnx.checker.check_model(model)
except ValidationError as e:
except (ValidationError, AssertionError) as e:
if "Field 'shape' of type is required but missing" in str(e):
# input or output shapes are missing because the function
# was defined with FLOAT[...].
Expand Down Expand Up @@ -109,12 +115,16 @@ def test_unary_op(self):
from onnxscript.test.models import m1
self.validate_save(m1)

def test_subfunction(self):
def test_subfunction_check_model(self):
from onnxscript.test.models import subfunction
model = subfunction.MyElu.function_ir.to_model_proto(producer_name='p2o')
model = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model)

def test_subfunction(self):
from onnxscript.test.models import subfunction
self.validate_save(subfunction, check_ort=True)

def test_if_models(self):
from onnxscript.test.models import if_statement
self.validate_save(if_statement)
Expand Down Expand Up @@ -152,5 +162,5 @@ def clipmax(x: FLOAT['N'], max: FLOAT): # noqa: F821
if __name__ == '__main__':
# import logging
# logging.basicConfig(level=logging.DEBUG)
# TestConverter().test_signal()
# TestConverter().test_subfunction()
unittest.main()
26 changes: 20 additions & 6 deletions onnxscript/test/models/subfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,41 @@
from onnxscript.values import CustomOpset
from onnxscript.onnx import opset15 as op

opset = CustomOpset('this', 1)

@script(opset)
@script(CustomOpset('this', 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any significance to this change? If I am not mistaken, CustomOpset is mutable, with methods to add new functions to it. In principle, we could potentially update the opset object in script.

I think this does not matter right now, but I feel it may be useful to have examples share the same opset object in the long run.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No significance, except that the converter failed. We should decide how to handle global variables: ignore them, or do something else. Ignoring them is easier.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised. Are you saying it fails with

   opset = CustomOpset('this', 1)
   @script(opset)

but succeeds with

   @script(CustomOpset('this', 1))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

def MySelu(X: FLOAT[None], alpha: FLOAT[1], gamma: FLOAT[1]) -> FLOAT[None]:
zero = op.Constant(value_float=1.)
neg = gamma * (alpha * op.Exp(X) - alpha)
pos = gamma * X
return op.Where(X <= 0, neg, pos)
return op.Where(X <= zero, neg, pos)

@script(opset)
@script(CustomOpset('this', 1))
def MyElu(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
alpha = op.Constant(value_float=1.)
return MySelu(X, alpha, beta)

@script(opset)
@script(CustomOpset('this', 1))
def MyEluB(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
alpha = op.Constant(value_float=1.)
res = MySelu(X, alpha, beta)
return res

@script(opset)
@script(CustomOpset('this', 1))
def MyEluC(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
alpha = op.Constant(value_float=1.)
res = op.Identity(MySelu(X, alpha, beta))
return res

@script(CustomOpset('this', 1))
def MyEluD(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
res = op.Identity(MyEluC(X, beta))
return res

@script(CustomOpset('this', 1))
def IfMyEluD(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
zero = op.Constant(value_float=1.)
if beta > zero:
result = MyEluB(X, beta)
else:
result = MyEluC(X, beta)
return result
4 changes: 2 additions & 2 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ def __call__(self, *args, **kwargs):
def to_function_proto(self, domain=None):
return self.function_ir.to_function_proto(domain or self.opset)

def to_model_proto(self):
def to_model_proto(self, **kwargs):
if self.function_ir.attrs:
raise ValueError("A function with attributes cannot be exported as a model.")
# Note: The function must also have monomorphic type annotation for inputs/outputs
# to be converted into a valid model. Otherwise, we can still produce an ONNX
# model, but it will not pass the ONNX model checker. We do not report an error
# at this stage.
return self.function_ir.to_model_proto()
return self.function_ir.to_model_proto(**kwargs)


# Values fall into the following categories:
Expand Down