diff --git a/tests/run_pretrained_models.py b/tests/run_pretrained_models.py index bc585f338..8ac41127e 100644 --- a/tests/run_pretrained_models.py +++ b/tests/run_pretrained_models.py @@ -15,6 +15,7 @@ import time import traceback import zipfile +import logging import PIL.Image import numpy as np @@ -34,6 +35,9 @@ # pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("tf2onnx") + TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained") PERFITER = 1000 @@ -246,9 +250,10 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No for k in inputs.keys(): # pylint: disable=consider-iterating-dictionary t = sess.graph.get_tensor_by_name(k) dtype = tf.as_dtype(t.dtype).name - if type != "float32": - v = inputs[k] - inputs[k] = v.astype(dtype) + v = inputs[k] + if dtype != v.dtype: + log.warning("input dtype doesn't match tensorflow's") + inputs[k] = np.array(v, dtype=dtype) if self.force_input_shape: for k, v in inputs.items(): shape_override[k] = list(v.shape) diff --git a/tests/test_backend.py b/tests/test_backend.py index d81edab9f..a47c802ee 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -596,6 +596,15 @@ def test_min(self): _ = tf.identity(mi, name=_TFOUTPUT) self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2}) + tf.reset_default_graph() + x_val1 = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.int32).reshape((2, 2)) + x_val2 = np.array([4.0, 4.0, 4.0, 4.0], dtype=np.int32).reshape((2, 2)) + x1 = tf.placeholder(tf.int32, x_val1.shape, name=_TFINPUT) + x2 = tf.placeholder(tf.int32, x_val2.shape, name=_TFINPUT1) + mi = tf.minimum(x1, x2) + _ = tf.identity(mi, name=_TFOUTPUT) + self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2}) + @skip_caffe2_backend("issue with broadcasting scalar") @check_onnxruntime_incompatibility("Sub") def test_min_broadcast(self): @@ -788,6 +797,35 @@ def test_concat(self): 
def test_concat_empty_const_input(self):
    """Concatenate placeholders with an empty float32 constant.

    Three graphs are built and run one after another:
      * a 1-D placeholder followed by an empty 1-D constant, axis 0
      * a 2-D placeholder followed by an empty (1, 0) constant, axis 1
      * an empty 1-D constant sandwiched between two 1-D placeholders, axis 0
    """
    # (leading placeholder data, empty constant data, optional trailing placeholder data, axis)
    scenarios = (
        ([1, 2, 3], [], None, 0),
        ([[1, 2, 3]], [[]], None, 1),
        ([1, 2, 3], [], [13, 14, 15], 0),
    )
    for index, (lead, hollow, trail, axis) in enumerate(scenarios):
        if index > 0:
            # every scenario after the first starts from a fresh default graph
            tf.reset_default_graph()

        lead_val = np.array(lead, dtype=np.float32)
        hollow_val = np.array(hollow, dtype=np.float32)
        trail_val = None if trail is None else np.array(trail, dtype=np.float32)

        # ops are created in concat order: placeholder, constant, (placeholder)
        pieces = [
            tf.placeholder(tf.float32, lead_val.shape, name=_TFINPUT),
            tf.constant(hollow_val, dtype=tf.float32),
        ]
        feed = {_INPUT: lead_val}
        if trail_val is not None:
            pieces.append(tf.placeholder(tf.float32, trail_val.shape, name=_TFINPUT1))
            feed[_INPUT1] = trail_val

        joined = tf.concat(pieces, axis)
        _ = tf.identity(joined, name=_TFOUTPUT)
        self._run_test_case([_OUTPUT], feed)
