From bc857c8d384354714b180fb7ec97cc1ca19c36f4 Mon Sep 17 00:00:00 2001
From: Dustin Tran
Date: Mon, 27 Mar 2017 14:52:00 -0400
Subject: [PATCH] add get_sample_shape() method; add to docstring

---
 edward/models/random_variable.py              | 40 +++++++++++++------
 .../test-models/test_random_variable_value.py |  5 ++-
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/edward/models/random_variable.py b/edward/models/random_variable.py
index 7bbbc0927..6f8c46264 100644
--- a/edward/models/random_variable.py
+++ b/edward/models/random_variable.py
@@ -24,6 +24,17 @@ class RandomVariable(object):
   graph, allowing random variables to be used in conjunction with
   other TensorFlow ops.
 
+  The random variable's shape is given by
+
+  ``sample_shape + batch_shape + event_shape``,
+
+  where ``sample_shape`` is an optional argument representing the
+  dimensions of samples drawn from the distribution (default is
+  a scalar); ``batch_shape`` is the number of independent random variables
+  (determined by the shape of its parameters); and ``event_shape`` is
+  the shape of one draw from the distribution (e.g., ``Normal`` has a
+  scalar ``event_shape``; ``Dirichlet`` has a vector ``event_shape``).
+
   Notes
   -----
   ``RandomVariable`` assumes use in a multiple inheritance setting. The
@@ -51,8 +62,8 @@ class in ``tf.contrib.distributions``. With Python's method resolution
   >>> p = tf.constant(0.5)
   >>> x = Bernoulli(p=p)
   >>>
-  >>> z1 = tf.constant([[2.0, 8.0]])
-  >>> z2 = tf.constant([[1.0, 2.0]])
+  >>> z1 = tf.constant([[2.0, 8.0], [1.0, 2.0]])
+  >>> z2 = tf.constant([[1.0, 2.0], [3.0, 1.0]])
   >>> x = Bernoulli(p=tf.matmul(z1, z2))
   >>>
   >>> mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
@@ -63,23 +74,24 @@ def __init__(self, *args, **kwargs):
     self._args = args
     self._kwargs = kwargs
 
-    # need to temporarily pop value before __init__
+    # temporarily pop before calling parent __init__
     value = kwargs.pop('value', None)
-    sample_shape = kwargs.pop('sample_shape', ())
+    self._sample_shape = kwargs.pop('sample_shape', tf.TensorShape([]))
     super(RandomVariable, self).__init__(*args, **kwargs)
+    # reinsert (needed for copying)
     if value is not None:
-      self._kwargs['value'] = value  # reinsert (needed for copying)
-    if sample_shape is not None:
-      self._kwargs['sample_shape'] = sample_shape
+      self._kwargs['value'] = value
+    if self._sample_shape != tf.TensorShape([]):
+      self._kwargs['sample_shape'] = self._sample_shape
 
     tf.add_to_collection(RANDOM_VARIABLE_COLLECTION, self)
 
     if value is not None:
       t_value = tf.convert_to_tensor(value, self.dtype)
-      expected_shape = (self.get_batch_shape().as_list() +
-                        self.get_event_shape().as_list())
-      value_shape = t_value.get_shape().as_list()
-      if value_shape[:len(expected_shape)] != expected_shape:
+      value_shape = t_value.shape
+      expected_shape = self.get_sample_shape().concatenate(
+          self.get_batch_shape()).concatenate(self.get_event_shape())
+      if value_shape != expected_shape:
         raise ValueError(
             "Incompatible shape for initialization argument 'value'. "
             "Expected %s, got %s." % (expected_shape, value_shape))
@@ -87,7 +99,7 @@ def __init__(self, *args, **kwargs):
       self._value = t_value
     else:
       try:
-        self._value = self.sample(sample_shape)
+        self._value = self.sample(self._sample_shape)
       except NotImplementedError:
         raise NotImplementedError(
             "sample is not implemented for {0}. You must either pass in the "
@@ -294,6 +306,10 @@ def get_shape(self):
     """Get shape of random variable."""
     return self.shape
 
+  def get_sample_shape(self):
+    """Sample shape of random variable."""
+    return self._sample_shape
+
   @staticmethod
   def _session_run_conversion_fetch_function(tensor):
     return ([tensor.value()], lambda val: val[0])
diff --git a/tests/test-models/test_random_variable_value.py b/tests/test-models/test_random_variable_value.py
index 3f4f90830..4d7b0eccc 100644
--- a/tests/test-models/test_random_variable_value.py
+++ b/tests/test-models/test_random_variable_value.py
@@ -14,8 +14,9 @@ class test_random_variable_value_class(tf.test.TestCase):
   def _test_sample(self, RV, value, *args, **kwargs):
     rv = RV(*args, value=value, **kwargs)
     value_shape = rv.value().shape
-    expected_shape = rv.get_batch_shape().concatenate(rv.get_event_shape())
-    self.assertEqual(value_shape[-len(expected_shape):], expected_shape)
+    expected_shape = rv.get_sample_shape().concatenate(
+        rv.get_batch_shape()).concatenate(rv.get_event_shape())
+    self.assertEqual(value_shape, expected_shape)
     self.assertEqual(rv.dtype, rv.value().dtype)
 
   def _test_copy(self, RV, value, *args, **kwargs):