Move all op creation to initialize(), including finalize()
dustinvtran committed Mar 21, 2017
1 parent 175e526 commit 8c2dd82
Showing 2 changed files with 26 additions and 24 deletions.
4 changes: 3 additions & 1 deletion edward/inferences/inference.py
@@ -131,7 +131,9 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):

   def initialize(self, n_iter=1000, n_print=None, scale=None, logdir=None,
                  debug=False):
-    """Initialize inference algorithm.
+    """Initialize inference algorithm. It initializes hyperparameters
+    and builds ops for the algorithm's computational graph. No ops
+    should be created outside the call to ``initialize()``.

     Parameters
     ----------
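The rest of the commit enforces exactly the convention this docstring states: any op created inside a per-iteration method like update() or inside finalize() would be added to the graph again on every call. A minimal sketch of the pattern in plain TensorFlow 1.x (the class and attribute names here are illustrative, not Edward's):

  import tensorflow as tf

  class ToyInference(object):
    def initialize(self):
      """Build every op once, up front."""
      self.theta = tf.Variable(0.0)
      self.train_op = self.theta.assign_add(1.0)    # per-iteration op
      self.finalize_ops = [self.theta.assign(0.0)]  # post-convergence op

    def update(self, sess):
      # Only executes a pre-built op; never adds nodes to the graph.
      return sess.run(self.train_op)

    def finalize(self, sess):
      # Likewise: run ops, don't create them.
      sess.run(self.finalize_ops)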
46 changes: 23 additions & 23 deletions edward/inferences/laplace.py
@@ -75,14 +75,32 @@ def __init__(self, latent_vars, data=None):
     # call grandparent's method; avoid parent (MAP)
     super(MAP, self).__init__(latent_vars, data)

-  def initialize(self, var_list=None, *args, **kwargs):
+  def initialize(self, *args, **kwargs):
     # Store latent variables in a temporary attribute; MAP will
     # optimize ``PointMass`` random variables, which subsequently
     # optimizes mean parameters of the normal approximations.
-    self.latent_vars_normal = self.latent_vars.copy()
+    latent_vars_normal = self.latent_vars.copy()
     self.latent_vars = {z: PointMass(params=qz.mu)
-                        for z, qz in six.iteritems(self.latent_vars_normal)}
-    super(Laplace, self).initialize(var_list, *args, **kwargs)
+                        for z, qz in six.iteritems(latent_vars_normal)}
+
+    super(Laplace, self).initialize(*args, **kwargs)
+
+    hessians = tf.hessians(self.loss, list(six.itervalues(self.latent_vars)))
+    self.finalize_ops = []
+    for z, hessian in zip(six.iterkeys(self.latent_vars), hessians):
+      qz = latent_vars_normal[z]
+      sigma_var = get_variables(qz.sigma)[0]
+      if isinstance(qz, MultivariateNormalCholesky):
+        sigma = tf.matrix_inverse(tf.cholesky(hessian))
+      elif isinstance(qz, MultivariateNormalDiag):
+        sigma = 1.0 / tf.diag_part(hessian)
+      else:  # qz is MultivariateNormalFull
+        sigma = tf.matrix_inverse(hessian)
+
+      self.finalize_ops.append(sigma_var.assign(sigma))
+
+    self.latent_vars = latent_vars_normal.copy()
+    del latent_vars_normal

   def finalize(self, feed_dict=None):
     """Function to call after convergence.
@@ -103,24 +103,6 @@ def finalize(self, feed_dict=None):
         if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
           feed_dict[key] = value

-    var_list = list(six.itervalues(self.latent_vars))
-    hessians = tf.hessians(self.loss, var_list)
-
-    assign_ops = []
-    for z, hessian in zip(six.iterkeys(self.latent_vars), hessians):
-      qz = self.latent_vars_normal[z]
-      sigma_var = get_variables(qz.sigma)[0]
-      if isinstance(qz, MultivariateNormalCholesky):
-        sigma = tf.matrix_inverse(tf.cholesky(hessian))
-      elif isinstance(qz, MultivariateNormalDiag):
-        sigma = 1.0 / tf.diag_part(hessian)
-      else:  # qz is MultivariateNormalFull
-        sigma = tf.matrix_inverse(hessian)
-
-      assign_ops.append(sigma_var.assign(sigma))
-
     sess = get_session()
-    sess.run(assign_ops, feed_dict)
-    self.latent_vars = self.latent_vars_normal.copy()
-    del self.latent_vars_normal
+    sess.run(self.finalize_ops, feed_dict)
     super(Laplace, self).finalize()
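For reference, the sigma assignments encode the standard Laplace approximation. With $z_{\mathrm{MAP}}$ the mode found by the MAP optimization and $H$ the Hessian of the negative log joint at that mode, a second-order Taylor expansion gives

  \log p(z \mid x) \approx \log p(z_{\mathrm{MAP}} \mid x) - \tfrac{1}{2}\,(z - z_{\mathrm{MAP}})^\top H\,(z - z_{\mathrm{MAP}}),

so the posterior is approximated by $q(z) = \mathcal{N}(z \mid z_{\mathrm{MAP}}, H^{-1})$. The MultivariateNormalFull branch computes $H^{-1}$ directly; the MultivariateNormalDiag branch keeps only the reciprocal of the Hessian's diagonal, a cheaper factorized approximation; the MultivariateNormalCholesky branch parameterizes the approximation through a triangular factor of $H$.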

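End to end, user-facing behavior is unchanged: run() still calls initialize(), the update loop, and finalize() in order; only where the covariance ops get built has moved. A usage sketch against the Edward API of this era (constructor keyword names such as mu, sigma, and diag_stdev follow the TF-contrib distributions Edward wrapped at the time, so treat them as assumptions):

  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import MultivariateNormalDiag, Normal

  # Toy model: z ~ N(0, 1); x | z ~ N(z, 1).
  z = Normal(mu=tf.zeros(1), sigma=tf.ones(1))
  x = Normal(mu=z, sigma=tf.ones(1))

  # Normal approximation whose mean/scale variables Laplace fills in.
  qz = MultivariateNormalDiag(
      mu=tf.Variable(tf.zeros(1)),
      diag_stdev=tf.Variable(tf.ones(1)))

  inference = ed.Laplace({z: qz}, data={x: np.array([0.8], dtype=np.float32)})
  inference.run(n_iter=100)  # finalize() now only executes pre-built ops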