Skip to content

Commit

Permalink
add test to copy_parent_rvs
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Sep 10, 2017
1 parent ec4976e commit 7f3fa57
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
22 changes: 5 additions & 17 deletions edward/util/random_variables.py
Expand Up @@ -87,7 +87,7 @@ def _copy_default(x, *args, **kwargs):


def copy(org_instance, dict_swap=None, scope="copied",
replace_itself=False, copy_q=False, copy_parents=True):
replace_itself=False, copy_q=False, copy_parent_rvs=True):
"""Build a new node in the TensorFlow graph from `org_instance`,
where any of its ancestors existing in `dict_swap` are
replaced with `dict_swap`'s corresponding value.
Expand Down Expand Up @@ -118,9 +118,10 @@ def copy(org_instance, dict_swap=None, scope="copied",
copy_q: bool, optional.
Whether to copy the replaced tensors too (if not already
copied within the new scope). Otherwise will reuse them.
copy_parents:
copy_parent_rvs:
Whether to copy parent random variables `org_instance` depends
on before copy `org_instance`.
on. Otherwise will copy only the sample tensors and not the
random variable class itself.
Returns:
RandomVariable, tf.Variable, tf.Tensor, or tf.Operation.
Expand Down Expand Up @@ -219,22 +220,9 @@ def copy(org_instance, dict_swap=None, scope="copied",
# Preserve ordering of random variables. Random variables are always
# copied first (from parent -> child) before any deterministic
# operations that depend on them.
if copy_parents and \
if copy_parent_rvs and \
isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)):
for v in get_parents(org_instance):
# 'False' forces the top-most random variables to be copied
# first. This may be slow: suppose x[1] -> ... -> x[T] and we
# call copy(x[T]). get_parents finds x[t-1] and calls copy
# again; this leads to calling get_parents and copy on T many
# random variables.
# Subsequent calls will be True, so it never always calls this.
# TODO
# False has the unintended consequence of copying priors rather
# than just outputting their associated copy thing. so maybe i
# don't want False? this conflicts with the replace_itself
# clause to enter into this
# + this is why we need yet another arg such that we can still
# call replace itself
copy(v, dict_swap, scope, True, copy_q, True)

if isinstance(org_instance, RandomVariable):
Expand Down
10 changes: 10 additions & 0 deletions tests/test-util/test_copy.py
Expand Up @@ -38,6 +38,16 @@ def test_copy_q(self):
self.assertNotEqual(x_new_val, x_val)
self.assertNotEqual(x_new_val, y_val)

def test_copy_parent_rvs(self):
with self.test_session() as sess:
x = Normal(0.0, 1.0)
y = tf.constant(3.0)
z = x * y
z_new = ed.copy(z, scope='no_copy_parent_rvs', copy_parent_rvs=False)
self.assertEqual(len(ed.random_variables()), 1)
z_new = ed.copy(z, scope='copy_parent_rvs', copy_parent_rvs=True)
self.assertEqual(len(ed.random_variables()), 2)

def test_placeholder(self):
with self.test_session() as sess:
x = tf.placeholder(tf.float32, name="CustomName")
Expand Down

0 comments on commit 7f3fa57

Please sign in to comment.