implement vimco by ourselves; pin ZhuSuan dependency to the last commit (48c0f4e) of 3.x
haowen-xu committed May 26, 2019
1 parent 6bc0e96 commit 94b7f6d
Showing 8 changed files with 342 additions and 217 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -4,7 +4,7 @@
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

-## v0.2.0-alpha.3
+## v0.2.0-alpha.4
This version introduces breaking changes. Existing code might be better off sticking to [v0.1.2](https://github.com/haowen-xu/tfsnippet/tree/v0.1.2)

### Added
@@ -17,6 +17,7 @@ This version introduces breaking changes. Existing code might better stick to [v
- Added `utils.EventSource`.

### Changed
- Pin `ZhuSuan` dependency to the last commit (48c0f4e) of 3.x.
- `global_reuse`, `instance_reuse`, `reopen_variable_scope`, `root_variable_scope` and `VarScopeObject` have been rewritten, and their behaviors have been slightly changed. This might cause existing code to malfunction, if that code relies heavily on the precise variable scope or name scope of certain variables or tensors.
- `Trainer` now accepts `summaries` argument on construction.
- `flows` package now moved to `layers.flows`, and all its contents
@@ -34,3 +35,4 @@ This version introduces breaking changes. Existing code might better stick to [v
- `auto_reuse_variables` has been removed.
- `VariableSaver` has been removed.
- `EarlyStopping` has been removed.
- `VariationalTrainingObjectives.rws_wake` has been removed.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,4 +12,4 @@
scipy >= 1.2.0
semver >= 2.7.9
six >= 1.11.0
tqdm >= 4.23.0
-git+https://github.com/thu-ml/zhusuan.git
+git+https://github.com/thu-ml/zhusuan.git@48c0f4e
158 changes: 158 additions & 0 deletions tests/variational/test_estimators.py
@@ -6,6 +6,8 @@

from tfsnippet.utils import get_static_shape, ensure_variables_initialized
from tfsnippet.variational import *
from tfsnippet.variational.estimators import (_vimco_replace_diag,
_vimco_control_variate)


def prepare_test_payload(is_reparameterized):
@@ -200,3 +202,159 @@ def test_nvil(self):
-2 * (f - baseline) * (-3.14 * tf.sin(y)),
axis=0) / 7
]))


def log_mean_exp(x, axis, keepdims=False):
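    # numerically stable log-mean-exp: subtract the per-axis max before
    # exponentiating, then add it back in log-space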
x_max = np.max(x, axis=axis, keepdims=True)
x_max_reduced = x_max if keepdims else np.squeeze(x_max, axis=axis)
out = x_max_reduced + np.log(
np.mean(np.exp(x - x_max), axis=axis, keepdims=keepdims))
return out


def slice_at(arr, axis, start, stop=None, step=None):
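    # take arr[..., start:stop:step, ...], slicing only along `axis`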
if axis < 0:
axis += len(arr.shape)
s = (slice(None, None, None),) * axis + (slice(start, stop, step),)
return arr[s]


def vimco_control_variate(log_f, axis):
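    # for each index k along `axis`, replace log_f[k] with the mean of the
    # other K - 1 log-values and take log-mean-exp over the axis; this is
    # the per-sample VIMCO baseline of Mnih & Rezende (2016)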
K = log_f.shape[axis]
mean_except_k = (np.sum(log_f, axis=axis, keepdims=True) - log_f) / (K - 1)

def sub_k(k):
tmp = np.concatenate(
[slice_at(log_f, axis, 0, k),
slice_at(mean_except_k, axis, k, k + 1),
slice_at(log_f, axis, k+1)],
axis=axis
)
return log_mean_exp(tmp, axis=axis, keepdims=True)

return np.concatenate([sub_k(k) for k in range(K)], axis=axis)
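
In equation form, the baseline that `vimco_control_variate` reproduces is the per-sample control variate from the VIMCO paper (Mnih & Rezende, 2016), with `log_f` holding the per-sample log-weights:

$$
\hat{L}(h^{-k}) = \log \frac{1}{K} \Big( \sum_{i \neq k} f\big(x, h^{(i)}\big) + \exp\Big( \frac{1}{K - 1} \sum_{i \neq k} \log f\big(x, h^{(i)}\big) \Big) \Big)
$$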


class VIMCOEstimatorTestCase(tf.test.TestCase):

def test_vimco_replace_diag(self):
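        # `_vimco_replace_diag(x, y, axis)` is expected to substitute the
        # diagonal of the K x K block at dimensions (axis, axis + 1) of `x`
        # with `y` (which has size 1 at dimension axis + 1), as the
        # assertions below verify: x[..., k, k, ...] <- y[..., k, 0, ...]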
with self.test_session() as sess:
# 2-d
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = tf.constant([[10], [11], [12]])
z = sess.run(_vimco_replace_diag(x, y, -2))
np.testing.assert_equal(z, [[10, 2, 3], [4, 11, 6], [7, 8, 12]])

# 4-d
x = np.arange(4 * 3 * 3 * 5, dtype=np.int32).reshape([4, 3, 3, 5])
y = -np.arange(4 * 3 * 1 * 5, dtype=np.int32).reshape([4, 3, 1, 5])
x_ph = tf.placeholder(tf.int32, [None] * 4)
y_ph = tf.placeholder(tf.int32, [None, None, 1, None])
diag_mask = np.eye(3, 3).reshape([1, 3, 3, 1])
z = sess.run(_vimco_replace_diag(
tf.convert_to_tensor(x_ph), tf.convert_to_tensor(y_ph), -3),
feed_dict={x_ph: x, y_ph: y}
)
np.testing.assert_equal(z, x * (1 - diag_mask) + y * diag_mask)

def test_vimco_control_variate(self):
with self.test_session() as sess:
np.random.seed(1234)
log_f = np.random.randn(4, 5, 6, 7).astype(np.float64)
log_f_ph = tf.placeholder(tf.float64, [None] * 4)
rank = len(log_f.shape)

for axis in range(rank):
out = sess.run(_vimco_control_variate(log_f, axis=axis - rank))
out2 = sess.run(
_vimco_control_variate(log_f_ph, axis=axis - rank),
feed_dict={log_f_ph: log_f}
)
ans = vimco_control_variate(log_f, axis=axis - rank)
np.testing.assert_allclose(out, ans)
np.testing.assert_allclose(out2, ans)

def test_error(self):
x, y, z, f, log_f, log_q = \
prepare_test_payload(is_reparameterized=False)

with pytest.raises(ValueError,
match='vimco_estimator requires multi-samples of '
'latent variables'):
_ = vimco_estimator(log_f, log_q, axis=None)

with pytest.raises(TypeError,
match=r'vimco_estimator only supports integer '
r'`axis`: got \[0, 1\]'):
_ = vimco_estimator(log_f, log_q, axis=[0, 1])

with pytest.raises(ValueError,
match='`axis` out of range: rank 2 vs axis 2'):
_ = vimco_estimator(log_f, log_q, axis=2)

with pytest.raises(ValueError,
match='`axis` out of range: rank 2 vs axis -3'):
_ = vimco_estimator(log_f, log_q, axis=-3)

with pytest.raises(ValueError,
match='vimco_estimator only supports `log_values` '
'with deterministic ndims'):
_ = vimco_estimator(
tf.placeholder(tf.float32, None),
tf.zeros([1, 2]),
axis=0
)

with pytest.raises(ValueError,
match='VIMCO requires sample size >= 2: '
'sample axis is 0'):
_ = vimco_estimator(
tf.placeholder(tf.float32, [1, None]),
tf.zeros([1, 2]),
axis=0
)

with pytest.raises(Exception,
match='VIMCO requires sample size >= 2: '
'sample axis is 1'):
ph = tf.placeholder(tf.float32, [3, None])
with tf.Session() as sess:
sess.run(vimco_estimator(ph, tf.zeros([3, 1]), axis=1),
feed_dict={ph: np.zeros([3, 1])})

def test_vimco(self):
assert_allclose = functools.partial(
np.testing.assert_allclose, rtol=1e-5, atol=1e-6)

with self.test_session() as sess:
x, y, z, f, log_f, log_q = \
prepare_test_payload(is_reparameterized=False)

# compute the gradient
x_out, y_out, z_out, f_out, log_f_out, log_q_out = \
sess.run([x, y, z, f, log_f, log_q])
log_q_grad_out = (x_out ** 2 - 1) * 3 * (y_out ** 2)
log_f_out = y_out * z_out
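
            # expected gradient of the VIMCO surrogate w.r.t. y: a score-
            # function term with per-sample baselines, plus a pathwise term
            # reweighted by the normalized importance weights w_k_hat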

t = np.sum(
log_q_grad_out * (
log_mean_exp(log_f_out, axis=0, keepdims=True) -
vimco_control_variate(log_f_out, axis=0)
),
axis=0
)
w_k_hat = f_out / np.sum(f_out, axis=0, keepdims=True)
log_f_grad_out = z_out
t += np.sum(
w_k_hat * log_f_grad_out,
axis=0
)

cost = vimco_estimator(log_f, log_q, axis=0)
cost_shape = cost.get_shape().as_list()
assert_allclose(sess.run(tf.gradients([cost], [y])[0]), t)

cost_k = vimco_estimator(log_f, log_q, axis=0, keepdims=True)
self.assertListEqual(
[1] + cost_shape, cost_k.get_shape().as_list())
            assert_allclose(sess.run(tf.gradients([cost_k], [y])[0]), t)
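
As a quick orientation for the API exercised above, here is a minimal usage sketch. It assumes TF 1.x graph mode; the shapes and the `tf.random_normal` stand-ins are hypothetical, and the signature `vimco_estimator(log_values, latent_log_joint, axis)` is inferred from these tests.

    import tensorflow as tf
    from tfsnippet.variational import vimco_estimator

    # hypothetical setup: K = 16 samples along axis 0, batch size 32
    log_p = tf.random_normal([16, 32])  # stands in for log p(x, z)
    log_q = tf.random_normal([16, 32])  # stands in for log q(z | x)

    # surrogate objective whose gradient is the VIMCO estimate of the
    # lower-bound gradient; negate it to obtain a loss to minimize
    cost = vimco_estimator(log_p - log_q, log_q, axis=0)
    loss = tf.reduce_mean(-cost)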
152 changes: 20 additions & 132 deletions tests/variational/test_inference.py
@@ -1,14 +1,10 @@
import functools
from collections import namedtuple

import numpy as np
import pytest
import tensorflow as tf
import zhusuan as zs

from tfsnippet import Normal
from tfsnippet.ops import add_n_broadcast
-from tfsnippet.utils import ensure_variables_initialized, get_static_shape
+from tfsnippet.utils import ensure_variables_initialized
from tfsnippet.variational import *


@@ -43,10 +39,6 @@ def test_errors(self):
ValueError, match='vimco training objective requires '
'multi-samples'):
_ = vi.training.vimco()
with pytest.raises(
ValueError, match='reweighted wake-sleep training objective '
'requires multi-samples'):
_ = vi.training.rws_wake()

def test_elbo(self):
with self.test_session() as sess:
Expand Down Expand Up @@ -167,6 +159,25 @@ def test_iwae(self):
log_p - (log_q1 + log_q2), axis=[0, 1])
np.testing.assert_allclose(*sess.run([output, answer]))

def test_vimco(self):
        # test that omitting the sampling axis causes an error
        vi = VariationalInference(tf.constant(0.), [tf.constant(0.)],
                                  axis=None)
        with pytest.raises(
                ValueError, match='vimco training objective '
                'requires multi-samples'):
            _ = vi.training.vimco()

with self.test_session() as sess:
log_p = tf.random_normal(shape=[5, 7])
log_q = tf.random_normal(shape=[4, 5, 7])

vi = VariationalInference(log_p, [log_q], axis=0)
output = vi.training.vimco()
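            # the training objective is the negated estimator, i.e. a loss
            # to be minimized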
answer = -vimco_estimator(
log_p - log_q, log_q, axis=0)
np.testing.assert_allclose(*sess.run([output, answer]))

def test_is_loglikelihood(self):
        # test that omitting the sampling axis causes an error
vi = VariationalInference(tf.constant(0.), [tf.constant(0.)],
@@ -186,126 +197,3 @@ def test_is_loglikelihood(self):
answer = importance_sampling_log_likelihood(
log_p, log_q1 + log_q2, axis=[0, 1])
np.testing.assert_allclose(*sess.run([output, answer]))


class VariationalInferenceZhuSuanTestCase(tf.test.TestCase):

def prepare_model(self, zs_func, axis, n_z):
PreparedModel = namedtuple(
'PreparedModel',
['model_func', 'log_joint_func', 'q_net', 'zs_obj', 'log_joint',
'vi']
)

x = tf.constant([1., 2., 3.])
with zs.BayesianNet() as q_net:
z_posterior = zs.Normal('z', mean=x, std=tf.ones([3]),
n_samples=n_z)

def model_func(observed):
with zs.BayesianNet(observed) as net:
z = zs.Normal('z', mean=tf.zeros([3]), std=tf.ones([3]),
n_samples=n_z)
x = zs.Normal('x', mean=z, std=tf.ones([3]))
return net

def log_joint_func(observed):
net = model_func(observed)
return add_n_broadcast(net.local_log_prob(['z', 'x']))

# derive :class:`zhusuan.variational.VariationalObjective`
# by ZhuSuan utilities
latent = {'z': q_net.query(['z'], outputs=True,
local_log_prob=True)[0]}
zs_obj = zs_func(log_joint_func, observed={'x': x}, latent=latent,
axis=axis)

# derive :class:`zhusuan.variational.VariationalObjective`
# by :class:`VariationalInference`
log_joint = log_joint_func({'z': q_net.outputs('z'), 'x': x})
vi = VariationalInference(log_joint, [q_net.local_log_prob('z')],
axis=axis)

return PreparedModel(model_func, log_joint_func, q_net, zs_obj,
log_joint, vi)

def test_elbo(self):
# test no sampling axis
with self.test_session() as sess:
prepared = self.prepare_model(
zs.variational.elbo, axis=None, n_z=None)
vi = prepared.vi
zs_obj = prepared.zs_obj

# test :meth:`VariationalInference.zs_objective`
vi_obj = prepared.vi.zs_objective(zs.variational.elbo)
self.assertIsInstance(
vi_obj, zs.variational.EvidenceLowerBoundObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

# test :meth:`VariationalInference.zs_elbo`
vi_obj = prepared.vi.zs_elbo()
self.assertIsInstance(
vi_obj, zs.variational.EvidenceLowerBoundObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

# test with sampling axis
with self.test_session() as sess:
prepared = self.prepare_model(
zs.variational.elbo, axis=0, n_z=7)
vi = prepared.vi
zs_obj = prepared.zs_obj

# test :meth:`VariationalInference.zs_objective`
vi_obj = prepared.vi.zs_objective(zs.variational.elbo)
self.assertIsInstance(
vi_obj, zs.variational.EvidenceLowerBoundObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

# test :meth:`VariationalInference.zs_elbo`
vi_obj = prepared.vi.zs_elbo()
self.assertIsInstance(
vi_obj, zs.variational.EvidenceLowerBoundObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

def test_importance_weighted_objective(self):
with self.test_session() as sess:
prepared = self.prepare_model(
zs.variational.importance_weighted_objective, axis=0, n_z=7)
vi = prepared.vi
zs_obj = prepared.zs_obj

# test :meth:`VariationalInference.zs_objective`
vi_obj = prepared.vi.zs_objective(
zs.variational.importance_weighted_objective)
self.assertIsInstance(
vi_obj, zs.variational.ImportanceWeightedObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

# test :meth:`VariationalInference.zs_importance_weighted_objective`
vi_obj = prepared.vi.zs_importance_weighted_objective()
self.assertIsInstance(
vi_obj, zs.variational.ImportanceWeightedObjective)
np.testing.assert_allclose(*sess.run([zs_obj, vi_obj]))

# test :meth:`VariationalTrainingObjectives.vimco`
np.testing.assert_allclose(
*sess.run([zs_obj.vimco(), vi.training.vimco()]))

def test_klpq(self):
with self.test_session() as sess:
prepared = self.prepare_model(zs.variational.klpq, axis=0, n_z=7)
vi = prepared.vi
zs_obj = prepared.zs_obj

# test :meth:`VariationalInference.zs_objective`
vi_obj = prepared.vi.zs_objective(zs.variational.klpq)
self.assertIsInstance(vi_obj, zs.variational.InclusiveKLObjective)

# test :meth:`VariationalInference.zs_klpq`
vi_obj = prepared.vi.zs_klpq()
self.assertIsInstance(vi_obj, zs.variational.InclusiveKLObjective)

# test :meth:`VariationalTrainingObjectives.rws_wake`
np.testing.assert_allclose(
*sess.run([zs_obj.rws(), vi.training.rws_wake()]))
2 changes: 1 addition & 1 deletion tfsnippet/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.2.0a3'
+__version__ = '0.2.0a4'


from . import (dataflows, datasets, distributions, evaluation, layers,
1 change: 1 addition & 0 deletions tfsnippet/variational/__init__.py
@@ -9,4 +9,5 @@
'VariationalLowerBounds', 'VariationalTrainingObjectives',
'elbo_objective', 'importance_sampling_log_likelihood', 'iwae_estimator',
'monte_carlo_objective', 'nvil_estimator', 'sgvb_estimator',
'vimco_estimator',
]
