forked from posterior/treecat
/
serving.py
278 lines (236 loc) · 10.7 KB
/
serving.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import numpy as np
from scipy.stats import entropy

try:
    from scipy.special import logsumexp
except ImportError:  # scipy < 0.19 only provides scipy.misc.logsumexp.
    from scipy.misc import logsumexp

from treecat.structure import TreeStructure
from treecat.structure import make_propagation_schedule
from treecat.util import profile
from treecat.util import sample_from_probs2
# Module-level logger used by the server classes in this file.
logger = logging.getLogger(__name__)
def correlation(probs):
    """Compute correlation rho(X,Y) = sqrt(1 - exp(-2 I(X;Y))).

    This generalizes Pearson's correlation to discrete variables using the
    mutual information I(X;Y) of their joint distribution.

    Args:
      probs: A square numpy array holding the joint distribution of X (one
        axis) and Y (the other). scipy.stats.entropy normalizes each
        distribution it receives, so probs need not sum to 1.

    Returns:
      A non-negative scalar correlation, at most 1.
    """
    assert probs.shape[0] == probs.shape[1]
    mutual_information = (entropy(probs.sum(0)) + entropy(probs.sum(1)) -
                          entropy(probs.flatten()))
    # Mutual information is non-negative, but floating point cancellation in
    # the difference above can make it slightly negative, which would turn
    # the square root below into NaN. Clamp it at zero.
    mutual_information = np.maximum(mutual_information, 0.0)
    # -expm1(-2 I) equals 1 - exp(-2 I) but stays accurate when I is tiny.
    return np.sqrt(-np.expm1(-2.0 * mutual_information))
