Skip to content

Commit

Permalink
Gibbs Sampling (#607)
Browse files Browse the repository at this point in the history
* add MWE of ed.Gibbs

* add scan order as argument

* enable blocked gibbs

* improve gibbs test

* preserve order in looping over dicts

* avoid error when attempting conjugate_log_prob on Empirical

* minor revisions to docstrings

* remove unnecessary import statements
  • Loading branch information
dustinvtran committed Apr 14, 2017
1 parent 5872095 commit 973e38d
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 3 deletions.
2 changes: 1 addition & 1 deletion edward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
KLpq, KLqp, ReparameterizationKLqp, ReparameterizationKLKLqp, \
ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
GANInference, WGANInference, ImplicitKLqp, MAP, Laplace, \
complete_conditional
complete_conditional, Gibbs
from edward.models import RandomVariable
from edward.util import check_data, check_latent_vars, copy, dot, \
get_ancestors, get_blanket, get_children, get_control_variate_coef, \
Expand Down
1 change: 1 addition & 0 deletions edward/inferences/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from edward.inferences.conjugacy import complete_conditional
from edward.inferences.gan_inference import *
from edward.inferences.gibbs import *
from edward.inferences.hmc import *
from edward.inferences.implicit_klqp import *
from edward.inferences.inference import *
Expand Down
4 changes: 2 additions & 2 deletions edward/inferences/conjugacy/conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from edward.inferences.conjugacy.simplify \
import symbolic_suff_stat, full_simplify, expr_contains, reconstruct_expr
from edward.models import random_variables as rvs
from edward.util import copy, random_variables
from edward.util import copy, get_blanket


def normal_from_natural_params(p1, p2):
Expand Down Expand Up @@ -94,7 +94,7 @@ def complete_conditional(rv, cond_set=None):
result in unpredictable behavior.
"""
if cond_set is None:
cond_set = random_variables()
cond_set = get_blanket(rv) + [rv]
with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
# log_joint holds all the information we need to get a conditional.
cond_set = set([rv] + list(cond_set))
Expand Down
146 changes: 146 additions & 0 deletions edward/inferences/gibbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import six
import tensorflow as tf

from collections import OrderedDict
from edward.inferences.conjugacy import complete_conditional
from edward.inferences.monte_carlo import MonteCarlo
from edward.models import RandomVariable
from edward.util import check_latent_vars, get_session


class Gibbs(MonteCarlo):
    """Gibbs sampling (Geman and Geman, 1984).

    Every call to ``update()`` draws each latent variable from its
    proposal distribution (by default its complete conditional) while
    feeding the most recent draws of the other variables, and then writes
    the result into the random variables that approximate the posterior.
    """
    def __init__(self, latent_vars, proposal_vars=None, data=None):
        """
        Parameters
        ----------
        latent_vars : dict of RandomVariable to RandomVariable
          Latent variables to infer, each bound to the random variable
          that stores its draws (an ``Empirical`` in the example below).
        proposal_vars : dict of RandomVariable to RandomVariable, optional
          Collection of random variables to perform inference on; each is
          bound to the complete conditional which Gibbs draws from.
          If not specified, default is to use ``ed.complete_conditional``.
        data : dict, optional
          Observed data; passed unchanged to the parent constructor.

        Examples
        --------
        >>> x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
        >>>
        >>> p = Beta(1.0, 1.0)
        >>> x = Bernoulli(p=p, sample_shape=10)
        >>>
        >>> qp = Empirical(tf.Variable(tf.zeros(500)))
        >>> inference = ed.Gibbs({p: qp}, data={x: x_data})
        """
        if proposal_vars is None:
            proposal_vars = {z: complete_conditional(z)
                             for z in six.iterkeys(latent_vars)}
        else:
            check_latent_vars(proposal_vars)

        self.proposal_vars = proposal_vars
        super(Gibbs, self).__init__(latent_vars, data)

    def initialize(self, scan_order='random', *args, **kwargs):
        """
        Parameters
        ----------
        scan_order : list or str, optional
          The scan order for each Gibbs update. If list, it is the
          deterministic order of latent variables. An element in the list
          can be a ``RandomVariable`` or itself a list of
          ``RandomVariable``s (this defines a blocked Gibbs sampler). If
          'random', will use a random order at each update.
        *args, **kwargs
          Forwarded to the parent class's ``initialize``.
        """
        self.scan_order = scan_order
        # Persistent feeds shared by all conditionals. Left empty here and
        # populated on the first call to ``update()``.
        self.feed_dict = {}
        return super(Gibbs, self).initialize(*args, **kwargs)

    def update(self, feed_dict=None):
        """Run one iteration of Gibbs sampling.

        Parameters
        ----------
        feed_dict : dict, optional
          Feed dictionary for a TensorFlow session run. It is used to feed
          placeholders that are not fed during initialization. The
          dictionary passed in is not modified.

        Returns
        -------
        dict
          Dictionary of algorithm-specific information. In this case, the
          iteration counter under ``'t'`` and, under ``'accept_rate'``,
          the acceptance rate of samples since (and including) this
          iteration.
        """
        sess = get_session()
        if not self.feed_dict:
            # Initialize feed for all conditionals to be the draws at step 0.
            samples = OrderedDict(self.latent_vars)
            inits = sess.run([qz.params[0] for qz in six.itervalues(samples)])
            for z, init in zip(six.iterkeys(samples), inits):
                self.feed_dict[z] = init

            # Add the observed data to the persistent feeds: placeholders
            # are fed their value directly, variable-backed data is
            # evaluated once.
            for key, value in six.iteritems(self.data):
                if isinstance(key, tf.Tensor) and \
                        "Placeholder" in key.op.type:
                    self.feed_dict[key] = value
                elif isinstance(key, RandomVariable) and \
                        isinstance(value, tf.Variable):
                    self.feed_dict[key] = sess.run(value)

        # Work on a private copy so that the caller's dictionary is never
        # polluted with the sampler's internal draws. Entries tracked in
        # ``self.feed_dict`` still take precedence over the caller's.
        feed_dict = {} if feed_dict is None else dict(feed_dict)
        feed_dict.update(self.feed_dict)

        # Determine scan order.
        if self.scan_order == 'random':
            scan_order = list(six.iterkeys(self.latent_vars))
            random.shuffle(scan_order)
        else:  # list
            scan_order = self.scan_order

        # Fetch samples by iterating over complete conditional draws. Each
        # draw is fed to every later conditional in this sweep and is
        # remembered for the next call.
        for z in scan_order:
            if isinstance(z, RandomVariable):
                draw = sess.run(self.proposal_vars[z], feed_dict)
                feed_dict[z] = draw
                self.feed_dict[z] = draw
            else:  # list: a block whose members are drawn in one run
                draws = sess.run(
                    [self.proposal_vars[zz] for zz in z], feed_dict)
                for zz, draw in zip(z, draws):
                    feed_dict[zz] = draw
                    self.feed_dict[zz] = draw

        # Assign the samples to the Empirical random variables.
        _, accept_rate = sess.run(
            [self.train, self.n_accept_over_t], feed_dict)
        t = sess.run(self.increment_t)

        if self.debug:
            sess.run(self.op_check)

        if self.logging and self.n_print != 0:
            if t == 1 or t % self.n_print == 0:
                summary = sess.run(self.summarize, feed_dict)
                self.train_writer.add_summary(summary, t)

        return {'t': t, 'accept_rate': accept_rate}

    def build_update(self):
        """Build the op that records one Gibbs draw per latent variable.

        Returns
        -------
        A grouped TensorFlow op containing one scatter update per latent
        variable plus the increment of ``self.n_accept``.

        Notes
        -----
        The updates assume each Empirical random variable is directly
        parameterized by ``tf.Variable``s.
        """
        # Update Empirical random variables according to the complete
        # conditionals. We will feed the conditionals when calling
        # ``update()``.
        assign_ops = []
        for z, qz in six.iteritems(self.latent_vars):
            variable = qz.get_variables()[0]
            # Write the value at index ``self.t`` of the backing variable.
            assign_ops.append(
                tf.scatter_update(variable, self.t, self.proposal_vars[z]))

        # Increment n_accept (if accepted). This op is added
        # unconditionally, so the counter grows on every update.
        assign_ops.append(self.n_accept.assign_add(1))
        return tf.group(*assign_ops)
52 changes: 52 additions & 0 deletions tests/test-inferences/test_gibbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import numpy as np
import tensorflow as tf

from edward.models import Bernoulli, Beta, Empirical, Normal


class test_gibbs_class(tf.test.TestCase):
    """Checks ``ed.Gibbs`` against posteriors that are known in closed form."""

    def test_beta_bernoulli(self):
        """Beta(1, 1) prior with ten Bernoulli observations, two of them 1."""
        with self.test_session() as sess:
            observations = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])

            prior = Beta(a=1.0, b=1.0)
            likelihood = Bernoulli(p=prior, sample_shape=10)

            chain = Empirical(tf.Variable(tf.zeros(1000)))
            sampler = ed.Gibbs({prior: chain},
                               data={likelihood: observations})
            sampler.run()

            # Reference posterior that the chain is compared against.
            exact = Beta(a=3.0, b=9.0)

            # Compare the first two moments, mean first and then variance.
            for moment in ('mean', 'variance'):
                estimated, expected = sess.run(
                    [getattr(chain, moment)(), getattr(exact, moment)()])
                self.assertAllClose(
                    estimated, expected, rtol=1e-2, atol=1e-2)

    def test_normal_normal(self):
        """Standard normal prior on the mean, fifty observations all zero."""
        with self.test_session():
            observations = np.zeros(50, dtype=np.float32)

            prior = Normal(mu=0.0, sigma=1.0)
            likelihood = Normal(mu=prior, sigma=1.0, sample_shape=50)

            chain = Empirical(params=tf.Variable(tf.ones(1000)))

            # analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
            sampler = ed.Gibbs({prior: chain},
                               data={likelihood: observations})
            sampler.run()

            chain_mean = chain.mean().eval()
            self.assertAllClose(chain_mean, 0, rtol=1e-2, atol=1e-2)

            chain_std = chain.std().eval()
            target_std = np.sqrt(1 / 51)
            self.assertAllClose(chain_std, target_std, rtol=1e-2, atol=1e-2)

# Script entry point: call Edward's ``set_seed`` with a fixed value before
# handing control to the TensorFlow test runner.
# NOTE(review): the seed is only set when this file is run as a script; other
# test runners that import the module skip it -- confirm that is intended.
if __name__ == '__main__':
    ed.set_seed(127832)
    tf.test.main()

0 comments on commit 973e38d

Please sign in to comment.