Add regularization terms in klqp.py (#813)
* Add regularization terms in klqp.py

* Fix pep8 error in klqp.py

* Add notes regarding regularization in klqp.py
siddharth-agrawal authored and dustinvtran committed Jan 5, 2018
1 parent 7a583d1 commit 48dbee4
Showing 1 changed file with 60 additions and 11 deletions.
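
The change makes every KLqp-style objective subtract the sum of whatever tensors live in TensorFlow's `REGULARIZATION_LOSSES` collection. A minimal sketch of how tensors typically end up in that collection; the layer width, the 1e-3 scale, and the variable names are illustrative, not part of this commit:

import tensorflow as tf

# Option 1: let a layer register its own penalty via kernel_regularizer.
x = tf.placeholder(tf.float32, [None, 10])
h = tf.layers.dense(x, 32, activation=tf.nn.relu,
                    kernel_regularizer=tf.keras.regularizers.l2(1e-3))

# Option 2: add a hand-built penalty to the collection directly.
w = tf.get_variable("w", [32, 1])
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                     1e-3 * tf.nn.l2_loss(w))

# This is the sum each patched loss builder folds into its objective.
reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

With tensors registered this way, no further user-side change is needed; the diff below subtracts `reg_penalty` in each loss builder.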
edward/inferences/klqp.py — 71 changes: 60 additions & 11 deletions
@@ -46,6 +46,9 @@ class KLqp(VariationalInference):
where $z^{(s)} \sim q(z; \lambda)$ and $\\beta^{(s)}
\sim q(\\beta)$.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -164,6 +167,9 @@ class ReparameterizationKLqp(VariationalInference):
This class minimizes the objective using the reparameterization
gradient.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -221,6 +227,9 @@ class ReparameterizationKLKLqp(VariationalInference):
This class minimizes the objective using the reparameterization
gradient and an analytic KL term.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -292,6 +301,9 @@ class ReparameterizationEntropyKLqp(VariationalInference):
This class minimizes the objective using the reparameterization
gradient and an analytic entropy term.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -350,6 +362,9 @@ class ScoreKLqp(VariationalInference):
This class minimizes the objective using the score function
gradient.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -407,6 +422,9 @@ class ScoreKLKLqp(VariationalInference):
This class minimizes the objective using the score function gradient
and an analytic KL term.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -478,6 +496,9 @@ class ScoreEntropyKLqp(VariationalInference):
This class minimizes the objective using the score function gradient
and an analytic entropy term.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -542,6 +563,9 @@ class ScoreRBKLqp(VariationalInference):
stochastic nodes in the computation graph. It does not
Rao-Blackwellize within a node such as when a node represents
multiple random variables via non-scalar batch shape.
The objective function also adds to itself a summation over all
tensors in the `REGULARIZATION_LOSSES` collection.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -640,14 +664,17 @@ def build_reparam_loss_and_gradients(inference, var_list):

p_log_prob = tf.reduce_mean(p_log_prob)
q_log_prob = tf.reduce_mean(q_log_prob)
reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", p_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", q_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_prob - q_log_prob)
loss = -(p_log_prob - q_log_prob - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
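
For the reparameterization estimator, the quantity being minimized is therefore (writing $\mathcal{R}$ for the sum over the `REGULARIZATION_LOSSES` collection and $S$ for the number of samples; the notation is ours, not the commit's)

$\text{loss} = -\Big(\frac{1}{S}\sum_{s=1}^{S} \log p(x, z^{(s)}) - \frac{1}{S}\sum_{s=1}^{S} \log q(z^{(s)}; \lambda) - \mathcal{R}\Big), \qquad \mathcal{R} = \sum_{r \in \text{REGULARIZATION\_LOSSES}} r,$

so an unregularized graph ($\mathcal{R} = 0$) recovers the previous behavior exactly.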
@@ -702,13 +729,17 @@ def build_reparam_kl_loss_and_gradients(inference, var_list):
tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_lik", p_log_lik,
collections=[inference._summary_key])
tf.summary.scalar("loss/kl_penalty", kl_penalty,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_lik - kl_penalty)
loss = -(p_log_lik - kl_penalty - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
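
The KL-form variant above is regularized the same way; only the bound is decomposed into an expected log-likelihood and an analytic KL term before $\mathcal{R}$ is subtracted (per-variable `kl_scaling` factors omitted for brevity):

$\text{loss} = -\Big(\frac{1}{S}\sum_{s=1}^{S} \log p(x \mid z^{(s)}) - \sum_{z} \text{KL}\big(q(z; \lambda) \,\|\, p(z)\big) - \mathcal{R}\Big).$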
@@ -766,13 +797,17 @@ def build_reparam_entropy_loss_and_gradients(inference, var_list):
tf.reduce_sum(qz.entropy())
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", p_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/q_entropy", q_entropy,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_prob + q_entropy)
loss = -(p_log_prob + q_entropy - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
@@ -823,21 +858,24 @@ def build_score_loss_and_gradients(inference, var_list):

p_log_prob = tf.stack(p_log_prob)
q_log_prob = tf.stack(q_log_prob)
reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

losses = p_log_prob - q_log_prob
loss = -tf.reduce_mean(losses)
loss = -(tf.reduce_mean(losses) - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)),
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)) - reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
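
In the score-function estimator above, `reg_penalty` sits outside the `tf.stop_gradient`: the sampled losses act only as constant REINFORCE weights, while the regularization sum remains differentiable, so penalties attached to variational parameters still receive ordinary gradients (model parameters get theirs through `loss` in `p_grads`). Differentiating the surrogate with respect to the variational parameters $\lambda$ gives, in our notation,

$\nabla_\lambda \text{loss} \approx -\Big(\frac{1}{S}\sum_{s=1}^{S} \big[\log p(x, z^{(s)}) - \log q(z^{(s)}; \lambda)\big]\, \nabla_\lambda \log q(z^{(s)}; \lambda) - \nabla_\lambda \mathcal{R}\Big),$

with the bracketed term held fixed by the stop-gradient.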
@@ -891,19 +929,24 @@ def build_score_kl_loss_and_gradients(inference, var_list):
tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_lik", tf.reduce_mean(p_log_lik),
collections=[inference._summary_key])
tf.summary.scalar("loss/kl_penalty", kl_penalty,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(tf.reduce_mean(p_log_lik) - kl_penalty)
loss = -(tf.reduce_mean(p_log_lik) - kl_penalty - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty),
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty -
reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
@@ -962,22 +1005,26 @@ def build_score_entropy_loss_and_gradients(inference, var_list):
tf.reduce_sum(qz.entropy())
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_entropy", q_entropy,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
loss = -(tf.reduce_mean(p_log_prob) + q_entropy - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
q_entropy),
q_entropy - reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
@@ -1062,7 +1109,8 @@ def build_score_rb_loss_and_gradients(inference, var_list):
qi_log_prob = tf.stack(qi_log_prob)
grad = tf.gradients(
-tf.reduce_mean(qi_log_prob *
tf.stop_gradient(pi_log_prob - qi_log_prob)),
tf.stop_gradient(pi_log_prob - qi_log_prob)) +
tf.reduce_sum(tf.losses.get_regularization_losses()),
var)
grads.extend(grad)
grads_vars.append(var)
@@ -1071,7 +1119,8 @@ def build_score_rb_loss_and_gradients(inference, var_list):
loss = -(tf.reduce_mean([tf.reduce_sum(list(six.itervalues(p_log_prob)))
for p_log_prob in p_log_probs]) -
tf.reduce_mean([tf.reduce_sum(list(six.itervalues(q_log_prob)))
for q_log_prob in q_log_probs]))
for q_log_prob in q_log_probs]) -
tf.reduce_sum(tf.losses.get_regularization_losses()))
model_vars = [v for v in var_list if v not in grads_vars]
model_grads = tf.gradients(loss, model_vars)
grads.extend(model_grads)
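
End to end, nothing else changes for the user: a regularizer registered while building the model or the variational approximation is now minimized alongside the variational objective. A hypothetical sketch, assuming Edward's standard `ed.KLqp` workflow (the toy data, dimensions, and 1e-3 scale are made up for illustration):

import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal

# Toy regression data.
N, D = 50, 5
X_train = np.random.randn(N, D).astype(np.float32)
y_train = np.random.randn(N).astype(np.float32)

# Model: Bayesian linear regression.
X = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros(D), scale=tf.ones(D))
y = Normal(loc=ed.dot(X, w), scale=tf.ones(N))

# Variational approximation, with an L2 penalty on its location
# parameter registered in REGULARIZATION_LOSSES.
qw_loc = tf.get_variable("qw_loc", [D])
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                     1e-3 * tf.nn.l2_loss(qw_loc))
qw = Normal(loc=qw_loc,
            scale=tf.nn.softplus(tf.get_variable("qw_scale", [D])))

# With this commit, KLqp subtracts the regularization sum from the ELBO,
# so the penalty is optimized without any extra user code.
inference = ed.KLqp({w: qw}, data={X: X_train, y: y_train})
inference.run(n_iter=250)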
