forked from probmods/webppl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
daipp.wppl
117 lines (94 loc) · 4.68 KB
/
daipp.wppl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
WebPPL functions to manage the context and prediction networks.
These should be used to annotate the target program as follows:
Insert a call to initContext at the start of the thunk to initialize the context.
Insert a call to updateContext after each sample statement.
Optionally, also call updateContext at other places (after primitives, function returns).
At sample statements, attach the guide like so: `sample(origERP, origParams, {guide: DAIPPguide(origERP, origParams)})`.
This file provides a simple helper function sampleDaipp that replaces sample and expands to the
right guide+update code for known ERPs.
It should be possible for the daipp package to provide a macro that replaces sample with sampleDaipp and
inserts initContext at the start of the model fn.
Two key helper functions are in a corresponding DAIPP.js file:
daipp.val2vec embeds js objects into vectors,
daipp.vec2dist chooses an importance ERP and generates its params.
*/
// Network used by initContext to seed the context from an embedding of the
// data (or a summary of the data). It has a single tanh layer whose output
// width is daipp.latentSize. The first nn.mlp argument is presumably the
// input width; TODO confirm against nn.mlp.
// dritchie: It's not clear to me why we really need another net here? Why not
// let the initial context be the output of val2vec?
var initNet = nn.mlp(
  daipp.latentSize,
  [{nOut: daipp.latentSize, activation: nn.tanh}],
  'initNet',
  daipp.debug
);
initNet.setTraining(true);
/*
 * Seed the context vector for the current execution.
 * `data` is embedded with daipp.val2vec. The embedding is run through initNet
 * and added back onto itself (a residual connection). The sum is stored in
 * globalStore.context, where DAIPPguide and updateContext read it.
 * @param data  observed data, or a summary of it, to condition the context on.
 * @returns undefined; the result is communicated through globalStore.context.
 */
var initContext = function(data) {
  var embedding = daipp.val2vec(data);
  globalStore.context = T.add(embedding, daipp.nneval(initNet, embedding));
  return;
};
// A function used to update the context upon getting a sampled (or
// deterministic) value.
// TODO: make deeper? use GRU / LSTM? ResNet?
// TODO: take optional name/type to allow different update nets, eg for sample vs deterministic vs return.
// TODO: Update using a similar scheme to init/predict (i.e. add the net's
// output back onto its input instead of replacing the context outright).

// Network consumed by updateContext. updateContext cannot run without it:
// referencing an undefined `updateNet` throws a ReferenceError on every call.
// The input width is 3 * daipp.latentSize because the input is the
// concatenation of three vectors (context, address embedding, value
// embedding). Each is presumably daipp.latentSize wide; TODO confirm against
// daipp.val2vec.
var updateNet = nn.mlp(3 * daipp.latentSize, [
  {nOut: daipp.latentSize, activation: nn.tanh},
  {nOut: daipp.latentSize}
], 'updateNet', daipp.debug);
updateNet.setTraining(true);

/*
 * Fold a newly produced value into the running context.
 * The new context is updateNet applied to the concatenation of:
 *   - the current globalStore.context,
 *   - an embedding of this call's relative address (getRelativeAddress()),
 *   - an embedding of `val`.
 * The result overwrites globalStore.context.
 * @param val  the value just produced (e.g. the result of a sample statement).
 * @returns undefined; the result is communicated through globalStore.context.
 * NOTE(review): assumes initContext already ran in this execution so that
 * globalStore.context is set; confirm at call sites.
 */
var updateContext = function(val) {
  var dataVec = daipp.val2vec(val);
  var address = getRelativeAddress();
  var addressVec = daipp.val2vec(address);
  var context = globalStore.context;
  var newContext = daipp.nneval(updateNet, ad.tensor.concat(context, addressVec, dataVec));
  globalStore.context = newContext;
  return;
};
//A function to predict the params of the importance distribution at a sample statement.
// `ERP` is the distribution we're guiding.
// `val` will usually (by default) be the original ERP params (so
// params === val), but we want to maintain the flexibility to put in
// other info.
// paul, paraphrasing noah:
// when `val` is the original ERP params, we want to treat them as a
// tensor (so that they get minimally mangled). The simplest approach is
// probably to upgrade them to a tensor in the webppl program before
// calling DAIPPguide, e.g. `DAIPPguide(ERP, ad.tensor(params))`.
// paul: note that params isn't necessarily an array of reals.
// Network used by DAIPPguide. It maps the current context to a residual that
// is added back onto the context before the guide distribution is built.
// Only the context feeds it, hence an input of 1 * daipp.latentSize. The
// first nn.mlp argument is presumably the input width; TODO confirm.
var predictNet = nn.mlp(
  1 * daipp.latentSize,
  [{nOut: daipp.latentSize, activation: nn.tanh}],
  'predictNet',
  daipp.debug
);
predictNet.setTraining(true);
/*
 * Build the importance (guide) distribution for a sample statement on `ERP`.
 * The current context (globalStore.context) is run through predictNet, and
 * the output is added back onto the context (a residual connection).
 * daipp.vec2dist then turns that vector into a guide for ERP.
 * @param ERP  the distribution being guided.
 * @param val  extra information about the sample site (sampleDaipp passes the
 *             ERP's params). Currently unused: only the context feeds the
 *             prediction.
 * @returns the value daipp.vec2dist produces for (prediction, ERP).
 * TODO: if embeddings of `val` and getRelativeAddress() are fed in again, note
 * that the address is currently a single string, so val2vec would learn a
 * separate embedding per address. Splitting it into an array of syntax sites
 * would let val2vec's RNN give related addresses related vectors.
 */
var DAIPPguide = function(ERP, val) {
  var ctx = globalStore.context;
  var prediction = T.add(ctx, daipp.nneval(predictNet, ctx));
  return daipp.vec2dist(prediction, ERP);
};
/*
 * Drop-in replacement for `sample` that wires in the DAIPP guide.
 * - If opt.observedVal is provided (neither null nor undefined), nothing is
 *   sampled. factor(ERP.score(observedVal)) adds the observation's score and
 *   the observed value is returned.
 * - Otherwise ERP is sampled with DAIPPguide(ERP, params) as the guide, where
 *   params are ERP.params converted by daipp.orderedValues.
 * Note: the updateContext call after sampling is currently disabled, so this
 * helper does not advance globalStore.context.
 * @param ERP  the distribution to sample from or score against.
 * @param opt  optional; opt.observedVal supplies an observed value.
 * @returns the sampled or observed value.
 * FIXME: in fantasy mode for the sleep phase, do we want to sample even for
 * observed vals? Or just make sure they are undefined in that context?
 */
var sampleDaipp = function(ERP, opt) {
  var isObserved = opt != null && opt.observedVal != null;
  if (isObserved) {
    var observed = opt.observedVal;
    factor(ERP.score(observed));
    return observed;
  } else {
    // The params are handed over as an ordered array to retain the behavior of
    // the old ERP interface; handling this directly in val2vec would be cleaner.
    var guide = DAIPPguide(ERP, daipp.orderedValues(ERP.params));
    return sample(ERP, {guide: guide});
  }
};