[Frontend] Onnx improvement (#165)
* fix recently released layers

* fix fc layers with partial infer_shape
zhreshold authored and tqchen committed Oct 10, 2017
1 parent fd05263 commit 52fafd3
Showing 1 changed file with 37 additions and 17 deletions: python/nnvm/frontend/onnx.py
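For context, the entry point of the importer touched here is from_onnx, which converts an ONNX GraphProto into an nnvm symbol plus its parameters. A minimal usage sketch, assuming the onnx package is installed and a serialized model exists at the hypothetical path "model.onnx":

import onnx
import nnvm.frontend

# Load the serialized model; its .graph field is the GraphProto that
# nnvm's importer consumes.
model = onnx.load("model.onnx")
# Returns an nnvm symbol and a dict mapping parameter names to ndarrays.
sym, params = nnvm.frontend.from_onnx(model.graph)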
@@ -3,6 +3,8 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import symbol as _sym
+from .. import graph as _graph
+from .. compiler import graph_util
 from .common import Renamer, AttrConverter as AttrCvt
 
 __all__ = ['from_onnx']
@@ -60,9 +62,9 @@ def _pooling(name):
             'kernel_shape': 'pool_size',
             'pads': ('padding', (0, 0), _revert_caffe2_pad)},
         # very weird attributes here in onnx, force check
-        excludes=['dilations'],
+        ignores=['dilations'],
         # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
-        extras={'ceil_mode': True},
+        extras={'ceil_mode': False},
         custom_check=_dimension_constraint())
 
 def _conv():
@@ -90,7 +92,7 @@ def _batch_norm():
     return AttrCvt(
         op_name='batch_norm',
         disables=['momentum'],
-        ignores=['spatial', 'is_test'])
+        ignores=['spatial', 'is_test', 'consumed_inputs'])
 
 
 # compatible operators that do NOT require any conversion.
@@ -100,6 +102,7 @@ def _batch_norm():
 _convert_map = {
     # defs/experimental
     'FC' : AttrCvt('dense', ignores=['axis', 'axis_w']),
+    'SpatialBN' : _batch_norm(),
 
     # defs/generator
     # 'Constant'
@@ -200,7 +203,7 @@ def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
     elif op_name in convert_map:
         op_name, attrs = convert_map[op_name](attrs)
     else:
-        _raise_not_supported('Operator: ' + op_name)
+        raise NotImplementedError("Operator {} not implemented.".format(op_name))
     op = getattr(_sym, op_name, None)
     if not op:
         raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
@@ -267,10 +270,11 @@ def from_onnx(self, graph):
                 new_attr = self._fix_channels(new_op, new_attr, list(node.input))
                 self._fix_bias_shape(node.op_type, graph.node[idx-1].op_type, node.input)
                 op = new_op(name=node_name, *inputs, **new_attr)
-                assert len(node.output) == len(op.list_output_names()), (
-                    "Number of output mismatch {} vs {}.".format(
-                        len(node.output), len(op.list_output_names())))
-                for k, i in zip(list(node.output), range(len(node.output))):
+                node_output = self._fix_outputs(op_name, node.output)
+                assert len(node_output) == len(op.list_output_names()), (
+                    "Number of output mismatch {} vs {} in {}.".format(
+                        len(node_output), len(op.list_output_names()), op_name))
+                for k, i in zip(list(node_output), range(len(node_output))):
                     self._nodes[k] = op[i]
         # now return the outputs
         out = [self._nodes[i] for i in graph.output]
@@ -310,6 +314,15 @@ def _parse_attr(self, attr_proto):
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return attrs
 
+    def _fix_outputs(self, op, outputs):
+        """A hack to handle dropout or similar operator that have more than one out
+        in ONNX.
+        """
+        if op == 'Dropout':
+            assert len(outputs) == 2, "ONNX have two outputs for dropout layer."
+            outputs = outputs[:-1]
+        return outputs
+
     def _fix_bias(self, op, attrs, num_inputs):
         """A hack for 'use_bias' attribute since onnx don't provide this attribute,
         we have to check the number of inputs to decide it."""
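The _fix_outputs hack above is needed because ONNX's Dropout node declares two outputs, the result and a dropout mask, while nnvm's dropout yields only the result. A standalone sketch of the same trimming logic (the function name trim_dropout_outputs is illustrative, not part of the patch):

def trim_dropout_outputs(op_type, outputs):
    # ONNX Dropout lists (output, mask); nnvm produces only the output,
    # so drop the trailing mask name before binding node outputs.
    if op_type == 'Dropout':
        assert len(outputs) == 2, "expected (output, mask) from ONNX Dropout"
        outputs = outputs[:-1]
    return outputs

# trim_dropout_outputs('Dropout', ['y', 'mask']) returns ['y'];
# any other op type passes through unchanged.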
Expand Down Expand Up @@ -340,17 +353,24 @@ def _fix_channels(self, op, attrs, inputs):
"""
if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
return attrs
weight_name = self._renames[inputs[1]]
if not weight_name in self._params:
raise ValueError("Unable to get channels/units attr from onnx graph.")
if inputs[1] not in self._renames:
assert inputs[1] in self._nodes
g = _graph.create(self._nodes[inputs[1]])
shape_dict = {k: v.shape for k, v in self._params.items()}
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
channels = out_shapes[0][0]
else:
wshape = self._params[weight_name].shape
assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
channels = wshape[0]
if op in [_sym.dense]:
attrs['units'] = channels
weight_name = self._renames[inputs[1]]
if not weight_name in self._params:
raise ValueError("Unable to get channels/units attr from onnx graph.")
else:
attrs['channels'] = channels
wshape = self._params[weight_name].shape
assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
channels = wshape[0]
if op in [_sym.dense]:
attrs['units'] = channels
else:
attrs['channels'] = channels
return attrs

def from_onnx(graph):
Expand Down

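The _fix_channels rewrite above implements the second commit-message bullet: when the weight of a conv2d, conv2d_transpose, or dense op is itself produced by upstream nodes (so it is not a stored parameter), the channel/unit count can no longer be read from a parameter's shape. The patch instead wraps the weight symbol in a temporary graph and runs nnvm shape inference over it. A hedged sketch of that fallback in isolation (infer_weight_channels and its arguments are illustrative names):

import nnvm.graph as _graph
from nnvm.compiler import graph_util

def infer_weight_channels(weight_sym, params):
    # Build a throwaway graph whose sole output is the weight symbol,
    # feed in the shapes of all known parameters, and let nnvm's
    # InferShape pass work out the rest.
    g = _graph.create(weight_sym)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    # As in the patch, channels/units are read from the first dimension
    # of the inferred weight shape.
    return out_shapes[0][0]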