/
metropolishastings.go
213 lines (191 loc) · 7.52 KB
/
metropolishastings.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package samplemv
import (
"math"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat/distmv"
)
// Compile-time check that MetropolisHastingser satisfies the Sampler interface.
var _ Sampler = MetropolisHastingser{}
// MHProposal defines a proposal distribution for Metropolis Hastings.
type MHProposal interface {
	// ConditionalLogProb returns the log probability of the first argument
	// conditioned on being at the second argument, i.e. log p(x|y).
	// ConditionalLogProb panics if the input slices are not the same length.
	ConditionalLogProb(x, y []float64) (prob float64)

	// ConditionalRand generates a new random location conditioned on being at
	// the location y. If the first argument is nil, a new slice is allocated and
	// returned. Otherwise, the random location is stored in-place into the first
	// argument, and ConditionalRand will panic if the input slice lengths differ.
	ConditionalRand(x, y []float64) []float64
}
// MetropolisHastingser is a type for generating samples using the Metropolis Hastings
// algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
// with the given target and proposal distributions, starting at the location
// specified by Initial. If Src is non-nil, it is used to generate the uniform
// random numbers for the accept/reject step; otherwise the package-level
// rand.Float64 is used. The proposal distribution draws its own random numbers.
//
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
// samples according to the distribution specified by target using the Markov
// chain implicitly defined by the proposal distribution. At each
// iteration, a proposal point is generated randomly from the current location.
// This proposal point is accepted with probability
//  p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
// If the new location is accepted, it becomes the new current location.
// If it is rejected, the current location remains. This is the sample stored in
// batch, ignoring BurnIn and Rate (discussed below).
//
// The samples in Metropolis Hastings are correlated with one another through the
// Markov chain. As a result, the initial value can have a significant influence
// on the early samples, and so, typically, the first samples generated by the chain
// are ignored. This is known as "burn-in", and the number of samples ignored
// at the beginning is specified by BurnIn. The proper BurnIn value will depend
// on the mixing time of the Markov chain defined by the target and proposal
// distributions.
//
// Many choose to have a sampling "rate" where a number of samples
// are ignored in between each kept sample. This helps decorrelate
// the samples from one another, but also reduces the number of available samples.
// This value is specified by Rate: with Rate k > 1, k-1 chain states are
// discarded between consecutive kept samples. If Rate is 0 it is defaulted
// to 1 (keep every sample).
//
// The initial value is NOT changed during calls to Sample.
type MetropolisHastingser struct {
	// Initial is the starting location of the chain. Its length must equal
	// the number of columns of the batch passed to Sample. It is copied, not
	// modified, by Sample.
	Initial []float64

	// Target is the distribution being sampled; only its LogProb is used.
	Target distmv.LogProber

	// Proposal generates candidate moves and supplies the proposal log
	// probabilities used in the acceptance ratio.
	Proposal MHProposal

	// Src, if non-nil, supplies the uniform draws for the accept/reject step.
	Src rand.Source

	// BurnIn is the number of initial chain steps that are discarded. It must
	// not be negative.
	BurnIn int

	// Rate is the thinning interval between kept samples; 0 is treated as 1.
	// It must not be negative.
	Rate int
}
// Sample generates rows(batch) samples using the Metropolis Hastings sample
// generation method. The initial location is NOT updated during the call to Sample.
//
// The chain first takes BurnIn steps whose states are discarded. The next
// state is stored in row 0 of batch, and every following row holds the state
// Rate steps after the previous row (a Rate of 0 is treated as 1).
//
// The number of columns in batch must equal len(m.Initial), otherwise Sample
// will panic. Sample also panics if BurnIn or Rate is negative.
func (m MetropolisHastingser) Sample(batch *mat.Dense) {
	// Validate up front: negative values would otherwise surface as opaque
	// index/dimension panics from mat, possibly after a full burn-in.
	burnIn := m.BurnIn
	if burnIn < 0 {
		panic("metropolishastings: negative burn-in")
	}
	rate := m.Rate
	if rate < 0 {
		panic("metropolishastings: negative rate")
	}
	if rate == 0 {
		rate = 1
	}
	r, c := batch.Dims()
	if len(m.Initial) != c {
		panic("metropolishastings: length mismatch")
	}

	// Use the optimal size for the temporary memory to allow the fewest calls
	// to MetropolisHastings. The case where tmp shadows samples must be
	// aligned with the logic after burn-in so that tmp does not shadow samples
	// during the rate portion.
	tmp := batch
	if rate > r {
		tmp = mat.NewDense(rate, c, nil)
	}
	rTmp, _ := tmp.Dims()

	// Perform burn-in in chunks of at most rTmp steps, restarting each chunk
	// from the last state of the previous one. The chain state lives in a copy
	// so that m.Initial is never modified.
	remaining := burnIn
	initial := make([]float64, c)
	copy(initial, m.Initial)
	for remaining > 0 {
		newSamp := min(rTmp, remaining)
		metropolisHastings(tmp.Slice(0, newSamp, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
		copy(initial, tmp.RawRowView(newSamp-1))
		remaining -= newSamp
	}

	// Without thinning every chain step is kept, so fill batch directly.
	if rate == 1 {
		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
		return
	}

	// If tmp still aliases batch (rate <= r), thinning needs separate scratch
	// space of exactly rate rows; otherwise tmp is already rate×c.
	if rTmp <= r {
		tmp = mat.NewDense(rate, c, nil)
	}

	// Take a single sample from the chain.
	metropolisHastings(batch.Slice(0, 1, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
	copy(initial, batch.RawRowView(0))

	// For all of the other samples, first generate Rate samples and then actually
	// accept the last one.
	for i := 1; i < r; i++ {
		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
		v := tmp.RawRowView(rate - 1)
		batch.SetRow(i, v)
		copy(initial, v)
	}
}
// metropolisHastings advances a Markov chain starting at initial and writes
// the chain state after each step into successive rows of batch.
//
// At each step a candidate is drawn from proposal conditioned on the current
// state. The candidate replaces the current state when
//
//	exp(log target(candidate) + log q(state|candidate) - log q(candidate|state) - log target(state))
//
// exceeds a uniform draw. Uniform draws come from src when it is non-nil and
// from the package-level rand.Float64 otherwise. initial is not modified.
// metropolisHastings panics if initial is empty.
func metropolisHastings(batch *mat.Dense, initial []float64, target distmv.LogProber, proposal MHProposal, src rand.Source) {
	var uniform func() float64
	if src == nil {
		uniform = rand.Float64
	} else {
		uniform = rand.New(src).Float64
	}

	dim := len(initial)
	if dim == 0 {
		panic("metropolishastings: zero length initial")
	}

	rows, _ := batch.Dims()
	state := append([]float64(nil), initial...)
	candidate := make([]float64, dim)
	stateLogProb := target.LogProb(initial)

	for row := 0; row < rows; row++ {
		// The call order below is significant: proposals may carry state
		// (ProposalNormal resets its mean on each call).
		proposal.ConditionalRand(candidate, state)
		candidateLogProb := target.LogProb(candidate)
		logForward := proposal.ConditionalLogProb(candidate, state)
		logReverse := proposal.ConditionalLogProb(state, candidate)

		ratio := math.Exp(candidateLogProb + logReverse - logForward - stateLogProb)
		if ratio > uniform() {
			copy(state, candidate)
			stateLogProb = candidateLogProb
		}
		batch.SetRow(row, state)
	}
}
// ProposalNormal is a sampling distribution for Metropolis-Hastings. It has a
// fixed covariance matrix and changes the mean based on the current sampling
// location.
//
// Construct a ProposalNormal with NewProposalNormal; the zero value holds a
// nil distribution and its methods would dereference it.
type ProposalNormal struct {
	// normal carries the fixed covariance. Its mean is overwritten with the
	// conditioning location on every ConditionalLogProb/ConditionalRand call.
	// NOTE(review): because of that mutation, sharing one ProposalNormal
	// across goroutines is presumably unsafe — confirm before concurrent use.
	normal *distmv.Normal
}
// NewProposalNormal constructs a new ProposalNormal for use as a proposal
// distribution for Metropolis-Hastings. ProposalNormal is a multivariate normal
// distribution (implemented by distmv.Normal) where the covariance matrix sigma
// is fixed and the mean of the distribution follows the conditioning location.
// src is handed to distmv.NewNormal.
//
// NewProposalNormal returns (nil, false) if distmv.NewNormal rejects sigma,
// i.e. if the covariance matrix is not positive-definite.
func NewProposalNormal(sigma *mat.SymDense, src rand.Source) (*ProposalNormal, bool) {
	// The mean is only a placeholder; every conditional call replaces it.
	zeroMean := make([]float64, sigma.Symmetric())
	normal, ok := distmv.NewNormal(zeroMean, sigma, src)
	if ok {
		return &ProposalNormal{normal: normal}, true
	}
	return nil, false
}
// ConditionalLogProb returns the log probability of the first argument
// conditioned on being at the second argument, i.e. log p(x|y), where p is a
// normal distribution centered at y with the fixed proposal covariance.
// ConditionalLogProb panics if the input slices are not the same length or
// are not equal to the dimension of the covariance matrix.
//
// The mean of the underlying distribution is left set to y after the call.
func (p *ProposalNormal) ConditionalLogProb(x, y []float64) (prob float64) {
	dist := p.normal
	// No explicit length check: SetMean rejects a bad len(y) and LogProb a
	// bad len(x).
	dist.SetMean(y)
	prob = dist.LogProb(x)
	return prob
}
// ConditionalRand generates a new random location conditioned on being at the
// location y, drawn from a normal distribution centered at y with the fixed
// proposal covariance. If the first argument is nil, a new slice of the
// proposal's dimension is allocated and returned. Otherwise, the random
// location is stored in-place into the first argument and that slice is
// returned. ConditionalRand panics if the input slice lengths differ or if
// they are not equal to the dimension of the covariance matrix.
func (p *ProposalNormal) ConditionalRand(x, y []float64) []float64 {
	dist := p.normal
	dst := x
	if dst == nil {
		dst = make([]float64, dist.Dim())
	}
	if len(dst) != len(y) {
		panic(errLengthMismatch)
	}
	dist.SetMean(y)
	dist.Rand(dst)
	return dst
}