Skip to content

Commit

Permalink
generalize RandomVariable's value shape check; fix #519
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Apr 19, 2017
1 parent 9854a55 commit be60dec
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion edward/models/random_variable.py
Expand Up @@ -106,7 +106,7 @@ def __init__(self, *args, **kwargs):
value_shape = t_value.shape
expected_shape = self.get_sample_shape().concatenate(
self.get_batch_shape()).concatenate(self.get_event_shape())
if value_shape != expected_shape:
if not value_shape.is_compatible_with(expected_shape):
raise ValueError(
"Incompatible shape for initialization argument 'value'. "
"Expected %s, got %s." % (expected_shape, value_shape))
Expand Down
6 changes: 5 additions & 1 deletion tests/test-models/test_random_variable_value.py
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import tensorflow as tf

from edward.models import Normal, Poisson, RandomVariable
from edward.models import Bernoulli, Normal, Poisson, RandomVariable
from edward.util import copy


Expand All @@ -32,6 +32,10 @@ def test_shape_and_dtype(self):
self._test_sample(Normal, [2], mu=[0.5], sigma=[1.0])
self._test_sample(Poisson, 2, lam=0.5)

def test_unknown_shape(self):
with self.test_session():
x = Bernoulli(0.5, value=tf.placeholder(tf.int32))

def test_mismatch_raises(self):
with self.test_session():
self.assertRaises(ValueError, self._test_sample, Normal, 2,
Expand Down

0 comments on commit be60dec

Please sign in to comment.