Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import time
import traceback
import zipfile
import logging

import PIL.Image
import numpy as np
Expand All @@ -34,6 +35,9 @@

# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("tf2onnx")

TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
PERFITER = 1000

Expand Down Expand Up @@ -246,9 +250,10 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
for k in inputs.keys(): # pylint: disable=consider-iterating-dictionary
t = sess.graph.get_tensor_by_name(k)
dtype = tf.as_dtype(t.dtype).name
if type != "float32":
v = inputs[k]
inputs[k] = v.astype(dtype)
v = inputs[k]
if dtype != v.dtype:
log.warning("input dtype doesn't match tensorflow's")
inputs[k] = np.array(v, dtype=dtype)
if self.force_input_shape:
for k, v in inputs.items():
shape_override[k] = list(v.shape)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,15 @@ def test_min(self):
_ = tf.identity(mi, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

tf.reset_default_graph()
x_val1 = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.int32).reshape((2, 2))
x_val2 = np.array([4.0, 4.0, 4.0, 4.0], dtype=np.int32).reshape((2, 2))
x1 = tf.placeholder(tf.int32, x_val1.shape, name=_TFINPUT)
x2 = tf.placeholder(tf.int32, x_val2.shape, name=_TFINPUT1)
mi = tf.minimum(x1, x2)
_ = tf.identity(mi, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

@skip_caffe2_backend("issue with broadcasting scalar")
@check_onnxruntime_incompatibility("Sub")
def test_min_broadcast(self):
Expand Down Expand Up @@ -788,6 +797,35 @@ def test_concat(self):
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2, "input3:0": x_val3})

def test_concat_empty_const_input(self):
x_val1 = np.array([1, 2, 3], dtype=np.float32)
x_val2 = np.array([], dtype=np.float32)
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
x2 = tf.constant(x_val2, dtype=tf.float32)
x_ = tf.concat([x1, x2], 0)
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1})

tf.reset_default_graph()
x_val1 = np.array([[1, 2, 3]], dtype=np.float32)
x_val2 = np.array([[]], dtype=np.float32)
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
x2 = tf.constant(x_val2, dtype=tf.float32)
x_ = tf.concat([x1, x2], 1)
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1})

tf.reset_default_graph()
x_val1 = np.array([1, 2, 3], dtype=np.float32)
x_val2 = np.array([], dtype=np.float32)
x_val3 = np.array([13, 14, 15], dtype=np.float32)
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
x2 = tf.constant(x_val2, dtype=tf.float32)
x3 = tf.placeholder(tf.float32, x_val3.shape, name=_TFINPUT1)
x_ = tf.concat([x1, x2, x3], 0)
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val3})

@check_opset_min_version(6, "cast")
def test_concat_int64(self):
x_val1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
Expand Down
34 changes: 33 additions & 1 deletion tests/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from collections import namedtuple

import graphviz as gv
import numpy as np
from onnx import TensorProto
from onnx import helper
from onnx import helper, numpy_helper

import tensorflow as tf
from tf2onnx import utils
Expand Down Expand Up @@ -247,6 +248,37 @@ def test_node_attr_onnx(self):
self.assertTrue("my_attr" in n1.attr)
self.assertTrue("my_attr" in n1.attr_onnx)

def test_tensor_data(self):
tensors = {
"empty_tensor": np.array([], dtype=np.float32),
"multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
"scalar": np.array(1., dtype=np.float32),
"one_item_array": np.array([1.], dtype=np.float32),
"normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32)
}
tf.reset_default_graph()
with tf.Session() as sess:
for n, data in tensors.items():
tf.constant(data, dtype=tf.float32, name=n)

for tf_node in sess.graph.get_operations():
name = tf_node.name
self.assertTrue(name in tensors.keys())

self.assertTrue("value" in tf_node.node_def.attr)
# convert to onnx tensor value
tensor_value = utils.tf_to_onnx_tensor(
utils.get_tf_node_attr(tf_node, "value"),
name=utils.port_name(tf_node.name)
)
attr = helper.make_attribute("value", tensor_value)
# same as node.get_tensor_value(is_list=False)
actual = numpy_helper.to_array(helper.get_attribute_value(attr))

expected = tensors[name]

self.assertTrue(np.array_equal(expected, actual))


if __name__ == '__main__':
unittest_main()
2 changes: 1 addition & 1 deletion tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
break
return new_node

