all VariationalInference methods must use build_loss_and_gradients (#385)

* all VariationalInference methods must use build_loss_and_gradients

* update iwvi

* update docs

* minor fix for debug
dustinvtran committed Dec 17, 2016
1 parent 7d65f73 commit bbe6a70
Showing 6 changed files with 89 additions and 113 deletions.
7 changes: 5 additions & 2 deletions docs/tex/api/inference-development.tex
@@ -29,9 +29,12 @@ \subsubsection{Developing Inference Algorithms}
have their own default methods.

For example, developing a new variational inference algorithm is as simple as
inheriting from \texttt{VariationalInference} or one of its derived
classes. \texttt{VariationalInference} implements many default methods such
inheriting from \texttt{VariationalInference} and writing a
\texttt{build_loss_and_gradients()} method. \texttt{VariationalInference} implements many default methods such
as \texttt{initialize()} with options for an optimizer.
For example, see the
\href{https://github.com/blei-lab/edward/blob/master/examples/tf_iwvi.py}{importance
weighted variational inference} script.

\begin{center}\rule{3in}{0.4pt}\end{center}

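To make the doc change above concrete, here is a minimal sketch (not part of the commit) of the pattern it describes: inherit from VariationalInference and return a (loss, gradients-and-variables) pair from build_loss_and_gradients(). The class name MyVI is hypothetical, the one-sample objective is only illustrative, and the import path and model-wrapper call are assumptions modeled on the IWVI example referenced above.

import six
import tensorflow as tf
from edward.inferences import VariationalInference


class MyVI(VariationalInference):
  """Illustrative subclass: one-sample reparameterized -ELBO."""

  def build_loss_and_gradients(self, var_list):
    # Draw one posterior sample and score it under q.
    z_sample = {}
    q_log_prob = 0.0
    for z, qz in six.iteritems(self.latent_vars):
      z_sample[z] = qz.value()
      q_log_prob += tf.reduce_sum(qz.log_prob(z_sample[z]))

    # Score the sample under the model (model-wrapper style).
    p_log_prob = self.model_wrapper.log_prob(self.data, z_sample)
    loss = -(p_log_prob - q_log_prob)

    # Same gradient plumbing as the helpers in this commit.
    if var_list is None:
      var_list = tf.trainable_variables()
    grads = tf.gradients(loss, [v.ref() for v in var_list])
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars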
6 changes: 3 additions & 3 deletions edward/inferences/inference.py
Expand Up @@ -265,7 +265,7 @@ def initialize(self, n_iter=1000, n_print=None, n_minibatch=None, scale=None,
logdir : str, optional
Directory where event file will be written. For details,
see `tf.train.SummaryWriter`. Default is to write nothing.
debug: boolean, optional
debug: bool, optional
If True, add checks for NaN and Inf to all computations in the graph.
May result in substantially slower execution times.
"""
@@ -319,8 +319,8 @@ def initialize(self, n_iter=1000, n_print=None, n_minibatch=None, scale=None,
else:
self.logging = False

if debug:
self.debug = True
self.debug = debug
if self.debug:
self.op_check = tf.add_check_numerics_ops()

def update(self):
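For context on the simplified debug handling above (not part of the commit): tf.add_check_numerics_ops() attaches a finite-value assertion to every floating-point tensor already in the graph, which is why the docstring warns about substantially slower execution. A standalone sketch, assuming the graph-mode TensorFlow API of this era:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.log(x)  # NaN for negative inputs

# As in initialize() above, the check op is built after the graph exists.
op_check = tf.add_check_numerics_ops()

with tf.Session() as sess:
  # Running the checks alongside y raises InvalidArgumentError on NaN/Inf.
  sess.run([y, op_check], feed_dict={x: -1.0})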
55 changes: 33 additions & 22 deletions edward/inferences/klqp.py
Expand Up @@ -101,18 +101,11 @@ def build_loss_and_gradients(self, var_list):
(z_is_normal or hasattr(self.model_wrapper, 'log_lik'))
if is_reparameterizable:
if is_analytic_kl:
loss = build_reparam_kl_loss(self)
return build_reparam_kl_loss_and_gradients(self, var_list)
# elif is_analytic_entropy:
# loss = build_reparam_entropy_loss(self)
# return build_reparam_entropy_loss_and_gradients(self, var_list)
else:
loss = build_reparam_loss(self)

if var_list is None:
var_list = tf.trainable_variables()

grads = tf.gradients(loss, [v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
return build_reparam_loss_and_gradients(self, var_list)
else:
if is_analytic_kl:
return build_score_kl_loss_and_gradients(self, var_list)
@@ -155,8 +148,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
self.n_samples = n_samples
return super(ReparameterizationKLqp, self).initialize(*args, **kwargs)

def build_loss(self):
return build_reparam_loss(self)
def build_loss_and_gradients(self, var_list):
return build_reparam_loss_and_gradients(self, var_list)


class ReparameterizationKLKLqp(VariationalInference):
@@ -184,8 +177,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
self.n_samples = n_samples
return super(ReparameterizationKLKLqp, self).initialize(*args, **kwargs)

def build_loss(self):
return build_reparam_kl_loss(self)
def build_loss_and_gradients(self, var_list):
return build_reparam_kl_loss_and_gradients(self, var_list)


class ReparameterizationEntropyKLqp(VariationalInference):
@@ -214,8 +207,8 @@ def initialize(self, n_samples=1, *args, **kwargs):
return super(ReparameterizationEntropyKLqp, self).initialize(
*args, **kwargs)

def build_loss(self):
return build_reparam_entropy_loss(self)
def build_loss_and_gradients(self, var_list):
return build_reparam_entropy_loss_and_gradients(self, var_list)


class ScoreKLqp(VariationalInference):
@@ -305,7 +298,7 @@ def build_loss_and_gradients(self, var_list):
return build_score_entropy_loss_and_gradients(self, var_list)


def build_reparam_loss(inference):
def build_reparam_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -369,10 +362,16 @@ def build_reparam_loss(inference):
p_log_prob = tf.pack(p_log_prob)
q_log_prob = tf.pack(q_log_prob)
loss = -tf.reduce_mean(p_log_prob - q_log_prob)
return loss

if var_list is None:
var_list = tf.trainable_variables()

grads = tf.gradients(loss, [v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
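For reference (not an addition to the code), the quantity this helper differentiates is the standard reparameterized estimate of the negative ELBO, averaging S = n_samples draws z_s = z(\epsilon_s; \lambda) from q:

\[
\text{loss} \;=\; -\frac{1}{S} \sum_{s=1}^{S}
\big[ \log p(x, z_s) - \log q(z_s; \lambda) \big].
\]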


def build_reparam_kl_loss(inference):
def build_reparam_kl_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -436,10 +435,16 @@ def build_reparam_kl_loss(inference):
for qz in six.itervalues(inference.latent_vars)])

loss = -(tf.reduce_mean(p_log_lik) - kl)
return loss

if var_list is None:
var_list = tf.trainable_variables()

grads = tf.gradients(loss, [v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
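The KL variant keeps the likelihood term as a Monte Carlo average but evaluates the KL term analytically, summing over the latent variables (the kl tensor above):

\[
\text{loss} \;=\; -\Big( \mathbb{E}_{q(z;\lambda)}\big[\log p(x \mid z)\big]
\;-\; \textstyle\sum_{z} \text{KL}\big(q(z;\lambda)\,\|\,p(z)\big) \Big).
\]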


def build_reparam_entropy_loss(inference):
def build_reparam_entropy_loss_and_gradients(inference, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -502,7 +507,13 @@ def build_reparam_entropy_loss(inference):
for z, qz in six.iteritems(inference.latent_vars)])

loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
return loss

if var_list is None:
var_list = tf.trainable_variables()

grads = tf.gradients(loss, [v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
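Analogously, the entropy variant replaces the analytic KL with an analytic entropy term (q_entropy above):

\[
\text{loss} \;=\; -\Big( \mathbb{E}_{q(z;\lambda)}\big[\log p(x, z)\big]
\;+\; \mathbb{H}\big[q(z;\lambda)\big] \Big).
\]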


def build_score_loss_and_gradients(inference, var_list):
11 changes: 9 additions & 2 deletions edward/inferences/map.py
@@ -99,7 +99,7 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):

super(MAP, self).__init__(latent_vars, data, model_wrapper)

def build_loss(self):
def build_loss_and_gradients(self, var_list):
"""Build loss function. Its automatic differentiation
is the gradient of
@@ -141,7 +141,14 @@ def build_loss(self):
x = self.data
p_log_prob = self.model_wrapper.log_prob(x, z_mode)

return -p_log_prob
loss = -p_log_prob

if var_list is None:
var_list = tf.trainable_variables()

grads = tf.gradients(loss, [v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
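For reference, the loss assembled here is simply the negative joint density evaluated at the point estimates z^{*} collected in z_mode:

\[
\text{loss} \;=\; -\log p\big(x, z^{*}\big).
\]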


class Laplace(MAP):
69 changes: 30 additions & 39 deletions edward/inferences/variational_inference.py
@@ -44,6 +44,31 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
"""
super(VariationalInference, self).initialize(*args, **kwargs)

if var_list is None:
if self.model_wrapper is None:
# Traverse random variable graphs to get default list of variables.
var_list = set([])
trainables = tf.trainable_variables()
for z, qz in six.iteritems(self.latent_vars):
if isinstance(z, RandomVariable):
var_list.update(get_variables(z, collection=trainables))

var_list.update(get_variables(qz, collection=trainables))

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable) and \
not isinstance(qx, RandomVariable):
var_list.update(get_variables(x, collection=trainables))

var_list = list(var_list)
else:
# Variables may not be instantiated for model wrappers until
# their methods are first called. For now, hard-code
# ``var_list`` inside ``build_loss_and_gradients``.
var_list = None

self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)

if optimizer is None:
# Use ADAM with a decaying scale factor.
global_step = tf.Variable(0, trainable=False)
@@ -77,46 +102,12 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
else:
raise TypeError()

if var_list is None:
if self.model_wrapper is None:
# Traverse random variable graphs to get default list of variables.
var_list = set([])
trainables = tf.trainable_variables()
for z, qz in six.iteritems(self.latent_vars):
if isinstance(z, RandomVariable):
var_list.update(get_variables(z, collection=trainables))

var_list.update(get_variables(qz, collection=trainables))

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable) and \
not isinstance(qx, RandomVariable):
var_list.update(get_variables(x, collection=trainables))

var_list = list(var_list)
else:
# Variables may not be instantiated for model wrappers until
# their methods are first called. For now, hard-code
# ``var_list`` inside build_losses.
var_list = None

if getattr(self, 'build_loss_and_gradients', None) is not None:
self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
else:
self.loss = self.build_loss()
if var_list is None:
var_list = tf.trainable_variables()

grads_and_vars = optimizer.compute_gradients(self.loss, var_list=var_list)

if not use_prettytensor:
self.train = optimizer.apply_gradients(grads_and_vars,
global_step=global_step)
else:
if getattr(self, 'build_loss_and_gradients', None) is not None:
raise NotImplementedError("PrettyTensor optimizer does not accept "
"manual gradients.")

# Note PrettyTensor optimizer does not accept manual updates;
# it autodiffs the loss directly.
self.train = pt.apply_optimizer(optimizer, losses=[self.loss],
global_step=global_step,
var_list=var_list)
@@ -167,11 +158,11 @@ def print_progress(self, info_dict):
string += ': Loss = {0:.3f}'.format(loss)
print(string)

def build_loss(self):
def build_loss_and_gradients(self, var_list):
"""Build loss function.
Any derived class of ``VariationalInference`` must implement
this method or ``build_loss_and_gradients``.
Any derived class of ``VariationalInference`` **must** implement
this method.
Raises
------
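Because existing algorithms now all route through build_loss_and_gradients, end-to-end usage is unchanged. A usage sketch (not part of the commit), assuming the Edward API at this point in time (Normal with mu/sigma keyword arguments, ed.KLqp, inference.run):

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Normal

x_train = np.random.randn(50).astype(np.float32)

# Model: z ~ N(0, 1); x_n ~ N(z, 1) for n = 1, ..., 50.
z = Normal(mu=tf.zeros([]), sigma=tf.ones([]))
x = Normal(mu=tf.ones(50) * z, sigma=tf.ones(50))

# Variational approximation q(z; lambda).
qz = Normal(mu=tf.Variable(0.0),
            sigma=tf.nn.softplus(tf.Variable(0.0)))

# KLqp.build_loss_and_gradients() supplies the (loss, grads_and_vars)
# that initialize() above hands to the optimizer.
inference = ed.KLqp({z: qz}, data={x: x_train})
inference.run(n_iter=250)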
54 changes: 9 additions & 45 deletions examples/tf_iwvi.py
@@ -50,13 +50,7 @@ def initialize(self, K=5, *args, **kwargs):
self.K = K
return super(IWVI, self).initialize(*args, **kwargs)

def build_loss(self):
if self.score:
return self.build_score_loss()
else:
return self.build_reparam_loss()

def build_score_loss(self):
def build_loss_and_gradients(self, var_list):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
@@ -91,46 +85,16 @@ def build_score_loss(self):
log_w = tf.reshape(log_w, [self.n_samples, self.K])
# Take log mean exp across importance weights (columns).
losses = log_mean_exp(log_w, 1)
self.loss = tf.reduce_mean(losses)
return -tf.reduce_mean(q_log_prob * tf.stop_gradient(losses))

def build_reparam_loss(self):
"""Build loss function. Its automatic differentiation
is a stochastic gradient of
.. math::
-E_{q(z^1; \lambda), ..., q(z^K; \lambda)} [
\log 1/K \sum_{k=1}^K p(x, z^k)/q(z^k; \lambda) ]
based on the reparameterization trick. (Kingma and Welling, 2014)
Computed by sampling from :math:`q(z;\lambda)` and evaluating
the expectation using Monte Carlo sampling. Note there is a
difference between the number of samples to approximate the
expectations (`n_samples`) and the number of importance
samples to determine how many expectations (`K`).
"""
x = self.data
# Form n_samples x K matrix of log importance weights.
log_w = []
for s in range(self.n_samples * self.K):
z_sample = {}
q_log_prob = 0.0
for z, qz in six.iteritems(self.latent_vars):
# Copy q(z) to obtain new set of posterior samples.
qz_copy = copy(qz, scope='inference_' + str(s))
z_sample[z] = qz_copy.value()
q_log_prob += tf.reduce_sum(qz.log_prob(z_sample[z]))
loss = -tf.reduce_mean(losses)

p_log_prob = self.model_wrapper.log_prob(x, z_sample)
log_w += [p_log_prob - q_log_prob]
if var_list is None:
var_list = tf.trainable_variables()

log_w = tf.reshape(log_w, [self.n_samples, self.K])
# Take log mean exp across importance weights (columns).
losses = log_mean_exp(log_w, 1)
self.loss = tf.reduce_mean(losses)
return -self.loss
grads = tf.gradients(
-tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)),
[v.ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars
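For reference, both the removed reparameterization branch and the retained score-function branch target the same importance weighted objective quoted in the docstrings of this file, estimated over n_samples rows of K importance samples:

\[
\text{loss} \;=\;
-\,\mathbb{E}_{q(z^1;\lambda)\cdots q(z^K;\lambda)}
\Big[ \log \tfrac{1}{K} \textstyle\sum_{k=1}^{K}
\tfrac{p(x, z^k)}{q(z^k;\lambda)} \Big].
\]

The returned gradients are a score-function (REINFORCE) estimator: tf.stop_gradient holds the log-mean-exp weights fixed, so differentiating -tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)) yields \nabla_\lambda \log q(z;\lambda) scaled by the estimated objective.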


class BetaBernoulli:
