Skip to content

Commit

Permalink
Fixed a bug where rv isn't part of blanket
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Hoffman committed Apr 7, 2017
1 parent e945c63 commit b4a774d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
6 changes: 3 additions & 3 deletions edward/inferences/conjugacy/conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def get_log_joint(blanket):
def complete_conditional(rv, blanket, log_joint=None):
with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
# log_joint holds all the information we need to get a conditional.
extended_blanket = copy(blanket)
blanket = set([rv] + list(blanket))
if log_joint is None:
log_joint = get_log_joint(extended_blanket)
log_joint = get_log_joint(blanket)
else:
log_joint = log_joint(extended_blanket)
log_joint = log_joint(blanket)

# Pull out the nodes that are nonlinear functions of rv into s_stats.
stop_nodes = set([i.value() for i in blanket])
Expand Down
11 changes: 11 additions & 0 deletions tests/test-inferences/test_conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ def test_basic_bernoulli(self):

self.assertAllClose(p_val, 0.75 + np.zeros(N, np.float32))

def test_incomplete_blanket(self):
N = 10
z = rvs.Bernoulli(p=0.75, sample_shape=N)
z_cond = conj.complete_conditional(z, [])
self.assertIsInstance(z_cond, rvs.Bernoulli)

sess = tf.InteractiveSession()
p_val = sess.run(z_cond.p)

self.assertAllClose(p_val, 0.75 + np.zeros(N, np.float32))

def test_beta_bernoulli(self):
x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])

Expand Down

0 comments on commit b4a774d

Please sign in to comment.