-
Notifications
You must be signed in to change notification settings - Fork 761
/
laplace.py
126 lines (108 loc) · 4.8 KB
/
laplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from edward.inferences.map import MAP
from edward.models import \
MultivariateNormalCholesky, MultivariateNormalDiag, \
MultivariateNormalFull, PointMass, RandomVariable
from edward.util import get_session, get_variables
class Laplace(MAP):
  """Laplace approximation (Laplace, 1774).

  It approximates the posterior distribution using a multivariate
  normal distribution centered at the mode of the posterior.

  We implement this by running ``MAP`` to find the posterior mode. This
  forms the mean of the normal approximation. We then compute the
  inverse Hessian at the mode of the posterior. This forms the
  covariance of the normal approximation.
  """
  def __init__(self, latent_vars, data=None):
    """
    Parameters
    ----------
    latent_vars : list of RandomVariable or
                  dict of RandomVariable to RandomVariable
      Collection of random variables to perform inference on. If list,
      each random variable will be implictly optimized using a
      ``MultivariateNormalCholesky`` random variable that is defined
      internally (with unconstrained support). If dictionary, each
      random variable must be a ``MultivariateNormalCholesky``,
      ``MultivariateNormalFull``, or ``MultivariateNormalDiag`` random
      variable.

    Notes
    -----
    If ``MultivariateNormalDiag`` random variables are specified as
    approximations, then the Laplace approximation will only produce
    the diagonal. This does not capture correlation among the
    variables but it does not require a potentially expensive matrix
    inversion.

    Examples
    --------
    >>> X = tf.placeholder(tf.float32, [N, D])
    >>> w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
    >>> y = Normal(mu=ed.dot(X, w), sigma=tf.ones(N))
    >>>
    >>> qw = MultivariateNormalFull(mu=tf.Variable(tf.random_normal([D])),
    >>>                             sigma=tf.Variable(tf.random_normal([D, D])))
    >>>
    >>> inference = ed.Laplace({w: qw}, data={X: X_train, y: y_train})
    """
    if isinstance(latent_vars, list):
      # Build default Cholesky-parameterized normal approximations, one
      # per latent variable, with trainable mean and Cholesky factor.
      with tf.variable_scope("posterior"):
        latent_vars = {rv: MultivariateNormalCholesky(
            mu=tf.Variable(tf.random_normal(rv.batch_shape())),
            # chol is square over the last batch dimension:
            # shape batch_shape + [batch_shape[-1]].
            chol=tf.Variable(tf.random_normal(
                rv.get_batch_shape().concatenate(rv.get_batch_shape()[-1]))))
            for rv in latent_vars}
    elif isinstance(latent_vars, dict):
      for qz in six.itervalues(latent_vars):
        if not isinstance(
            qz, (MultivariateNormalCholesky, MultivariateNormalDiag,
                 MultivariateNormalFull)):
          # BUG FIX: error message previously named a nonexistent class
          # "MultivariateCholesky"; corrected to MultivariateNormalCholesky.
          raise TypeError("Posterior approximation must consist of only "
                          "MultivariateNormalCholesky, "
                          "MultivariateNormalDiag, "
                          "or MultivariateNormalFull random variables.")

    # call grandparent's method; avoid parent (MAP)
    super(MAP, self).__init__(latent_vars, data)

  def initialize(self, var_list=None, *args, **kwargs):
    # Store latent variables in a temporary attribute; MAP will
    # optimize ``PointMass`` random variables, which subsequently
    # optimizes mean parameters of the normal approximations.
    self.latent_vars_normal = self.latent_vars.copy()
    self.latent_vars = {z: PointMass(params=qz.mu)
                        for z, qz in six.iteritems(self.latent_vars_normal)}
    super(Laplace, self).initialize(var_list, *args, **kwargs)

  def finalize(self, feed_dict=None):
    """Function to call after convergence.

    Computes the Hessian at the mode.

    Parameters
    ----------
    feed_dict : dict, optional
      Feed dictionary for a TensorFlow session run during evaluation
      of Hessian. It is used to feed placeholders that are not fed
      during initialization.
    """
    if feed_dict is None:
      feed_dict = {}

    # Feed in any placeholders from the data dictionary that were not
    # already supplied by the caller.
    for key, value in six.iteritems(self.data):
      if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
        feed_dict[key] = value

    var_list = list(six.itervalues(self.latent_vars))
    hessians = tf.hessians(self.loss, var_list)

    assign_ops = []
    # iterkeys and itervalues iterate the same dict in the same order,
    # so each z here corresponds to its Hessian in ``hessians``.
    for z, hessian in zip(six.iterkeys(self.latent_vars), hessians):
      qz = self.latent_vars_normal[z]
      # get_variables finds the trainable variable underlying the
      # approximation's scale parameter, which we overwrite in place.
      sigma_var = get_variables(qz.sigma)[0]
      if isinstance(qz, MultivariateNormalCholesky):
        # NOTE(review): with H = L L^T, inv(L) satisfies
        # inv(L) inv(L)^T = inv(L^T L), which is not inv(H) in general;
        # confirm this matches the intended chol parameterization.
        sigma = tf.matrix_inverse(tf.cholesky(hessian))
      elif isinstance(qz, MultivariateNormalDiag):
        # Diagonal approximation: invert only the Hessian's diagonal,
        # avoiding a full matrix inversion.
        sigma = 1.0 / tf.diag_part(hessian)
      else:  # qz is MultivariateNormalFull
        sigma = tf.matrix_inverse(hessian)

      assign_ops.append(sigma_var.assign(sigma))

    sess = get_session()
    sess.run(assign_ops, feed_dict)

    # Restore the normal approximations as the inference's latent
    # variables, replacing the temporary PointMass variables.
    self.latent_vars = self.latent_vars_normal.copy()
    del self.latent_vars_normal
    super(Laplace, self).finalize()