Skip to content

Commit

Permalink
fix pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Apr 17, 2017
1 parent 8c46b20 commit 88018c7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions edward/inferences/conjugacy/conjugacy.py
Expand Up @@ -15,16 +15,16 @@
from edward.util import copy, get_blanket


def normal_from_natural_params(p1, p2):
def mvn_diag_from_natural_params(p1, p2):
sigmasq = 0.5 * tf.reciprocal(-p1)
mu = sigmasq * p2
return {'mu': mu, 'sigma': tf.sqrt(sigmasq)}
return {'mu': mu, 'diag_stdev': tf.sqrt(sigmasq)}


def multivariate_normal_diag_from_natural_params(p1, p2):
def normal_from_natural_params(p1, p2):
sigmasq = 0.5 * tf.reciprocal(-p1)
mu = sigmasq * p2
return {'mu': mu, 'diag_stdev': tf.sqrt(sigmasq)}
return {'mu': mu, 'sigma': tf.sqrt(sigmasq)}


_suff_stat_to_dist = defaultdict(dict)
Expand All @@ -43,12 +43,12 @@ def multivariate_normal_diag_from_natural_params(p1, p2):
_suff_stat_to_dist['nonnegative'][(('#CPow-1.0000e+00', ('#x',)),
('#Log', ('#x',)))] = (
rvs.InverseGamma, lambda p1, p2: {'alpha': -p2 - 1, 'beta': -p1})
_suff_stat_to_dist['multivariate_real'][(('#CPow2.0000e+00', ('#x',)),
('#x',))] = (
rvs.MultivariateNormalDiag, mvn_diag_from_natural_params)
_suff_stat_to_dist['real'][(('#CPow2.0000e+00', ('#x',)),
('#x',))] = (
rvs.Normal, normal_from_natural_params)
_suff_stat_to_dist['multivariate_real'][(('#CPow2.0000e+00', ('#x',)),
('#x',))] = (
rvs.MultivariateNormalDiag, multivariate_normal_diag_from_natural_params)


def complete_conditional(rv, cond_set=None):
Expand Down
4 changes: 2 additions & 2 deletions edward/inferences/conjugacy/conjugate_log_probs.py
Expand Up @@ -77,15 +77,15 @@ def poisson_log_prob(self, val):


@_val_wrapper
def multivariate_normal_diag_log_prob(self, val):
def mvn_diag_log_prob(self, val):
mu = self.parameters['mu']
sigma = self.parameters['diag_stdev']
prec = tf.reciprocal(tf.square(sigma))
result = prec * (-0.5 * tf.square(val) - 0.5 * tf.square(mu) +
val * mu)
result -= tf.log(sigma) + 0.5 * tf.log(2 * np.pi)
return result
rvs.MultivariateNormalDiag.conjugate_log_prob = multivariate_normal_diag_log_prob
rvs.MultivariateNormalDiag.conjugate_log_prob = mvn_diag_log_prob


@_val_wrapper
Expand Down

0 comments on commit 88018c7

Please sign in to comment.