Skip to content

Commit

Permalink
Merge de44561 into 62f4ad5
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshikawaMasashi committed Mar 16, 2018
2 parents 62f4ad5 + de44561 commit b227ff3
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 115 deletions.
3 changes: 2 additions & 1 deletion edward/inferences/monte_carlo.py
Expand Up @@ -78,7 +78,8 @@ def __init__(self, latent_vars=None, data=None):
if isinstance(latent_vars, list):
with tf.variable_scope(None, default_name="posterior"):
latent_vars = {z: Empirical(params=tf.Variable(tf.zeros(
[1e4] + z.batch_shape.concatenate(z.event_shape).as_list())))
[1e4] + z.batch_shape.concatenate(z.event_shape).as_list(),
dtype=z.dtype)))
for z in latent_vars}
elif isinstance(latent_vars, dict):
for qz in six.itervalues(latent_vars):
Expand Down
154 changes: 66 additions & 88 deletions edward/inferences/replica_exchange_mc.py
Expand Up @@ -15,7 +15,7 @@
class _stateful_lambda:
"""Class to use instead of lambda.
lambda is affected by the change of x,
so memory_lambdda output x at the time of definition.
so _stateful_lambda(x)() output x at the time of definition.
"""

def __init__(self, x):
Expand Down Expand Up @@ -54,26 +54,23 @@ def __init__(self, latent_vars, proposal_vars, data=None,
"""
check_latent_vars(proposal_vars)
self.proposal_vars = proposal_vars
super(ReplicaExchangeMC, self).__init__(latent_vars, data)

self.n_replica = len(inverse_temperatures)
if inverse_temperatures[0] != 1:
raise ValueError("inverse_temperatures[0] must be 1.")
self.inverse_temperatures = [tf.convert_to_tensor(inverse_temperature,
dtype=list(latent_vars.values())[0].dtype)
for inverse_temperature in
inverse_temperatures]
self.inverse_temperatures = tf.cast(inverse_temperatures,
dtype=list(self.latent_vars)[0].dtype)

# Make replica.
self.replica_vars = []
for inverse_temperature in self.inverse_temperatures:
for i in range(self.n_replica):
self.replica_vars.append({z: Empirical(params=tf.Variable(tf.zeros(
qz.params.shape, dtype=latent_vars[z].dtype))) for z, qz in
six.iteritems(latent_vars)})
qz.params.shape, dtype=self.latent_vars[z].dtype))) for z, qz in
six.iteritems(self.latent_vars)})

self.exchange_freq = exchange_freq

super(ReplicaExchangeMC, self).__init__(latent_vars, data)

def initialize(self, *args, **kwargs):
kwargs['auto_transform'] = False
return super(ReplicaExchangeMC, self).initialize(*args, **kwargs)
Expand All @@ -84,9 +81,9 @@ def build_update(self):
# Sample by Metropolis-Hastings for each replica.
replica_sample = []
replica_accept = []
for i, inverse_temperature in enumerate(self.inverse_temperatures):
for i in range(self.n_replica):
sample_, accept_ = self._mh_sample(self.replica_vars[i],
inverse_temperature)
self.inverse_temperatures[i])
replica_sample.append(sample_)
replica_accept.append(accept_)
accept = replica_accept[0]
Expand All @@ -95,33 +92,27 @@ def build_update(self):
new_replica_idx = tf.Variable(tf.range(self.n_replica))
new_replica_idx = tf.assign(new_replica_idx, tf.range(self.n_replica))

# Exchange adjacent replicas at frequency of exchange_freq
i = tf.random_uniform((), maxval=2, dtype=tf.int32)

def cond(i, new_replica_idx):
return tf.less(i, self.n_replica - 1)

def body(i, new_replica_idx):
return [i + 2, self._replica_exchange(i, i + 1, replica_sample,
new_replica_idx)]

def exchange_all():
return tf.while_loop(cond, body, loop_vars=[i, new_replica_idx])
# Variable to store ratio of current samples
replica_ratio = tf.Variable(tf.zeros(
self.n_replica, dtype=list(self.latent_vars)[0].dtype))
replica_ratio = self._replica_ratio(replica_ratio, replica_sample)

# Exchange adjacent replicas at frequency of exchange_freq
u = tf.random_uniform([])
exchange = u < self.exchange_freq
i, new_replica_idx = tf.cond(exchange,
exchange_all,
lambda: [i, new_replica_idx])
new_replica_idx = tf.cond(
exchange, lambda: self._replica_exchange(
new_replica_idx, replica_ratio), lambda: new_replica_idx)

# New replica sorted by new_replica_idx
new_replica_sample = []
for i in range(self.n_replica):
new_replica_sample.append(tf.case(
{tf.equal(tf.gather(new_replica_idx, i), j):
_stateful_lambda(replica_sample[j])
for j in range(self.n_replica)}, default=lambda: replica_sample[0],
exclusive=True))
new_replica_sample.append(
{z: tf.case({tf.equal(tf.gather(new_replica_idx, i), j):
_stateful_lambda(replica_sample[j][z])
for j in range(self.n_replica)},
default=lambda: replica_sample[0][z], exclusive=True) for z, qz in
six.iteritems(self.latent_vars)})

assign_ops = []

Expand All @@ -131,7 +122,7 @@ def exchange_all():
assign_ops.append(tf.scatter_update(variable, self.t,
new_replica_sample[0][z]))

for i, inverse_temperature in enumerate(self.inverse_temperatures):
for i in range(self.n_replica):
for z, qz in six.iteritems(self.replica_vars[i]):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t,
Expand Down Expand Up @@ -164,7 +155,6 @@ def _mh_sample(self, latent_vars, inverse_temperature):
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

dict_swap_old = dict_swap.copy()
dict_swap_old.update(old_sample)
base_scope = tf.get_default_graph().unique_name("inference") + '/'
Expand Down Expand Up @@ -223,34 +213,33 @@ def _mh_sample(self, latent_vars, inverse_temperature):
zip(six.iterkeys(new_sample), sample_values)}
return sample, accept

def _replica_exchange(self, candi, candj, replica_sample, new_replica_idx):
def _replica_exchange(self, new_replica_idx, replica_ratio):
"""Exchange replica according to the Metropolis-Hastings criterion.
$\\text{ratio} =
(\log p(x, z_i) - \log p(x, x_j))(\\text{inverse_temperature}_j -
\\text{inverse_temperature}_i)
"""
sample_i = tf.case({tf.equal(new_replica_idx[candi], i): _stateful_lambda(
replica_sample[i])for i in range(self.n_replica)},
default=lambda: replica_sample[0], exclusive=True)
inverse_temperature_i = tf.case({tf.equal(candi, i):
_stateful_lambda(inverse_temperature)
for i, inverse_temperature in
enumerate(self.inverse_temperatures)},
default=lambda:
self.inverse_temperatures[0],
exclusive=True)
sample_j = tf.case({tf.equal(new_replica_idx[candj], i): _stateful_lambda(
replica_sample[i])for i in range(self.n_replica)},
default=lambda: replica_sample[0], exclusive=True)
inverse_temperature_j = tf.case({tf.equal(candj, i):
_stateful_lambda(inverse_temperature)
for i, inverse_temperature in
enumerate(self.inverse_temperatures)},
default=lambda:
self.inverse_temperatures[0],
exclusive=True)
i = tf.random_uniform((), maxval=2, dtype=tf.int32)

ratio = 0.0
def cond(i, new_replica_idx):
return tf.less(i, self.n_replica - 1)

def body(i, new_replica_idx):
ratio = replica_ratio[i] - replica_ratio[i + 1]
ratio *= (self.inverse_temperatures[i + 1] - self.inverse_temperatures[i])
u = tf.random_uniform([], dtype=ratio.dtype)
exchange = tf.log(u) < ratio
new_replica_idx = tf.cond(
exchange,
lambda: tf.scatter_update(new_replica_idx, [i, i + 1], [i + 1, i]),
lambda: new_replica_idx)
return [i + 2, new_replica_idx]

return tf.while_loop(cond, body, loop_vars=[i, new_replica_idx])[1]

def _replica_ratio(self, replica_ratio, replica_sample):
replica_ratio = tf.assign(replica_ratio, tf.zeros(
self.n_replica, dtype=list(self.latent_vars)[0].dtype))

dict_swap = {}
for x, qx in six.iteritems(self.data):
Expand All @@ -260,39 +249,28 @@ def _replica_exchange(self, candi, candj, replica_sample, new_replica_idx):
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx
dict_swap_i = dict_swap.copy()
dict_swap_i.update(sample_i)
dict_swap_j = dict_swap.copy()
dict_swap_j.update(sample_j)

base_scope = tf.get_default_graph().unique_name("inference") + '/'
scope_i = base_scope + '_i'
scope_j = base_scope + '_j'
for i in range(self.n_replica):
dict_swap_i = dict_swap.copy()
dict_swap_i.update(replica_sample[i])

for z in six.iterkeys(self.latent_vars):
# Build priors p(z_i) and p(z_j).
z_i = copy(z, dict_swap_i, scope=scope_i)
z_j = copy(z, dict_swap_j, scope=scope_j)
# Increment ratio.
ratio += tf.reduce_sum(z_i.log_prob(dict_swap_i[z]))
ratio -= tf.reduce_sum(z_j.log_prob(dict_swap_j[z]))
base_scope = tf.get_default_graph().unique_name("inference") + '/'
scope_i = base_scope + '_%d' % i

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
# Build likelihoods p(x | z_i) and p(x | z_j).
x_z_i = copy(x, dict_swap_i, scope=scope_i)
x_z_j = copy(x, dict_swap_j, scope=scope_j)
for z in six.iterkeys(self.latent_vars):
# Build priors p(z_i) and p(z_j).
z_i = copy(z, dict_swap_i, scope=scope_i)
# Increment ratio.
ratio += tf.reduce_sum(x_z_i.log_prob(dict_swap[x]))
ratio -= tf.reduce_sum(x_z_j.log_prob(dict_swap[x]))

ratio *= inverse_temperature_j - inverse_temperature_i

u = tf.random_uniform([], dtype=ratio.dtype)
exchange = tf.log(u) < ratio

# exchange new_replica_idx
return tf.cond(exchange,
lambda: tf.scatter_update(new_replica_idx, [candi, candj],
[candj, candi]),
lambda: new_replica_idx)
replica_ratio = tf.scatter_update(
replica_ratio, i,
replica_ratio[i] + tf.reduce_sum(z_i.log_prob(dict_swap_i[z])))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
# Build likelihoods p(x | z_i) and p(x | z_j).
x_z_i = copy(x, dict_swap_i, scope=scope_i)
# Increment ratio.
replica_ratio = tf.scatter_update(
replica_ratio, i,
replica_ratio[i] + tf.reduce_sum(x_z_i.log_prob(dict_swap[x])))
return replica_ratio
105 changes: 79 additions & 26 deletions tests/inferences/replicaexchangemc_test.py
Expand Up @@ -11,56 +11,109 @@

class test_metropolishastings_class(tf.test.TestCase):

def test_normalnormal_float32(self):
def _test_normal_normal(self, default, dtype):
with self.test_session() as sess:
x_data = np.array([0.0] * 50, dtype=np.float32)

mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)
mu = Normal(loc=tf.constant(0.0, dtype=dtype),
scale=tf.constant(1.0, dtype=dtype))
x = Normal(loc=mu, scale=tf.constant(1.0, dtype=dtype),
sample_shape=50)

n_samples = 2000
qmu = Empirical(params=tf.Variable(tf.ones(n_samples)))

# analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
inference = ed.ReplicaExchangeMC({mu: qmu},
{mu: mu},
data={x: x_data})
if not default:
qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=dtype)))

# analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
inference = ed.ReplicaExchangeMC({mu: qmu},
{mu: mu},
data={x: x_data})
else:
inference = ed.ReplicaExchangeMC([mu],
{mu: mu},
data={x: x_data})
qmu = inference.latent_vars[mu]
inference.run()

self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
rtol=1e-1, atol=1e-1)

old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
self.assertEqual(old_t, n_samples)
if not default:
self.assertEqual(old_t, n_samples)
else:
self.assertEqual(old_t, 1e4)
self.assertGreater(old_n_accept, 0.1)
sess.run(inference.reset)
new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
self.assertEqual(new_t, 0)
self.assertEqual(new_n_accept, 0)

def test_normalnormal_float64(self):
def _test_linear_regression(self, default, dtype):
def build_toy_dataset(N, w, noise_std=0.1):
D = len(w)
x = np.random.randn(N, D)
y = np.dot(x, w) + np.random.normal(0, noise_std, size=N)
return x, y

with self.test_session() as sess:
x_data = np.array([0.0] * 50, dtype=np.float32)
N = 40 # number of data points
D = 10 # number of features

mu = Normal(loc=tf.constant(0.0, dtype=tf.float64),
scale=tf.constant(1.0, dtype=tf.float64))
x = Normal(loc=mu,
scale=tf.constant(1.0, dtype=tf.float64),
sample_shape=50)
w_true = np.random.randn(D)
X_train, y_train = build_toy_dataset(N, w_true)
X_test, y_test = build_toy_dataset(N, w_true)

n_samples = 2000
qmu = Empirical(params=tf.Variable(tf.ones(n_samples, dtype=tf.float64)))
X = tf.placeholder(dtype, [N, D])
w = Normal(loc=tf.zeros(D, dtype=dtype), scale=tf.ones(D, dtype=dtype))
b = Normal(loc=tf.zeros(1, dtype=dtype), scale=tf.ones(1, dtype=dtype))
y = Normal(loc=ed.dot(X, w) + b, scale=0.1 * tf.ones(N, dtype=dtype))

proposal_w = Normal(loc=w, scale=0.5 * tf.ones(D, dtype=dtype))
proposal_b = Normal(loc=b, scale=0.5 * tf.ones(1, dtype=dtype))

# analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
inference = ed.ReplicaExchangeMC({mu: qmu},
{mu: mu},
data={x: x_data})
n_samples = 2000
if not default:
qw = Empirical(tf.Variable(tf.zeros([n_samples, D], dtype=dtype)))
qb = Empirical(tf.Variable(tf.zeros([n_samples, 1], dtype=dtype)))

inference = ed.ReplicaExchangeMC(
{w: qw, b: qb}, {w: proposal_w, b: proposal_b},
data={X: X_train, y: y_train})
else:
inference = ed.ReplicaExchangeMC(
[w, b], {w: proposal_w, b: proposal_b},
data={X: X_train, y: y_train})
qw = inference.latent_vars[w]
qb = inference.latent_vars[b]
inference.run()

self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
rtol=1e-1, atol=1e-1)
self.assertAllClose(qw.mean().eval(), w_true, rtol=5e-1, atol=5e-1)
self.assertAllClose(qb.mean().eval(), [0.0], rtol=5e-1, atol=5e-1)

old_t, old_n_accept = sess.run([inference.t, inference.n_accept])
if not default:
self.assertEqual(old_t, n_samples)
else:
self.assertEqual(old_t, 1e4)
self.assertGreater(old_n_accept, 0.1)
sess.run(inference.reset)
new_t, new_n_accept = sess.run([inference.t, inference.n_accept])
self.assertEqual(new_t, 0)
self.assertEqual(new_n_accept, 0)

def test_normalnormal(self):
self._test_normal_normal(True, tf.float32)
self._test_normal_normal(False, tf.float32)
self._test_normal_normal(True, tf.float64)
self._test_normal_normal(False, tf.float64)

def test_linearregression(self):
self._test_linear_regression(True, tf.float32)
self._test_linear_regression(False, tf.float32)
self._test_linear_regression(True, tf.float64)
self._test_linear_regression(False, tf.float64)

if __name__ == '__main__':
ed.set_seed(42)
Expand Down

0 comments on commit b227ff3

Please sign in to comment.