Skip to content

Commit

Permalink
Fix reference implementation for nested local functions (#5817)
Browse files Browse the repository at this point in the history
### Description
Propagation model local functions to subgraphs.

### Motivation and Context
ReferenceEvaluator fails at calling a local function from a subgraph in
a local functions.

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
  • Loading branch information
xadupre and gramalingam committed Jan 3, 2024
1 parent 717b630 commit 75c6892
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 7 deletions.
22 changes: 17 additions & 5 deletions onnx/reference/op_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,16 @@ def _extract_attribute_value(
)

new_ops = self.run_params.get("new_ops", None)
if "existing_functions" in self.run_params:
functions = list(self.run_params["existing_functions"].values())
else:
functions = None
return ReferenceEvaluator(
att.g,
opsets=self.run_params["opsets"],
verbose=max(0, self.run_params.get("verbose", 0) - 2),
new_ops=None if new_ops is None else list(new_ops.values()),
functions=functions,
)
if att.type in OpRun._attribute_conversion_functions:
return OpRun._attribute_conversion_functions[att.type](att) # type: ignore
Expand Down Expand Up @@ -629,7 +634,9 @@ def eval( # noqa: A003
class OpRunExpand(OpRun):
"""Class any operator to avoid must inherit from."""

def __init__(self, onnx_node: NodeProto, log_function: Any, impl: Any = None):
def __init__(
self, onnx_node: NodeProto, run_params: dict[str, Any], impl: Any = None
):
raise RuntimeError(
f"The reference implementation must not use this node ({type(self)})."
)
Expand All @@ -646,7 +653,7 @@ class OpFunction(OpRun):
def __init__(
self,
onnx_node: NodeProto,
log_function: Any,
run_params: dict[str, Any] | None,
impl: Any = None,
attributes: dict[str, Any] | None = None,
):
Expand All @@ -655,7 +662,7 @@ def __init__(
f"impl cannot be None for node type {onnx_node.op_type!r} "
f"from domain {onnx_node.domain!r}."
)
OpRun.__init__(self, onnx_node, log_function)
OpRun.__init__(self, onnx_node, run_params) # type: ignore[arg-type]
self.impl_ = impl
# The function implementation is the same whenever the function is called
# but the attributes may be different at every call.
Expand Down Expand Up @@ -693,8 +700,13 @@ class OpFunctionContextDependant(OpFunction):
This is needed when the schema of an operator defines a context dependant function.
"""

def __init__(self, onnx_node: NodeProto, log_function: Any, parent: Any = None):
OpFunction.__init__(self, onnx_node, log_function, impl=self, attributes={})
def __init__(
self,
onnx_node: NodeProto,
run_params: dict[str, Any] | None,
parent: Any = None,
):
OpFunction.__init__(self, onnx_node, run_params, impl=self, attributes={})
self.parent = parent
version = parent.opsets[onnx_node.domain]
self.schema_ = get_schema(onnx_node.op_type, version, onnx_node.domain)
Expand Down
2 changes: 2 additions & 0 deletions onnx/reference/ops/op_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __init__(self, onnx_node, run_params): # type: ignore
raise KeyError("run_params must contains key 'opsets'.")
if "verbose" not in run_params:
raise KeyError("run_params must contains key 'verbose'.")
if "existing_functions" not in self.run_params:
raise KeyError("run_params must contains key 'existing_functions'.")

def need_context(self) -> bool:
"""Tells the runtime if this node needs the context
Expand Down
4 changes: 2 additions & 2 deletions onnx/reference/reference_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,8 @@ def __init__( # type: ignore
if functions is not None:
for f in functions: # type: ignore
if isinstance(f, FunctionProto):
existing_functions = list(self.functions_.values())
self.functions_[f.domain, f.name] = ReferenceEvaluator(
f, verbose=verbose, functions=existing_functions
f, verbose=verbose, functions=list(self.functions_.values())
)
elif isinstance(f, ReferenceEvaluator):
onx = f.proto_ # type: ignore
Expand Down Expand Up @@ -411,6 +410,7 @@ def _init(self) -> None:
"opsets": self.opsets,
"verbose": self.verbose,
"new_ops": self.new_ops_,
"existing_functions": self.functions_.copy(),
}
if self.input_types_:
all_types = {i.name: i.type for i in self.onnx_graph_.input}
Expand Down
158 changes: 158 additions & 0 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
make_model,
make_model_gen_version,
make_node,
make_operatorsetid,
make_opsetid,
make_sequence_type_proto,
make_tensor,
Expand Down Expand Up @@ -5339,6 +5340,163 @@ def test_regex_invalid_pattern(self):
with self.assertRaises(ValueError):
ref.run(None, {"X": np.array(["x"])})

def test_a_function_calling_a_function_once(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, ["N"])
output = make_tensor_value_info("output", TensorProto.FLOAT, ["N"])
Z = make_tensor_value_info("output", TensorProto.FLOAT, ["N"])

func_def_add = make_function(
"this",
"fctadd",
["input2"],
["output"],
[
make_node("Constant", [], ["one"], value_floats=[1.0], name="CC0"),
make_node("Add", ["input2", "one"], ["output"], name="A1"),
],
opset_imports=[make_operatorsetid("", 15)],
)

func_def = make_function(
"this",
"fct",
["input"],
["output"],
[
make_node("Constant", [], ["one"], value_floats=[1.0], name="CC"),
make_node("Greater", ["input", "one"], ["cond"]),
make_node(
"If",
["cond"],
["output"],
then_branch=make_graph(
[make_node("fctadd", ["input"], ["output"], domain="this")],
"gthen",
[],
[output],
),
else_branch=make_graph(
[make_node("Add", ["input", "one"], ["output"], domain="")],
"gelse",
[],
[output],
),
name=":IF",
),
],
opset_imports=[
make_operatorsetid("", 15),
make_operatorsetid("this", 1),
],
)

model_def = make_model(
make_graph(
[
make_node("fct", ["X"], ["output"], domain="this"),
],
"test",
[X],
[Z],
),
ir_version=7,
opset_imports=[
make_operatorsetid("", 15),
make_operatorsetid("this", 1),
],
functions=[func_def_add, func_def],
)

feeds = {"X": np.array([-5], dtype=np.float32)}
oinf = ReferenceEvaluator(model_def)
expected = oinf.run(None, feeds)

# inlining does not work here
# inlined = inline_local_functions(model_def)
# oinf = ReferenceEvaluator(inlined)
# goti = oinf.run(None, feeds)
# self.assertEqual(expected[0].tolist(), goti[0].tolist())
self.assertEqual(expected[0], np.array([-4], dtype=np.float32))

def test_a_function_calling_a_function_double(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, ["N"])
output = make_tensor_value_info("output", TensorProto.FLOAT, ["N"])
Z = make_tensor_value_info("output", TensorProto.FLOAT, ["N"])

func_def_add = make_function(
"this",
"fctadd",
["input2"],
["output"],
[
make_node("Constant", [], ["one"], value_floats=[1.0], name="CC0"),
make_node("Add", ["input2", "one"], ["output"], name="A1"),
],
opset_imports=[make_operatorsetid("", 15)],
)

func_def = make_function(
"this",
"fct",
["input"],
["output"],
[
make_node("Constant", [], ["one"], value_floats=[1.0], name="CC"),
make_node("Greater", ["input", "one"], ["cond"]),
make_node(
"If",
["cond"],
["output"],
then_branch=make_graph(
[make_node("fctadd", ["input"], ["output"], domain="this")],
"gthen",
[],
[output],
),
else_branch=make_graph(
[make_node("Add", ["input", "one"], ["output"], domain="")],
"gelse",
[],
[output],
),
name=":IF",
),
],
opset_imports=[
make_operatorsetid("", 15),
make_operatorsetid("this", 1),
],
)

model_def = make_model(
make_graph(
[
make_node("fct", ["X"], ["ztmp"], domain="this"),
make_node("fct", ["ztmp"], ["output"], domain="this"),
],
"test",
[X],
[Z],
),
ir_version=7,
opset_imports=[
make_operatorsetid("", 15),
make_operatorsetid("this", 1),
],
functions=[func_def_add, func_def],
)

feeds = {"X": np.array([-5], dtype=np.float32)}
oinf = ReferenceEvaluator(model_def)
expected = oinf.run(None, feeds)

# inlining does not work here
# inlined = inline_local_functions(model_def)
# oinf = ReferenceEvaluator(inlined)
# goti = oinf.run(None, feeds)
# self.assertEqual(expected[0].tolist(), goti[0].tolist())
self.assertEqual(expected[0], np.array([-3], dtype=np.float32))


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 75c6892

Please sign in to comment.