all VariationalInference methods must use build_loss_and_gradients (#385)

* all VariationalInference methods must use build_loss_and_gradients

* update iwvi

* update docs

* minor fix for debug
1 parent 7d65f73 · commit bbe6a703c07acc37cca06389272f04012cc31b39 · @dustinvtran committed on GitHub on Dec 17, 2016
@@ -29,9 +29,12 @@ \subsubsection{Developing Inference Algorithms}
have their own default methods.
For example, developing a new variational inference algorithm is as simple as
-inheriting from \texttt{VariationalInference} or one of its derived
-classes. \texttt{VariationalInference} implements many default methods such
+inheriting from \texttt{VariationalInference} and writing a
+\texttt{build_loss_and_gradients()} method. \texttt{VariationalInference} implements many default methods such
as \texttt{initialize()} with options for an optimizer.
+See, for instance, the
+\href{https://github.com/blei-lab/edward/blob/master/examples/tf_iwvi.py}{importance
+weighted variational inference} script.
\begin{center}\rule{3in}{0.4pt}\end{center}
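For orientation beyond this docs change, here is a minimal sketch of what such a subclass now looks like. It is illustrative only: the import path, the class name, and the single-sample ELBO are assumptions (it also omits the likelihood term for observed data); it simply mirrors the (loss, grads_and_vars) convention introduced in this commit.

import six
import tensorflow as tf
from edward.inferences import VariationalInference  # import path is an assumption

class SimpleReparamVI(VariationalInference):
  """Hypothetical subclass: the only required override is
  build_loss_and_gradients(var_list) -> (loss, grads_and_vars)."""

  def build_loss_and_gradients(self, var_list):
    p_log_prob = 0.0
    q_log_prob = 0.0
    for z, qz in six.iteritems(self.latent_vars):
      z_sample = qz.value()  # reparameterized sample from q(z)
      q_log_prob += tf.reduce_sum(qz.log_prob(z_sample))
      p_log_prob += tf.reduce_sum(z.log_prob(z_sample))

    loss = -(p_log_prob - q_log_prob)  # negative single-sample (partial) ELBO

    if var_list is None:
      var_list = tf.trainable_variables()

    grads = tf.gradients(loss, [v.ref() for v in var_list])
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars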
@@ -265,7 +265,7 @@ def initialize(self, n_iter=1000, n_print=None, n_minibatch=None, scale=None,
logdir : str, optional
Directory where event file will be written. For details,
see `tf.train.SummaryWriter`. Default is to write nothing.
- debug: boolean, optional
+ debug: bool, optional
If True, add checks for NaN and Inf to all computations in the graph.
May result in substantially slower execution times.
"""
@@ -319,8 +319,8 @@ def initialize(self, n_iter=1000, n_print=None, n_minibatch=None, scale=None,
else:
self.logging = False
- if debug:
- self.debug = True
+ self.debug = debug
+ if self.debug:
self.op_check = tf.add_check_numerics_ops()
def update(self):
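Aside: with this change, self.debug is always defined, so update() can test it safely even when debugging is off. A usage fragment (the inference object here is a placeholder, not from this diff):

inference.initialize(n_iter=1000, debug=True)  # installs tf.add_check_numerics_ops() for NaN/Inf checks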
@@ -101,18 +101,11 @@ def build_loss_and_gradients(self, var_list):
(z_is_normal or hasattr(self.model_wrapper, 'log_lik'))
if is_reparameterizable:
if is_analytic_kl:
- loss = build_reparam_kl_loss(self)
+ return build_reparam_kl_loss_and_gradients(self, var_list)
# elif is_analytic_entropy:
- # loss = build_reparam_entropy_loss(self)
+ # return build_reparam_entropy_loss_and_gradients(self, var_list)
else:
- loss = build_reparam_loss(self)
-
- if var_list is None:
- var_list = tf.trainable_variables()
-
- grads = tf.gradients(loss, [v.ref() for v in var_list])
- grads_and_vars = list(zip(grads, var_list))
- return loss, grads_and_vars
+ return build_reparam_loss_and_gradients(self, var_list)
else:
if is_analytic_kl:
return build_score_kl_loss_and_gradients(self, var_list)
@@ -155,8 +148,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
self.n_samples = n_samples
return super(ReparameterizationKLqp, self).initialize(*args, **kwargs)
- def build_loss(self):
- return build_reparam_loss(self)
+ def build_loss_and_gradients(self, var_list):
+ return build_reparam_loss_and_gradients(self, var_list)
class ReparameterizationKLKLqp(VariationalInference):
@@ -184,8 +177,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
self.n_samples = n_samples
return super(ReparameterizationKLKLqp, self).initialize(*args, **kwargs)
- def build_loss(self):
- return build_reparam_kl_loss(self)
+ def build_loss_and_gradients(self, var_list):
+ return build_reparam_kl_loss_and_gradients(self, var_list)
class ReparameterizationEntropyKLqp(VariationalInference):
@@ -214,8 +207,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
return super(ReparameterizationEntropyKLqp, self).initialize(
*args, **kwargs)
- def build_loss(self):
- return build_reparam_entropy_loss(self)
+ def build_loss_and_gradients(self, var_list):
+ return build_reparam_entropy_loss_and_gradients(self, var_list)
class ScoreKLqp(VariationalInference):
@@ -305,7 +298,7 @@ def build_loss_and_gradients(self, var_list):
return build_score_entropy_loss_and_gradients(self, var_list)
-def build_reparam_loss(inference):
+def build_reparam_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -369,10 +362,16 @@ def build_reparam_loss(inference):
p_log_prob = tf.pack(p_log_prob)
q_log_prob = tf.pack(q_log_prob)
loss = -tf.reduce_mean(p_log_prob - q_log_prob)
- return loss
+
+ if var_list is None:
+ var_list = tf.trainable_variables()
+
+ grads = tf.gradients(loss, [v.ref() for v in var_list])
+ grads_and_vars = list(zip(grads, var_list))
+ return loss, grads_and_vars
-def build_reparam_kl_loss(inference):
+def build_reparam_kl_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -436,10 +435,16 @@ def build_reparam_kl_loss(inference):
for qz in six.itervalues(inference.latent_vars)])
loss = -(tf.reduce_mean(p_log_lik) - kl)
- return loss
+
+ if var_list is None:
+ var_list = tf.trainable_variables()
+
+ grads = tf.gradients(loss, [v.ref() for v in var_list])
+ grads_and_vars = list(zip(grads, var_list))
+ return loss, grads_and_vars
-def build_reparam_entropy_loss(inference):
+def build_reparam_entropy_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -502,7 +507,13 @@ def build_reparam_entropy_loss(inference):
for z, qz in six.iteritems(inference.latent_vars)])
loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
- return loss
+
+ if var_list is None:
+ var_list = tf.trainable_variables()
+
+ grads = tf.gradients(loss, [v.ref() for v in var_list])
+ grads_and_vars = list(zip(grads, var_list))
+ return loss, grads_and_vars
def build_score_loss_and_gradients(inference, var_list):
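For reference, the three reparameterization builders above construct the losses visible in the code; reconstructed here in standard notation (the truncated docstrings are not quoted):

\begin{align*}
\texttt{build\_reparam\_loss\_and\_gradients:} &\quad -\,\mathbb{E}_{q(z;\lambda)}\big[\log p(x, z) - \log q(z;\lambda)\big] \\
\texttt{build\_reparam\_kl\_loss\_and\_gradients:} &\quad -\Big(\mathbb{E}_{q(z;\lambda)}[\log p(x \mid z)] - \mathrm{KL}\big(q(z;\lambda)\,\|\,p(z)\big)\Big) \\
\texttt{build\_reparam\_entropy\_loss\_and\_gradients:} &\quad -\Big(\mathbb{E}_{q(z;\lambda)}[\log p(x, z)] + \mathbb{H}\big[q(z;\lambda)\big]\Big)
\end{align*}

Each now also returns tf.gradients of its loss with respect to var_list, as shown.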
@@ -99,7 +99,7 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
super(MAP, self).__init__(latent_vars, data, model_wrapper)
- def build_loss(self):
+ def build_loss_and_gradients(self, var_list):
"""Build loss function. Its automatic differentiation
is the gradient of
@@ -141,7 +141,14 @@ def build_loss(self):
x = self.data
p_log_prob = self.model_wrapper.log_prob(x, z_mode)
- return -p_log_prob
+ loss = -p_log_prob
+
+ if var_list is None:
+ var_list = tf.trainable_variables()
+
+ grads = tf.gradients(loss, [v.ref() for v in var_list])
+ grads_and_vars = list(zip(grads, var_list))
+ return loss, grads_and_vars
class Laplace(MAP):
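For MAP, the loss built above is simply the negative joint density at the point estimate, \( \text{loss} = -\log p(x, z_{\text{mode}}) \), now packaged with its gradients in the same way as the variational losses.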
@@ -44,6 +44,31 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
"""
super(VariationalInference, self).initialize(*args, **kwargs)
+ if var_list is None:
+ if self.model_wrapper is None:
+ # Traverse random variable graphs to get default list of variables.
+ var_list = set([])
+ trainables = tf.trainable_variables()
+ for z, qz in six.iteritems(self.latent_vars):
+ if isinstance(z, RandomVariable):
+ var_list.update(get_variables(z, collection=trainables))
+
+ var_list.update(get_variables(qz, collection=trainables))
+
+ for x, qx in six.iteritems(self.data):
+ if isinstance(x, RandomVariable) and \
+ not isinstance(qx, RandomVariable):
+ var_list.update(get_variables(x, collection=trainables))
+
+ var_list = list(var_list)
+ else:
+ # Variables may not be instantiated for model wrappers until
+ # their methods are first called. For now, hard-code
+ # ``var_list`` inside build_losses.
+ var_list = None
+
+ self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
+
if optimizer is None:
# Use ADAM with a decaying scale factor.
global_step = tf.Variable(0, trainable=False)
@@ -77,46 +102,12 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
else:
raise TypeError()
- if var_list is None:
- if self.model_wrapper is None:
- # Traverse random variable graphs to get default list of variables.
- var_list = set([])
- trainables = tf.trainable_variables()
- for z, qz in six.iteritems(self.latent_vars):
- if isinstance(z, RandomVariable):
- var_list.update(get_variables(z, collection=trainables))
-
- var_list.update(get_variables(qz, collection=trainables))
-
- for x, qx in six.iteritems(self.data):
- if isinstance(x, RandomVariable) and \
- not isinstance(qx, RandomVariable):
- var_list.update(get_variables(x, collection=trainables))
-
- var_list = list(var_list)
- else:
- # Variables may not be instantiated for model wrappers until
- # their methods are first called. For now, hard-code
- # ``var_list`` inside build_losses.
- var_list = None
-
- if getattr(self, 'build_loss_and_gradients', None) is not None:
- self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
- else:
- self.loss = self.build_loss()
- if var_list is None:
- var_list = tf.trainable_variables()
-
- grads_and_vars = optimizer.compute_gradients(self.loss, var_list=var_list)
-
if not use_prettytensor:
self.train = optimizer.apply_gradients(grads_and_vars,
global_step=global_step)
else:
- if getattr(self, 'build_loss_and_gradients', None) is not None:
- raise NotImplementedError("PrettyTensor optimizer does not accept "
- "manual gradients.")
-
+ # Note PrettyTensor optimizer does not accept manual updates;
+ # it autodiffs the loss directly.
self.train = pt.apply_optimizer(optimizer, losses=[self.loss],
global_step=global_step,
var_list=var_list)
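With this refactor, gradients always come from build_loss_and_gradients() and initialize() only applies them through the chosen optimizer. A usage sketch under that reading (z, qz, x, and x_train are placeholders; ReparameterizationKLqp is one of the classes updated above):

# Placeholder model/approximation pair and data; only the call pattern is the point here.
inference = ReparameterizationKLqp({z: qz}, data={x: x_train})
inference.run(n_samples=5, n_iter=1000)  # run() calls initialize(), which builds loss + grads, then loops update()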
@@ -167,11 +158,11 @@ def print_progress(self, info_dict):
string += ': Loss = {0:.3f}'.format(loss)
print(string)
- def build_loss(self):
+ def build_loss_and_gradients(self, var_list):
"""Build loss function.
- Any derived class of ``VariationalInference`` must implement
- this method or ``build_loss_and_gradients``.
+ Any derived class of ``VariationalInference`` **must** implement
+ this method.
Raises
------
@@ -50,13 +50,7 @@ def initialize(self, K=5, *args, **kwargs):
self.K = K
return super(IWVI, self).initialize(*args, **kwargs)
- def build_loss(self):
- if self.score:
- return self.build_score_loss()
- else:
- return self.build_reparam_loss()
-
- def build_score_loss(self):
+ def build_loss_and_gradients(self, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -91,46 +85,16 @@ def build_score_loss(self):
log_w = tf.reshape(log_w, [self.n_samples, self.K])
# Take log mean exp across importance weights (columns).
losses = log_mean_exp(log_w, 1)
- self.loss = tf.reduce_mean(losses)
- return -tf.reduce_mean(q_log_prob * tf.stop_gradient(losses))
-
- def build_reparam_loss(self):
- """Build loss function. Its automatic differentiation
- is a stochastic gradient of
-
- .. math::
-
- -E_{q(z^1; \lambda), ..., q(z^K; \lambda)} [
- \log 1/K \sum_{k=1}^K p(x, z^k)/q(z^k; \lambda) ]
-
- based on the reparameterization trick. (Kingma and Welling, 2014)
-
- Computed by sampling from :math:`q(z;\lambda)` and evaluating
- the expectation using Monte Carlo sampling. Note there is a
- difference between the number of samples to approximate the
- expectations (`n_samples`) and the number of importance
- samples to determine how many expectations (`K`).
- """
- x = self.data
- # Form n_samples x K matrix of log importance weights.
- log_w = []
- for s in range(self.n_samples * self.K):
- z_sample = {}
- q_log_prob = 0.0
- for z, qz in six.iteritems(self.latent_vars):
- # Copy q(z) to obtain new set of posterior samples.
- qz_copy = copy(qz, scope='inference_' + str(s))
- z_sample[z] = qz_copy.value()
- q_log_prob += tf.reduce_sum(qz.log_prob(z_sample[z]))
+ loss = -tf.reduce_mean(losses)
- p_log_prob = self.model_wrapper.log_prob(x, z_sample)
- log_w += [p_log_prob - q_log_prob]
+ if var_list is None:
+ var_list = tf.trainable_variables()
- log_w = tf.reshape(log_w, [self.n_samples, self.K])
- # Take log mean exp across importance weights (columns).
- losses = log_mean_exp(log_w, 1)
- self.loss = tf.reduce_mean(losses)
- return -self.loss
+ grads = tf.gradients(
+ -tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)),
+ [v.ref() for v in var_list])
+ grads_and_vars = list(zip(grads, var_list))
+ return loss, grads_and_vars
class BetaBernoulli:
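Schematically, the surviving IWVI builder reports the importance-weighted bound as the loss while returning score-function gradients of a surrogate; reconstructed from the code above, with importance weights \( w^k = p(x, z^k)/q(z^k;\lambda) \):

\[
\text{loss} = -\,\mathbb{E}\Big[\log \tfrac{1}{K}\sum_{k=1}^{K} w^k\Big],
\qquad
\text{grads} = \nabla_\lambda\Big(-\,\mathbb{E}\big[\log q(z;\lambda)\cdot \texttt{stop\_gradient}\big(\log\tfrac{1}{K}\textstyle\sum_k w^k\big)\big]\Big).
\]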
