diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements
index 0654fe287668..6e938e3cd7b0 100644
--- a/ci/docker/install/requirements
+++ b/ci/docker/install/requirements
@@ -25,7 +25,8 @@ graphviz<0.9.0,>=0.8.1
 contextvars;python_version<"3.7"
 
 # Optional dependencies
-onnx==1.5.0
+onnx==1.7.0
+onnxruntime==1.4.0
 # protobuf version frozen due to ps-lite
 protobuf==3.5.2
 scipy==1.4.1
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 34249506f427..2a719a9c243e 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -913,11 +913,13 @@ unittest_centos7_gpu() {
 integrationtest_ubuntu_cpu_onnx() {
     set -ex
     export PYTHONPATH=./python/
-    export DMLC_LOG_STACK_TRACE_DEPTH=10
+    export MXNET_SUBGRAPH_VERBOSE=0
+    export DMLC_LOG_STACK_TRACE_DEPTH=10
     python3 tests/python/unittest/onnx/backend_test.py
     OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -n 4 tests/python/unittest/onnx/mxnet_export_test.py
     OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -n 4 tests/python/unittest/onnx/test_models.py
     OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -n 4 tests/python/unittest/onnx/test_node.py
+    OMP_NUM_THREADS=$(expr $(nproc) / 4) pytest -n 4 tests/python-pytest/onnx/test_onnxruntime.py
 }
 
 integrationtest_ubuntu_cpu_dist_kvstore() {
""" + opset_version = kwargs["opset_version"] name, input_nodes, attrs = get_inputs(node, kwargs) mxnet_pad_width = convert_string_to_list(attrs.get("pad_width")) onnx_pad_width = transform_padding(mxnet_pad_width) pad_mode = attrs.get("mode") + pad_value = np.float32(attrs.get("constant_value", 0.0)) - if pad_mode == "constant": - pad_value = float(attrs.get("constant_value")) \ - if "constant_value" in attrs else 0.0 - node = onnx.helper.make_node( - 'Pad', - inputs=input_nodes, - outputs=[name], - mode='constant', - value=pad_value, - pads=onnx_pad_width, - name=name - ) + if opset_version >= 11: + # starting with opset 11, pads and constant_value are inputs instead of attributes + from onnx.helper import make_tensor, make_tensor_value_info + initializer = kwargs["initializer"] + pads_input_name = name + "_pads" + pads_input_type = onnx.TensorProto.INT64 + pads_input_shape = np.shape(np.array(onnx_pad_width)) + pads_value_node = make_tensor_value_info(pads_input_name, pads_input_type, pads_input_shape) + pads_tensor_node = make_tensor(pads_input_name, pads_input_type, pads_input_shape, onnx_pad_width) + initializer.append(pads_tensor_node) + input_nodes.append(pads_input_name) + + if pad_mode == "constant": + const_input_name = name + "_constant" + const_input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[pad_value.dtype] + const_value_node = make_tensor_value_info(const_input_name, const_input_type, ()) + const_tensor_node = make_tensor(const_input_name, const_input_type, (), [pad_value]) + initializer.append(const_tensor_node) + input_nodes.append(const_input_name) + pad_node = onnx.helper.make_node( + "Pad", + input_nodes, + [name], + mode=pad_mode, + name=name + ) + return [pads_value_node, const_value_node, pad_node] + else: + pad_node = onnx.helper.make_node( + "Pad", + input_nodes, + [name], + mode=pad_mode, + name=name + ) + return [pads_value_node, pad_node] else: - node = onnx.helper.make_node( - 'Pad', - inputs=input_nodes, - outputs=[name], - mode=pad_mode, - pads=onnx_pad_width, - name=name - ) - - return [node] + if pad_mode == "constant": + node = onnx.helper.make_node( + 'Pad', + inputs=input_nodes, + outputs=[name], + mode='constant', + value=pad_value, + pads=onnx_pad_width, + name=name + ) + return [node] + else: + node = onnx.helper.make_node( + 'Pad', + inputs=input_nodes, + outputs=[name], + mode=pad_mode, + pads=onnx_pad_width, + name=name + ) + return [node] def create_helper_tensor_node(input_vals, output_name, kwargs): """create extra tensor node from numpy values""" @@ -766,6 +803,7 @@ def convert_pooling(node, **kwargs): MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators based on the input node's attributes and return the created node. """ + opset_version = kwargs["opset_version"] name, input_nodes, attrs = get_inputs(node, kwargs) kernel = eval(attrs["kernel"]) @@ -777,12 +815,12 @@ def convert_pooling(node, **kwargs): pooling_convention = attrs.get('pooling_convention', 'valid') ceil_mode = False if pooling_convention == 'full': - if onnx.__version__ < "1.5.0": + if opset_version < 10: pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \ "This might lead to shape or accuracy issues. 
" \ "https://github.com/onnx/onnx/issues/549" + logging.warning(pooling_warning) ceil_mode = True - logging.warning(pooling_warning) pad_dims = list(parse_helper(attrs, "pad", [0, 0])) pad_dims = pad_dims + pad_dims @@ -822,7 +860,7 @@ def convert_pooling(node, **kwargs): name=name ) else: - if onnx.__version__ >= "1.5.0": + if opset_version >= 10: node = onnx.helper.make_node( pool_types[pool_type], input_nodes, # input @@ -1353,17 +1391,35 @@ def convert_dropout(node, **kwargs): and return the created node. """ name, input_nodes, attrs = get_inputs(node, kwargs) + opset_version = kwargs["opset_version"] probability = float(attrs.get("p", 0.5)) - dropout_node = onnx.helper.make_node( - "Dropout", - input_nodes, - [name], - ratio=probability, - name=name - ) - return [dropout_node] + if opset_version >= 12: + # opset >= 12 requires the ratio to be an input + initializer = kwargs["initializer"] + ratio_input_name = name + "_ratio" + value_node = onnx.helper.make_tensor_value_info(ratio_input_name, + onnx.TensorProto.FLOAT, ()) + tensor_node = onnx.helper.make_tensor(ratio_input_name, onnx.TensorProto.FLOAT, + (), [probability]) + initializer.append(tensor_node) + dropout_node = onnx.helper.make_node( + "Dropout", + [input_nodes[0], ratio_input_name], + [name], + name=name + ) + return [value_node, dropout_node] + else: + dropout_node = onnx.helper.make_node( + "Dropout", + input_nodes, + [name], + ratio=probability, + name=name + ) + return [dropout_node] @mx_op.register("Flatten") @@ -1379,19 +1435,46 @@ def convert_clip(node, **kwargs): and return the created node. """ name, input_nodes, attrs = get_inputs(node, kwargs) + opset_version = kwargs["opset_version"] - a_min = np.float(attrs.get('a_min', -np.inf)) - a_max = np.float(attrs.get('a_max', np.inf)) + a_min = float(attrs.get('a_min', -np.inf)) + a_max = float(attrs.get('a_max', np.inf)) - clip_node = onnx.helper.make_node( - "Clip", - input_nodes, - [name], - name=name, - min=a_min, - max=a_max - ) - return [clip_node] + if opset_version >= 11: + # opset >= 11 requires min/max to be inputs + initializer = kwargs["initializer"] + min_input_name = name + "_min" + max_input_name = name + "_max" + min_value_node = onnx.helper.make_tensor_value_info(min_input_name, + onnx.TensorProto.FLOAT, ()) + max_value_node = onnx.helper.make_tensor_value_info(max_input_name, + onnx.TensorProto.FLOAT, ()) + min_tensor_node = onnx.helper.make_tensor(min_input_name, onnx.TensorProto.FLOAT, + (), [a_min]) + max_tensor_node = onnx.helper.make_tensor(max_input_name, onnx.TensorProto.FLOAT, + (), [a_max]) + initializer.append(min_tensor_node) + initializer.append(max_tensor_node) + input_nodes.append(min_input_name) + input_nodes.append(max_input_name) + clip_node = onnx.helper.make_node( + "Clip", + input_nodes, + [name], + name=name + ) + return [min_value_node, max_value_node, clip_node] + + else: + clip_node = onnx.helper.make_node( + "Clip", + input_nodes, + [name], + name=name, + min=a_min, + max=a_max + ) + return [clip_node] def scalar_op_helper(node, op_name, **kwargs): @@ -2496,22 +2579,34 @@ def convert_topk(node, **kwargs): else: raise NotImplementedError("ONNX expects both value and indices as output") - export_nodes = [] - - k = np.asarray([k], dtype=np.int) - k_node = create_helper_tensor_node(k, name + '__k', kwargs) - export_nodes.extend(k_node) - k_node = k_node[-1].name - - input_node = input_nodes[0] - topk_node = onnx.helper.make_node( - "TopK", - [input_node, k_node], - outputs, - axis=axis, - name=name - ) - 
+    opset_version = kwargs['opset_version']
+    if opset_version >= 10:
+        from onnx.helper import make_tensor, make_tensor_value_info
+        initializer = kwargs["initializer"]
+        k_input_name = name + "_k"
+        k_input_type = onnx.TensorProto.INT64
+        k_value_node = make_tensor_value_info(k_input_name, k_input_type, ())
+        # make_tensor expects an iterable of values, so wrap the scalar k
+        k_tensor_node = make_tensor(k_input_name, k_input_type, (), [k])
+        initializer.append(k_tensor_node)
+        input_nodes.append(k_input_name)
+
+        topk_node = onnx.helper.make_node(
+            "TopK",
+            input_nodes,
+            outputs,
+            axis=axis,
+            name=name
+        )
+        return [k_value_node, topk_node]
+    else:
+        topk_node = onnx.helper.make_node(
+            "TopK",
+            input_nodes,
+            outputs,
+            axis=axis,
+            k=k,
+            name=name
+        )
 
     return [topk_node]
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index 51a62ed46e59..2fc77604b9b6 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -29,7 +29,7 @@
 
 
 def export_model(sym, params, input_shape, input_type=np.float32,
-                 onnx_file_path='model.onnx', verbose=False):
+                 onnx_file_path='model.onnx', verbose=False, opset_version=None):
     """Exports the MXNet model file, passed as a parameter, into ONNX model.
     Accepts both symbol, parameter objects as well as json and params filepaths as input.
     Operator support and coverage -
@@ -63,11 +63,15 @@ def export_model(sym, params, input_shape, input_type=np.float32,
 
     try:
         from onnx import helper, mapping
+        from onnx.defs import onnx_opset_version
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed. "
                           + "Instructions to install - https://github.com/onnx/onnx")
 
     converter = MXNetGraph()
+    if opset_version is None:
+        # default is to use the latest opset version the onnx package supports
+        opset_version = onnx_opset_version()
 
     data_format = np.dtype(input_type)
     # if input parameters are strings (file paths), load files and create symbol parameter objects
@@ -76,11 +80,11 @@ def export_model(sym, params, input_shape, input_type=np.float32,
         sym_obj, params_obj = load_module(sym, params)
         onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       verbose=verbose, opset_version=opset_version)
     elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
         onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
                                                        mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
-                                                       verbose=verbose)
+                                                       verbose=verbose, opset_version=opset_version)
     else:
         raise ValueError("Input sym and params should either be files or objects")
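Illustrative usage of the new argument (not part of the patch; the model file names below are hypothetical). With `opset_version` exposed on `export_model`, a caller can pin the opset an exported graph declares instead of silently tracking whatever `onnx` package is installed:

```python
# Sketch only: pin the export opset for older runtimes.
# 'resnet18-symbol.json' / 'resnet18-0000.params' are hypothetical files.
import numpy as np
import mxnet as mx
from onnx.defs import onnx_opset_version

print("default export opset would be:", onnx_opset_version())

onnx_path = mx.contrib.onnx.export_model(
    'resnet18-symbol.json',      # saved symbol (hypothetical)
    'resnet18-0000.params',      # saved params (hypothetical)
    [(1, 3, 224, 224)],          # one shape per model input
    np.float32,
    'resnet18.onnx',
    opset_version=11)            # e.g. target runtimes that stop at opset 11
```

`export_model` returns the path of the written ONNX file, so `onnx_path` here is `'resnet18.onnx'`.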
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index 8e36685b2d40..07fdabd97598 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -159,7 +159,7 @@ def convert_weights_to_numpy(weights_dict):
         return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
                      for k, v in weights_dict.items()])
 
-    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False):
+    def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, opset_version=None):
         """Convert MXNet graph to ONNX graph
 
         Parameters
@@ -174,6 +174,8 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
             Input data type e.g. np.float32
         verbose : Boolean
             If true will print logs of the model conversion
+        opset_version : Int
+            ONNX opset version to use for export, defaults to latest supported by onnx package
 
         Returns
         -------
@@ -183,10 +185,14 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
         try:
             from onnx import (checker, helper, NodeProto, ValueInfoProto, TensorProto)
             from onnx.helper import make_tensor_value_info
+            from onnx.defs import onnx_opset_version
         except ImportError:
             raise ImportError("Onnx and protobuf need to be installed. "
                               + "Instructions to install - https://github.com/onnx/onnx")
 
+        if opset_version is None:
+            opset_version = onnx_opset_version()
+
         # When an MXNet model is saved to a json file, MXNet adds a node for the label.
         # The name of this node is the name of the last node + "_label" (i.e. if the last node
         # name is "Softmax", this node will have the name "Softmax_label". Also, the new node
@@ -251,7 +257,8 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
                     graph_shapes=graph_shapes,
                     initializer=initializer,
                     index_lookup=index_lookup,
-                    idx=idx
+                    idx=idx,
+                    opset_version=opset_version
                 )
 
                 if isinstance(converted, list):
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 1bf60a02160b..76c8e611f450 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -195,11 +195,24 @@ def relu(attrs, inputs, proto_obj):
 
 def pad(attrs, inputs, proto_obj):
     """ Add padding to input tensor"""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'pads'  : 'pad_width',
-                                                               'value' : 'constant_value'
-                                                              })
-    new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
-    return 'pad', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if 'mode' not in attrs.keys():
+        attrs['mode'] = 'constant'
+    if opset_version >= 11:
+        pads = list(proto_obj._params[inputs[1].name].asnumpy())
+        pads = tuple([int(i) for i in pads])
+        new_attrs = translation_utils._add_extra_attributes(attrs, {'pad_width': pads})
+        if len(inputs) == 3:
+            const = proto_obj._params[inputs[2].name].asnumpy()[0]
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'constant_value': const})
+        new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
+        return 'pad', new_attrs, inputs[0]
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs, {'pads'  : 'pad_width',
+                                                                   'value' : 'constant_value'
+                                                                  })
+        new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width'))
+        return 'pad', new_attrs, inputs
 
 def matrix_multiplication(attrs, inputs, proto_obj):
     """Performs general matrix multiplication"""
@@ -322,7 +335,7 @@ def deconv(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_bias('Deconvolution', new_attrs, len(inputs))
     new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs, inputs, proto_obj)
 
-    kernel = new_attrs['kernel']
+    kernel = new_attrs['kernel'] if 'kernel' in new_attrs else []
     stride = new_attrs['stride'] if 'stride' in new_attrs else []
     padding = new_attrs['pad'] if 'pad' in new_attrs else []
     dilations = new_attrs['dilate'] if 'dilate' in new_attrs else []
@@ -412,12 +425,22 @@ def local_response_norm(attrs, inputs, proto_obj):
 def dropout(attrs, inputs, proto_obj):
     """Dropout Regularization."""
     mode = 'training'
+    opset_version = proto_obj.opset_version
     if 'is_test' in attrs and attrs['is_test'] == 0:
         mode = 'always'
-    new_attrs = translation_utils._fix_attribute_names(attrs,
-                                                       {'ratio': 'p'})
-    new_attrs = translation_utils._remove_attributes(new_attrs, ['is_test'])
+    new_attrs = translation_utils._remove_attributes(attrs, ['is_test'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs, {'mode': mode})
+    if opset_version >= 12:
+        new_attrs = translation_utils._remove_attributes(new_attrs, ['seed'])
+        if len(inputs) == 2:
+            ratio_float = proto_obj._params[inputs[1].name].asnumpy()[0]
+            new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'p': ratio_float})
+        elif len(inputs) == 1:
+            new_attrs = translation_utils._fix_attribute_names(new_attrs, {'ratio': 'p'})
+        return 'Dropout', new_attrs, inputs[0]
+    else:
+        new_attrs = translation_utils._fix_attribute_names(new_attrs, {'ratio': 'p'})
     return 'Dropout', new_attrs, inputs
 
 # Changing shape and type.
@@ -467,15 +490,30 @@ def _slice(attrs, inputs, proto_obj):
     """Returns a slice of the input tensor along multiple axes."""
     input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
     input_shape = input_tensor_data[1]
-    new_attrs = translation_utils._fix_attribute_names(attrs,
-                                                       {'axes' : 'axis',
-                                                        'ends' : 'end',
-                                                        'starts' : 'begin'})
-    # onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator
-    # for multiple axes from mxnet
-    begin = new_attrs.get('begin')
-    end = list(new_attrs.get('end'))
-    axes = new_attrs.get('axis', tuple(range(len(begin))))
+
+    if proto_obj.opset_version >= 10:
+        begin = proto_obj._params[inputs[1].name].asnumpy()
+        # keep end as a list so entries can be reset to None below
+        end = list(proto_obj._params[inputs[2].name].asnumpy())
+        if len(inputs) >= 4:
+            axes = list(proto_obj._params[inputs[3].name].asnumpy())
+            axes = tuple([int(i) for i in axes])
+        else:
+            axes = tuple(range(len(begin)))
+        new_attrs = translation_utils._add_extra_attributes(attrs, {'axes' : axes,
+                                                                    'begin' : begin,
+                                                                    'end' : end
+                                                                   })
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs,
+                                                           {'axes' : 'axis',
+                                                            'ends' : 'end',
+                                                            'starts' : 'begin'})
+        # onnx slice provides slicing on multiple axes. Adding multiple slice_axis operators
+        # for the multiple axes from mxnet
+        begin = new_attrs.get('begin')
+        end = list(new_attrs.get('end'))
+        axes = new_attrs.get('axis', tuple(range(len(begin))))
+
     for i, axis in enumerate(axes):
         end[i] = None if end[i] >= input_shape[axis] else end[i]
     slice_op = symbol.slice_axis(inputs[0], axis=axes[0], begin=begin[0], end=end[0])
@@ -515,13 +553,28 @@ def flatten(attrs, inputs, proto_obj):
 
 def clip(attrs, inputs, proto_obj):
     """Clips (limits) the values in an array."""
-    new_attrs = translation_utils._fix_attribute_names(attrs, {'min' : 'a_min',
-                                                               'max' : 'a_max'})
-    if 'a_max' not in new_attrs:
-        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf})
-    if 'a_min' not in new_attrs:
-        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
-    return 'clip', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if opset_version >= 11:
+        if len(inputs) == 1:
+            new_attrs = translation_utils._add_extra_attributes(attrs, {'a_max' : np.inf,
+                                                                        'a_min' : -np.inf})
+        elif len(inputs) == 2:
+            min_float = proto_obj._params[inputs[1].name].asnumpy()
+            new_attrs = translation_utils._add_extra_attributes(attrs, {'a_min': min_float[0],
+                                                                        'a_max': np.inf})
+        elif len(inputs) == 3:
+            min_float = proto_obj._params[inputs[1].name].asnumpy()
+            max_float = proto_obj._params[inputs[2].name].asnumpy()
+            new_attrs = translation_utils._add_extra_attributes(attrs, {'a_min': min_float[0],
+                                                                        'a_max': max_float[0]})
+    else:
+        new_attrs = translation_utils._fix_attribute_names(attrs, {'min' : 'a_min',
+                                                                   'max' : 'a_max'})
+        if 'a_max' not in new_attrs:
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf})
+        if 'a_min' not in new_attrs:
+            new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
+    return 'clip', new_attrs, inputs[0]
 
 def gather(attrs, inputs, proto_obj):
     """Gather elements from an input array along the given axis."""
@@ -756,4 +809,10 @@ def topk(attrs, inputs, proto_obj):
     new_attrs = translation_utils._add_extra_attributes(attrs,
                                                         {'ret_typ': 'both',
                                                          'dtype': 'int64'})
-    return 'topk', new_attrs, inputs
+    opset_version = proto_obj.opset_version
+    if opset_version >= 10:
+        # from opset 10 on, k arrives as a 1-element initializer input
+        k_vals = proto_obj._params[inputs[1].name].asnumpy()
+        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'k': int(k_vals[0])})
+        return 'topk', new_attrs, inputs[0]
+    else:
+        return 'topk', new_attrs, inputs
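The import-side translations above read operands that newer opsets moved from node attributes into graph inputs (usually initializers) via `proto_obj._params`. A small hand-built sketch of what such a graph looks like for opset-11 `Clip`, using only standard `onnx.helper` calls (illustrative, not part of the patch):

```python
# Opset-11 Clip: min/max travel as initializer inputs, which the clip()
# translation above reads from proto_obj._params[inputs[i].name].
import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4])
clip_min = helper.make_tensor('clip_min', TensorProto.FLOAT, (), [0.0])
clip_max = helper.make_tensor('clip_max', TensorProto.FLOAT, (), [6.0])
node = helper.make_node('Clip', ['x', 'clip_min', 'clip_max'], ['y'], name='clip0')
graph = helper.make_graph([node], 'clip_example', [x], [y],
                          initializer=[clip_min, clip_max])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 11)])
onnx.checker.check_model(model)   # valid: Clip has no min/max attributes from opset 11 on
```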
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_model.py b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
index 1c195435729b..d060b082cc5c 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_model.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_model.py
@@ -56,7 +56,8 @@ def import_model(model_file):
                           + "Instructions to install - https://github.com/onnx/onnx")
     # loads model file and returns ONNX protobuf object
     model_proto = onnx.load_model(model_file)
-    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
+    model_opset_version = max([x.version for x in model_proto.opset_import])
+    sym, arg_params, aux_params = graph.from_onnx(model_proto.graph, opset_version=model_opset_version)
     return sym, arg_params, aux_params
 
 def get_model_metadata(model_file):
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
index 72913ddaec4e..c2be83d8f12e 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
@@ -36,6 +36,7 @@ def __init__(self):
         self.aux_dict = {}
         self.arg_dict = {}
         self.model_metadata = {}
+        self.opset_version = 0
 
     def _convert_operator(self, node_name, op_name, attrs, inputs):
         """Convert from onnx operator to mxnet operator.
@@ -72,7 +73,7 @@ def _convert_operator(self, node_name, op_name, attrs, inputs):
                 return mxnet_sym
         return op_name
 
-    def from_onnx(self, graph):
+    def from_onnx(self, graph, opset_version):
         """Construct symbol from onnx graph.
 
         Parameters
@@ -87,6 +88,7 @@ def from_onnx(self, graph):
         params : dict
             A dict of name: nd.array pairs, used as pretrained weights
         """
+        self.opset_version = opset_version
         # get input, output shapes
         self.model_metadata = self.get_graph_metadata(graph)
         # parse network inputs, aka parameters
@@ -156,7 +158,7 @@ def get_graph_metadata(self, graph):
             }
         return metadata
 
-    def graph_to_gluon(self, graph, ctx):
+    def graph_to_gluon(self, graph, ctx, opset_version):
         """Construct SymbolBlock from onnx graph.
 
         Parameters
@@ -171,7 +173,7 @@ def graph_to_gluon(self, graph, ctx):
         sym_block :gluon.nn.SymbolBlock
             The returned gluon SymbolBlock
         """
-        sym, arg_params, aux_params = self.from_onnx(graph)
+        sym, arg_params, aux_params = self.from_onnx(graph, opset_version)
         metadata = self.get_graph_metadata(graph)
         data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
         data_inputs = [symbol.var(data_name) for data_name in data_names]
diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
index 13ad5b9f8fa1..f6e10365d5d1 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py
@@ -49,5 +49,6 @@ def import_to_gluon(model_file, ctx):
         raise ImportError("Onnx and protobuf need to be installed. Instructions to"
                           + " install - https://github.com/onnx/onnx#installation")
     model_proto = onnx.load_model(model_file)
-    net = graph.graph_to_gluon(model_proto.graph, ctx)
+    model_opset_version = max([x.version for x in model_proto.opset_import])
+    net = graph.graph_to_gluon(model_proto.graph, ctx, model_opset_version)
     return net
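Both import entry points now derive the opset from the model's `opset_import` field themselves, so caller code is unchanged. A brief hedged sketch, with `model.onnx` as a placeholder path:

```python
# The opset detection shown here mirrors what import_model / import_to_gluon
# now do internally; 'model.onnx' is a placeholder file.
import onnx
import mxnet as mx

model_proto = onnx.load_model('model.onnx')
print('declared opset:', max(x.version for x in model_proto.opset_import))

sym, arg_params, aux_params = mx.contrib.onnx.import_model('model.onnx')
net = mx.contrib.onnx.import_to_gluon('model.onnx', ctx=mx.cpu())
```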
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
new file mode 100644
index 000000000000..052b24185735
--- /dev/null
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+import onnxruntime
+
+import json
+import os
+import shutil
+import tempfile
+
+
+def test_cv_model_inference_onnxruntime():
+    def get_gluon_cv_model(model_name, tmp):
+        tmpfile = os.path.join(tmp, model_name)
+        ctx = mx.cpu(0)
+        net_fp32 = mx.gluon.model_zoo.vision.get_model(model_name, pretrained=True, ctx=ctx, root=tmp)
+        net_fp32.hybridize()
+        data = mx.nd.zeros((1, 3, 224, 224), dtype='float32', ctx=ctx)
+        net_fp32.forward(data)
+        net_fp32.export(tmpfile, 0)
+        sym_file = tmpfile + '-symbol.json'
+        params_file = tmpfile + '-0000.params'
+        return sym_file, params_file
+
+    def export_model_to_onnx(sym_file, params_file):
+        input_shape = (1, 3, 224, 224)
+        onnx_file = os.path.join(os.path.dirname(sym_file), "model.onnx")
+        converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, [input_shape],
+                                                            np.float32, onnx_file)
+        return onnx_file
+
+    def normalize_image(imgfile):
+        image = mx.image.imread(imgfile).asnumpy()
+        image_data = np.array(image).transpose(2, 0, 1)
+        img_data = image_data.astype('float32')
+        mean_vec = np.array([0.485, 0.456, 0.406])
+        stddev_vec = np.array([0.229, 0.224, 0.225])
+        norm_img_data = np.zeros(img_data.shape).astype('float32')
+        for i in range(img_data.shape[0]):
+            norm_img_data[i, :, :] = (img_data[i, :, :] / 255 - mean_vec[i]) / stddev_vec[i]
+        return norm_img_data.reshape(1, 3, 224, 224).astype('float32')
+
+    def softmax(x):
+        x = x.reshape(-1)
+        e_x = np.exp(x - np.max(x))
+        return e_x / e_x.sum(axis=0)
+
+    def load_imgnet_labels():
+        mx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/image_net_labels.json')
+        return np.array(json.load(open('image_net_labels.json', 'r')))
+
+    def download_test_images():
+        test_images = [
+            ['dog.jpg', ['boxer']],
+            ['apron.jpg', ['apron', 'maillot']],
+            ['dolphin.jpg', ['great white shark', 'grey whale']],
+            ['hammerheadshark.jpg', ['tiger shark']],
+            ['lotus.jpg', ['pinwheel', 'pot']]
+        ]
+        for f, _ in test_images:
+            mx.test_utils.download('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/images/'+f+'?raw=true',
+                                   fname=f)
+        return test_images
+
+
+    test_models = [
+        'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
+        'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25',
+        'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2',
+        'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2',
+        'squeezenet1.0', 'squeezenet1.1',
+        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
+    ]
+    labels = load_imgnet_labels()
+    test_images = download_test_images()
+
+    for model in test_models:
+        tmpdir = tempfile.mkdtemp()
+        sym_file, params_file = get_gluon_cv_model(model, tmpdir)
+        onnx_file = export_model_to_onnx(sym_file, params_file)
+
+        # create onnxruntime session using the generated onnx file
+        ses_opt = onnxruntime.SessionOptions()
+        ses_opt.log_severity_level = 3
+        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+        input_name = session.get_inputs()[0].name
+
+        for img, classes in test_images:
+            img_data = normalize_image(img)
+            raw_result = session.run([], {input_name: img_data})
+            res = softmax(np.array(raw_result)).tolist()
+            class_idx = np.argmax(res)
",".join(labels[sort_idx[:5]])) + correct_classification = False + for label in labels[sort_idx[:5]]: + for c in classes: + if c in label: + correct_classification = True + assert correct_classification == True + + # cleanup + shutil.rmtree(tmpdir) + + + + +if __name__ == "__main__": + test_cv_model_inference_onnxruntime() + diff --git a/tests/python/unittest/onnx/backend.py b/tests/python/unittest/onnx/backend.py index 2f9e2470d225..eb803f790332 100644 --- a/tests/python/unittest/onnx/backend.py +++ b/tests/python/unittest/onnx/backend.py @@ -26,6 +26,7 @@ try: from onnx import helper, TensorProto, mapping from onnx.backend.base import Backend + from onnx.defs import onnx_opset_version except ImportError: raise ImportError("Onnx and protobuf need to be installed. Instructions to" + " install - https://github.com/onnx/onnx#installation") @@ -57,13 +58,16 @@ def perform_import_export(sym, arg_params, aux_params, input_shape): params = {} params.update(arg_params) params.update(aux_params) + # use the latest opset version supported by the onnx library + opset_version = onnx_opset_version() # exporting to onnx graph proto format converter = MXNetGraph() graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape, - in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]) + in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')], + opset_version=opset_version) # importing back to MXNET for verifying result. - sym, arg_params, aux_params = graph.from_onnx(graph_proto) + sym, arg_params, aux_params = graph.from_onnx(graph_proto, opset_version) return sym, arg_params, aux_params @@ -95,8 +99,11 @@ def prepare(cls, model, device='CPU', **kwargs): else: raise NotImplementedError("ONNX tests are run only for CPU context.") + # determine opset version model uses + model_opset_version = max([x.version for x in model.opset_import]) + if backend == 'mxnet': - sym, arg_params, aux_params = graph.from_onnx(model.graph) + sym, arg_params, aux_params = graph.from_onnx(model.graph, model_opset_version) if operation == 'export': metadata = graph.get_graph_metadata(model.graph) input_data = metadata['input_tensor_data'] @@ -107,7 +114,7 @@ def prepare(cls, model, device='CPU', **kwargs): return MXNetBackendRep(sym, arg_params, aux_params, device) elif backend == 'gluon': if operation == 'import': - net = graph.graph_to_gluon(model.graph, ctx) + net = graph.graph_to_gluon(model.graph, ctx, model_opset_version) return GluonBackendRep(net, device) elif operation == 'export': raise NotImplementedError("Gluon->ONNX export not implemented.")