def insert_new_node_on_output(self, op_type, output_name, name=None, domain=None, **kwargs):
def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **kwargs):
"""Create and insert a new node into the graph.
Args:
op_type: type for new operation
Expand Down
32 changes: 32 additions & 0 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,15 @@ def concat_op(ctx, node, name, args):
def concatv2_op(ctx, node, name, args):
# T output = ConcatV2(T values, Tidx axis, @int N, @type Tidx)
# T concat_result = Concat(T inputs, @INT axis)
# if any input is empty, remove the input and concat the others
# NOTE: workaround for https://github.com/Microsoft/onnxruntime/issues/681
for i, inp in enumerate(node.inputs):
if inp.is_const() and inp.get_tensor_value(as_list=False).size == 0:
ctx.remove_input(node, node.input[i])
# all inputs are deleted
if not node.input:
raise RuntimeError("all inputs of {} are empty".format(name))

axis_node = node.inputs[-1]
axis_val = axis_node.get_tensor_value()
ctx.remove_input(node, node.input[-1])
Expand Down Expand Up @@ -1154,6 +1163,28 @@ def minmax_op(ctx, node, name, args):
# handle this by doing something like:
# y = min(x1, add(x2, sub(x1, x1))), where x1, x2 are the inputs and x2 is a scalar
# this will create a tensor of zeros of the shape of x1, adds x2 to it (which broadcasts) and use that for min.
# support more dtype
supported_dtypes = [
onnx_pb.TensorProto.FLOAT,
onnx_pb.TensorProto.FLOAT16,
onnx_pb.TensorProto.DOUBLE
]
target_dtype = onnx_pb.TensorProto.FLOAT
for inp in node.input:
dtype = ctx.get_dtype(inp)
utils.make_sure(dtype, "dtype of {} is None".format(inp))
if dtype not in supported_dtypes:
inp_cast = ctx.insert_new_node_on_input(node, "Cast", inp, to=target_dtype)
ctx.copy_shape(inp, inp_cast.output[0])
ctx.set_dtype(inp_cast.output[0], target_dtype)
origin_dtype = ctx.get_dtype(node.output[0])
utils.make_sure(origin_dtype is not None, "dtype of {} is None".format(node.output[0]))
ctx.set_dtype(node.output[0], target_dtype)
cast_name = utils.make_name(name)
cast_node = ctx.insert_new_node_on_output("Cast", node.output[0], name=cast_name, to=origin_dtype)
to_replace = [n for n in ctx.get_nodes() if n != cast_node]
ctx.replace_all_inputs(to_replace, node.output[0], cast_node.output[0])

shapeo = ctx.get_shape(node.output[0])
needs_broadcast_op = []
has_correct_shape = []
Expand Down Expand Up @@ -1676,6 +1707,7 @@ def where_op(ctx, node, name, args):
"BiasAdd": (biasadd_op, []),
"BiasAddV1": (biasadd_op, []),
"Cast": (cast_op, []),
"CheckNumerics": (identity_op, ["Identity"]),
"Concat": (concat_op, ["Concat"]),
"ConcatV2": (concatv2_op, ["Concat"]),
"Const": (direct_op, []),
Expand Down
49 changes: 33 additions & 16 deletions tf2onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,39 @@ def split_nodename_and_shape(name):


def tf_to_onnx_tensor(tensor, name=""):
"""Convert tensorflow tensor to onnx tensor."""
"""
Convert tensorflow tensor to onnx tensor.
Here deal with three types of tensor:
1. normal tensor, e.g., np.array([1,2,3], dtype=DTYPE):
tensor_content: raw data of [1,2,3]
tensor_shape.dim: [3]
DTYPE_val: empty
2. scalar tensor, e.g., np.array(1, dtype=DTYPE):
tensor_content: empty
tensor_shape.dim: [0]
DTYPE_val: 1
3. empty tensor, e.g., np.array([], dtype=DTYPE) and np.array([[]], dtype=DTYPE):
tensor_content: empty
tensor_shape.dim: [0] and [1, 0]
DTYPE_val: empty
"""
new_type = TF_TO_ONNX_DTYPE[tensor.dtype]
tdim = tensor.tensor_shape.dim
dims = [d.size for d in tdim]
# FIXME: something is fishy here
if dims == [0]:
dims = [1]
is_raw, data = get_tf_tensor_data(tensor)
# empty tensor
if not is_raw and data is None:
np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type)).reshape(dims)
return numpy_helper.from_array(np_data, name=name)
make_sure(data, "tensor data isn't expected to be None or empty")
# scalar tensor
if dims == [0] and not is_raw and len(data) == 1:
return helper.make_tensor(name, new_type, [], data, False)
if not is_raw and len(data) == 1 and np.prod(dims) > 1:
batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
batch_data.fill(data[0])
onnx_tensor = numpy_helper.from_array(batch_data, name=name)
else:
onnx_tensor = helper.make_tensor(name, new_type, dims, data, is_raw)
return onnx_tensor
return numpy_helper.from_array(batch_data, name=name)
return helper.make_tensor(name, new_type, dims, data, is_raw)


def get_tf_tensor_data(tensor):
Expand All @@ -154,16 +172,15 @@ def get_tf_tensor_data(tensor):
data = tensor.int64_val
elif tensor.bool_val:
data = tensor.bool_val
elif tensor.dtype == tf.int32:
data = [0]
elif tensor.dtype == tf.int64:
data = [0]
elif tensor.dtype == tf.float32:
data = [0.]
elif tensor.dtype == tf.float16:
data = [0]
elif tensor.string_val:
data = tensor.string_val
elif tensor.dtype in [
tf.int32,
tf.int64,
tf.float32,
tf.float16
]:
data = None
else:
raise ValueError('tensor data not supported')
return [is_raw, data]
Expand Down