// non-recurrent-guide.wppl
// Note: the repository this example comes from was archived by its owner on Jul 21, 2023 and is now read-only.
// Run with:
// webppl --require . --require webppl-viz examples/non-recurrent-guide.wppl
// Repeatedly applies `f`, threading each result into the next call, and
// collects the results: [f(prev), f(f(prev)), ...].
//
// n    - number of applications to perform.
// prev - seed value passed to the first call of `f`.
// f    - function mapping the previous value to the next one.
//
// Returns an array with one entry per application. A count that is not a
// positive number (0, negative, NaN, undefined) yields [].
var iterate = function(n, prev, f) {
  // `!(n > 0)` instead of `n === 0`: the strict-equality base case is never
  // reached for negative, fractional or NaN counts, so the recursion would
  // never terminate.
  if (!(n > 0)) {
    return [];
  } else {
    var newval = f(prev);
    return [newval].concat(iterate(n - 1, newval, f));
  }
};
// Builds the distribution of `project(v)` where `v` is drawn from `dist`.
//
// dist    - distribution to sample from.
// project - function applied to each sampled value.
//
// Returns whatever `Infer` produces for the projected model (called with
// only the `model` option, so the inference method is left to Infer).
var marginalize = function(dist, project) {
  var projectedModel = function() {
    var draw = sample(dist);
    return project(draw);
  };
  return Infer({model: projectedModel});
};
// Layers of the network that `predict` uses to produce guide parameters.
// NOTE(review): webppl-nn's `stack` is understood to compose right to left,
// i.e. the 60-unit affine layer runs first, then relu, then the 2-unit
// affine layer -- confirm against the webppl-nn docs.
var predictLayers = [
  affine(2, 'p2'),
  relu,
  affine(60, 'p1')
];
// `predict` reads indices 0 and 1 of this network's output.
var predictNet = stack(predictLayers);
// Computes the guide's Gaussian parameters for the current step.
//
// state - object with `ctx` (the running context) and `pos` (step index).
//
// The network input is the context concatenated with a [1, 1] parameter
// whose name is specific to this position. Output index 0 is used as the
// mean; output index 1 is exponentiated, so sigma is always positive.
//
// Returns {mu, sigma}.
var predict = function(state) {
  var posInput = param({name: 'predict-pos' + state.pos, dims: [1, 1]});
  var netInput = concat([state.ctx, posInput]);
  var netOutput = predictNet(netInput);
  var mean = T.get(netOutput, 0);
  var logSigma = T.get(netOutput, 1);
  return {
    mu: mean,
    sigma: Math.exp(logSigma)
  };
};
// Size of the context vector carried between steps; also used by `prior`
// to build the initial all-zeros context of shape [ctxdim, 1].
var ctxdim = 4;
// Network mapping the scaled sampled value to a context-sized increment.
// NOTE(review): presumably webppl-nn's bias-free `linear` layer with
// `ctxdim` outputs -- confirm against the webppl-nn docs.
var updateNet = linear(ctxdim, 'update-net');
// Produces the next state after a value has been sampled.
//
// state  - current state; `ctx` and `pos` are read.
// val    - the value sampled at this step.
// params - guide parameters that were used for this step.
//
// The sampled value is scaled by a per-position parameter (declared with
// mu: 0, sigma: 0), passed through `updateNet`, and added to the old
// context. There is no recurrent transformation of the old context itself.
//
// Returns {val, ctx, pos, params}, with `pos` advanced by one and every
// entry of `params` passed through `ad.value`.
var update = function(state, val, params) {
  var gain = param({name: 'update-pos' + state.pos, mu: 0, sigma: 0});
  var increment = updateNet(Vector([val * gain]));
  var nextCtx = T.add(state.ctx, increment);
  // NOTE(review): presumably strips AD wrappers so the recorded params are
  // plain numbers -- confirm.
  var recordedParams = _.mapValues(params, ad.value);
  return {
    val: val,
    ctx: nextCtx,
    pos: state.pos + 1,
    params: recordedParams
  };
};
// Performs one step of the chain: predict guide parameters from the current
// state, sample a value, and fold that value into the next state.
//
// state - current state, as produced by `prior`'s initial state or `update`.
//
// The target distribution is a standard Gaussian; the guide is a Gaussian
// with the parameters returned by `predict`.
//
// Returns the state produced by `update`.
var step = function(state) {
  var guideParams = predict(state);
  var guideThunk = function() {
    return Gaussian(guideParams);
  };
  var target = Gaussian({mu: 0, sigma: 1});
  var drawn = sample(target, {guide: guideThunk});
  return update(state, drawn, guideParams);
};
// Runs `n` chained steps starting from position 0 with an all-zeros
// context of shape [ctxdim, 1].
//
// n - number of steps (random choices) to take.
//
// Returns the array of states produced by `iterate`, one per step.
var prior = function(n) {
  var emptyCtx = zeros([ctxdim, 1]);
  return iterate(n, {pos: 0, ctx: emptyCtx}, step);
};
// Model over a chain of `n` sampled values that couples the first value (x)
// with the last value (y).
//
// n - chain length; the assertion requires at least two sampled values.
//
// The factor is -40 * (1 + y - x^2)^2, which is largest when y = x^2 - 1.
// (An alternative tried here was -40 * (x - y)^2.)
//
// Returns {x, y}.
var model = function(n) {
  var vals = _.map(prior(n), 'val');
  assert.ok(vals.length >= 2);
  var x = first(vals);
  var y = last(vals);
  var residual = 1 + y - Math.pow(x, 2);
  factor(-40 * Math.pow(residual, 2));
  return {x: x, y: y};
};
// var m = Infer({
// model() { return model(2); },
// method: 'MCMC',
// //kernel: {HMC: {}},
// kernel: {HMC: {stepSize: .01, steps: 20}},
// samples: 1000,
// burn: 1000
// });
// viz(m);
// Number of chained random choices in the model.
var N = 8;
// Run inference with method 'optimize' on the N-step model and plot the
// result.
// NOTE(review): per the WebPPL docs, `steps` is the number of optimization
// steps and `samples` the number of draws used to build the returned
// marginal -- confirm.
var m = Infer({
  method: 'optimize',
  samples: 1000,
  model: function() {
    return model(N);
  },
  steps: 10000
});
viz(m);
// Sample from the guide collecting the first value samples, and the
// predicted mus from all steps.
// var m2 = Infer({
// method: 'forward',
// guide: true,
// samples: 100,
// model() {
// var steps = prior(N);
// var x = steps[0].val;
// var mus = map(function(step) { return step.params.mu; }, steps);
// return {x, mus};
// }
// });
// // Viz how the value sampled at the first random choice maps to the
// // predicted mean at each step.
// map(function(i) {
// viz(marginalize(m2, function(obj) {
// return _.zipObject(['x', 'mu' + i], [obj.x, obj.mus[i]]);
// }));
// }, _.range(N));