from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import six
import tensorflow as tf

from collections import OrderedDict
from edward.inferences.conjugacy import complete_conditional
from edward.inferences.monte_carlo import MonteCarlo
from edward.models import RandomVariable
from edward.util import check_latent_vars, get_session


class Gibbs(MonteCarlo):
"""Gibbs sampling (Geman and Geman, 1984).
#### Examples
```python
x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
p = Beta(1.0, 1.0)
x = Bernoulli(probs=p, sample_shape=10)
qp = Empirical(tf.Variable(tf.zeros(500)))
inference = ed.Gibbs({p: qp}, data={x: x_data})
```
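
  The Beta prior is conjugate to the Bernoulli likelihood, so the
  complete conditional of `p` is tractable and the default
  `ed.complete_conditional` proposal applies. A minimal sketch of
  running the sampler, assuming the setup above:

  ```python
  inference.run()
  ```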
"""

  def __init__(self, latent_vars, proposal_vars=None, data=None):
    """Create an inference algorithm.

    Args:
      proposal_vars: dict of RandomVariable to RandomVariable, optional.
        Collection of random variables to perform inference on; each is
        bound to its complete conditional, which the Gibbs cycle draws
        from. If not specified, default is to derive each complete
        conditional with `ed.complete_conditional`.
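
    For example, the complete conditionals can also be passed
    explicitly (a sketch, assuming the Beta-Bernoulli model above):

    ```python
    inference = ed.Gibbs({p: qp},
                         proposal_vars={p: ed.complete_conditional(p)},
                         data={x: x_data})
    ```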
"""
if proposal_vars is None:
proposal_vars = {z: complete_conditional(z)
for z in six.iterkeys(latent_vars)}
else:
check_latent_vars(proposal_vars)
self.proposal_vars = proposal_vars
super(Gibbs, self).__init__(latent_vars, data)

  def initialize(self, scan_order='random', *args, **kwargs):
    """Initialize inference algorithm. It initializes hyperparameters
    and builds ops for the algorithm's computation graph.

    Args:
      scan_order: list or str, optional.
        The scan order for each Gibbs update. If list, it is the
        deterministic order of latent variables. An element in the list
        can be a `RandomVariable` or itself a list of
        `RandomVariable`s (this defines a blocked Gibbs sampler). If
        'random', will use a random order at each update.
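
    For example, one can request a fixed blocked scan (a sketch; `p` is
    from the model above, while `z1` and `z2` stand in for additional
    latent variables):

    ```python
    inference.initialize(scan_order=[p, [z1, z2]])
    ```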
"""
self.scan_order = scan_order
self.feed_dict = {}
return super(Gibbs, self).initialize(*args, **kwargs)

  def update(self, feed_dict=None):
    """Run one iteration of sampling.

    Args:
      feed_dict: dict, optional.
        Feed dictionary for a TensorFlow session run. It is used to feed
        placeholders that are not fed during initialization.

    Returns:
      dict.
      Dictionary of algorithm-specific information. In this case, the
      acceptance rate of samples since (and including) this iteration.
      Because Gibbs draws directly from complete conditionals, every
      proposal is accepted and the rate is always 1.
    """
sess = get_session()
if not self.feed_dict:
# Initialize feed for all conditionals to be the draws at step 0.
samples = OrderedDict(self.latent_vars)
inits = sess.run([qz.params[0] for qz in six.itervalues(samples)])
for z, init in zip(six.iterkeys(samples), inits):
self.feed_dict[z] = init
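
      # Feed observed data: placeholders receive their bound values
      # directly; random variables bound to tensors or variables are
      # evaluated once and fed as fixed values.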
for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
self.feed_dict[key] = value
elif isinstance(key, RandomVariable) and \
isinstance(value, (tf.Tensor, tf.Variable)):
self.feed_dict[key] = sess.run(value)
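    # Merge the persistent feed (latest draws and observed data) into
    # this call's feed dictionary.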
if feed_dict is None:
feed_dict = {}
feed_dict.update(self.feed_dict)
# Determine scan order.
if self.scan_order == 'random':
scan_order = list(six.iterkeys(self.latent_vars))
random.shuffle(scan_order)
else: # list
scan_order = self.scan_order
# Fetch samples by iterating over complete conditional draws.
for z in scan_order:
if isinstance(z, RandomVariable):
draw = sess.run(self.proposal_vars[z], feed_dict)
feed_dict[z] = draw
self.feed_dict[z] = draw
else: # list
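        # Blocked update: draw every variable in the block in a single
        # session run, conditioned on the current values of all other
        # latent variables.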
draws = sess.run([self.proposal_vars[zz] for zz in z], feed_dict)
for zz, draw in zip(z, draws):
feed_dict[zz] = draw
self.feed_dict[zz] = draw
# Assign the samples to the Empirical random variables.
_, accept_rate = sess.run([self.train, self.n_accept_over_t], feed_dict)
t = sess.run(self.increment_t)
if self.debug:
sess.run(self.op_check, feed_dict)
if self.logging and self.n_print != 0:
if t == 1 or t % self.n_print == 0:
summary = sess.run(self.summarize, feed_dict)
self.train_writer.add_summary(summary, t)
return {'t': t, 'accept_rate': accept_rate}

  def build_update(self):
    """
    #### Notes

    The updates assume each Empirical random variable is directly
    parameterized by `tf.Variable`s.
    """
# Update Empirical random variables according to the complete
# conditionals. We will feed the conditionals when calling `update()`.
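    # Each `tf.scatter_update` writes the fed draw into row `self.t` of
    # the Empirical approximation's parameter variable.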
assign_ops = []
for z, qz in six.iteritems(self.latent_vars):
variable = qz.get_variables()[0]
assign_ops.append(
tf.scatter_update(variable, self.t, self.proposal_vars[z]))
    # Gibbs proposals are always accepted, so increment `n_accept`
    # unconditionally; the acceptance rate is always 1.
    assign_ops.append(self.n_accept.assign_add(1))
return tf.group(*assign_ops)