Skip to content

Commit

Permalink
add test to verify debug=True works
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed May 28, 2017
1 parent 43e28ba commit 88e1ee8
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 3 deletions.
2 changes: 1 addition & 1 deletion edward/inferences/inference.py
Expand Up @@ -231,7 +231,7 @@ def update(self, feed_dict=None):
t = sess.run(self.increment_t)

if self.debug:
-            sess.run(self.op_check)
+            sess.run(self.op_check, feed_dict)

if self.logging and self.n_print != 0:
if t == 1 or t % self.n_print == 0:
Expand Down
File renamed without changes.
41 changes: 41 additions & 0 deletions tests/test-inferences/test_inference_debug.py
@@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import numpy as np
import tensorflow as tf

from edward.models import Normal


class test_inference_debug_class(tf.test.TestCase):
  """Exercises inference with debug=True, i.e. with the numeric-check op
  enabled, for both placeholder-fed and tensor-bound data."""

  def _toy_model(self, n_data):
    """Build a scalar Normal prior with n_data observations, plus its
    variational approximation. Returns (mu, x, qmu)."""
    mu = Normal(loc=0.0, scale=1.0)
    x = Normal(loc=tf.ones(n_data) * mu, scale=tf.ones(n_data))
    qmu = Normal(loc=tf.Variable(0.0), scale=tf.constant(1.0))
    return mu, x, qmu

  def test_placeholder(self):
    """debug=True must pass feed_dict through to the check op when data
    is bound to a placeholder."""
    n_data = 5
    with self.test_session():
      mu, x, qmu = self._toy_model(n_data)
      x_ph = tf.placeholder(tf.float32, [n_data])
      inference = ed.KLqp({mu: qmu}, data={x: x_ph})
      inference.initialize(debug=True)
      tf.global_variables_initializer().run()
      inference.update(feed_dict={x_ph: np.zeros(n_data, np.float32)})

  def test_tensor(self):
    """debug=True must work when data is a constant tensor (no feed_dict)."""
    n_data = 5
    with self.test_session():
      mu, x, qmu = self._toy_model(n_data)
      inference = ed.KLqp({mu: qmu}, data={x: tf.zeros(n_data)})
      inference.run(n_iter=1, debug=True)

# Run this module's test cases when executed directly as a script.
if __name__ == '__main__':
  tf.test.main()
Expand Up @@ -21,8 +21,7 @@ def test_scale_0d(self):
qmu = Normal(loc=tf.Variable(0.0), scale=tf.constant(1.0))

x_ph = tf.placeholder(tf.float32, [M])
-      data = {x: x_ph}
-      inference = ed.KLqp({mu: qmu}, data)
+      inference = ed.KLqp({mu: qmu}, data={x: x_ph})
inference.initialize(scale={x: float(N) / M})
self.assertAllEqual(inference.scale[x], float(N) / M)

Expand Down

0 comments on commit 88e1ee8

Please sign in to comment.