This repository has been archived by the owner on Jul 21, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
vae.wppl
61 lines (50 loc) · 1.55 KB
/
vae.wppl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Variational Autoencoder
// https://arxiv.org/abs/1312.6114
// Size of the latent code z: output size of the 'encM'/'encS' layers, input
// size of the 'dec0' layer, and the leading dimension of the prior.
var zDim = 2;
// Width of the decoder's hidden layer (output of 'dec0', input of 'dec1').
var hDecodeDim = 500;
// Width of the encoder's hidden layer (output of 'encH').
var hEncodeDim = 500;
// Length of one image vector (input of 'encH', output of 'dec1').
// Presumably 28 x 28 MNIST pixels -- confirm against the data file.
var xDim = 784;
// Requires the MNIST data set to be downloaded and unpacked first; see
// examples/data/mnist.js.
// Also requires the webppl-fs package: https://github.com/null-a/webppl-fs
// Run with:
// webppl --require . --require webppl-fs examples/vae.wppl
// Load the training images: read the file as text, parse it as JSON, then
// convert every parsed entry with Vector.
var imagesJson = fs.read('examples/data/mnist_images.json');
var images = map(Vector, JSON.parse(imagesJson));
// Recognition (inference) network.
// Takes an input image x and returns the parameters {mu, sigma} that the
// guide distribution is built from.
var encode = function(x) {
  // Layers are created in the same order as their names: encH, encM, encS.
  var hiddenLayer = compose(tanh, affine('encH', {in: xDim, out: hEncodeDim}));
  var meanLayer = affine('encM', {in: hEncodeDim, out: zDim});
  var scaleLayer = affine('encS', {in: hEncodeDim, out: zDim});

  // Both output heads read from the same hidden activation.
  var hidden = hiddenLayer(x);

  return {
    mu: meanLayer(hidden),
    // The scale head's output is passed through exp before being used as sigma.
    sigma: T.exp(scaleLayer(hidden))
  };
};
// Generative network.
// Builds and returns the function that maps a latent vector to pixel values.
var sampleDecoder = function() {
  // Both affine layers are created with `param: modelParam`.
  // 'dec1' is created before 'dec0', matching the order they are listed in.
  var outputLayer = affine('dec1', {in: hDecodeDim, out: xDim, param: modelParam});
  var hiddenLayer = affine('dec0', {in: zDim, out: hDecodeDim, param: modelParam});

  // NOTE(review): stack is expected to apply the functions right to left,
  // i.e. sigmoid(dec1(tanh(dec0(z)))) -- confirm against webppl-nn.
  return stack([sigmoid, outputLayer, tanh, hiddenLayer]);
};
// Prior over the latent code: a TensorGaussian with mu 0 and sigma 1 over a
// [zDim, 1] tensor.
var zPrior = TensorGaussian({
  mu: 0,
  sigma: 1,
  dims: [zDim, 1]
});
// The generative model together with its guide.
// For every image (processed in mini-batches of 100) a latent code is sampled
// and the image is observed under the decoder's output.
var model = function() {
  var decode = sampleDecoder();

  mapData({data: images, batchSize: 100}, function(image) {
    // The guide is a thunk, so the encoder only runs when the guide is needed.
    var z = sample(zPrior, {
      guide: function() {
        return DiagCovGaussian(encode(image));
      }
    });

    // The decoder output is used as the per-pixel Bernoulli probabilities.
    var pixelProbs = decode(z);
    observe(MultivariateBernoulli({ps: pixelProbs}), image);
  });
};
// Fit the parameters: 500 steps of Adam (step size 0.001), using the ELBO
// estimator with a single sample per step.
Optimize({
  model: model,
  steps: 500,
  estimator: {
    ELBO: {samples: 1}
  },
  optMethod: {
    adam: {stepSize: 0.001}
  }
});