-
Notifications
You must be signed in to change notification settings - Fork 761
/
random_variable.py
344 lines (268 loc) · 10.9 KB
/
random_variable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
try:
  from tensorflow.python.client.session import (
      register_session_run_conversion_functions)
except ImportError as e:
  # Only an import failure is re-labelled; presumably the installed
  # TensorFlow predates this session-run conversion hook. Other errors
  # propagate unchanged instead of being reported as a version problem.
  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))

# Name of the TensorFlow graph collection that each RandomVariable adds
# itself to during construction (see ``RandomVariable.__init__``).
RANDOM_VARIABLE_COLLECTION = "_random_variable_collection_"
class RandomVariable(object):
  """Base class for random variables.

  A random variable is an object parameterized by tensors. It is
  equipped with methods such as the log-density, mean, and sample.

  It also wraps a tensor, where the tensor corresponds to a sample
  from the random variable. This enables operations on the TensorFlow
  graph, allowing random variables to be used in conjunction with
  other TensorFlow ops.

  The random variable's shape is given by

  ``sample_shape + batch_shape + event_shape``,

  where ``sample_shape`` is an optional argument representing the
  dimensions of samples drawn from the distribution (default is
  a scalar); ``batch_shape`` is the number of independent random variables
  (determined by the shape of its parameters); and ``event_shape`` is
  the shape of one draw from the distribution (e.g., ``Normal`` has a
  scalar ``event_shape``; ``Dirichlet`` has a vector ``event_shape``).

  Notes
  -----
  ``RandomVariable`` assumes use in a multiple inheritance setting. The
  child class must first inherit ``RandomVariable``, then second inherit a
  class in ``tf.contrib.distributions``. With Python's method resolution
  order, this implies the following during initialization (using
  ``distributions.Bernoulli`` as an example):

  1. Start the ``__init__()`` of the child class, which passes all
     ``*args, **kwargs`` to ``RandomVariable``.
  2. This in turn passes all ``*args, **kwargs`` (except ``value`` and
     ``sample_shape``, which ``RandomVariable`` consumes) to
     ``distributions.Bernoulli``, completing the ``__init__()`` of
     ``distributions.Bernoulli``.
  3. Complete the ``__init__()`` of ``RandomVariable``, which, unless
     ``value`` was given, calls ``self.sample()``, relying on the method
     from ``distributions.Bernoulli``.
  4. Complete the ``__init__()`` of the child class.

  Methods from both ``RandomVariable`` and ``distributions.Bernoulli``
  populate the namespace of the child class. Methods from
  ``RandomVariable`` will take higher priority if there are conflicts.

  Examples
  --------
  >>> p = tf.constant(0.5)
  >>> x = Bernoulli(p=p)
  >>>
  >>> z1 = tf.constant([[2.0, 8.0], [1.0, 2.0]])
  >>> z2 = tf.constant([[1.0, 2.0], [3.0, 1.0]])
  >>> x = Bernoulli(p=tf.matmul(z1, z2))
  >>>
  >>> mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
  >>> x = Normal(mu=mu, sigma=tf.constant(1.0))
  """
  def __init__(self, *args, **kwargs):
    """Initialize the parent distribution and the wrapped tensor.

    Parameters
    ----------
    *args
      Forwarded unchanged to the next class in the MRO.
    **kwargs
      Forwarded to the next class in the MRO, except for two keys that
      are consumed here:

      value : tensor-like, optional
        Value to wrap instead of drawing a sample. It is converted with
        ``tf.convert_to_tensor`` using ``self.dtype``, and its shape must
        be compatible with ``sample_shape + batch_shape + event_shape``.
      sample_shape : optional
        Shape passed to ``self.sample()``; anything accepted by
        ``tf.TensorShape``. Defaults to ``tf.TensorShape([])``.

    Raises
    ------
    ValueError
      If ``value`` has a shape incompatible with the expected shape.
    NotImplementedError
      If ``value`` is not given and ``sample`` is not implemented.
    """
    # Store args/kwargs for easy graph copying. ``self._kwargs`` is the
    # same dict object as ``kwargs``, so the pops below also remove the
    # keys from it; they are reinserted after the parent __init__.
    self._args = args
    self._kwargs = kwargs

    # Temporarily pop arguments the parent distribution does not accept.
    value = kwargs.pop('value', None)
    self._sample_shape = kwargs.pop('sample_shape', tf.TensorShape([]))
    super(RandomVariable, self).__init__(*args, **kwargs)

    # Reinsert (needed for copying); the default sample shape is omitted.
    if value is not None:
      self._kwargs['value'] = value
    if self._sample_shape != tf.TensorShape([]):
      self._kwargs['sample_shape'] = self._sample_shape

    if value is not None:
      t_value = tf.convert_to_tensor(value, self.dtype)
      value_shape = t_value.shape
      # tf.TensorShape(...) also accepts list/int sample shapes, which
      # have no ``concatenate`` method of their own.
      expected_shape = tf.TensorShape(self.get_sample_shape()).concatenate(
          self.get_batch_shape()).concatenate(self.get_event_shape())
      # Use compatibility rather than (in)equality: shapes with unknown
      # (``None``) dimensions or unknown rank must not be rejected.
      if not value_shape.is_compatible_with(expected_shape):
        raise ValueError(
            "Incompatible shape for initialization argument 'value'. "
            "Expected %s, got %s." % (expected_shape, value_shape))
      self._value = t_value
    else:
      try:
        self._value = self.sample(self._sample_shape)
      except NotImplementedError:
        raise NotImplementedError(
            "sample is not implemented for {0}. You must either pass in the "
            "value argument or implement sample for {0}."
            .format(self.__class__.__name__))

    # Register only once ``_value`` exists, so a failed construction does
    # not leave a half-built random variable in the graph collection.
    tf.add_to_collection(RANDOM_VARIABLE_COLLECTION, self)

  @property
  def shape(self):
    """Shape of random variable (the shape of the wrapped tensor)."""
    return self._value.shape

  def __str__(self):
    return "RandomVariable(\"%s\"%s%s%s)" % (
        self.name,
        (", shape=%s" % self.shape)
        if self.shape.ndims is not None else "",
        (", dtype=%s" % self.dtype.name) if self.dtype else "",
        (", device=%s" % self.value().device) if self.value().device else "")

  def __repr__(self):
    return "<ed.RandomVariable '%s' shape=%s dtype=%s>" % (
        self.name, self.shape, self.dtype.name)

  # Arithmetic, comparison, and logical operators build TensorFlow ops;
  # the random variable is converted to its tensor through the conversion
  # function registered for this class.
  def __add__(self, other):
    return tf.add(self, other)

  def __radd__(self, other):
    return tf.add(other, self)

  def __sub__(self, other):
    return tf.subtract(self, other)

  def __rsub__(self, other):
    return tf.subtract(other, self)

  def __mul__(self, other):
    return tf.multiply(self, other)

  def __rmul__(self, other):
    return tf.multiply(other, self)

  def __div__(self, other):
    # Python 2 ``/`` without ``from __future__ import division``.
    return tf.div(self, other)

  def __rdiv__(self, other):
    return tf.div(other, self)

  def __truediv__(self, other):
    # True division, as for ``tf.Tensor``; ``tf.div`` would floor-divide
    # integer operands.
    return tf.truediv(self, other)

  def __rtruediv__(self, other):
    return tf.truediv(other, self)

  def __floordiv__(self, other):
    # ``tf.floordiv`` matches floor(div) for floats and also supports
    # integer dtypes, for which ``tf.floor`` is not defined.
    return tf.floordiv(self, other)

  def __rfloordiv__(self, other):
    return tf.floordiv(other, self)

  def __mod__(self, other):
    return tf.mod(self, other)

  def __rmod__(self, other):
    return tf.mod(other, self)

  def __lt__(self, other):
    return tf.less(self, other)

  def __le__(self, other):
    return tf.less_equal(self, other)

  def __gt__(self, other):
    return tf.greater(self, other)

  def __ge__(self, other):
    return tf.greater_equal(self, other)

  def __and__(self, other):
    return tf.logical_and(self, other)

  def __rand__(self, other):
    return tf.logical_and(other, self)

  def __or__(self, other):
    return tf.logical_or(self, other)

  def __ror__(self, other):
    return tf.logical_or(other, self)

  def __xor__(self, other):
    return tf.logical_xor(self, other)

  def __rxor__(self, other):
    return tf.logical_xor(other, self)

  def __getitem__(self, key):
    """Subset the tensor associated to the random variable, not the
    random variable itself."""
    return self.value()[key]

  def __pow__(self, other):
    return tf.pow(self, other)

  def __rpow__(self, other):
    return tf.pow(other, self)

  def __invert__(self):
    return tf.logical_not(self)

  def __neg__(self):
    return tf.negative(self)

  def __abs__(self):
    return tf.abs(self)

  # Equality and hashing are identity-based (not elementwise tensor ops),
  # so random variables can be used as dictionary keys.
  def __hash__(self):
    return id(self)

  def __eq__(self, other):
    return id(self) == id(other)

  def __iter__(self):
    raise TypeError("'RandomVariable' object is not iterable.")

  def __bool__(self):
    raise TypeError(
        "Using a `ed.RandomVariable` as a Python `bool` is not allowed. "
        "Use `if t is not None:` instead of `if t:` to test if a "
        "random variable is defined, and use TensorFlow ops such as "
        "tf.cond to execute subgraphs conditioned on a draw from "
        "a random variable.")

  # Python 2 name for the truth-value hook.
  __nonzero__ = __bool__

  def eval(self, session=None, feed_dict=None):
    """In a session, computes and returns the value of this random variable.

    This is not a graph construction method, it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used.

    Parameters
    ----------
    session : tf.BaseSession, optional
      The ``tf.Session`` to use to evaluate this random variable. If
      none, the default session is used.
    feed_dict : dict, optional
      A dictionary that maps ``tf.Tensor`` objects to feed values. See
      ``tf.Session.run()`` for a description of the valid feed values.

    Examples
    --------
    >>> x = Normal(0.0, 1.0)
    >>> with tf.Session() as sess:
    >>>   # Usage passing the session explicitly.
    >>>   print(x.eval(sess))
    >>>   # Usage with the default session. The 'with' block
    >>>   # above makes 'sess' the default session.
    >>>   print(x.eval())
    """
    return self.value().eval(session=session, feed_dict=feed_dict)

  def value(self):
    """Get tensor that the random variable corresponds to."""
    return self._value

  # The graph-query helpers below import lazily from edward.util;
  # presumably this avoids a circular import at module load.
  def get_ancestors(self, collection=None):
    """Get ancestor random variables."""
    from edward.util.random_variables import get_ancestors
    return get_ancestors(self, collection)

  def get_children(self, collection=None):
    """Get child random variables."""
    from edward.util.random_variables import get_children
    return get_children(self, collection)

  def get_descendants(self, collection=None):
    """Get descendant random variables."""
    from edward.util.random_variables import get_descendants
    return get_descendants(self, collection)

  def get_parents(self, collection=None):
    """Get parent random variables."""
    from edward.util.random_variables import get_parents
    return get_parents(self, collection)

  def get_siblings(self, collection=None):
    """Get sibling random variables."""
    from edward.util.random_variables import get_siblings
    return get_siblings(self, collection)

  def get_variables(self, collection=None):
    """Get TensorFlow variables that the random variable depends on."""
    from edward.util.random_variables import get_variables
    return get_variables(self, collection)

  def get_shape(self):
    """Get shape of random variable."""
    return self.shape

  def get_sample_shape(self):
    """Sample shape of random variable, as passed to ``__init__``."""
    return self._sample_shape

  # Hooks for tf.Session.run: fetching or feeding a random variable is
  # delegated to its wrapped tensor.
  @staticmethod
  def _session_run_conversion_fetch_function(tensor):
    return ([tensor.value()], lambda val: val[0])

  @staticmethod
  def _session_run_conversion_feed_function(feed, feed_val):
    return [(feed.value(), feed_val)]

  @staticmethod
  def _session_run_conversion_feed_function_for_partial_run(feed):
    return [feed.value()]

  @staticmethod
  def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
    """Convert random variable ``v`` to its wrapped tensor.

    Raises
    ------
    ValueError
      If ``dtype`` is given and incompatible with ``v.dtype``, or if a
      reference (``as_ref=True``) is requested.
    """
    del name  # Unused; required by the conversion-function signature.
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise ValueError("%s: Ref type is not supported." % v)
    return v.value()
# Let tf.Session.run fetch and feed RandomVariable objects; every hook
# routes through the tensor returned by ``value()``.
register_session_run_conversion_functions(
    RandomVariable,
    RandomVariable._session_run_conversion_fetch_function,
    RandomVariable._session_run_conversion_feed_function,
    RandomVariable._session_run_conversion_feed_function_for_partial_run,
)

# Let tf.convert_to_tensor (and therefore TensorFlow ops) accept a
# RandomVariable by converting it to its wrapped tensor.
tf.register_tensor_conversion_function(
    RandomVariable,
    RandomVariable._tensor_conversion_function,
)