def test_tensor_data(self):
    """Check that utils.tf_to_onnx_tensor preserves constant tensor data.

    float32 constants of several layouts (empty, multi-dimensional empty,
    scalar, single item, regular 2-D) are added to a fresh graph. The
    "value" attribute of each resulting op is converted to an ONNX tensor,
    wrapped in an ONNX attribute, read back as a numpy array and compared
    with the array the constant was created from.
    """
    tensors = {
        "empty_tensor": np.array([], dtype=np.float32),
        "multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
        "scalar": np.array(1., dtype=np.float32),
        "one_item_array": np.array([1.], dtype=np.float32),
        "normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32)
    }
    tf.reset_default_graph()
    with tf.Session() as sess:
        for n, data in tensors.items():
            tf.constant(data, dtype=tf.float32, name=n)

        for tf_node in sess.graph.get_operations():
            name = tf_node.name
            # assertIn reports the offending name instead of "False is not true"
            self.assertIn(name, tensors)
            self.assertIn("value", tf_node.node_def.attr)

            # convert to onnx tensor value
            tensor_value = utils.tf_to_onnx_tensor(
                utils.get_tf_node_attr(tf_node, "value"),
                name=utils.port_name(tf_node.name)
            )
            attr = helper.make_attribute("value", tensor_value)
            # same as node.get_tensor_value(as_list=False)
            actual = numpy_helper.to_array(helper.get_attribute_value(attr))

            expected = tensors[name]

            # np.array_equal ignores dtype, so the dtype mapping is checked separately
            self.assertEqual(expected.dtype, actual.dtype,
                             "{}: dtype changed during conversion".format(name))
            # np.array_equal is shape-strict: a scalar is not accepted in place
            # of a one-item array (np.testing.assert_array_equal would broadcast)
            self.assertTrue(
                np.array_equal(expected, actual),
                "{}: expected {!r} with shape {}, got {!r} with shape {}".format(
                    name, expected, expected.shape, actual, actual.shape)
            )
Args: op_type: type for new operation diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 6b4190a33..f35b86e7d 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -766,6 +766,15 @@ def concat_op(ctx, node, name, args): def concatv2_op(ctx, node, name, args): # T output = ConcatV2(T values, Tidx axis, @int N, @type Tidx) # T concat_result = Concat(T inputs, @INT axis) + # if any input is empty, remove the input and concat the others + # NOTE: workaround for https://github.com/Microsoft/onnxruntime/issues/681 + for i, inp in enumerate(node.inputs): + if inp.is_const() and inp.get_tensor_value(as_list=False).size == 0: + ctx.remove_input(node, node.input[i]) + # all inputs are deleted + if not node.input: + raise RuntimeError("all inputs of {} are empty".format(name)) + axis_node = node.inputs[-1] axis_val = axis_node.get_tensor_value() ctx.remove_input(node, node.input[-1]) @@ -1154,6 +1163,28 @@ def minmax_op(ctx, node, name, args): # handle this by doing something like: # y = min(x1, add(x2, sub(x1, x1))), where x1, x2 are the inputs and x2 is a scalar # this will create a tensor of zeros of the shape of x1, adds x2 to it (which broadcasts) and use that for min. 
def tf_to_onnx_tensor(tensor, name=""):
    """Convert a tensorflow tensor proto to an onnx tensor.

    The conversion depends on what get_tf_tensor_data() reports for the tensor:

    1. no stored values (not raw, data is None):
       an array of shape tensor_shape.dim is emitted. For an empty shape such as
       [0] or [1, 0] this is an empty tensor; for any other shape it is zero-filled.
    2. exactly one typed value and tensor_shape.dim == [0]:
       emitted as a rank-0 (scalar) tensor holding that value.
    3. exactly one typed value and a shape with more than one element:
       the value is broadcast to the full shape.
    4. everything else (raw tensor_content, or a full list of typed values):
       passed to onnx as is, together with tensor_shape.dim.

    Args:
        tensor: tensorflow tensor proto; dtype, tensor_shape.dim and the data
            returned by get_tf_tensor_data() are read from it.
        name: name given to the resulting onnx tensor.

    Returns:
        The onnx tensor built by onnx.helper / onnx.numpy_helper.

    Raises:
        KeyError: if tensor.dtype has no entry in TF_TO_ONNX_DTYPE.
        Whatever make_sure() raises when data is present but empty.
    """
    new_type = TF_TO_ONNX_DTYPE[tensor.dtype]
    dims = [d.size for d in tensor.tensor_shape.dim]
    is_raw, data = get_tf_tensor_data(tensor)

    if not is_raw and data is None:
        # No values are stored. np.zeros yields the same empty array as
        # np.array([]).reshape(dims) whenever dims describes zero elements, and
        # it also covers a non-empty shape (zero-filled, as the earlier `[0]`
        # default did) where reshape would raise ValueError.
        np_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
        return numpy_helper.from_array(np_data, name=name)

    make_sure(data, "tensor data isn't expected to be None or empty")

    if not is_raw and len(data) == 1:
        if dims == [0]:
            # a single value with shape [0] is treated as a scalar
            return helper.make_tensor(name, new_type, [], data, False)
        if np.prod(dims) > 1:
            # a single value stands for the whole tensor: broadcast it
            batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
            batch_data.fill(data[0])
            return numpy_helper.from_array(batch_data, name=name)

    return helper.make_tensor(name, new_type, dims, data, is_raw)