Skip to content

Commit

Permalink
[Relay][Frontend][TF] fix _parse_param bug (apache#4711)
Browse files Browse the repository at this point in the history
  • Loading branch information
fwd4 authored and zhiics committed Mar 2, 2020
1 parent 7b362cb commit 76d93db
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,7 @@ def _parse_param(self, key, value, name, shape):
if np_array.dtype == np.dtype(object):
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
# Just leave it as placeholder.
if shape:
if shape and name in shape:
var_shape = shape[name]
else:
var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
Expand Down
18 changes: 11 additions & 7 deletions tests/python/frontend/tensorflow/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow

def run_relay(graph, *vars):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
def run_relay(graph, shape_dict=None, *vars):
mod, params = from_tensorflow(
graph.as_graph_def(add_shapes=True),
shape=shape_dict)
ex = relay.create_executor('debug', mod=mod)
return ex.evaluate()(*vars)

def test_assert_true():
g = tf.Graph()
shape = (1, 2)
with g.as_default():
x = tf.placeholder(tf.float32, shape=())
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
x = tf.placeholder(tf.float32, shape=shape, name="input")
assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"])

with tf.Session() as sess:
x_value = np.random.rand()
x_value = np.random.rand(*shape)
assert sess.run(assert_op, feed_dict={x: x_value}) is None

# In TVM, tf.assert is converted to a no-op which is actually a 0,
Expand All @@ -44,7 +47,7 @@ def test_assert_true():
# do that, it's happening in Relay, and that optimization shouldn't
# affect the arity of the main function. We should have to pass in
# x_value here.
np.testing.assert_allclose(0, run_relay(g).asnumpy())
np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy())

def test_assert_true_var_capture():
g = tf.Graph()
Expand All @@ -65,7 +68,8 @@ def test_assert_true_var_capture():
# the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the
# arity of the translated function should be 1, not 2.
np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())
np.testing.assert_allclose(True,
run_relay(g, None, x_value, x_value).asnumpy())

def test_assert_false():
g = tf.Graph()
Expand Down

0 comments on commit 76d93db

Please sign in to comment.