support for subfunctions included in subgraphs #41

Merged: 8 commits, Jun 7, 2022
Changes from 2 commits
onnxscript/converter.py (3 changes: 2 additions, 1 deletion)
@@ -87,6 +87,7 @@ def _known_modules():
     return {
         'onnxscript': onnxscript,
         'onnxscript.onnx': onnxscript.onnx,
+        'onnxscript.values': onnxscript.values,
         'onnxscript.onnx_types': onnxscript.onnx_types,
         'onnxscript.onnx.opset15': onnxscript.onnx.opset15
     }
@@ -587,7 +588,7 @@ def translate_function_def(self, fn: ast.FunctionDef):
             warn(f"{fn.name}: Default values not yet implemented.")
         if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg:
             warn(f"{fn.name}: Unsupported feature in function signature.")
-        domain = self.globals["__opset_domain__"] if "__opset_domain__" in self.globals else ""
+        domain = self.globals["__opset_domain__"] if "__opset_domain__" in self.globals else "this"
         self.current_fn = self.ir_builder.new_function(fn.name, domain)
         for x in args.args:
             if x.annotation:
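
A script module can still choose its own function domain by defining __opset_domain__ at module level; only the fallback changes, from "" to "this". A minimal sketch of the fallback logic in the changed line, where module_globals stands in for self.globals (the globals of the script being converted) and "custom.domain" is purely an illustrative value:

    def resolve_domain(module_globals):
        # Explicit __opset_domain__ wins; otherwise the converter now
        # defaults to "this" instead of "".
        return module_globals["__opset_domain__"] if "__opset_domain__" in module_globals else "this"

    assert resolve_domain({}) == "this"
    assert resolve_domain({"__opset_domain__": "custom.domain"}) == "custom.domain"
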
onnxscript/test/converter_test.py (28 changes: 18 additions, 10 deletions)
@@ -8,6 +8,7 @@
 from onnx.helper import printable_graph
 from onnx.onnx_cpp2py_export.checker import ValidationError
 import onnxruntime
+from onnxruntime.capi.onnxruntime_pybind11_state import Fail
 from onnxscript.converter import Converter
 from onnxscript.values import Opset

@@ -29,23 +30,26 @@ def _convert_and_save(self, script, save_text=False, check_ort=False):
             with self.subTest(f=f.name):
                 model = f.to_model_proto(producer_name='p2o')
                 if save_text:
-                    with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".txt"), 'w') as f:
-                        f.write(printable_graph(model.graph))
+                    with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".txt"), 'w') as fi:
+                        fi.write(printable_graph(model.graph))
                         for fct in model.functions:
-                            f.write("\n-------------------------\n")
-                            f.write(printable_graph(fct))
+                            fi.write("\n-------------------------\n")
+                            fi.write(printable_graph(fct))
                 if check_ort:
-                    onnxruntime.InferenceSession(model.SerializeToString())
+                    try:
+                        onnxruntime.InferenceSession(model.SerializeToString())
+                    except Fail as e:
+                        raise AssertionError(f"onnxruntime cannot load function {f.name}\n{str(model)}") from e
                 model = onnx.shape_inference.infer_shapes(model)
                 if save_text:
-                    with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".shape.txt"), 'w') as f:
-                        f.write(printable_graph(model.graph))
+                    with open(os.path.join(TEST_OUTPUT_DIR, f.name + ".shape.txt"), 'w') as fi:
+                        fi.write(printable_graph(model.graph))
                         for fct in model.functions:
                             f.write("\n-------------------------\n")
                             f.write(printable_graph(fct))
                 try:
                     onnx.checker.check_model(model)
-                except ValidationError as e:
+                except (ValidationError, AssertionError) as e:
                     onnx.save(model, os.path.join(TEST_OUTPUT_DIR, f.name + ".error.onnx"))
                     raise AssertionError(
                         "Verification of model failed.") from e
@@ -101,12 +105,15 @@ def test_onnxfns1A(self):
     def test_models(self):
         self._convert_and_save(os.path.join(TEST_INPUT_DIR, "onnxmodels.py"))

-    def test_subfunction(self):
-        from .models import subfunction
+    def test_subfunction_check_model(self):
+        from onnxscript.test.models import subfunction
         model = subfunction.MyElu.function_ir.to_model_proto(producer_name='p2o')
         model = onnx.shape_inference.infer_shapes(model)
         onnx.checker.check_model(model)

+    def test_subfunction(self):
+        self._convert_and_save(os.path.join(TEST_INPUT_DIR, "subfunction.py"), check_ort=True)
+
     def test_if_models(self):
         self._convert_and_save(os.path.join(TEST_INPUT_DIR, "if_statement.py"))

@@ -123,4 +130,5 @@ def test_docstring(self):
 if __name__ == '__main__':
     # import logging
     # logging.basicConfig(level=logging.DEBUG)
+    TestConverter().test_subfunction()
     unittest.main()
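
As a usage note, the two tests added here can also be run on their own with the standard unittest API. A small sketch, assuming the repository is importable from the current environment (the class and test names are taken from the diff above):

    import unittest
    from onnxscript.test.converter_test import TestConverter

    # Run only the subfunction-related tests added in this PR.
    suite = unittest.TestSuite()
    suite.addTest(TestConverter("test_subfunction_check_model"))
    suite.addTest(TestConverter("test_subfunction"))
    unittest.TextTestRunner(verbosity=2).run(suite)
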
onnxscript/test/models/subfunction.py (20 changes: 14 additions, 6 deletions)
@@ -5,27 +5,35 @@
 from onnxscript.values import CustomOpset
 from onnxscript.onnx import opset15 as op

-opset = CustomOpset('this', 1)

-@script(opset)
+@script(CustomOpset('this', 1))

Review discussion on the @script(CustomOpset('this', 1)) change:

Collaborator: Is there any significance to this change? If I am not mistaken, CustomOpset is mutable, with methods to add new functions to it. In principle, we could potentially update the opset object in script. I think this does not matter right now, but I feel it may be useful to have examples share the same opset object in the long run.

Member (author): No significance, except that the converter failed. We should decide how to handle global variables: ignore them, or do something else. Ignoring them is easier.

Collaborator: I am surprised. Are you saying it fails with

    opset = CustomOpset('this', 1)
    @script(opset)

but succeeds with

    @script(CustomOpset('this', 1))

Member (author): Yes.

 def MySelu(X: FLOAT[None], alpha: FLOAT[1], gamma: FLOAT[1]) -> FLOAT[None]:
     neg = gamma * (alpha * op.Exp(X) - alpha)
     pos = gamma * X
-    return op.Where(X <= 0, neg, pos)
+    return op.Where(X <= 0., neg, pos)

-@script(opset)
+@script(CustomOpset('this', 1))
 def MyElu(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
     alpha = op.Constant(value_float=1.)
     return MySelu(X, alpha, beta)

-@script(opset)
+@script(CustomOpset('this', 1))
 def MyEluB(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
     alpha = op.Constant(value_float=1.)
     res = MySelu(X, alpha, beta)
     return res

-@script(opset)
+@script(CustomOpset('this', 1))
 def MyEluC(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
     alpha = op.Constant(value_float=1.)
     res = op.Identity(MySelu(X, alpha, beta))
     return res
+
+@script(CustomOpset('this', 1))
+def IfMyEluD(X: FLOAT[None], beta: FLOAT[1]) -> FLOAT[None]:
+    zero = op.Constant(value_float=1.)
+    if beta > 0:
+        result = MyEluB(X, beta)
+    else:
+        result = MyEluC(X, beta)
+    return result
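
For completeness, a minimal end-to-end sketch of what test_subfunction_check_model and check_ort=True exercise together: convert MyElu (which calls MySelu, emitted as a model-local FunctionProto in the 'this' domain) to a ModelProto, validate it, and load and run it with onnxruntime. The input names X and beta are assumed to follow the script signature, and the sample values are illustrative only:

    import numpy as np
    import onnx
    import onnxruntime
    from onnxscript.test.models import subfunction

    # Build a ModelProto whose main graph calls MyElu; MySelu is carried
    # along as a local function, which is the point of this PR.
    model = subfunction.MyElu.function_ir.to_model_proto(producer_name='p2o')
    model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model)

    # Load and run with onnxruntime, as _convert_and_save(check_ort=True) does.
    sess = onnxruntime.InferenceSession(model.SerializeToString())
    result = sess.run(None, {"X": np.array([-1., 0., 1.], dtype=np.float32),
                             "beta": np.array([1.], dtype=np.float32)})
    print(result[0])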