forked from probmods/webppl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ldaCollapsed.wppl
123 lines (103 loc) · 3.48 KB
/
ldaCollapsed.wppl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Collapsed Dirichlet-discrete distribution exposing sample, observe and
// getCounts (a candidate for the webppl header).
//
// The (pseudo)counts are kept in globalStore under a key unique to each
// instance ('DD1', 'DD2', ...), and every sampled or observed value adds one
// to the count at that value's index.
//
// pseudocounts: array of initial counts, one entry per outcome index.
// Returns an object with function-valued fields 'sample', 'observe' and
// 'getCounts'.
var makeDirichletDiscrete = function(pseudocounts) {
  // Returns a new array equal to `counts` with 1 added at index `target`.
  // `offset` is the index of counts[0] in the full array; callers omit it.
  // When `target` matches no index the result has the same values as `counts`.
  var addCount = function(counts, target, offset) {
    var position = offset == undefined ? 0 : offset;
    return counts.length == 0 ?
        [] :
        [counts[0] + (target == position)].concat(
            addCount(counts.slice(1), target, position + 1));
  };

  // Allocate the next instance number and the store key derived from it.
  var previousIndex = globalStore.DDindex == undefined ? 0 : globalStore.DDindex;
  globalStore.DDindex = 1 + previousIndex;
  var ddname = 'DD' + globalStore.DDindex;
  globalStore[ddname] = pseudocounts;

  // Writes back `counts` with the count at `index` incremented.
  var storeIncremented = function(counts, index) {
    globalStore[ddname] = addCount(counts, index);
  };

  // Draws an index from discreteERP parameterised by the current (unnormalized)
  // counts, records the draw in the counts and returns it.
  var ddSample = function() {
    var current = globalStore[ddname];
    var drawn = sample(discreteERP, [current]);
    storeIncremented(current, drawn);
    return drawn;
  };

  // Scores `val` under the normalized current counts via factor, then records
  // the observation in the counts.
  var ddObserve = function(val) {
    var current = globalStore[ddname];
    var predictive = util.normalizeArray(current);
    factor(discreteERP.score([predictive], val));
    storeIncremented(current, val);
  };

  // Returns the counts currently stored for this instance.
  var ddCounts = function() {
    return globalStore[ddname];
  };

  return {
    'sample': ddSample,
    'observe': ddObserve,
    'getCounts': ddCounts
  };
};
// Observes value `v` under the Dirichlet-discrete object `dd`.
//
// vs: array of possible values; the position of `v` in it (as reported by
//     indexOf) is the index that gets observed.
// dd: object with a function-valued `observe` field.
// v:  the observed value.
var dirichletDiscreteFactor = function(vs, dd, v) {
  var valueIndex = indexOf(v, vs);
  // The function is pulled out of the object before it is called rather than
  // invoked as dd.observe(...); WebPPL treats method-style calls as calls to
  // plain JavaScript functions, so keep this two-step form.
  var recordObservation = dd.observe;
  recordObservation(valueIndex);
};
// Parameters

// Word types that may occur in a document; a word is identified by its index
// in this array.
var vocabulary = ['bear', 'wolf', 'python', 'prolog'];

// Topic names. Only the keys are used by the code in this file; the values
// are placeholders.
var topics = {
  topic1: null,
  topic2: null
};

// Corpus: document name -> array of words (each string split on single spaces).
var docs = {
  doc1: 'bear wolf bear wolf bear wolf python wolf bear wolf'.split(' '),
  doc2: 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  doc3: 'bear wolf bear wolf bear wolf bear wolf bear wolf'.split(' '),
  doc4: 'python prolog python prolog python prolog python prolog python prolog'.split(' '),
  doc5: 'bear wolf bear python bear wolf bear wolf bear wolf'.split(' ')
};
// Constants and helper functions

// Returns the result of repeat(n, f) where f always yields 1.0, i.e. n copies
// of 1.0 produced through the runtime's `repeat` helper.
var ones = function(n) {
  var one = function() {
    return 1.0;
  };
  return repeat(n, one);
};
// Builds a new object with the same keys as `obj`, where each value is
// fn(key, value) for the corresponding entry of `obj`.
//
// fn:  function called as fn(key, value).
// obj: object whose [key, value] pairs are taken with _.pairs.
var mapObject = function(fn, obj) {
  var transformPair = function(pair) {
    var key = pair[0];
    return [key, fn(key, pair[1])];
  };
  var transformedPairs = map(transformPair, _.pairs(obj));
  return _.object(transformedPairs);
};
// Model

// Creates a Dirichlet-discrete object over vocabulary indices, starting from
// ones(vocabulary.length) as pseudocounts. Takes no parameters; any arguments
// a caller passes are ignored.
var makeWordDist = function() {
  var vocabularySize = vocabulary.length;
  var initialCounts = ones(vocabularySize);
  return makeDirichletDiscrete(initialCounts);
};
// Creates a Dirichlet-discrete object over topic indices, starting from
// ones(_.size(topics)) as pseudocounts. Takes no parameters; any arguments a
// caller passes are ignored.
var makeTopicDist = function() {
  var topicCount = _.size(topics);
  var initialCounts = ones(topicCount);
  return makeDirichletDiscrete(initialCounts);
};
// LDA model using Dirichlet-discrete objects.
//
// Steps:
//   1. create one word distribution per topic and one topic distribution per
//      document;
//   2. sample a topic for every word position of every document;
//   3. observe each word under the word distribution of its sampled topic;
//   4. return the counts of the two topic word distributions.
//
// Returns [counts of 'topic1', counts of 'topic2'].
var model = function() {
  // mapObject calls the factories with (key, value); both are ignored.
  var wordDistForTopic = mapObject(makeWordDist, topics);
  var topicDistForDoc = mapObject(makeTopicDist, docs);
  var topicNames = _.keys(topics);

  // Samples one topic name per word of the document. Only the number of words
  // matters here; the words themselves are not inspected.
  var drawTopicsForDoc = function(docName, words) {
    // Function-valued fields are copied into a local before being called
    // instead of being invoked as methods.
    var drawTopicIndex = topicDistForDoc[docName].sample;
    return map(
        function(word) {
          return topicNames[drawTopicIndex()];
        },
        words);
  };
  var topicsForDoc = mapObject(drawTopicsForDoc, docs);

  // Observes every word of the document under its topic's word distribution.
  var observeDoc = function(docName, words) {
    map2(
        function(topic, word) {
          dirichletDiscreteFactor(vocabulary, wordDistForTopic[topic], word);
        },
        topicsForDoc[docName],
        words);
  };
  mapObject(observeDoc, docs);

  // Read out the pseudocounts of the (Dirichlet-discrete) word distributions
  // for the two topics.
  var countsForTopic1 = wordDistForTopic.topic1.getCounts;
  var countsForTopic2 = wordDistForTopic.topic2.getCounts;
  return [countsForTopic1(), countsForTopic2()];
};
// Entry point: run `model` under the runtime's MH inference routine, passing
// 5000 as its second argument (presumably the number of MH iterations --
// confirm against the WebPPL version in use).
MH(model, 5000)