-
Notifications
You must be signed in to change notification settings - Fork 761
/
normal_sghmc.py
70 lines (58 loc) · 1.85 KB
/
normal_sghmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
"""Correlated normal posterior. Inference with stochastic gradient Hamiltonian
Monte Carlo.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import edward as ed
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from edward.models import Empirical, MultivariateNormalFull
plt.style.use("ggplot")
# Plotting helper function.
def mvn_plot_contours(z, label=False, ax=None):
    """Plot the density contours of a 2-d normal random variable.

    The axes are scaled to cover 3 standard deviations on either side of
    the mean in each dimension.

    Args:
      z: random variable exposing ``mu``, ``sigma`` and ``log_prob``. ``mu``
        must evaluate to two values and ``sigma`` is indexed as a 2x2
        matrix whose diagonal holds the per-dimension variances
        (e.g. a MultivariateNormalFull).
      label: if True, annotate the contour lines with their density value.
      ax: matplotlib axes to draw on. A new figure/axes pair is created
        when omitted.
    """
    sess = ed.get_session()
    mu_x, mu_y = sess.run(z.mu)
    Sigma = sess.run(z.sigma)
    sigma_x, sigma_y = np.sqrt(Sigma[0, 0]), np.sqrt(Sigma[1, 1])

    # Evaluation grid spanning +/- 3 standard deviations around the mean.
    xs = np.linspace(mu_x - 3 * sigma_x, mu_x + 3 * sigma_x, num=100)
    ys = np.linspace(mu_y - 3 * sigma_y, mu_y + 3 * sigma_y, num=100)
    X, Y = np.meshgrid(xs, ys)

    # One (x, y) row per grid point, evaluated in a single log_prob call.
    T = tf.convert_to_tensor(np.c_[X.flatten(), Y.flatten()], dtype=tf.float32)
    # meshgrid returns arrays of shape (len(ys), len(xs)); reshape to X.shape
    # so the density grid stays aligned with X and Y even if the two axis
    # resolutions ever differ.
    Z = sess.run(tf.exp(z.log_prob(T))).reshape(X.shape)

    if ax is None:
        _, ax = plt.subplots()
    cs = ax.contour(X, Y, Z)
    if label:
        # Label through the axes that owns the contours rather than through
        # pyplot's implicit "current axes".
        ax.clabel(cs, inline=True, fontsize=10)
# Example body.
# Fix the random seed (42) before any graph ops are created.
ed.set_seed(42)
# MODEL
# Target distribution: a 2-d normal with mean (1, 1), unit variances and
# an off-diagonal covariance of 0.8, i.e. strongly correlated components.
z = MultivariateNormalFull(mu=tf.ones(2),
sigma=tf.constant([[1.0, 0.8], [0.8, 1.0]]))
# INFERENCE
# Empirical approximation backed by a [5000, 2] variable, initialised with
# standard-normal noise. NOTE(review): presumably each of the 5000 rows is
# overwritten by one sampler draw -- confirm against the Edward docs.
qz = Empirical(params=tf.Variable(tf.random_normal([5000, 2])))
# Stochastic gradient Hamiltonian Monte Carlo, binding z to qz.
inference = ed.SGHMC({z: qz})
inference.run(step_size=0.02)
# CRITICISM
# Summarise the collected samples by their mean and standard deviation.
sess = ed.get_session()
mean, std = sess.run([qz.mean(), qz.std()])
print("Inferred posterior mean:")
print(mean)
print("Inferred posterior std:")
print(std)
# VISUALIZATION
# Scatter the stored samples and overlay the contours of the target density
# on the same axes for a visual comparison.
fig, ax = plt.subplots()
trace = sess.run(qz.params)
ax.scatter(trace[:, 0], trace[:, 1], marker=".")
mvn_plot_contours(z, ax=ax)
plt.show()