from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.monte_carlo import MonteCarlo
from edward.models import RandomVariable, Empirical
from edward.util import copy

try:
  from edward.models import Normal
except Exception as e:
  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))


class SGHMC(MonteCarlo):
  """Stochastic gradient Hamiltonian Monte Carlo [@chen2014stochastic].

  #### Notes

  In conditional inference, we infer $z$ in $p(z, \\beta \mid x)$ while
  fixing inference over $\\beta$ using another distribution $q(\\beta)$.
  `SGHMC` substitutes the model's log marginal density

  $\log p(x, z) = \log \mathbb{E}_{q(\\beta)} [ p(x, z, \\beta) ]
                \\approx \log p(x, z, \\beta^*),$

  leveraging a single Monte Carlo sample, where $\\beta^* \sim
  q(\\beta)$. This is unbiased (and therefore asymptotically exact as a
  pseudo-marginal method) if $q(\\beta) = p(\\beta \mid x)$.

  #### Examples

  ```python
  mu = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=mu, scale=1.0, sample_shape=10)

  qmu = Empirical(tf.Variable(tf.zeros(500)))
  inference = ed.SGHMC({mu: qmu}, {x: np.zeros(10, dtype=np.float32)})
  ```
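
  The example assumes the conventional aliases `import edward as ed`,
  `import numpy as np`, and `import tensorflow as tf`. As a sketch, the
  hyperparameters below can then be forwarded to `initialize` via `run`:

  ```python
  inference.run(step_size=0.25, friction=0.1)
  ```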
"""

  def __init__(self, *args, **kwargs):
    super(SGHMC, self).__init__(*args, **kwargs)

  def initialize(self, step_size=0.25, friction=0.1, *args, **kwargs):
    """Initialize inference algorithm.

    Args:
      step_size: float, optional.
        Constant scale factor of the learning rate; the effective
        learning rate used in the dynamics is `step_size * 0.01`.
      friction: float, optional.
        Constant scale on the friction term in the Hamiltonian system.
    """
    self.step_size = step_size
    self.friction = friction
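    # Allocate one velocity buffer per latent variable, shaped like a
    # single sample from the corresponding Empirical approximation.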
    self.v = {z: tf.Variable(tf.zeros(qz.params.shape[1:], dtype=qz.dtype))
              for z, qz in six.iteritems(self.latent_vars)}
    return super(SGHMC, self).initialize(*args, **kwargs)

  def build_update(self):
    """Simulate Hamiltonian dynamics with friction using a discretized
    integrator. Its discretization error goes to zero as the learning
    rate decreases.

    Implements the update equations from (15) of @chen2014stochastic.
    """
    old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
                  for z, qz in six.iteritems(self.latent_vars)}
    old_v_sample = {z: v for z, v in six.iteritems(self.v)}

    # Simulate Hamiltonian dynamics with friction.
    learning_rate = self.step_size * 0.01
    grad_log_joint = tf.gradients(self._log_joint(old_sample),
                                  list(six.itervalues(old_sample)))

    # v_sample is so named b/c it represents a velocity rather than momentum.
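    # Eq. 15 of @chen2014stochastic can be written as
    #   theta_t = theta_{t-1} + v_{t-1}
    #   v_t = (1 - alpha) * v_{t-1} + eta * grad log p(x, theta_{t-1})
    #         + Normal(0, 2 * (alpha - beta_hat) * eta).
    # The loop below matches this with eta = learning_rate,
    # alpha = friction / 2, and beta_hat = 0, so the injected noise has
    # variance learning_rate * friction; note the position update uses
    # the previous velocity.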
    sample = {}
    v_sample = {}
    for z, grad_log_p in zip(six.iterkeys(old_sample), grad_log_joint):
      qz = self.latent_vars[z]
      event_shape = qz.event_shape
      normal = Normal(
          loc=tf.zeros(event_shape, dtype=qz.dtype),
          scale=(tf.sqrt(tf.cast(learning_rate * self.friction, qz.dtype)) *
                 tf.ones(event_shape, dtype=qz.dtype)))
      sample[z] = old_sample[z] + old_v_sample[z]
      v_sample[z] = ((1.0 - 0.5 * self.friction) * old_v_sample[z] +
                     learning_rate * tf.convert_to_tensor(grad_log_p) +
                     normal.sample())

    # Update Empirical random variables.
    assign_ops = []
    for z, qz in six.iteritems(self.latent_vars):
      variable = qz.get_variables()[0]
      assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
      assign_ops.append(tf.assign(self.v[z], v_sample[z]).op)

    # Increment n_accept; SGHMC has no Metropolis correction, so every
    # proposal is accepted.
    assign_ops.append(self.n_accept.assign_add(1))
    return tf.group(*assign_ops)

  def _log_joint(self, z_sample):
    """Utility function to calculate the model's log joint density,
    log p(x, z), for inputs z (and fixed data x).

    Args:
      z_sample: dict.
        Latent variable keys to samples.
    """
    scope = tf.get_default_graph().unique_name("inference")
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    dict_swap = z_sample.copy()
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx
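
    # Sum the scaled log densities of the prior and likelihood terms,
    # evaluated at the swapped-in values. `self.scale` reweights
    # individual terms, e.g. to scale a minibatch log-likelihood up to
    # the full data set during subsampling.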
    log_joint = 0.0
    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      log_joint += tf.reduce_sum(
          self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        log_joint += tf.reduce_sum(
            self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    return log_joint