#!/usr/bin/env python
"""
Hierarchical logistic regression using mean-field variational inference.
Probability model:
Hierarchical logistic regression
Prior: Normal
Likelihood: Bernoulli-Logit
Variational model
Likelihood: Mean-field Normal
"""
import edward as ed
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from edward.stats import bernoulli, norm
from edward.variationals import Variational, Normal
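
# NB: this example targets the early Edward API (edward.stats,
# edward.variationals) and a pre-1.0 TensorFlow, where tf.pack/tf.unpack
# were the names later renamed to tf.stack/tf.unstack.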


class HierarchicalLogistic:
    """
    Hierarchical logistic regression for outputs y on inputs x.

    p((x, y), z) = Bernoulli(y | link^{-1}(x*z)) *
                   Normal(z | 0, prior_variance),

    where z are weights, and with known link function and
    prior_variance.

    Parameters
    ----------
    weight_dim : list
        Dimension of weights, given as [input dimension, output
        dimension].
    inv_link : function, optional
        Inverse of link function, which is applied to the linear
        transformation.
    prior_variance : float, optional
        Scale of the normal prior on weights; applied in log_prob as an
        L2 (ridge) penalty coefficient, so the log-prior is
        -prior_variance * sum(z**2) up to an additive constant.
    """
    def __init__(self, weight_dim, inv_link=tf.sigmoid, prior_variance=0.01):
        self.weight_dim = weight_dim
        self.inv_link = inv_link
        self.prior_variance = prior_variance
        # One weight per input-output pair, plus one bias per output.
        self.num_vars = (self.weight_dim[0] + 1) * self.weight_dim[1]

    def mapping(self, x, z):
        """
        Inverse link function applied to the linear transformation,
        link^{-1}(x*W + b).
        """
        m, n = self.weight_dim[0], self.weight_dim[1]
        W = tf.reshape(z[:m*n], [m, n])
        b = tf.reshape(z[m*n:], [1, n])
        # Broadcasting to do (x*W) + b (e.g., 40x10 + 1x10).
        h = self.inv_link(tf.matmul(x, W) + b)
        h = tf.squeeze(h)  # n_data x 1 to n_data
        return h
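
    # Shape sketch for mapping (assuming weight_dim=[1, 1], so num_vars=2):
    # z[:1] reshapes to the 1x1 weight W and z[1:] to the 1x1 bias b; for
    # x of shape [n_data, 1], tf.matmul(x, W) + b broadcasts to
    # [n_data, 1], and tf.squeeze drops the trailing dimension, giving
    # probabilities of shape [n_data].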

    def log_prob(self, xs, zs):
        """Returns a vector [log p(xs, zs[1,:]), ..., log p(xs, zs[S,:])]."""
        # Data must have labels in the first column and features in
        # subsequent columns.
        y = xs[:, 0]
        x = xs[:, 1:]
        log_lik = []
        for z in tf.unpack(zs):
            p = self.mapping(x, z)
            # Sum per-datum Bernoulli log-probabilities to a scalar, so
            # that log_lik stacks into a length-S vector.
            log_lik += [tf.reduce_sum(bernoulli.logpmf(y, p))]

        log_lik = tf.pack(log_lik)
        # Unnormalized log-prior; prior_variance acts as the L2 penalty
        # coefficient on the weights (see the class docstring).
        log_prior = -self.prior_variance * tf.reduce_sum(zs*zs, 1)
        return log_lik + log_prior
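

# For intuition, a minimal NumPy transcription of log_prob for a single
# sample z (a hedged cross-check sketch: the helper name log_prob_numpy
# is illustrative, not part of the model or Edward's API; it assumes
# output dimension 1, matching weight_dim=[D, 1]).
def log_prob_numpy(y, x, z, prior_variance=0.01):
    """Joint log-density for one flattened weight vector z = (W, b)."""
    D = x.shape[1]
    W = z[:D].reshape(D, 1)                            # weights
    b = z[D]                                           # scalar bias
    p = 1.0 / (1.0 + np.exp(-(x.dot(W).ravel() + b)))  # inverse logit
    log_lik = np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
    log_prior = -prior_variance * np.sum(z * z)        # unnormalized prior
    return log_lik + log_prior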


def build_toy_dataset(n_data=40, noise_std=0.1):
    ed.set_seed(0)
    D = 1
    x = np.linspace(-3, 3, num=n_data)
    # Binarize noisy tanh outputs into 0/1 labels.
    y = np.tanh(x) + norm.rvs(0, noise_std, size=n_data)
    y[y < 0.5] = 0
    y[y >= 0.5] = 1
    # Shift and rescale the inputs.
    x = (x - 4.0) / 4.0
    x = x.reshape((n_data, D))
    y = y.reshape((n_data, 1))
    data = np.concatenate((y, x), axis=1)  # n_data x (D+1)
    data = tf.constant(data, dtype=tf.float32)
    return ed.Data(data)
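

# Column convention: the wrapped n_data x 2 tensor stores the binary
# labels y in column 0 and the rescaled inputs x in column 1, which is
# the layout log_prob expects.
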
ed.set_seed(42)
model = HierarchicalLogistic(weight_dim=[1,1])
variational = Variational()
variational.add(Normal(model.num_vars))
data = build_toy_dataset()
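
# With weight_dim=[1, 1], num_vars = (1 + 1) * 1 = 2: one weight and one
# bias, so the variational family is a fully factorized (mean-field)
# Normal over two scalars.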

# Set up an interactive figure; print_progress (patched onto ed.MFVI
# below) redraws it as inference proceeds.
fig = plt.figure(figsize=(8, 8), facecolor='white')
ax = fig.add_subplot(111, frameon=False)
plt.ion()
plt.show(block=False)


def print_progress(self, t, losses, sess):
    if t % self.n_print == 0:
        print("iter %d loss %.2f " % (t, np.mean(losses)))
        self.variational.print_params(sess)

        # Sample functions from the variational model.
        mean, std = sess.run([self.variational.layers[0].m,
                              self.variational.layers[0].s])
        rs = np.random.RandomState(0)
        zs = rs.randn(10, self.variational.num_vars) * std + mean
        zs = tf.constant(zs, dtype=tf.float32)
        inputs = np.linspace(-3, 3, num=400, dtype=np.float32)
        x = tf.expand_dims(tf.constant(inputs), 1)
        mus = tf.pack([self.model.mapping(x, z) for z in tf.unpack(zs)])
        outputs = sess.run(mus)

        # Get data.
        y, x = sess.run([self.data.data[:, 0], self.data.data[:, 1]])

        # Plot data and functions.
        plt.cla()
        ax.plot(x, y, 'bx')
        ax.plot(inputs, outputs.T)
        ax.set_xlim([-3, 3])
        ax.set_ylim([-0.5, 1.5])
        plt.draw()
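

# Monkey-patch MFVI so each print interval also redraws the figure with
# functions sampled from the current variational approximation.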
ed.MFVI.print_progress = print_progress
inference = ed.MFVI(model, variational, data)
# TODO it gets NaN's at iteration 608 and beyond
inference.run(n_iter=600, n_print=5)
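
# A hedged note on the NaN TODO above: a saturated sigmoid drives
# bernoulli.logpmf to log(0). One common stabilizer (an assumption, not
# verified here) is to clip inside log_prob, e.g.
#   p = tf.clip_by_value(p, 1e-8, 1.0 - 1e-8)
# before evaluating the log-likelihood.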