Skip to content

Commit

Permalink
change op-level seed during copy; fix #501
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Apr 19, 2017
1 parent b3a0f34 commit 3a5654b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
18 changes: 11 additions & 7 deletions edward/util/random_variables.py
Expand Up @@ -93,11 +93,13 @@ def copy(org_instance, dict_swap=None, scope="copied",
where any of its ancestors existing in `dict_swap` are
replaced with `dict_swap`'s corresponding value.
The copying is done recursively, so any `Operation` whose output
is required to evaluate `org_instance` is also copied (if it isn't
already copied within the new scope). This is with the exception of
`tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue`, which
are reused and not newly copied.
Copying is done recursively. Any `Operation` whose output is
required to copy `org_instance` is also copied (if it isn't already
copied within the new scope).
`tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are
always reused and not copied. In addition, `tf.Operation`s with
operation-level seeds are copied with a new operation-level seed.
Parameters
----------
Expand Down Expand Up @@ -259,10 +261,12 @@ def copy(org_instance, dict_swap=None, scope="copied",
return op

# Copy the node def.
# It stores string-based info such as name, device, and type of
# the op. It is unique to every Operation instance.
# It is unique to every Operation instance. Replace the name and
# its operation-level seed if it has one.
node_def = deepcopy(op.node_def)
node_def.name = new_name
if 'seed2' in node_def.attr:
node_def.attr['seed2'].i = tf.get_seed(None)[1]

# Copy other arguments needed for initialization.
output_types = op._output_types[:]
Expand Down
21 changes: 9 additions & 12 deletions tests/test-util/test_copy.py
Expand Up @@ -54,6 +54,15 @@ def test_list(self):
z_new = ed.copy(z, {x: y.value()})
self.assertGreater(z_new.value().eval(), 5.0)

def test_random(self):
with self.test_session() as sess:
ed.set_seed(3742)
x = tf.random_normal([])
x_copy = ed.copy(x)

result_copy, result = sess.run([x_copy, x])
self.assertNotAlmostEquals(result_copy, result)

def test_scan(self):
with self.test_session() as sess:
ed.set_seed(42)
Expand All @@ -64,18 +73,6 @@ def test_scan(self):
self.assertAllClose(result_copy, [2.0, 5.0, 6.0])
self.assertAllClose(result, [2.0, 5.0, 6.0])

def test_scan_random(self):
with self.test_session() as sess:
ed.set_seed(1234)
op = tf.scan(lambda a, x: a + x, tf.random_normal([3]))
copy_op = ed.copy(op)

# check that the random inputs are different
# currently ed.set_seed prevents variate generation to work
# result_copy, result = sess.run([copy_op, op])
# for elem_copy, elem in zip(result_copy, result):
# self.assertNotAlmostEquals(elem_copy, elem)

def test_swap_tensor_tensor(self):
with self.test_session():
x = tf.constant(2.0)
Expand Down

0 comments on commit 3a5654b

Please sign in to comment.