Skip to content

Commit

Permalink
Merge pull request #10 from chrished/verbose_parallel
Browse files Browse the repository at this point in the history
add verbose parallel, need to update tests and rerun when converged
  • Loading branch information
chrished committed Jul 9, 2019
2 parents cfee36d + e2da6d3 commit b88d27f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/demcz.jl
Expand Up @@ -99,11 +99,11 @@ demcz_sample_par(logobj, Zmat, opts::DEMCopt; sync_every = 1000, prevrun=nothing
Runs each chain on a separate process - Z is updated simultaenously among all chains running in parallel.
"""
function demcz_sample_par(logobj, Zmat, opts::DEMCopt; sync_every = 1000, prevrun=nothing)
return demcz_sample_par(logobj, Zmat, opts.N, opts.K, opts.Ngeneration, opts.Nblocks, opts.blockindex, opts.eps_scale, opts.γ;sync_every = sync_every, prevrun=prevrun, autostop=opts.autostop, autostop_Rhat=opts.autostop_Rhat)
return demcz_sample_par(logobj, Zmat, opts.N, opts.K, opts.Ngeneration, opts.Nblocks, opts.blockindex, opts.eps_scale, opts.γ;sync_every = sync_every, prevrun=prevrun, autostop=opts.autostop, autostop_Rhat=opts.autostop_Rhat, verbose=opts.verbose)
end


function demcz_sample_par(logobj, Zmat, N=4, K=10, Ngeneration=5000, Nblocks=1, blockindex=[1:size(Zmat,2)], eps_scale=1e-4*ones(size(Zmat,2)), γ=2.38; sync_every = 1000, prevrun=nothing, autostop=:Rhat, autostop_Rhat=1.1)
function demcz_sample_par(logobj, Zmat, N=4, K=10, Ngeneration=5000, Nblocks=1, blockindex=[1:size(Zmat,2)], eps_scale=1e-4*ones(size(Zmat,2)), γ=2.38; sync_every = 1000, prevrun=nothing, autostop=:Rhat, autostop_Rhat=1.1, verbose=true)
# prep storage etc
nrowZ, d = size(Zmat)
global Zshared = SharedArray(vcat(Zmat, zeros(Int(ceil(N*Ngeneration/K)), d)))
Expand Down Expand Up @@ -135,6 +135,9 @@ function demcz_sample_par(logobj, Zmat, N=4, K=10, Ngeneration=5000, Nblocks=1,
global to = set[2]
passobj(myid(), workers(), [:from, :to], from_mod=DEMC, to_mod=DEMC)
pmap(ic -> runchain!(ic, from, to, mc, Zshared, K, M, logobj, blockindex, eps_scale, γ, Nblocks), 1:N)
if verbose
print_status(mc, to)
end
if autostop == :Rhat
Rhat = Rhat_gelman(mc.chain[:,:, from:to], N, to-from, d)
accept_ratio = sum(diff(mc.log_obj, dims = 2).!=0., dims = 2)./ (sync_every-1)
Expand Down
2 changes: 1 addition & 1 deletion test/example_linreg.jl
Expand Up @@ -55,7 +55,7 @@ Z = randn((10*ndim, ndim))
mc, Z = DEMC.demcz_sample(log_obj, Z, opts)

# drop first half of chain
Ntot = size(mc.chain,3)
N, Npar, Ntot = size(mc.chain)
keep = Int(Ntot-opts.autostop_every)+1:Ntot
Ngen_burned = length(keep)
chain_burned = mc.chain[:,:,keep]
Expand Down

0 comments on commit b88d27f

Please sign in to comment.