#ifndef STREAMINGEMTREE_H
#define STREAMINGEMTREE_H

#include "StdIncludes.h"
#include "SVectorStream.h"
#include "ClusterVisitor.h"
#include "InsertVisitor.h"

#include "tbb/mutex.h"
#include "tbb/pipeline.h"

namespace lmw {
/**
 * The streaming version of the EM-tree algorithm does not store the
 * data vectors in the tree. Therefore, the leaf level of the tree contains
 * cluster representatives rather than the original data.
 *
 * It has accumulator vectors for the centroids at the leaf level. The
 * accumulators are used to calculate a mean in the streaming setting. Note
 * that accumulators are only needed at the leaf level, as means at all higher
 * levels in the tree can be calculated from the leaves.
 *
 * T is the type of vector stored in the node.
 *
 * ACCUMULATOR is the type used for the accumulator vectors. For example,
 * with bit vectors, integer accumulators are used.
 *
 * ACCUMULATORs must support being constructed with the number of dimensions,
 *     auto a = ACCUMULATOR(dimensions);
 * report their dimensionality via a.size(), be cleared via a.setAll(0),
 * and support accumulation at a given dimension,
 *     a[i] += 1;
 *
 * OPTIMIZER provides the functions necessary for optimization.
 */
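/*
 * Illustrative usage sketch only; it is not compiled as part of this header.
 * ExampleAccumulator, MyOptimizer and initialTreeRoot are placeholders: the
 * accumulator must meet the requirements listed above, the optimizer must
 * provide the nearest() and squaredDistance() operations used by this class,
 * and initialTreeRoot is a Node<SVector<bool>>* holding the initial means.
 *
 *     // A minimal accumulator satisfying the stated requirements.
 *     struct ExampleAccumulator {
 *         explicit ExampleAccumulator(size_t dimensions) : counts(dimensions, 0) { }
 *         size_t size() const { return counts.size(); }
 *         void setAll(int value) { std::fill(counts.begin(), counts.end(), value); }
 *         int& operator[](size_t i) { return counts[i]; }
 *         std::vector<int> counts;
 *     };
 *
 *     // One streaming pass: insert vectors, prune empty clusters, update the
 *     // means from the accumulators, then clear before the next pass.
 *     StreamingEMTree<SVector<bool>, ExampleAccumulator, MyOptimizer>
 *             emtree(initialTreeRoot);
 *     SVectorStream<SVector<bool>> vs( ... );
 *     emtree.insert(vs);
 *     emtree.prune();
 *     emtree.update();
 *     emtree.clearAccumulators();
 *     std::cout << "RMSE = " << emtree.getRMSE() << std::endl;
 */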
template <typename T, typename ACCUMULATOR, typename OPTIMIZER>
class StreamingEMTree {
public:
    explicit StreamingEMTree(const Node<T>* root) :
            _root(new Node<AccumulatorKey>()) {
        _root->setOwnsKeys(true);
        deepCopy(root, _root);
    }

    ~StreamingEMTree() {
        delete _root;
    }
    size_t visit(SVectorStream<T>& vs, InsertVisitor<T>& visitor) {
        size_t totalRead = 0;
        // set up the parallel processing pipeline
        tbb::parallel_pipeline(_maxtokens,
            // input filter reads chunks of _readsize vectors serially
            tbb::make_filter<void, vector<T*>*>(
                tbb::filter::serial_out_of_order,
                inputFilter(vs, totalRead)
            ) &
            // visit filter visits chunks of vectors against the streaming EM-tree in parallel
            tbb::make_filter<vector<T*>*, void>(
                tbb::filter::parallel,
                [&] (vector<T*>* data) -> void {
                    visit(*data, visitor);
                    vs.free(data);
                    delete data;
                }
            )
        );
        return totalRead;
    }

    void visit(ClusterVisitor<T>& visitor) const {
        visit(NULL, _root, visitor);
    }
    void visit(vector<T*>& data, InsertVisitor<T>& visitor) const {
        for (T* object : data) {
            visit(_root, object, visitor);
        }
    }

    size_t insert(SVectorStream<T>& vs) {
        // -1 wraps to the maximum size_t, i.e. read the whole stream
        return insert(vs, -1);
    }
    /**
     * Reads vectors from the stream in chunks of _readsize and inserts them
     * into the tree, stopping once at least maxToRead vectors have been read.
     * Returns the total number of vectors read from the stream, or 0 if the
     * end of the stream had already been reached.
     */
    size_t insert(SVectorStream<T>& vs, const size_t maxToRead) {
        size_t totalRead = 0;
        // set up the parallel processing pipeline
        tbb::parallel_pipeline(_maxtokens,
            // input filter reads chunks of _readsize vectors serially
            tbb::make_filter<void, vector<T*>*>(
                tbb::filter::serial_out_of_order,
                inputFilter(vs, totalRead, maxToRead)
            ) &
            // insert filter inserts chunks of vectors into the streaming EM-tree in parallel
            tbb::make_filter<vector<T*>*, void>(
                tbb::filter::parallel,
                [&] (vector<T*>* data) -> void {
                    insert(*data);
                    vs.free(data);
                    delete data;
                }
            )
        );
        return totalRead;
    }
    /**
     * Insert is thread safe. Shared accumulators are locked.
     */
    void insert(vector<T*>& data) {
        for (T* object : data) {
            insert(_root, object);
        }
    }

    int prune() {
        return prune(_root);
    }

    void update() {
        update(_root);
    }

    void clearAccumulators() {
        clearAccumulators(_root);
    }

    int getMaxLevelCount() const {
        return maxLevelCount(_root);
    }

    int getClusterCount(int depth) const {
        return clusterCount(_root, depth);
    }

    uint64_t getObjCount() const {
        return objCount(_root);
    }

    double getRMSE() const {
        // RMSE = sqrt(sum of squared errors / number of inserted vectors)
        double RMSE = sumSquaredError(_root);
        uint64_t size = getObjCount();
        RMSE /= size;
        RMSE = sqrt(RMSE);
        return RMSE;
    }
private:
    typedef tbb::mutex Mutex;

    struct AccumulatorKey {
        AccumulatorKey() : key(NULL), sumSquaredError(0), accumulator(NULL),
                count(0), mutex(NULL) { }

        ~AccumulatorKey() {
            if (key) {
                delete key;
            }
            if (accumulator) {
                delete accumulator;
            }
            if (mutex) {
                delete mutex;
            }
        }

        T* key;
        double sumSquaredError;
        ACCUMULATOR* accumulator; // accumulator for partially updated key
        uint64_t count; // how many vectors have been added to accumulator
        Mutex* mutex;
    };

    struct Accessor {
        T* operator()(AccumulatorKey* accumulatorKey) const {
            return accumulatorKey->key;
        }
    };
    void visit(const T* parentKey, const Node<AccumulatorKey>* node,
            ClusterVisitor<T>& visitor, const int level = 1) const {
        for (size_t i = 0; i < node->size(); i++) {
            auto accumulatorKey = node->getKey(i);
            uint64_t count = objCount(node, i);
            double SSE = sumSquaredError(node, i);
            double RMSE = sqrt(SSE / count);
            visitor.accept(level, parentKey, accumulatorKey->key, RMSE, count);
            if (!node->isLeaf()) {
                visit(accumulatorKey->key, node->getChild(i), visitor, level + 1);
            }
        }
    }

    Nearest<AccumulatorKey> nearestKey(const T* object,
            const Node<AccumulatorKey>* node) const {
        return _optimizer.nearest(object, node->getKeys(), _accessor);
    }

    void visit(const Node<AccumulatorKey>* node, const T* object,
            InsertVisitor<T>& visitor, const int level = 1) const {
        auto nearest = nearestKey(object, node);
        auto accumulatorKey = nearest.key;
        visitor.accept(level, object, accumulatorKey->key, nearest.distance);
        if (node->isLeaf()) {
            // update stats but not accumulators
            Mutex::scoped_lock lock(*accumulatorKey->mutex);
            accumulatorKey->sumSquaredError +=
                    _optimizer.squaredDistance(object, accumulatorKey->key);
            accumulatorKey->count++;
        } else {
            visit(node->getChild(nearest.index), object, visitor, level + 1);
        }
    }
    void insert(Node<AccumulatorKey>* node, T* object) {
        auto nearest = nearestKey(object, node);
        if (node->isLeaf()) {
            // update stats and accumulators
            auto accumulatorKey = nearest.key;
            Mutex::scoped_lock lock(*accumulatorKey->mutex);
            T* key = accumulatorKey->key;
            accumulatorKey->sumSquaredError += _optimizer.squaredDistance(object, key);
            ACCUMULATOR* accumulator = accumulatorKey->accumulator;
            for (size_t i = 0; i < accumulator->size(); i++) {
                (*accumulator)[i] += (*object)[i];
            }
            accumulatorKey->count++;
        } else {
            insert(node->getChild(nearest.index), object);
        }
    }

    int prune(Node<AccumulatorKey>* node) {
        int pruned = 0;
        for (size_t i = 0; i < node->size(); i++) {
            if (objCount(node, i) == 0) {
                node->remove(i);
                pruned++;
            } else if (!node->isLeaf()) {
                pruned += prune(node->getChild(i));
            }
        }
        node->finalizeRemovals();
        return pruned;
    }
    void gatherAccumulators(Node<AccumulatorKey>* node, ACCUMULATOR* total,
            uint64_t* totalCount) {
        if (node->isLeaf()) {
            for (auto accumulatorKey : node->getKeys()) {
                auto accumulator = accumulatorKey->accumulator;
                for (size_t i = 0; i < accumulator->size(); i++) {
                    (*total)[i] += (*accumulator)[i];
                }
                *totalCount += accumulatorKey->count;
            }
        } else {
            for (auto child : node->getChildren()) {
                gatherAccumulators(child, total, totalCount);
            }
        }
    }

    /**
     * TODO(cdevries): Make it work for something other than bit vectors. It
     * needs to be parameterized; for example, with float vectors a mean would
     * be taken instead of the majority vote below.
     */
    static void updatePrototypeFromAccumulator(T* key, ACCUMULATOR* accumulator,
            uint64_t count) {
        if (count == 0) return;
        // calculate the new key based on the accumulator: set a bit if more
        // than half of the accumulated vectors had that bit set (majority vote)
        key->setAllBlocks(0);
        for (size_t i = 0; i < key->size(); i++) {
            if ((*accumulator)[i] > (count / 2)) {
                key->set(i);
            }
        }
    }
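    /*
     * A sketch of the float-vector case mentioned in the TODO above
     * (illustrative only, not used by this class). FLOAT_VEC stands for a
     * hypothetical dense float vector type with size() and operator[]; the
     * accumulator would hold per-dimension sums and the prototype would become
     * the mean rather than a majority vote:
     *
     *     template <typename FLOAT_VEC, typename FLOAT_ACCUMULATOR>
     *     static void updateMeanFromAccumulator(FLOAT_VEC* key,
     *             FLOAT_ACCUMULATOR* accumulator, uint64_t count) {
     *         if (count == 0) return;
     *         for (size_t i = 0; i < key->size(); i++) {
     *             (*key)[i] = (*accumulator)[i] / count;
     *         }
     *     }
     */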
    void update(Node<AccumulatorKey>* node) {
        if (node->isLeaf()) {
            // leaves flatten accumulators in node
            for (auto accumulatorKey : node->getKeys()) {
                updatePrototypeFromAccumulator(accumulatorKey->key,
                        accumulatorKey->accumulator, accumulatorKey->count);
            }
        } else {
            // internal nodes must gather accumulators from the leaves
            size_t dimensions = node->getKey(0)->key->size();
            for (size_t i = 0; i < node->size(); i++) {
                auto accumulatorKey = node->getKey(i);
                T* key = accumulatorKey->key;
                auto child = node->getChild(i);
                ACCUMULATOR total(dimensions);
                total.setAll(0);
                uint64_t totalCount = 0;
                gatherAccumulators(child, &total, &totalCount);
                updatePrototypeFromAccumulator(key, &total, totalCount);
            }
            for (auto child : node->getChildren()) {
                update(child);
            }
        }
    }

    void clearAccumulators(Node<AccumulatorKey>* node) {
        if (node->isLeaf()) {
            for (auto accumulatorKey : node->getKeys()) {
                accumulatorKey->sumSquaredError = 0;
                accumulatorKey->accumulator->setAll(0);
                accumulatorKey->count = 0;
            }
        } else {
            for (auto child : node->getChildren()) {
                clearAccumulators(child);
            }
        }
    }
    void deepCopy(const Node<T>* src, Node<AccumulatorKey>* dst) {
        if (!src->isEmpty()) {
            size_t dimensions = src->getKey(0)->size();
            for (size_t i = 0; i < src->size(); i++) {
                auto key = src->getKey(i);
                auto child = src->getChild(i);
                auto accumulatorKey = new AccumulatorKey();
                accumulatorKey->key = new T(*key);
                if (child->isLeaf()) {
                    // Do not copy the leaves of the original tree; instead,
                    // set up accumulators for the lowest-level cluster means.
                    accumulatorKey->accumulator = new ACCUMULATOR(dimensions);
                    accumulatorKey->accumulator->setAll(0);
                    accumulatorKey->mutex = new Mutex();
                    dst->add(accumulatorKey);
                } else {
                    auto newChild = new Node<AccumulatorKey>();
                    newChild->setOwnsKeys(true);
                    deepCopy(child, newChild);
                    dst->add(accumulatorKey, newChild);
                }
            }
        }
    }
    std::function<vector<T*>*(tbb::flow_control&)> inputFilter(
            SVectorStream<T>& vs, size_t& totalRead, const size_t maxToRead = -1) {
        // a maxToRead of -1 wraps around to the maximum size_t, which in
        // practice means the whole stream is read
        return ([&vs, &totalRead, this, maxToRead]
                (tbb::flow_control& fc) -> vector<T*>* {
            if (maxToRead > 0 && totalRead >= maxToRead) {
                fc.stop();
                return NULL;
            }
            auto data = new vector<T*>;
            size_t read = vs.read(_readsize, data);
            if (read == 0) {
                delete data;
                fc.stop();
                return NULL;
            }
            totalRead += data->size();
            return data;
        });
    }
    double sumSquaredError(const Node<AccumulatorKey>* node, const size_t i) const {
        if (node->isLeaf()) {
            return node->getKey(i)->sumSquaredError;
        } else {
            return sumSquaredError(node->getChild(i));
        }
    }

    double sumSquaredError(const Node<AccumulatorKey>* node) const {
        if (node->isLeaf()) {
            double localSum = 0;
            for (auto key : node->getKeys()) {
                localSum += key->sumSquaredError;
            }
            return localSum;
        } else {
            double localSum = 0;
            for (auto child : node->getChildren()) {
                localSum += sumSquaredError(child);
            }
            return localSum;
        }
    }

    /**
     * Object count for cluster i in node.
     */
    uint64_t objCount(const Node<AccumulatorKey>* node, const size_t i) const {
        if (node->isLeaf()) {
            return node->getKey(i)->count;
        } else {
            return objCount(node->getChild(i));
        }
    }

    uint64_t objCount(const Node<AccumulatorKey>* node) const {
        if (node->isLeaf()) {
            uint64_t localCount = 0;
            for (auto key : node->getKeys()) {
                localCount += key->count;
            }
            return localCount;
        } else {
            uint64_t localCount = 0;
            for (auto child : node->getChildren()) {
                localCount += objCount(child);
            }
            return localCount;
        }
    }
    int maxLevelCount(const Node<AccumulatorKey>* current) const {
        if (current->isLeaf()) {
            return 1;
        } else {
            int maxCount = 0;
            for (auto child : current->getChildren()) {
                maxCount = max(maxCount, maxLevelCount(child));
            }
            return maxCount + 1;
        }
    }

    int clusterCount(const Node<AccumulatorKey>* current, const int depth) const {
        if (depth == 1) {
            return current->size();
        } else {
            int localCount = 0;
            for (auto child : current->getChildren()) {
                localCount += clusterCount(child, depth - 1);
            }
            return localCount;
        }
    }
    Node<AccumulatorKey>* _root;
    OPTIMIZER _optimizer;
    Accessor _accessor;

    // How many vectors to read at once when processing a stream.
    int _readsize = 1000;

    // The maximum number of readsize vector chunks that can be loaded at once.
    int _maxtokens = 1024;
};

} // namespace lmw

#endif /* STREAMINGEMTREE_H */