Fix ReferenceEvaluator when run from a subclass #5936

Merged · 4 commits · Feb 23, 2024
10 changes: 5 additions & 5 deletions onnx/reference/op_run.py
@@ -249,16 +249,16 @@ def _extract_attribute_value(
) -> Any:
"""Converts an attribute value into a python value."""
if att.type == AttributeProto.GRAPH:
from onnx.reference.reference_evaluator import (
ReferenceEvaluator, # type: ignore
)

new_ops = self.run_params.get("new_ops", None)
if "existing_functions" in self.run_params:
functions = list(self.run_params["existing_functions"].values())
else:
functions = None
return ReferenceEvaluator(
evaluator_cls = self.run_params.get("evaluator_cls", None)
assert (
evaluator_cls is not None
), f"evaluator_cls must be specified to evaluate att={att}"
return evaluator_cls(
att.g,
opsets=self.run_params["opsets"],
verbose=max(0, self.run_params.get("verbose", 0) - 2),
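This hunk is the core of the fix: instead of importing `ReferenceEvaluator` unconditionally, `_extract_attribute_value` now instantiates whatever class was passed through `run_params["evaluator_cls"]`, so a subclass is also used when GRAPH attributes (the bodies of `If`, `Loop`, `Scan`) are evaluated. A minimal sketch of the behaviour this enables; `CountingEvaluator` and the tiny `If` model are illustrative placeholders, not part of the PR:

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator


class CountingEvaluator(ReferenceEvaluator):
    """Placeholder subclass; a real one might register custom ops or add tracing."""


def _branch(name, value):
    # Each branch is a GRAPH attribute: a single Constant node returning `value`.
    node = helper.make_node(
        "Constant", [], ["out"],
        value=helper.make_tensor(name, TensorProto.FLOAT, [1], [value]),
    )
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])
    return helper.make_graph([node], name, [], [out])


if_node = helper.make_node(
    "If", ["cond"], ["y"],
    then_branch=_branch("then_value", 1.0),
    else_branch=_branch("else_value", 2.0),
)
graph = helper.make_graph(
    [if_node], "g",
    [helper.make_tensor_value_info("cond", TensorProto.BOOL, [])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

# With this PR the branch graphs are evaluated with CountingEvaluator as well;
# before, they silently fell back to the plain ReferenceEvaluator.
sess = CountingEvaluator(model)
print(sess.run(None, {"cond": np.array(True)}))  # -> [array([1.], dtype=float32)]
```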
15 changes: 1 addition & 14 deletions onnx/reference/ops/_op.py
@@ -16,9 +16,6 @@ class OpRunUnary(OpRun):
Checks that input and output types are the same.
"""

def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
OpRun.__init__(self, onnx_node, run_params)

def run(self, x): # type: ignore
"""Calls method ``_run``, catches exceptions, displays a longer error message.

@@ -42,9 +39,6 @@ class OpRunUnaryNum(OpRunUnary):
Checks that input and output types are the same.
"""

def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
OpRunUnary.__init__(self, onnx_node, run_params)

def run(self, x): # type: ignore
"""Calls method ``OpRunUnary.run``.

@@ -68,9 +62,6 @@ class OpRunBinary(OpRun):
Checks that input and output types are the same.
"""

def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
OpRun.__init__(self, onnx_node, run_params)

def run(self, x, y): # type: ignore
"""Calls method ``_run``, catches exceptions, displays a longer error message.

@@ -101,8 +92,7 @@ def run(self, x, y): # type: ignore
class OpRunBinaryComparison(OpRunBinary):
"""Ancestor to all binary operators in this subfolder comparing tensors."""

def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
OpRunBinary.__init__(self, onnx_node, run_params)
pass


class OpRunBinaryNum(OpRunBinary):
@@ -111,9 +101,6 @@ class OpRunBinaryNum(OpRunBinary):
Checks that input and output types are the same.
"""

def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
OpRunBinary.__init__(self, onnx_node, run_params)

def run(self, x, y): # type: ignore
"""Calls method ``OpRunBinary.run``, catches exceptions, displays a longer error message."""
res = OpRunBinary.run(self, x, y)
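The constructors removed above did nothing but forward their arguments to the parent class, so dropping them changes no behaviour: Python resolves to the inherited `__init__` automatically. A generic illustration (not onnx code):

```python
class Base:
    def __init__(self, onnx_node, run_params):
        self.onnx_node = onnx_node
        self.run_params = run_params


class WithRedundantInit(Base):
    def __init__(self, onnx_node, run_params):  # what the PR deletes: a pure pass-through
        Base.__init__(self, onnx_node, run_params)


class WithoutInit(Base):  # what the PR keeps: the parent constructor is inherited
    pass


a = WithRedundantInit("node", {"verbose": 0})
b = WithoutInit("node", {"verbose": 0})
assert a.run_params == b.run_params == {"verbose": 0}
```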
16 changes: 10 additions & 6 deletions onnx/reference/ops/_op_list.py
@@ -252,6 +252,7 @@ def load_op(
node: Union[None, NodeProto] = None,
input_types: Union[None, List[TypeProto]] = None,
expand: bool = False,
evaluator_cls: TOptional[type] = None,
) -> Any:
"""Loads the implemented for a specified operator.

@@ -266,6 +267,7 @@
operator defines a function which is context dependent
expand: use the function implemented in the schema instead of
its reference implementation
evaluator_cls: evaluator to use

Returns:
class
@@ -292,10 +294,11 @@
f"and domain {domain!r}. Did you recompile the sources after updating the repository?"
) from None
if schema.has_function: # type: ignore
from onnx.reference import ReferenceEvaluator

body = schema.function_body # type: ignore
sess = ReferenceEvaluator(body)
assert (
evaluator_cls is not None
), f"evaluator_cls must be specified to implement operator {op_type!r} from domain {domain!r}"
sess = evaluator_cls(body)
return lambda *args, sess=sess: OpFunction(*args, impl=sess) # type: ignore
if schema.has_context_dependent_function: # type: ignore
if node is None or input_types is None:
@@ -304,14 +307,15 @@
f"and domain {domain!r}, the operator has a context dependent function. "
f"but argument node or input_types is not defined (input_types={input_types})."
)
from onnx.reference import ReferenceEvaluator

body = schema.get_context_dependent_function( # type: ignore
node.SerializeToString(), [it.SerializeToString() for it in input_types]
)
proto = FunctionProto()
proto.ParseFromString(body)
sess = ReferenceEvaluator(proto)
assert (
evaluator_cls is not None
), f"evaluator_cls must be specified to evaluate function {proto.name!r}"
sess = evaluator_cls(proto)
return lambda *args, sess=sess: OpFunction(*args, impl=sess) # type: ignore
found = False
if not found:
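Besides requiring `evaluator_cls`, both branches keep the existing `lambda *args, sess=sess: OpFunction(*args, impl=sess)` pattern: the default argument binds the freshly built session at definition time, so each returned factory carries its own evaluator. A standalone illustration of why the `sess=sess` binding matters:

```python
# Late binding: every closure sees the final value of the loop variable.
late = [lambda: tag for tag in ("a", "b", "c")]
print([f() for f in late])      # ['c', 'c', 'c']

# Default-argument binding (the `sess=sess` idiom): the current value is captured
# when the lambda is created, which is what load_op relies on.
bound = [lambda tag=tag: tag for tag in ("a", "b", "c")]
print([f() for f in bound])     # ['a', 'b', 'c']
```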
7 changes: 6 additions & 1 deletion onnx/reference/ops/aionnx_preview_training/_op_list.py
@@ -21,7 +21,11 @@ def _build_registered_operators() -> Dict[str, Dict[Union[int, None], OpRunTrain


def load_op(
domain: str, op_type: str, version: Union[None, int], custom: Any = None
domain: str,
op_type: str,
version: Union[None, int],
custom: Any = None,
evaluator_cls: TOptional[type] = None,
) -> Any:
"""Loads the implemented for a specified operator.

@@ -30,6 +34,7 @@ def load_op(
op_type: operator type
version: requested version
custom: custom implementation (like a function)
evaluator_cls: unused

Returns:
class
7 changes: 6 additions & 1 deletion onnx/reference/ops/aionnxml/_op_list.py
@@ -37,7 +37,11 @@ def _build_registered_operators() -> Dict[str, Dict[Union[int, None], OpRunAiOnn


def load_op(
domain: str, op_type: str, version: Union[None, int], custom: Any = None
domain: str,
op_type: str,
version: Union[None, int],
custom: Any = None,
evaluator_cls: TOptional[type] = None,
) -> Any:
"""Loads the implemented for a specified operator.

@@ -46,6 +50,7 @@ def load_op(
op_type: operator type
version: requested version
custom: custom implementation (like a function)
evaluator_cls: unused

Returns:
class
16 changes: 10 additions & 6 deletions onnx/reference/ops/aionnxml/op_tree_ensemble.py
@@ -175,18 +175,22 @@ def build_node(current_node_index, is_leaf) -> Node | Leaf:
nodes_modes[current_node_index],
set_members,
nodes_featureids[current_node_index],
nodes_missing_value_tracks_true[current_node_index]
if nodes_missing_value_tracks_true is not None
else False,
(
nodes_missing_value_tracks_true[current_node_index]
if nodes_missing_value_tracks_true is not None
else False
),
)
else:
node = Node(
nodes_modes[current_node_index],
nodes_splits[current_node_index],
nodes_featureids[current_node_index],
nodes_missing_value_tracks_true[current_node_index]
if nodes_missing_value_tracks_true is not None
else False,
(
nodes_missing_value_tracks_true[current_node_index]
if nodes_missing_value_tracks_true is not None
else False
),
)

# recurse true and false branches
7 changes: 6 additions & 1 deletion onnx/reference/ops/experimental/_op_list.py
@@ -21,7 +21,11 @@ def _build_registered_operators() -> (


def load_op(
domain: str, op_type: str, version: Union[None, int], custom: Any = None
domain: str,
op_type: str,
version: Union[None, int],
custom: Any = None,
evaluator_cls: TOptional[type] = None,
) -> Any:
"""Loads the implemented for a specified operator.

@@ -30,6 +34,7 @@ def load_op(
op_type: operator type
version: requested version
custom: custom implementation (like a function)
evaluator_cls: unused

Returns:
class
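The three domain-specific dispatchers touched above (`aionnx_preview_training`, `aionnxml`, `experimental`) never build sub-evaluators themselves, so their new `evaluator_cls` parameter is documented as unused; it exists only so every `load_op` shares one signature and the caller can forward the argument without special-casing the domain. A small sketch of that convention, with hypothetical names:

```python
from typing import Any, Callable, Dict, Optional


def load_op_ml(domain: str, op_type: str, version: Optional[int] = None,
               custom: Any = None, evaluator_cls: Optional[type] = None) -> Any:
    # evaluator_cls is accepted but unused in this domain; keeping it in the
    # signature lets the generic caller pass it to every dispatcher alike.
    return f"loaded {domain}.{op_type}"


DISPATCH: Dict[str, Callable[..., Any]] = {"ai.onnx.ml": load_op_ml}


def load_any(domain: str, op_type: str, evaluator_cls: Optional[type] = None) -> Any:
    # One uniform call site, no per-domain special cases.
    return DISPATCH[domain](domain, op_type, evaluator_cls=evaluator_cls)


print(load_any("ai.onnx.ml", "TreeEnsemble"))
```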
32 changes: 20 additions & 12 deletions onnx/reference/ops/op_rnn.py
@@ -36,22 +36,30 @@ def __init__(self, onnx_node, run_params): # type: ignore

self.f1 = self.choose_act(
self.activations[0], # type: ignore
self.activation_alpha[0] # type: ignore
if self.activation_alpha is not None and len(self.activation_alpha) > 0 # type: ignore
else None,
self.activation_beta[0] # type: ignore
if self.activation_beta is not None and len(self.activation_beta) > 0 # type: ignore
else None,
(
self.activation_alpha[0] # type: ignore
if self.activation_alpha is not None and len(self.activation_alpha) > 0 # type: ignore
else None
),
(
self.activation_beta[0] # type: ignore
if self.activation_beta is not None and len(self.activation_beta) > 0 # type: ignore
else None
),
)
if len(self.activations) > 1: # type: ignore
self.f2 = self.choose_act(
self.activations[1], # type: ignore
self.activation_alpha[1] # type: ignore
if self.activation_alpha is not None and len(self.activation_alpha) > 1 # type: ignore
else None,
self.activation_beta[1] # type: ignore
if self.activation_beta is not None and len(self.activation_beta) > 1 # type: ignore
else None,
(
self.activation_alpha[1] # type: ignore
if self.activation_alpha is not None and len(self.activation_alpha) > 1 # type: ignore
else None
),
(
self.activation_beta[1] # type: ignore
if self.activation_beta is not None and len(self.activation_beta) > 1 # type: ignore
else None
),
)
self.n_outputs = len(onnx_node.output)

36 changes: 22 additions & 14 deletions onnx/reference/ops/op_scan.py
@@ -16,10 +16,12 @@ def __init__(self, onnx_node, run_params): # type: ignore
f"Parameter 'body' must have a method 'run', type {type(self.body)}." # type: ignore
)
self.input_directions_ = [
0
if self.scan_input_directions is None # type: ignore
or i >= len(self.scan_input_directions) # type: ignore
else self.scan_input_directions[i] # type: ignore
(
0
if self.scan_input_directions is None # type: ignore
or i >= len(self.scan_input_directions) # type: ignore
else self.scan_input_directions[i]
) # type: ignore
for i in range(self.num_scan_inputs) # type: ignore
]
max_dir_in = max(self.input_directions_)
@@ -28,9 +30,11 @@
"Scan is not implemented for other output input_direction than 0."
)
self.input_axes_ = [
0
if self.scan_input_axes is None or i >= len(self.scan_input_axes) # type: ignore
else self.scan_input_axes[i] # type: ignore
(
0
if self.scan_input_axes is None or i >= len(self.scan_input_axes) # type: ignore
else self.scan_input_axes[i]
) # type: ignore
for i in range(self.num_scan_inputs) # type: ignore
]
max_axe_in = max(self.input_axes_)
@@ -44,10 +48,12 @@ def _common_run_shape(self, *args): # type: ignore
num_scan_outputs = len(args) - num_loop_state_vars

output_directions = [
0
if self.scan_output_directions is None # type: ignore
or i >= len(self.scan_output_directions) # type: ignore
else self.scan_output_directions[i] # type: ignore
(
0
if self.scan_output_directions is None # type: ignore
or i >= len(self.scan_output_directions) # type: ignore
else self.scan_output_directions[i]
) # type: ignore
for i in range(num_scan_outputs)
]
max_dir_out = max(output_directions)
@@ -56,9 +62,11 @@
"Scan is not implemented for other output output_direction than 0."
)
output_axes = [
0
if self.scan_output_axes is None or i >= len(self.scan_output_axes) # type: ignore
else self.scan_output_axes[i] # type: ignore
(
0
if self.scan_output_axes is None or i >= len(self.scan_output_axes) # type: ignore
else self.scan_output_axes[i]
) # type: ignore
for i in range(num_scan_outputs)
]
max_axe_out = max(output_axes)