class TreeCatServer(object):
    """Class for serving queries against a trained TreeCat model.

    The constructor turns the model's sufficient statistics into smoothed
    posterior marginals; sample(), logprob() and correlation() then answer
    queries by propagating messages over the latent tree.
    """

    def __init__(self, tree, suffstats, config):
        """Precompute the posterior quantities shared by all queries.

        Args:
          tree: A TreeStructure over the V features.
          suffstats: A dict of sufficient statistics; this reads the keys
            'ragged_index', 'vert_ss', 'edge_ss', 'feat_ss' and 'meas_ss'.
          config: A dict of model settings; this reads 'model_num_clusters'.
        """
        logger.info('TreeCatServer with %d features', tree.num_vertices)
        assert isinstance(tree, TreeStructure)
        ragged_index = suffstats['ragged_index']
        self._tree = tree
        self._config = config
        self._ragged_index = ragged_index
        self._schedule = make_propagation_schedule(tree.tree_grid)
        self._zero_row = np.zeros(self._ragged_index[-1], np.int8)

        # These are useful dimensions to import into locals().
        V = self._tree.num_vertices
        E = V - 1  # Number of edges in the tree.
        M = self._config['model_num_clusters']  # Clusters per mixture model.
        self._VEM = (V, E, M)

        # Use Jeffreys priors.
        vert_prior = 0.5
        edge_prior = 0.5 / M
        feat_prior = 0.5 / M
        meas_prior = feat_prior * np.array(
            [(ragged_index[v + 1] - ragged_index[v]) for v in range(V)],
            dtype=np.float32).reshape((V, 1))

        # These are posterior marginals for vertices and pairs of vertices.
        self._vert_probs = suffstats['vert_ss'].astype(np.float32) + vert_prior
        self._vert_probs /= self._vert_probs.sum(axis=1, keepdims=True)
        self._edge_probs = suffstats['edge_ss'].astype(np.float32) + edge_prior
        self._edge_probs /= self._edge_probs.sum(axis=(1, 2), keepdims=True)

        # This represents information in the pairwise joint posterior minus
        # information in the individual factors:
        #   edge_trans[e, i, j] = P(v1=i, v2=j) / (P(v1=i) * P(v2=j)).
        # Axis 1 is indexed by the edge's first vertex v1 and axis 2 by v2.
        # This matches sample(), logprob() and correlation(), which use
        # edge_trans[e] (or edge_probs[e]) untransposed when the smaller
        # vertex is on axis 1 -- assumes tree_grid lists edges with v1 < v2.
        self._edge_trans = self._edge_probs.copy()
        for e, v1, v2 in tree.tree_grid.T:
            self._edge_trans[e, :, :] /= self._vert_probs[v1, :, np.newaxis]
            self._edge_trans[e, :, :] /= self._vert_probs[v2, np.newaxis, :]

        # This is the conditional distribution of features given latent.
        self._feat_cond = suffstats['feat_ss'].astype(np.float32) + feat_prior
        meas_probs = suffstats['meas_ss'].astype(np.float32) + meas_prior
        for v in range(V):
            beg, end = ragged_index[v:v + 2]
            self._feat_cond[beg:end, :] /= meas_probs[v, np.newaxis, :]

    def zero_row(self):
        """Make an empty data row."""
        return self._zero_row.copy()

    @profile
    def sample(self, N, counts, data=None):
        """Draw N samples from the posterior distribution.

        Args:
          N: The number of samples to draw.
          counts: A [V]-shaped int8 numpy array of requested counts of
            multinomials to sample.
          data: An optional single row of conditioning data, as a ragged
            int8 numpy array of multinomial counts.

        Returns:
          An [N, R]-shaped int8 numpy array of sampled multinomial data,
          where R is the length of a data row.
        """
        logger.debug('sampling data')
        V, E, M = self._VEM
        if data is None:
            data = self._zero_row
        assert data.shape == self._zero_row.shape
        assert data.dtype == self._zero_row.dtype
        assert counts.shape == (V, )
        assert counts.dtype == np.int8
        edge_trans = self._edge_trans
        feat_cond = self._feat_cond

        messages_in = self._vert_probs.copy()
        messages_out = np.tile(self._vert_probs[:, np.newaxis, :], (1, N, 1))
        # Latent cluster ids range over [0, M); int32 avoids the int8
        # wrap-around that would corrupt these indices whenever M > 127.
        vert_samples = np.zeros([V, N], np.int32)
        feat_samples = np.zeros([N, self._zero_row.shape[0]], np.int8)
        range_N = np.arange(N, dtype=np.int32)

        for op, v, v2, e in self._schedule:
            if op == 0:  # OP_UP
                # Propagate upward from observed to latent.
                message = messages_in[v, :]
                beg, end = self._ragged_index[v:v + 2]
                for r in range(beg, end):
                    # This uses a with-replacement approximation which is
                    # exact for categorical data but approximate for
                    # multinomial.
                    message *= feat_cond[r, :]**data[r]
            elif op == 1:  # OP_IN
                # Propagate latent state inward from children to v.
                message = messages_in[v, :]
                trans = edge_trans[e, :, :]
                if v > v2:
                    trans = trans.T
                message *= np.dot(trans, messages_in[v2, :])
                message /= message.sum()
            else:  # OP_ROOT or OP_OUT
                message = messages_out[v, :, :]
                message[...] = messages_in[v, np.newaxis, :]
                # Propagate latent state outward from parent to v.
                if op == 3:  # OP_OUT
                    trans = edge_trans[e, :, :]
                    if v2 > v:
                        trans = trans.T
                    message *= trans[vert_samples[v2, :], :]
                message /= message.sum(axis=1, keepdims=True)
                vert_samples[v, :] = sample_from_probs2(message)
                # Propagate downward from latent to observed.
                beg, end = self._ragged_index[v:v + 2]
                feat_block = feat_cond[beg:end, :].T
                probs = feat_block[vert_samples[v, :], :]
                # samples_block is a view, so += writes into feat_samples.
                samples_block = feat_samples[:, beg:end]
                for _ in range(counts[v]):
                    samples_block[range_N, sample_from_probs2(probs)] += 1
        return feat_samples

    @profile
    def logprob(self, data):
        """Compute non-normalized log probabilities of many rows of data.

        To compute conditional probability, use the identity:
          log P(data|cond_data) = server.logprob(data + cond_data)
                                  - server.logprob(cond_data)

        Args:
          data: A [N, R]-shaped ragged int8 numpy array of multinomial count
            data, where N is the number of rows.

        Returns:
          An [N]-shaped float32 numpy array of log probabilities.
        """
        logger.debug('computing logprob')
        assert len(data.shape) == 2
        assert data.shape[1] == self._ragged_index[-1]
        assert data.dtype == np.int8
        N = data.shape[0]
        V, E, M = self._VEM
        edge_trans = self._edge_trans
        feat_cond = self._feat_cond

        messages = np.tile(self._vert_probs[:, :, np.newaxis], (1, 1, N))
        assert messages.shape == (V, M, N)
        logprob = np.zeros(N, np.float32)
        for op, v, v2, e in self._schedule:
            message = messages[v, :, :]
            if op == 0:  # OP_UP
                # Propagate upward from observed to latent.
                beg, end = self._ragged_index[v:v + 2]
                for r in range(beg, end):
                    # This uses a with-replacement approximation which is
                    # exact for categorical data but approximate for
                    # multinomial.
                    power = data[np.newaxis, :, r]
                    message *= feat_cond[r, :, np.newaxis]**power
            elif op == 1:  # OP_IN
                # Propagate latent state inward from children to v.
                trans = edge_trans[e, :, :]
                if v > v2:
                    trans = trans.T
                message *= np.dot(trans, messages[v2, :, :])
                # Normalize to avoid underflow, keeping the scale in logprob.
                message_sum = message.sum(axis=0, keepdims=True)
                message /= message_sum
                logprob += np.log(message_sum[0, :])
            elif op == 2:  # OP_ROOT
                # Aggregate the total logprob.
                logprob += np.log(message.sum(axis=0))
        return logprob

    @profile
    def correlation(self):
        """Compute correlation matrix among latent features.

        This computes the generalization of Pearson's correlation to discrete
        data. Let I(X;Y) be the mutual information. Then define correlation as

          rho(X,Y) = sqrt(1 - exp(-2 I(X;Y)))

        Returns:
          A [V, V] float32 numpy array of feature-feature correlations.
        """
        logger.debug('computing correlation')
        V, E, M = self._VEM
        edge_probs = self._edge_probs
        vert_probs = self._vert_probs
        result = np.zeros([V, V], np.float32)
        # Scratch buffer for joint distributions P(v, root); each root's
        # schedule writes messages[v] before any vertex reads it.
        messages = np.empty([V, M, M])
        for root in range(V):
            schedule = make_propagation_schedule(self._tree.tree_grid, root)
            for op, v, v2, e in schedule:
                if op == 2:  # OP_ROOT
                    messages[v, :, :] = np.diagflat(vert_probs[v, :])
                elif op == 3:  # OP_OUT
                    trans = edge_probs[e, :, :]
                    if v > v2:
                        trans = trans.T
                    # P(v | v2) applied to P(v2, root) gives P(v, root).
                    messages[v, :, :] = np.dot(
                        trans / vert_probs[v2, np.newaxis, :],
                        messages[v2, :, :])
            for v in range(V):
                result[root, v] = correlation(messages[v, :, :])
        return result
