From a0a0122618a224804f1d2b5ca9b410c0b3924ef4 Mon Sep 17 00:00:00 2001
From: masashi yoshikawa
Date: Sun, 18 Mar 2018 09:20:26 +0900
Subject: [PATCH 1/2] [add] add unit tests and small fix

---
 edward/inferences/conjugacy/conjugacy.py    |  12 +-
 .../conjugacy/conjugate_log_probs.py        |   2 +-
 edward/inferences/gibbs.py                  |   5 +-
 tests/inferences/gibbs_test.py              |  53 ++++++---
 tests/inferences/hmc_test.py                | 106 +++++++++++++-----
 tests/inferences/metropolishastings_test.py |  97 +++++++++++-----
 tests/inferences/replicaexchangemc_test.py  |  18 +--
 tests/inferences/sghmc_test.py              |  98 ++++++++++++----
 tests/inferences/sgld_test.py               |  98 ++++++++++++----
 9 files changed, 361 insertions(+), 128 deletions(-)

diff --git a/edward/inferences/conjugacy/conjugacy.py b/edward/inferences/conjugacy/conjugacy.py
index ce0941398..e394f8939 100644
--- a/edward/inferences/conjugacy/conjugacy.py
+++ b/edward/inferences/conjugacy/conjugacy.py
@@ -137,9 +137,15 @@ def complete_conditional(rv, cond_set=None):
   swap_dict = {}
   swap_back = {}
   for s_stat_expr in six.itervalues(s_stat_exprs):
-    s_stat_placeholder = tf.placeholder(tf.float32,
-                                        s_stat_expr[0][0].get_shape())
-    swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
+    if rv.dtype == tf.float64:
+      s_stat_placeholder = tf.placeholder(rv.dtype,
+                                          s_stat_expr[0][0].get_shape())
+      swap_back[s_stat_placeholder] = rv.value()
+    else:
+      s_stat_placeholder = tf.placeholder(tf.float32,
+                                          s_stat_expr[0][0].get_shape())
+      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
+
     s_stat_placeholders.append(s_stat_placeholder)
     for s_stat_node, multiplier in s_stat_expr:
       fake_node = s_stat_placeholder * multiplier
diff --git a/edward/inferences/conjugacy/conjugate_log_probs.py b/edward/inferences/conjugacy/conjugate_log_probs.py
index a2e25f0ee..92bed54ab 100644
--- a/edward/inferences/conjugacy/conjugate_log_probs.py
+++ b/edward/inferences/conjugacy/conjugate_log_probs.py
@@ -137,7 +137,7 @@ def normal_log_prob(self, val):
   prec = tf.reciprocal(tf.square(scale))
   result = prec * (-0.5 * tf.square(val) - 0.5 * tf.square(loc) +
                    val * loc)
-  result -= tf.log(scale) + 0.5 * tf.log(2 * np.pi)
+  result -= tf.log(scale) + 0.5 * tf.cast(tf.log(2 * np.pi), dtype=result.dtype)
   return result
 
 
diff --git a/edward/inferences/gibbs.py b/edward/inferences/gibbs.py
index 3efb2d0c9..7b7b0d714 100644
--- a/edward/inferences/gibbs.py
+++ b/edward/inferences/gibbs.py
@@ -41,14 +41,15 @@ def __init__(self, latent_vars, proposal_vars=None, data=None):
       bound to its complete conditional, which Gibbs cycling draws from.
       If not specified, default is to use `ed.complete_conditional`.
     """
+    super(Gibbs, self).__init__(latent_vars, data)
+
     if proposal_vars is None:
       proposal_vars = {z: complete_conditional(z)
-                       for z in six.iterkeys(latent_vars)}
+                       for z in six.iterkeys(self.latent_vars)}
     else:
       check_latent_vars(proposal_vars)
 
     self.proposal_vars = proposal_vars
-    super(Gibbs, self).__init__(latent_vars, data)
 
   def initialize(self, scan_order='random', *args, **kwargs):
     """Initialize inference algorithm. It initializes hyperparameters
diff --git a/tests/inferences/gibbs_test.py b/tests/inferences/gibbs_test.py
index 60ca373b4..2f8d73f7f 100644
--- a/tests/inferences/gibbs_test.py
+++ b/tests/inferences/gibbs_test.py
@@ -11,6 +11,40 @@
 
 class test_gibbs_class(tf.test.TestCase):
 
+  def _test_normal_normal(self, default, dtype):
+    with self.test_session() as sess:
+      x_data = np.array([0.0] * 50, dtype=np.float32)
+
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
+
+      n_samples = 2000
+      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.Gibbs({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.Gibbs([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
+      inference.run()
+
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
   def test_beta_bernoulli(self):
     with self.test_session() as sess:
       x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
@@ -31,21 +65,10 @@ def test_beta_bernoulli(self):
       self.assertAllClose(val_est, val_true, rtol=1e-2, atol=1e-2)
 
   def test_normal_normal(self):
-    with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float32)
-
-      mu = Normal(0.0, 1.0)
-      x = Normal(mu, 1.0, sample_shape=50)
-
-      qmu = Empirical(tf.Variable(tf.ones(1000)))
-
-      # analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
-      inference = ed.Gibbs({mu: qmu}, data={x: x_data})
-      inference.run()
-
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
 
   def test_data_tensor(self):
     with self.test_session() as sess:
diff --git a/tests/inferences/hmc_test.py b/tests/inferences/hmc_test.py
index ace139c42..388f1b323 100644
--- a/tests/inferences/hmc_test.py
+++ b/tests/inferences/hmc_test.py
@@ -11,42 +11,96 @@
 
 class test_hmc_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(2000)))
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
+      n_samples = 2000
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.HMC({mu: qmu}, data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.HMC({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.HMC([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run()
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
 
-  def test_normalnormal_float64(self):
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float64)
-
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(2000, dtype=tf.float64)))
-
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.HMC({mu: qmu}, data={x: x_data})
-      inference.run()
-
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+      N = 40  # number of data points
+      D = 10  # number of features
+
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
+
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
+
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.HMC({w: qw, b: qb}, data={X: X_train, y: y_train})
+      else:
+        inference = ed.HMC([w, b], data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
+      inference.run(step_size=0.01)
+
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
   def test_indexedslices(self):
     """Test that gradients accumulate when tf.gradients doesn't return
diff --git a/tests/inferences/metropolishastings_test.py b/tests/inferences/metropolishastings_test.py
index 4c65f4f20..4487827ca 100644
--- a/tests/inferences/metropolishastings_test.py
+++ b/tests/inferences/metropolishastings_test.py
@@ -11,20 +11,23 @@
 
 class test_metropolishastings_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
       n_samples = 2000
-      qmu = Empirical(params=tf.Variable(tf.ones(n_samples)))
-
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.MetropolisHastings({mu: qmu},
-                                        {mu: mu},
-                                        data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.MetropolisHastings({mu: qmu}, {mu: mu}, data={x: x_data})
+      else:
+        inference = ed.MetropolisHastings([mu], {mu: mu}, data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run()
 
       self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
@@ -32,35 +35,79 @@ def test_normalnormal_float32(self):
                           rtol=1e-1, atol=1e-1)
 
       old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
-      self.assertEqual(old_t, n_samples)
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
       self.assertGreater(old_n_accept, 0.1)
       sess.run(inference.reset)
       new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
       self.assertEqual(new_t, 0)
       self.assertEqual(new_n_accept, 0)
 
-  def test_normalnormal_float64(self):
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
+
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float32)
+      N = 40  # number of data points
+      D = 10  # number of features
 
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
 
-      n_samples = 2000
-      qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=tf.float64)))
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.MetropolisHastings({mu: qmu},
-                                        {mu: mu},
-                                        data={x: x_data})
+      proposal_w = Normal(loc=w, scale=0.5 * tf.ones(D, dtype=dtype))
+      proposal_b = Normal(loc=b, scale=0.5 * tf.ones(1, dtype=dtype))
+
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.MetropolisHastings(
+            {w: qw, b: qb}, {w: proposal_w, b: proposal_b},
+            data={X: X_train, y: y_train})
+      else:
+        inference = ed.MetropolisHastings(
+            [w, b], {w: proposal_w, b: proposal_b},
+            data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
       inference.run()
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
 if __name__ == '__main__':
   ed.set_seed(42)
diff --git a/tests/inferences/replicaexchangemc_test.py b/tests/inferences/replicaexchangemc_test.py
index e38f2aef6..da4e056a8 100644
--- a/tests/inferences/replicaexchangemc_test.py
+++ b/tests/inferences/replicaexchangemc_test.py
@@ -9,7 +9,7 @@
 from edward.models import Normal, Empirical
 
 
-class test_metropolishastings_class(tf.test.TestCase):
+class test_replicaexchangemc_class(tf.test.TestCase):
 
   def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
@@ -21,17 +21,12 @@ def _test_normal_normal(self, default, dtype):
                  sample_shape=50)
 
       n_samples = 2000
+      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
       if not default:
         qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
-
-        # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-        inference = ed.ReplicaExchangeMC({mu: qmu},
-                                         {mu: mu},
-                                         data={x: x_data})
+        inference = ed.ReplicaExchangeMC({mu: qmu}, {mu: mu}, data={x: x_data})
       else:
-        inference = ed.ReplicaExchangeMC([mu],
-                                         {mu: mu},
-                                         data={x: x_data})
+        inference = ed.ReplicaExchangeMC([mu], {mu: mu}, data={x: x_data})
         qmu = inference.latent_vars[mu]
       inference.run()
 
@@ -77,7 +72,6 @@ def build_toy_dataset(N, w, noise_std=0.1):
       if not default:
         qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
         qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
-
         inference = ed.ReplicaExchangeMC(
             {w: qw, b: qb}, {w: proposal_w, b: proposal_b},
             data={X: X_train, y: y_train})
@@ -103,13 +97,13 @@ def build_toy_dataset(N, w, noise_std=0.1):
       self.assertEqual(new_t, 0)
       self.assertEqual(new_n_accept, 0)
 
-  def test_normalnormal(self):
+  def test_normal_normal(self):
     self._test_normal_normal(True, tf.float32)
     self._test_normal_normal(False, tf.float32)
     self._test_normal_normal(True, tf.float64)
     self._test_normal_normal(False, tf.float64)
 
-  def test_linearregression(self):
+  def test_linear_regression(self):
     self._test_linear_regression(True, tf.float32)
     self._test_linear_regression(False, tf.float32)
     self._test_linear_regression(True, tf.float64)
diff --git a/tests/inferences/sghmc_test.py b/tests/inferences/sghmc_test.py
index cd5d07a58..da3b7bab6 100644
--- a/tests/inferences/sghmc_test.py
+++ b/tests/inferences/sghmc_test.py
@@ -11,42 +11,96 @@
 
 class test_sghmc_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(5000)))
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
+      n_samples = 2000
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.SGHMC({mu: qmu}, data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.SGHMC({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.SGHMC([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run(step_size=0.025)
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1.5e-2)
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=5e-2, atol=5e-2)
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
 
-  def test_normalnormal_float64(self):
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float64)
+      N = 40  # number of data points
+      D = 10  # number of features
 
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
 
-      qmu = Empirical(params=tf.Variable(tf.ones(5000, dtype=tf.float64)))
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.SGHMC({mu: qmu}, data={x: x_data})
-      inference.run(step_size=0.025)
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.SGHMC({w: qw, b: qb}, data={X: X_train, y: y_train})
+      else:
+        inference = ed.SGHMC([w, b], data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
+      inference.run(step_size=0.0001)
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1.5e-2)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=5e-2, atol=5e-2)
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
 if __name__ == '__main__':
   ed.set_seed(42)
diff --git a/tests/inferences/sgld_test.py b/tests/inferences/sgld_test.py
index d2a6d204a..bceb44053 100644
--- a/tests/inferences/sgld_test.py
+++ b/tests/inferences/sgld_test.py
@@ -11,42 +11,96 @@
 
 class test_sgld_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(5000)))
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
+      n_samples = 2000
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.SGLD({mu: qmu}, data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.SGLD({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.SGLD([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run(step_size=0.10)
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1.5e-2)
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=5e-2, atol=5e-2)
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
 
-  def test_normalnormal_float64(self):
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float64)
+      N = 40  # number of data points
+      D = 10  # number of features
 
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
 
-      qmu = Empirical(params=tf.Variable(tf.ones(5000, dtype=tf.float64)))
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.SGLD({mu: qmu}, data={x: x_data})
-      inference.run(step_size=0.10)
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.SGLD({w: qw, b: qb}, data={X: X_train, y: y_train})
+      else:
+        inference = ed.SGLD([w, b], data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
+      inference.run(step_size=0.001)
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
 if __name__ == '__main__':
   ed.set_seed(42)

From c1267d5700d030a016022aedec06a4f54d0514fd Mon Sep 17 00:00:00 2001
From: masashi yoshikawa
Date: Sun, 18 Mar 2018 10:25:17 +0900
Subject: [PATCH 2/2] [fix] PEP8 fix

---
 edward/inferences/conjugacy/conjugacy.py | 2 +-
 edward/inferences/gibbs.py               | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/edward/inferences/conjugacy/conjugacy.py b/edward/inferences/conjugacy/conjugacy.py
index e394f8939..a63f2ef9c 100644
--- a/edward/inferences/conjugacy/conjugacy.py
+++ b/edward/inferences/conjugacy/conjugacy.py
@@ -145,7 +145,7 @@ def complete_conditional(rv, cond_set=None):
       s_stat_placeholder = tf.placeholder(tf.float32,
                                           s_stat_expr[0][0].get_shape())
       swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
-      
+
     s_stat_placeholders.append(s_stat_placeholder)
     for s_stat_node, multiplier in s_stat_expr:
       fake_node = s_stat_placeholder * multiplier
diff --git a/edward/inferences/gibbs.py b/edward/inferences/gibbs.py
index 7b7b0d714..eac294ce7 100644
--- a/edward/inferences/gibbs.py
+++ b/edward/inferences/gibbs.py
@@ -42,7 +42,7 @@ def __init__(self, latent_vars, proposal_vars=None, data=None):
       If not specified, default is to use `ed.complete_conditional`.
     """
     super(Gibbs, self).__init__(latent_vars, data)
-    
+
     if proposal_vars is None:
       proposal_vars = {z: complete_conditional(z)
                        for z in six.iterkeys(self.latent_vars)}
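
Usage sketch (illustrative only, not part of either patch; assumes Edward 1.x on TensorFlow 1.x): the new tests exercise a default path in which latent variables are passed as a plain list rather than a {latent: Empirical} dict, and the inference object constructs its own Empirical approximation internally:

    import edward as ed
    import numpy as np
    import tensorflow as tf
    from edward.models import Normal

    x_data = np.array([0.0] * 50, dtype=np.float32)

    # Normal-normal model, mirroring the tests above.
    mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
                scale=tf.constant(1.0, dtype=tf.float64))
    x = Normal(loc=mu, scale=tf.constant(1.0, dtype=tf.float64),
               sample_shape=50)

    # Passing [mu] instead of {mu: qmu} lets HMC build the Empirical
    # approximation itself; it is retrieved through `latent_vars`. The
    # tests assert that this default run lasts 1e4 iterations
    # (self.assertEqual(old_t, 1e4)).
    inference = ed.HMC([mu], data={x: x_data})
    qmu = inference.latent_vars[mu]
    inference.run()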