Skip to content

Commit

Permalink
add get_sample_shape() method; add to docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 27, 2017
1 parent 230d505 commit bc857c8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
40 changes: 28 additions & 12 deletions edward/models/random_variable.py
Expand Up @@ -24,6 +24,17 @@ class RandomVariable(object):
graph, allowing random variables to be used in conjunction with
other TensorFlow ops.
The random variable's shape is given by
``sample_shape + batch_shape + event_shape``,
where ``sample_shape`` is an optional argument representing the
dimensions of samples drawn from the distribution (default is
a scalar); ``batch_shape`` is the number of independent random variables
(determined by the shape of its parameters); and ``event_shape`` is
the shape of one draw from the distribution (e.g., ``Normal`` has a
scalar ``event_shape``; ``Dirichlet`` has a vector ``event_shape``).
Notes
-----
``RandomVariable`` assumes use in a multiple inheritance setting. The
Expand Down Expand Up @@ -51,8 +62,8 @@ class in ``tf.contrib.distributions``. With Python's method resolution
>>> p = tf.constant(0.5)
>>> x = Bernoulli(p=p)
>>>
>>> z1 = tf.constant([[2.0, 8.0]])
>>> z2 = tf.constant([[1.0, 2.0]])
>>> z1 = tf.constant([[2.0, 8.0], [1.0, 2.0]])
>>> z2 = tf.constant([[1.0, 2.0], [3.0, 1.0]])
>>> x = Bernoulli(p=tf.matmul(z1, z2))
>>>
>>> mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
Expand All @@ -63,31 +74,32 @@ def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs

# need to temporarily pop value before __init__
# temporarily pop before calling parent __init__
value = kwargs.pop('value', None)
sample_shape = kwargs.pop('sample_shape', ())
self._sample_shape = kwargs.pop('sample_shape', tf.TensorShape([]))
super(RandomVariable, self).__init__(*args, **kwargs)
# reinsert (needed for copying)
if value is not None:
self._kwargs['value'] = value # reinsert (needed for copying)
if sample_shape is not None:
self._kwargs['sample_shape'] = sample_shape
self._kwargs['value'] = value
if self._sample_shape != tf.TensorShape([]):
self._kwargs['sample_shape'] = self._sample_shape

tf.add_to_collection(RANDOM_VARIABLE_COLLECTION, self)

if value is not None:
t_value = tf.convert_to_tensor(value, self.dtype)
expected_shape = (self.get_batch_shape().as_list() +
self.get_event_shape().as_list())
value_shape = t_value.get_shape().as_list()
if value_shape[:len(expected_shape)] != expected_shape:
value_shape = t_value.shape
expected_shape = self.get_sample_shape().concatenate(
self.get_batch_shape()).concatenate(self.get_event_shape())
if value_shape != expected_shape:
raise ValueError(
"Incompatible shape for initialization argument 'value'. "
"Expected %s, got %s." % (expected_shape, value_shape))
else:
self._value = t_value
else:
try:
self._value = self.sample(sample_shape)
self._value = self.sample(self._sample_shape)
except NotImplementedError:
raise NotImplementedError(
"sample is not implemented for {0}. You must either pass in the "
Expand Down Expand Up @@ -294,6 +306,10 @@ def get_shape(self):
"""Get shape of random variable."""
return self.shape

def get_sample_shape(self):
"""Sample shape of random variable."""
return self._sample_shape

@staticmethod
def _session_run_conversion_fetch_function(tensor):
return ([tensor.value()], lambda val: val[0])
Expand Down
5 changes: 3 additions & 2 deletions tests/test-models/test_random_variable_value.py
Expand Up @@ -14,8 +14,9 @@ class test_random_variable_value_class(tf.test.TestCase):
def _test_sample(self, RV, value, *args, **kwargs):
rv = RV(*args, value=value, **kwargs)
value_shape = rv.value().shape
expected_shape = rv.get_batch_shape().concatenate(rv.get_event_shape())
self.assertEqual(value_shape[-len(expected_shape):], expected_shape)
expected_shape = rv.get_sample_shape().concatenate(
rv.get_batch_shape()).concatenate(rv.get_event_shape())
self.assertEqual(value_shape, expected_shape)
self.assertEqual(rv.dtype, rv.value().dtype)

def _test_copy(self, RV, value, *args, **kwargs):
Expand Down

0 comments on commit bc857c8

Please sign in to comment.