Commit

Merge c1267d5 into 14de728
YoshikawaMasashi committed Mar 18, 2018
2 parents 14de728 + c1267d5, commit 09c14b2
Showing 9 changed files with 361 additions and 128 deletions.
12 changes: 9 additions & 3 deletions edward/inferences/conjugacy/conjugacy.py
@@ -137,9 +137,15 @@ def complete_conditional(rv, cond_set=None):
   swap_dict = {}
   swap_back = {}
   for s_stat_expr in six.itervalues(s_stat_exprs):
-    s_stat_placeholder = tf.placeholder(tf.float32,
-                                        s_stat_expr[0][0].get_shape())
-    swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
+    if rv.dtype == tf.float64:
+      s_stat_placeholder = tf.placeholder(rv.dtype,
+                                          s_stat_expr[0][0].get_shape())
+      swap_back[s_stat_placeholder] = rv.value()
+    else:
+      s_stat_placeholder = tf.placeholder(tf.float32,
+                                          s_stat_expr[0][0].get_shape())
+      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
+
     s_stat_placeholders.append(s_stat_placeholder)
     for s_stat_node, multiplier in s_stat_expr:
       fake_node = s_stat_placeholder * multiplier
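Why the branch matters: the old code forced every sufficient-statistic placeholder to float32, so a float64 random variable was silently downcast through tf.cast(rv.value(), tf.float32). A minimal sketch of the same dtype-preserving pattern outside the diff (the helper name make_placeholder is illustrative, not part of Edward; assumes TF 1.x):

import tensorflow as tf

def make_placeholder(rv_dtype, shape):
  # Mirror the branch above: float64 keeps its native precision,
  # everything else falls back to float32.
  if rv_dtype == tf.float64:
    return tf.placeholder(tf.float64, shape)
  return tf.placeholder(tf.float32, shape)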
2 changes: 1 addition & 1 deletion edward/inferences/conjugacy/conjugate_log_probs.py
@@ -137,7 +137,7 @@ def normal_log_prob(self, val):
   prec = tf.reciprocal(tf.square(scale))
   result = prec * (-0.5 * tf.square(val) - 0.5 * tf.square(loc) +
                    val * loc)
-  result -= tf.log(scale) + 0.5 * tf.log(2 * np.pi)
+  result -= tf.log(scale) + 0.5 * tf.cast(tf.log(2 * np.pi), dtype=result.dtype)
   return result


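The cast fixes a dtype mismatch rather than the math: 2 * np.pi is a Python float, which TF 1.x converts to a float32 tensor by default, so subtracting tf.log(2 * np.pi) from a float64 result raises a type error. A minimal sketch of the failure and the fix (assumes TF 1.x; the variable names are illustrative):

import numpy as np
import tensorflow as tf

result = tf.zeros(3, dtype=tf.float64)
# result - tf.log(2 * np.pi)   # fails: float64 vs. float32 operands
safe = result - tf.cast(tf.log(2 * np.pi), dtype=result.dtype)  # works for both dtypes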
5 changes: 3 additions & 2 deletions edward/inferences/gibbs.py
@@ -41,14 +41,15 @@ def __init__(self, latent_vars, proposal_vars=None, data=None):
         bound to its complete conditional, which the Gibbs cycle draws from.
         If not specified, default is to use `ed.complete_conditional`.
     """
+    super(Gibbs, self).__init__(latent_vars, data)
+
     if proposal_vars is None:
       proposal_vars = {z: complete_conditional(z)
                        for z in six.iterkeys(self.latent_vars)}
-                       for z in six.iterkeys(latent_vars)}
     else:
       check_latent_vars(proposal_vars)
 
     self.proposal_vars = proposal_vars
-    super(Gibbs, self).__init__(latent_vars, data)
 
   def initialize(self, scan_order='random', *args, **kwargs):
     """Initialize inference algorithm. It initializes hyperparameters
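Calling super().__init__ first populates self.latent_vars before the complete conditionals are built, which is what lets Gibbs accept a bare list of latent variables. A usage sketch of the two calling modes, adapted from the tests below (assumes the Edward 1.x API; x_data is toy data):

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Empirical, Normal

x_data = np.array([0.0] * 50, dtype=np.float32)
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

# Explicit mode: supply the Empirical approximation yourself.
qmu = Empirical(params=tf.Variable(tf.ones(2000)))
inference = ed.Gibbs({mu: qmu}, data={x: x_data})

# Default mode: pass a list; Gibbs builds the approximation and the
# complete-conditional proposals internally.
inference = ed.Gibbs([mu], data={x: x_data})
qmu = inference.latent_vars[mu]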
53 changes: 38 additions & 15 deletions tests/inferences/gibbs_test.py
@@ -11,6 +11,40 @@
 
 class test_gibbs_class(tf.test.TestCase):
 
+  def _test_normal_normal(self, default, dtype):
+    with self.test_session() as sess:
+      x_data = np.array([0.0] * 50, dtype=np.float32)
+
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
+
+      n_samples = 2000
+      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.Gibbs({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.Gibbs([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
+      inference.run()
+
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
   def test_beta_bernoulli(self):
     with self.test_session() as sess:
       x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
@@ -31,21 +65,10 @@ def test_beta_bernoulli(self):
       self.assertAllClose(val_est, val_true, rtol=1e-2, atol=1e-2)
 
   def test_normal_normal(self):
-    with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float32)
-
-      mu = Normal(0.0, 1.0)
-      x = Normal(mu, 1.0, sample_shape=50)
-
-      qmu = Empirical(tf.Variable(tf.ones(1000)))
-
-      # analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
-      inference = ed.Gibbs({mu: qmu}, data={x: x_data})
-      inference.run()
-
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
 
   def test_data_tensor(self):
     with self.test_session() as sess:
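Each refactored test also exercises the sampler's bookkeeping: t counts completed iterations, n_accept counts accepted draws, and the reset op zeroes both so a chain can be rerun in the same graph. A condensed, self-contained sketch of that check (assumes the Edward 1.x API shown above, and that inference.run() picks up the default session, as it does with self.test_session() in the tests):

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Empirical, Normal

x_data = np.array([0.0] * 50, dtype=np.float32)
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)
qmu = Empirical(params=tf.Variable(tf.ones(2000)))
inference = ed.Gibbs({mu: qmu}, data={x: x_data})

with tf.Session() as sess:
  inference.run()
  old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
  sess.run(inference.reset)  # zero the iteration and acceptance counters
  new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
  assert new_t == 0 and new_n_accept == 0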
106 changes: 80 additions & 26 deletions tests/inferences/hmc_test.py
@@ -11,42 +11,96 @@
 
 class test_hmc_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(2000)))
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
+      n_samples = 2000
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.HMC({mu: qmu}, data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.HMC({mu: qmu}, data={x: x_data})
+      else:
+        inference = ed.HMC([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run()
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+                          rtol=1e-1, atol=1e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
 
-  def test_normalnormal_float64(self):
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float64)
-
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
-
-      qmu = Empirical(params=tf.Variable(tf.ones(2000, dtype=tf.float64)))
-
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.HMC({mu: qmu}, data={x: x_data})
-      inference.run()
-
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-2, atol=1e-2)
+      N = 40  # number of data points
+      D = 10  # number of features
+
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
+
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
+
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.HMC({w: qw, b: qb}, data={X: X_train, y: y_train})
+      else:
+        inference = ed.HMC([w, b], data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
+      inference.run(step_size=0.01)
+
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
+
+      old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
+      self.assertGreater(old_n_accept, 0.1)
+      sess.run(inference.reset)
+      new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
+      self.assertEqual(new_t, 0)
+      self.assertEqual(new_n_accept, 0)
+
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
   def test_indexedslices(self):
     """Test that gradients accumulate when tf.gradients doesn't return
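A detail the assertions encode: with an explicit 2000-row Empirical the chain runs exactly n_samples iterations (old_t == n_samples), while the bare-list form falls back to the library default, which these tests pin at 1e4 iterations, so the auto-built Empirical approximation must hold 1e4 samples. The looser rtol/atol of 5e-1 on the regression posterior presumably reflects that a short HMC chain on a 10-dimensional model is only expected to land in the right neighborhood of w_true.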
97 changes: 72 additions & 25 deletions tests/inferences/metropolishastings_test.py
@@ -11,56 +11,103 @@
 
 class test_metropolishastings_class(tf.test.TestCase):
 
-  def test_normalnormal_float32(self):
+  def _test_normal_normal(self, default, dtype):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
-      mu = Normal(loc=0.0, scale=1.0)
-      x = Normal(loc=mu, scale=1.0, sample_shape=50)
+      mu = Normal(loc=tf.constant(0.0, dtype=dtype),
+                  scale=tf.constant(1.0, dtype=dtype))
+      x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
+                 sample_shape=50)
 
       n_samples = 2000
-      qmu = Empirical(params=tf.Variable(tf.ones(n_samples)))
-
       # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.MetropolisHastings({mu: qmu},
-                                        {mu: mu},
-                                        data={x: x_data})
+      if not default:
+        qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))
+        inference = ed.MetropolisHastings({mu: qmu}, {mu: mu}, data={x: x_data})
+      else:
+        inference = ed.MetropolisHastings([mu], {mu: mu}, data={x: x_data})
+        qmu = inference.latent_vars[mu]
      inference.run()
 
       self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
                           rtol=1e-1, atol=1e-1)
 
       old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
-      self.assertEqual(old_t, n_samples)
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
       self.assertGreater(old_n_accept, 0.1)
       sess.run(inference.reset)
       new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
       self.assertEqual(new_t, 0)
       self.assertEqual(new_n_accept, 0)
 
-  def test_normalnormal_float64(self):
+  def _test_linear_regression(self, default, dtype):
+    def build_toy_dataset(N, w, noise_std=0.1):
+      D = len(w)
+      x = np.random.randn(N, D)
+      y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
+      return x, y
+
     with self.test_session() as sess:
-      x_data = np.array([0.0] * 50, dtype=np.float32)
+      N = 40  # number of data points
+      D = 10  # number of features
 
-      mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
-                  scale=tf.constant(1.0, dtype=tf.float64))
-      x = Normal(loc=mu,
-                 scale=tf.constant(1.0, dtype=tf.float64),
-                 sample_shape=50)
+      w_true = np.random.randn(D)
+      X_train, y_train = build_toy_dataset(N, w_true)
+      X_test, y_test = build_toy_dataset(N, w_true)
 
-      n_samples = 2000
-      qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=tf.float64)))
+      X = tf.placeholder(dtype, [N, D])
+      w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
+      b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
+      y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = ed.MetropolisHastings({mu: qmu},
-                                        {mu: mu},
-                                        data={x: x_data})
+      proposal_w = Normal(loc=w, scale=0.5 * tf.ones(D, dtype=dtype))
+      proposal_b = Normal(loc=b, scale=0.5 * tf.ones(1, dtype=dtype))
+
+      n_samples = 2000
+      if not default:
+        qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
+        qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))
+        inference = ed.MetropolisHastings(
+            {w: qw, b: qb}, {w: proposal_w, b: proposal_b},
+            data={X: X_train, y: y_train})
+      else:
+        inference = ed.MetropolisHastings(
+            [w, b], {w: proposal_w, b: proposal_b},
+            data={X: X_train, y: y_train})
+        qw = inference.latent_vars[w]
+        qb = inference.latent_vars[b]
       inference.run()
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
-      self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
+      self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)
 
       old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
+      if not default:
+        self.assertEqual(old_t, n_samples)
+      else:
+        self.assertEqual(old_t, 1e4)
       self.assertGreater(old_n_accept, 0.1)
       sess.run(inference.reset)
       new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
       self.assertEqual(new_t, 0)
       self.assertEqual(new_n_accept, 0)
 
+  def test_normal_normal(self):
+    self._test_normal_normal(True, tf.float32)
+    self._test_normal_normal(False, tf.float32)
+    self._test_normal_normal(True, tf.float64)
+    self._test_normal_normal(False, tf.float64)
+
+  def test_linear_regression(self):
+    self._test_linear_regression(True, tf.float32)
+    self._test_linear_regression(False, tf.float32)
+    self._test_linear_regression(True, tf.float64)
+    self._test_linear_regression(False, tf.float64)
 
 if __name__ == '__main__':
   ed.set_seed(42)
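Unlike HMC, MetropolisHastings takes an explicit proposal for every latent variable: {mu: mu} proposes independently from the prior, while Normal(loc=w, scale=0.5 * tf.ones(D, dtype=dtype)) is a Gaussian random walk centered at the current state of w. An illustrative random-walk alternative for the normal-normal model (reusing mu, x, and x_data from the sketches above; the 0.5 scale is arbitrary):

proposal_mu = Normal(loc=mu, scale=0.5)  # random walk centered at the current state
inference = ed.MetropolisHastings([mu], {mu: proposal_mu}, data={x: x_data})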
