Skip to content

Commit

Permalink
Fix missing node names (#187)
Browse files Browse the repository at this point in the history
force node name to be not empty
  • Loading branch information
xadupre committed Jun 20, 2019
1 parent 200cb00 commit 6fc5a0e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
10 changes: 8 additions & 2 deletions skl2onnx/common/_container.py
Expand Up @@ -269,7 +269,7 @@ def _check_operator(self, op_type):
op_type, fct.__name__))

def add_node(self, op_type, inputs, outputs, op_domain='', op_version=1,
**attrs):
name=None, **attrs):
"""
Adds a *NodeProto* into the node list of the final ONNX model.
If the input operator's domain-version information cannot be
Expand All @@ -285,10 +285,15 @@ def add_node(self, op_type, inputs, outputs, op_domain='', op_version=1,
operator we are trying to add.
:param op_version: The version number (e.g., 0 and 1) of the
operator we are trying to add.
:param name: name of the node, this name cannot be empty
:param attrs: A Python dictionary. Keys and values are
attributes' names and attributes' values,
respectively.
"""
if name is None or not isinstance(
name, str) or name == '':
raise RuntimeError("Parameter name cannot be empty "
"and must be a string.")
if op_domain is None:
op_domain = get_domain()
self._check_operator(op_type)
Expand All @@ -314,7 +319,8 @@ def add_node(self, op_type, inputs, outputs, op_domain='', op_version=1,
raise ValueError('Failed to create ONNX node. Undefined '
'attribute pair (%s, %s) found' % (k, v))

node = helper.make_node(op_type, inputs, outputs, **attrs)
node = helper.make_node(op_type, inputs, outputs,
name=name, **attrs)
node.domain = op_domain

self.node_domain_version_pair_sets.add((op_domain, op_version))
Expand Down
14 changes: 10 additions & 4 deletions skl2onnx/operator_converters/nearest_neighbours.py
Expand Up @@ -133,21 +133,27 @@ def _get_probability_score(scope, container, operator, weights,
'weighted_distance')

container.add_node('Equal', [labels_name[i], topk_labels_name],
output_label_name[i])
output_label_name[i],
name=scope.get_unique_operator_name('Equal'))
apply_cast(scope, output_label_name[i], output_cast_label_name[i],
container, to=onnx_proto.TensorProto.FLOAT)
apply_mul(scope, [output_cast_label_name[i], weights_val],
weighted_distance_name, container, broadcast=0)
container.add_node('ReduceSum', weighted_distance_name,
output_label_reduced_name[i], axes=[1])
output_label_reduced_name[i], axes=[1],
name=scope.get_unique_operator_name(
'ReduceSum'))
else:
for i in range(len(classes)):
container.add_node('Equal', [labels_name[i], topk_labels_name],
output_label_name[i])
output_label_name[i],
name=scope.get_unique_operator_name('Equal'))
apply_cast(scope, output_label_name[i], output_cast_label_name[i],
container, to=onnx_proto.TensorProto.INT32)
container.add_node('ReduceSum', output_cast_label_name[i],
output_label_reduced_name[i], axes=[1])
output_label_reduced_name[i], axes=[1],
name=scope.get_unique_operator_name(
'ReduceSum'))

concat_labels_name = scope.get_unique_variable_name('concat_labels')
cast_concat_labels_name = scope.get_unique_variable_name(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_algebra_onnx_operators.py
Expand Up @@ -47,7 +47,9 @@ def conv(scope, operator, container):
op = OnnxSub(operator.inputs[0], W, output_names=operator.outputs)
op.add_to(scope, container)
text = str(container)
assert 'name:"Sub"' in text
if 'name:"Sub"' not in text:
raise AssertionError(
"Unnamed operator:\n".format(text))
nin = list(op.enumerate_initial_types())
nno = list(op.enumerate_nodes())
nva = list(op.enumerate_variables())
Expand Down

0 comments on commit 6fc5a0e

Please sign in to comment.