def serve_model(tree, suffstats, config):
    """Create a TreeCatServer for a single trained model.

    Args:
      tree: The model's tree structure.
      suffstats: The model's sufficient statistics.
      config: The model's config dict.

    Returns:
      A TreeCatServer built from the given arguments.
    """
    server = TreeCatServer(tree, suffstats, config)
    return server
class EnsembleServer(object):
    """Class for serving queries against a trained TreeCat ensemble model.

    Each ensemble member gets its own TreeCatServer, and queries treat the
    ensemble as a uniform mixture over its members.
    """

    def __init__(self, ensemble):
        """Build one TreeCatServer per ensemble member.

        Args:
          ensemble: A non-empty sequence of dicts, each with the keys 'tree',
            'suffstats' and 'config'.
        """
        logger.info('EnsembleServer of size %d', len(ensemble))
        assert ensemble
        self._ensemble = [
            TreeCatServer(model['tree'], model['suffstats'], model['config'])
            for model in ensemble
        ]
        # Use the member's public accessor (which returns a fresh copy)
        # rather than reaching into its private state.
        self._zero_row = self._ensemble[0].zero_row()

    def zero_row(self):
        """Make an empty data row."""
        return self._zero_row.copy()

    def sample(self, N, counts, data=None):
        """Draw N samples from the uniform mixture over ensemble members.

        The N samples are split among members by a multinomial draw. Each
        member samples its share, and the combined rows are shuffled.

        Args:
          N: The number of samples to draw.
          counts: A [V]-shaped numpy array of requested counts of
            multinomials to sample, passed to every member.
          data: An optional single row of conditioning data, passed to every
            member.

        Returns:
          An [N, R]-shaped numpy array of sampled multinomial data.
        """
        size = len(self._ensemble)
        # float64 pvals keep np.random.multinomial's sum(pvals[:-1]) <= 1
        # check robust against rounding error.
        pvals = np.full(size, 1.0 / size, dtype=np.float64)
        sub_Ns = np.random.multinomial(N, pvals)
        samples = np.concatenate([
            server.sample(sub_N, counts, data)
            for server, sub_N in zip(self._ensemble, sub_Ns)
        ])
        np.random.shuffle(samples)
        assert samples.shape[0] == N
        return samples

    def logprob(self, data):
        """Compute non-normalized log probabilities under the mixture.

        Args:
          data: A [N, R]-shaped ragged numpy array of multinomial count data.

        Returns:
          An [N]-shaped numpy array whose entries are the log of the mean of
          the members' probabilities for each row.
        """
        logprobs = np.stack(
            [server.logprob(data) for server in self._ensemble])
        logprobs = logsumexp(logprobs, axis=0)
        logprobs -= np.log(len(self._ensemble))
        assert logprobs.shape == (data.shape[0], )
        return logprobs
def serve_ensemble(ensemble):
    """Create an EnsembleServer for a trained ensemble.

    Args:
      ensemble: A sequence of trained model dicts.

    Returns:
      An EnsembleServer wrapping every member of the ensemble.
    """
    server = EnsembleServer(ensemble)
    return server