"""
    symplectic_integrate(x₀, p₀, Λ, U, δUδx; N=50, ϵ=0.1, progress=false, history_keys=nothing)

Do a symplectic integration of the potential energy `U` (with gradient
`δUδx`) starting from point `x₀` with momentum `p₀` and mass matrix
`Λ`. The number of steps is `N` and the step size is `ϵ`.

Returns `(ΔH, xᵢ, pᵢ)`, i.e. the change in Hamiltonian and the final
position and momentum. If `history_keys` is given, a history of the
requested variables at each step is also returned as a fourth value.
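
For example, a minimal sketch integrating a 1D standard-normal potential
(the identity mass matrix, step count, and step size here are illustrative
choices):

```julia
using LinearAlgebra: I
U(x) = sum(abs2, x) / 2                      # simple quadratic potential
ΔH, x, p = symplectic_integrate([1.0], [0.5], I, U, x -> x; N=10, ϵ=0.1)
```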
"""
function symplectic_integrate(
    x₀::AbstractVector{T}, p₀, Λ, U, δUδx=x->gradient(U,x)[1];
    N=50, ϵ=T(0.1), progress=false, history_keys=nothing
) where {T}

    xᵢ, pᵢ = x₀, p₀
    δUδxᵢ = δUδx(xᵢ)
    H(x,p) = U(x) - p⋅(Λ\p)/2

    history = []

    @showprogress (progress ? 1 : Inf) "Symplectic Integration: " for i=1:N
        xᵢ₊₁ = xᵢ - T(ϵ) * (Λ \ (pᵢ - T(ϵ)/2 * δUδxᵢ))
        δUδxᵢ₊₁ = δUδx(xᵢ₊₁)
        pᵢ₊₁ = pᵢ - T(ϵ)/2 * (δUδxᵢ₊₁ + δUδxᵢ)
        xᵢ, pᵢ, δUδxᵢ = xᵢ₊₁, pᵢ₊₁, δUδxᵢ₊₁

        if !isnothing(history_keys)
            historyᵢ = (;i, x=xᵢ, p=pᵢ, δUδx=δUδxᵢ, H=(:H in history_keys ? H(xᵢ,pᵢ) : nothing))
            push!(history, select(historyᵢ, history_keys))
        end

    end

    ΔH = H(xᵢ,pᵢ) - H(x₀,p₀)

    if isnothing(history_keys)
        return ΔH, xᵢ, pᵢ
    else
        return ΔH, xᵢ, pᵢ, history
    end

end
@doc doc"""
    grid_and_sample(lnP::Function, range; progress=false, nsamples=1)

Interpolate the log pdf `lnP` with support on `range`, and return the
interpolated (normalized) log pdf as well as `nsamples` samples (drawn via
inverse transform sampling).

Either `lnP` should accept a NamedTuple argument and `range` should be a
NamedTuple mapping those same names to `range` objects specifying where to
evaluate `lnP`, e.g.:
```julia
grid_and_sample(nt->-(nt.x^2+nt.y^2)/2, (x=range(-3,3,length=100),y=range(-3,3,length=100)))
```
or `lnP` should accept a single scalar argument and `range` should be directly
the range for this variable:
```julia
grid_and_sample(x->-x^2/2, range(-3,3,length=100))
```
The return value is `(samples, lnP, Px)` where `samples` gives the
Monte-Carlo sample(s) of the parameters (a NamedTuple in the NamedTuple
case), `lnP` is an interpolated/smoothed log PDF which can be evaluated
anywhere within the original range, and `Px` are the values of the
normalized log PDF at the grid points.
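
For example (a sketch of the 1D scalar case; the grid density and sample
count are illustrative):

```julia
samples, ilnP, lnPs = grid_and_sample(x->-x^2/2, range(-3,3,length=100); nsamples=100)
ilnP(0.5)  # evaluate the smoothed log PDF anywhere within (-3,3)
```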
(Note: only 1D sampling is currently implemented; 2D, as in the first
example above, is planned.)
"""
function grid_and_sample(lnP::Function, range::AbstractVector; progress=false, kwargs...)
    lnPs = @showprogress (progress ? 1 : Inf) "Grid Sample: " map(lnP, range)
    grid_and_sample(lnPs, range; progress=progress, kwargs...)
end
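# batched log-pdfs: grid-and-sample each batch index separately, then re-batch
# the three return values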
function grid_and_sample(lnPs::Vector{<:BatchedReal}, xs::AbstractVector; kwargs...)
    batches = [grid_and_sample(batch_index.(lnPs,i), xs; kwargs...) for i=1:batch_length(lnPs[1])]
    ((batch(getindex.(batches,i)) for i=1:3)...,)
end
function grid_and_sample(lnPs::Vector, xs::AbstractVector; progress=false, nsamples=1, span=0.25, require_convex=false)

    # trim leading/trailing zero-probability regions
    support = findnext(isfinite,lnPs,1):findprev(isfinite,lnPs,length(lnPs))
    xs = xs[support]
    lnPs = lnPs[support]

    if require_convex
        support = longest_run_of_trues(finite_second_derivative(lnPs) .< 0)
        xs = xs[support]
        lnPs = lnPs[support]
    end

    # interpolate PDF
    xmin, xmax = first(xs), last(xs)
    lnPs = lnPs .- maximum(lnPs)
    ilnP = loess(xs, lnPs, span=span)

    # normalize the PDF. note the smoothing is done on the log PDF.
    cdf(x) = quadgk(nan2zero∘exp∘ilnP, xmin, x, rtol=1e-3)[1]
    logA = nan2zero(log(cdf(xmax)))
    lnPs = (ilnP.ys .-= logA)
    ilnP.bs[:,1] .-= logA

    # draw samples via inverse transform sampling
    θsamples = @showprogress (progress ? 1 : Inf) map(1:nsamples) do i
        r = rand()
        if (cdf(xmin)-r)*(cdf(xmax)-r) >= 0
            first(lnPs) > last(lnPs) ? xmin : xmax
        else
            fzero(x->cdf(x)-r, xmin, xmax, xatol=(xmax-xmin)*1e-3)
        end
    end

    (nsamples==1 ? θsamples[1] : θsamples), ilnP, lnPs

end
function grid_and_sample(lnP::Function, range::NamedTuple{S,<:NTuple{1}}; kwargs...) where {S}
    NamedTuple{S}.(Ref.(grid_and_sample(x -> lnP(NamedTuple{S}(x)), first(range); kwargs...)))
end
# allow more convenient evaluation of Loess-interpolated functions
(m::Loess.LoessModel)(x) = Loess.predict(m,x)
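# e.g. (illustrative): for a fitted model ilnP, one can write ilnP(0.5)
# rather than Loess.predict(ilnP, 0.5)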
@doc doc"""
    sample_joint(ds::DataSet; kwargs...)

Sample the joint posterior, $\mathcal{P}(f,\phi,\theta\,|\,d)$.

Keyword arguments:
* `nsamps_per_chain` — The number of samples per chain.
* `nchains = nworkers()` — Number of chains in parallel.
* `nchunk = 1` — Number of steps between parallel chain communication.
* `nsavemaps = 1` — Number of steps in between saving maps into the chain.
* `nburnin_always_accept = 0` — Number of steps at the beginning of
  the chain to always accept HMC steps regardless of integration
  error.
* `nburnin_fixθ = 0` — Number of steps at the beginning of the chain
  before starting to sample `θ`.
* `Nϕ = :qe` — Noise to use in the initial approximation to the
  Hessian. Can give `:qe` to use the quadratic estimate noise.
* `filename = nothing` — Path to a `.jld2` file in which to save the
  chain. If the file already exists, the chain is automatically
  resumed from it.
* `θrange` — Range and density to grid sample parameters as a
  NamedTuple, e.g. `(Aϕ=range(0.7,1.3,length=20),)`.
* `θstart = :prior` — Starting values of parameters as a NamedTuple,
  e.g. `(Aϕ=1.2,)`, or `:prior` to draw a random starting point from
  `θrange`.
* `ϕstart = :prior` — Starting `ϕ`: a `Field` object, `0`, `:prior`,
  `:quasi_sample`, or `:best_fit`.
* `metadata` — Does nothing, but is saved into the chain file.
* `nhmc = 1` — Number of HMC passes per `ϕ` Gibbs step.
* `symp_kwargs = fill((N=25, ϵ=0.01), nhmc)` — An array of NamedTuple
  kwargs to pass to [`symplectic_integrate`](@ref). E.g.
  `[(N=50,ϵ=0.1),(N=25,ϵ=0.01)]` would do 50 large steps then 25
  smaller steps in each Gibbs pass. If specified, `nhmc` is ignored.
* `conjgrad_kwargs = (tol=1e-1, nsteps=500)` — Keyword arguments to
  pass to the conjugate gradient solver in [`argmaxf_lnP`](@ref)
  during the Wiener Filter Gibbs step.
* `MAP_kwargs` — Keyword arguments to pass to [`MAP_joint`](@ref) when
  computing the starting point.
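
A minimal invocation might look like (a sketch; `ds` and the values below
are illustrative):

```julia
sample_joint(
    ds;
    nsamps_per_chain = 1000,
    nchains = 2,
    θrange = (Aϕ=range(0.7,1.3,length=20),),
    filename = "chain.jld2",
)
```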
"""
function sample_joint(
    ds :: DataSet;
    nsamps_per_chain,
    nchains = nworkers(),
    nchunk = 1,
    nsavemaps = 1,
    nburnin_always_accept = 0,
    nburnin_fixθ = 0,
    Nϕ = :qe,
    filename = nothing,
    ϕstart = :prior,
    θstart = :prior,
    θrange = NamedTuple(),
    Nϕ_fac = 2,
    pmap = (myid() in workers() ? map : pmap),
    conjgrad_kwargs = (tol=1e-1, nsteps=500),
    preconditioner = :diag,
    nhmc = 1,
    symp_kwargs = fill((N=25, ϵ=0.01), nhmc),
    MAP_kwargs = (nsteps=40,),
    metadata = nothing,
    progress = :summary,
    interruptable = false,
    gibbs_pass_θ::Union{Function,Nothing} = nothing,
    postprocess = nothing,
    storage = ds.d.storage
)

    ds = cpu(ds)

    # save input configuration to later write to chain file
    rundat = Base.@locals
    pop!.(Ref(rundat), (:metadata, :ds)) # saved separately

    # validate arguments
    if (length(θrange) > 1 && gibbs_pass_θ == nothing)
        error("Can only currently sample one parameter at a time, otherwise must pass custom `gibbs_pass_θ`")
    end
    if !(progress in [false, :summary, :verbose])
        error("`progress` should be one of [false, :summary, :verbose]")
    end
    if (filename != nothing && splitext(filename)[2] != ".jld2")
        error("Chain filename '$filename' should have '.jld2' extension.")
    end
    if mod(nchunk, nsavemaps) != 0
        error("`nsavemaps` should divide evenly into `nchunk`")
    end

    # seed
    @everywhere @eval CMBLensing seed!.(global_rng_for.((Array,$storage)))

    # initialize chains
    if (filename != nothing) && isfile(filename)
        @info "Resuming chain at $filename"
        local chunks_index, last_chunks
        jldopen(filename,"r") do io
            chunks_index = maximum([parse(Int,k[8:end]) for k in keys(io) if startswith(k,"chunks_")])
            last_chunks = read(io, "chunks_$(chunks_index)")
        end
    else
        Nbatch = batch_length(ds.d)
        θstarts = if θstart == :prior
            [map(range->batch((first(range) .+ rand(Nbatch) .* (last(range) - first(range)))...), θrange) for i=1:nchains]
        elseif θstart isa NamedTuple
            fill(θstart, nchains)
        elseif θstart isa Vector{<:NamedTuple}
            θstart
        else
            error("`θstart` should be `:prior` to randomly sample the starting value from `θrange`, or a NamedTuple (or Vector of NamedTuples) giving the starting point.")
        end
        ϕstarts = if ϕstart == :prior
            pmap(θstarts) do θstart
                simulate(ds(;θstart...).Cϕ; Nbatch)
            end
        elseif ϕstart == 0
            fill(zero(diag(ds().Cϕ)), nchains)
        elseif ϕstart isa Field
            fill(ϕstart, nchains)
        elseif ϕstart isa Vector{<:Field}
            ϕstart
        elseif ϕstart in [:quasi_sample, :best_fit]
            pmap(θstarts) do θstart
                MAP_joint(adapt(storage,ds(;θstart...)), progress=(progress==:verbose ? :summary : false), Nϕ=adapt(storage,Nϕ), quasi_sample=(ϕstart==:quasi_sample); MAP_kwargs...).ϕ
            end
        else
            error("`ϕstart` should be `:prior`, `0`, `:quasi_sample`, `:best_fit`, or a `Field`.")
        end
        last_chunks = pmap(θstarts,ϕstarts) do θstart,ϕstart
            [@dict i=>1 f=>nothing ϕ°=>cpu(ds(;θstart...).G*cpu(ϕstart)) θ=>θstart]
        end
        chunks_index = 1
        if filename != nothing
            save(
                filename,
                "rundat", cpu(rundat),
                "ds", cpu(ds),
                "ds₀", cpu(ds()), # save separately in case the θ-dependent version has trouble loading
                "metadata", cpu(metadata),
                "chunks_1", cpu(last_chunks)
            )
        end
    end

    @unpack L, Cϕ = ds

    if (Nϕ == :qe)
        Nϕ = quadratic_estimate(ds()).Nϕ / Nϕ_fac
    end
    dsₐ, Nϕₐ = ds, Nϕ

    t_write = 0

    if progress == :summary
        @everywhere first(workers()) @eval CMBLensing begin
            pbar = Progress($(nsamps_per_chain-last_chunks[1][end][:i]+1), dt=0, desc="Gibbs chain: ")
            ProgressMeter.update!(pbar, showvalues=[("step", $(last_chunks[1][end][:i]))])
        end
    end

    # start chains
    try
        for chunks_index = (chunks_index+1):(nsamps_per_chain÷nchunk+1)

            last_chunks = pmap(last.(last_chunks)) do state

                @unpack i,ϕ°,f,θ = state
                f,ϕ°,ds,Nϕ = (adapt(storage, x) for x in (f,ϕ°,dsₐ,Nϕₐ))
                dsθ = ds(θ)
                ϕ = dsθ.G\ϕ°
                pϕ°, ΔH, accept = nothing, nothing, nothing
                L = ds.L
                lnPθ = nothing
                chain_chunk = []

                for (i, savemaps) in zip( (i+1):(i+nchunk), cycle([fill(false,nsavemaps-1); true]) )

                    # ==== gibbs P(f°|ϕ°,θ) ====
                    t_f = @elapsed begin
                        f = argmaxf_lnP(
                            ϕ, θ, dsθ;
                            which = :sample,
                            fstart = f,
                            preconditioner = preconditioner,
                            conjgrad_kwargs = (progress=(progress==:verbose), conjgrad_kwargs...)
                        )
                        f°, = mix(f,ϕ,dsθ)
                    end

                    # ==== gibbs P(ϕ°|f°,θ) ====
                    t_ϕ = @elapsed begin
                        Λm = pinv(dsθ.G)^2 * ((Nϕ == nothing) ? pinv(dsθ.Cϕ) : (pinv(dsθ.Cϕ) + pinv(Nϕ)))
                        for kwargs in symp_kwargs
                            pϕ° = simulate(Λm)
                            (ΔH, ϕtest°) = symplectic_integrate(
                                ϕ°, pϕ°, Λm,
                                ϕ°->lnP(:mix, f°, ϕ°, θ, dsθ);
                                progress=(progress==:verbose),
                                kwargs...
                            )
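                            # Metropolis-Hastings accept/reject based on the
                            # integration error ΔH (batched; always accept
                            # during the first `nburnin_always_accept` steps)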
                            accept = batch(@. (i < nburnin_always_accept) | (log(rand()) < $unbatch(ΔH)))
                            ϕ° = @. accept * ϕtest° + (1 - accept) * ϕ°
                        end
                    end

                    # ==== gibbs P(θ|f°,ϕ°) ====
                    t_θ = @elapsed begin
                        if (i > nburnin_fixθ && length(θrange) > 0)
                            if gibbs_pass_θ == nothing
                                θ, lnPθ = grid_and_sample(θ->lnP(:mix,f°,ϕ°,θ,ds), θrange, progress=(progress==:verbose))
                            else
                                θ, lnPθ = gibbs_pass_θ(;(Base.@locals)...)
                            end
                            dsθ = ds(θ)
                        end
                    end

                    # compute un-mixed maps
                    f, ϕ = unmix(f°,ϕ°,θ,dsθ)
                    f̃ = L(ϕ)*f

                    # save state to chain and print progress
                    timing = (f=t_f, θ=t_θ, ϕ=t_ϕ)
                    state = @dict i θ lnPθ ΔH accept lnP=>lnP(0,f,ϕ,θ,dsθ) timing
                    if savemaps
                        merge!(state, @dict f f° f̃ ϕ ϕ° pϕ°)
                    end
                    if postprocess != nothing
                        merge!(state, postprocess(;(Base.@locals)...))
                    end
                    push!(chain_chunk, cpu(state))

                    if @isdefined(pbar)
                        string_trunc(x) = Base._truncate_at_width_or_chars(string(x), displaysize(stdout)[2]-14)
                        next!(pbar, showvalues = [
                            ("step", i),
                            tuple.(keys(θ), string_trunc.(values(θ)))...,
                            ("ΔH", string_trunc(ΔH)),
                            ("accept", string_trunc(accept)),
                            ("timing", timing)
                        ])
                    end

                end

                return chain_chunk

            end

            if filename != nothing
                last_chunks[1][end][:t_write] = t_write
                t_write = @elapsed jldopen(filename,"a+") do io
                    wsession = JLD2.JLDWriteSession()
                    write(io, "chunks_$chunks_index", last_chunks, wsession)
                end
            end

        end
    catch err
        if interruptable && (err isa InterruptException)
            println()
            @warn("Chain interrupted. Returning current progress.")
        else
            rethrow(err)
        end
    end

end