Skip to content

Commit

Permalink
use concatenate() rather than as_list() methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 23, 2017
1 parent 25e3285 commit 4184414
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions edward/models/random_variable.py
Expand Up @@ -73,9 +73,9 @@ def __init__(self, *args, **kwargs):

if value is not None:
t_value = tf.convert_to_tensor(value, self.dtype)
expected_shape = (self.get_batch_shape().as_list() +
self.get_event_shape().as_list())
value_shape = t_value.shape.as_list()
value_shape = t_value.shape
expected_shape = self.get_batch_shape().concatenate(
self.get_event_shape())
if value_shape != expected_shape:
raise ValueError(
"Incompatible shape for initialization argument 'value'. "
Expand Down
9 changes: 4 additions & 5 deletions tests/test-models/test_random_variable_value.py
Expand Up @@ -13,17 +13,16 @@ class test_random_variable_value_class(tf.test.TestCase):

def _test_sample(self, RV, value, *args, **kwargs):
rv = RV(*args, value=value, **kwargs)
value_shape = rv.value().shape.as_list()
expected_shape = (rv.get_batch_shape().as_list() +
rv.get_event_shape().as_list())
value_shape = rv.value().shape
expected_shape = rv.get_batch_shape().concatenate(rv.get_event_shape())
self.assertEqual(value_shape, expected_shape)
self.assertEqual(rv.dtype, rv.value().dtype)

def _test_copy(self, RV, value, *args, **kwargs):
rv1 = RV(*args, value=value, **kwargs)
rv2 = copy(rv1)
value_shape1 = rv1.value().shape.as_list()
value_shape2 = rv2.value().shape.as_list()
value_shape1 = rv1.value().shape
value_shape2 = rv2.value().shape
self.assertEqual(value_shape1, value_shape2)

def test_shape_and_dtype(self):
Expand Down

0 comments on commit 4184414

Please sign in to comment.