diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..95ac9a593 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "julia.environmentPath": "C:\\Users\\Cornelius\\.julia\\environments\\v1.9" +} \ No newline at end of file diff --git a/Project.toml b/Project.toml index 22e75649d..6f0787f39 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" @@ -119,7 +120,7 @@ IrrationalConstants = "0.1, 0.2" KernelDensity = "0.5, 0.6" LaTeXStrings = "1" MacroTools = "0.5" -MeasureBase = "0.12, 0.13, 0.14" +MeasureBase = "0.14" Measurements = "2" NamedArrays = "0.9, 0.10" NestedSamplers = "0.8" diff --git a/examples/dev-internal/transformed_example.jl b/examples/dev-internal/transformed_example.jl new file mode 100644 index 000000000..79d1ea5ff --- /dev/null +++ b/examples/dev-internal/transformed_example.jl @@ -0,0 +1,67 @@ +using BAT +using BAT.MeasureBase +using AffineMaps +using ChangesOfVariables +using BAT.LinearAlgebra +using BAT.Distributions +using BAT.InverseFunctions +import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID +using Random123, PositiveFactorizations +using AutoDiffOperators +import AdvancedHMC + +import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling + +#ENV["JULIA_DEBUG"] = "BAT" + +context = BATContext(ad = ADModule(:ForwardDiff)) + +posterior = BAT.example_posterior() + +my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context) + + +density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior) +density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo, context) + +s = cholesky(Positive, BAT._approx_cov(density)).L +f = BAT.CustomTransform(Mul(s)) + +my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context) + +my_samples = my_result.result + + + +using Plots +plot(my_samples) + +r_mh = @time BAT.bat_sample_impl(posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context) + +r_hmc = @time BAT.bat_sample_impl(posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context) + +plot(bat_sample(posterior).result) + +using BAT.Distributions +using BAT.ValueShapes +prior2 = NamedTupleDist(ShapedAsNT, + b = [4.2, 3.3], + a = Exponential(1.0), + c = Normal(1.0,3.0), + d = product_distribution(Weibull.(ones(2),1)), + e = Beta(1.0, 1.0), + f = MvNormal([0.3,-2.9],Matrix([1.7 0.5;0.5 2.3])) + ) + +posterior.likelihood.density._log_f(rand(posterior.prior)) + +posterior.likelihood.density._log_f(rand(prior2)) + +posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2) + + +@profview r_ram2 = @time BAT.bat_sample_impl(posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context) + +@profview r_mh2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context) + +r_hmc2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index 6dece3cc4..69727e10f 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -78,7 +78,8 @@ function AHMCIterator( throw(ErrorException("HamiltonianMC requires an ADSelector to be specified in the BAT context")) end - f = checked_logdensityof(density) + # TODO AC ToDo!: discuss with @oschulz + f = logdensityof(density) fg = valgrad_func(f, adsel) init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg) diff --git a/src/BAT.jl b/src/BAT.jl index eb7142f5e..600bdbc14 100644 --- a/src/BAT.jl +++ b/src/BAT.jl @@ -58,6 +58,7 @@ import HypothesisTests import MeasureBase import Measurements import NamedArrays +import ProgressMeter import Random123 import Sobol import StableRNGs diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index c9e3604ca..2fa3dc85d 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -55,6 +55,41 @@ _gen_chains( context::BATContext ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids] +# TODO AC discuss +function _cluster_selection( + chains::AbstractVector{<:MCMCIterator}, + tuners, + outputs::AbstractVector{<:DensitySampleVector}, + scale::Real=3, + decision_range_skip::Real=0.9, +) + logds_by_chain = [view(s.logd,(floor(Int,decision_range_skip*length(s))):length(s)) for s in outputs] + medians = [median(x) for x in logds_by_chain] + stddevs = [std(x) for x in logds_by_chain] + + # yet uncategoriesed + uncat = eachindex(chains, tuners, outputs, logds_by_chain, stddevs, medians) + + # clustered indices + cidxs = Vector{Vector{eltype(uncat)}}() + # categories all to clusters + while length(uncat) > 0 + idxmin = findmin(view(stddevs,uncat))[2] + + cidx_sel = map(means_remaining_uncat -> abs(means_remaining_uncat-medians[uncat[idxmin]]) < scale*stddevs[uncat[idxmin]], view(medians,uncat)) + + push!(cidxs, uncat[cidx_sel]) + uncat = uncat[.!cidx_sel] + end + medians_c = [ median(reduce(vcat, view(logds_by_chain, ids))) for ids in cidxs] + idx_order = sortperm(medians_c, rev=true) + + chains_by_cluster = [ view(chains, ids) for ids in cidxs[idx_order]] + tuners_by_cluster = [ view(tuners, ids) for ids in cidxs[idx_order]] + outputs_by_cluster = [ view(outputs, ids) for ids in cidxs[idx_order]] + ( chains = chains_by_cluster[1], tuners = tuners_by_cluster[1], outputs = outputs_by_cluster[1], ) +end + function mcmc_init!( algorithm::MCMCAlgorithm, @@ -66,13 +101,13 @@ function mcmc_init!( callback::Function, context::BATContext ) - @info "MCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)." - initval_alg = init_alg.initval_alg min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains + @info "MCMCChainPoolInit: trying to generate $(min_nviable) viable MCMC chain(s)." + rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates)) ncandidates::Int = 0 @@ -87,46 +122,60 @@ function mcmc_init!( chains = similar([dummy_chain], 0) tuners = similar([dummy_tuner], 0) outputs = similar([DensitySampleVector(dummy_chain)], 0) - cycle::Int = 1 + init_tries::Int = 1 while length(tuners) < min_nviable && ncandidates < max_ncandidates - n = min(min_nviable, max_ncandidates - ncandidates) - @debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain(s)." + viable_tuners = similar(tuners, 0) + viable_chains = similar(chains, 0) + viable_outputs = similar(outputs, 0) - new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) + # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. + while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates + n = min(min_nviable, max_ncandidates - ncandidates) + @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." - filter!(isvalidchain, new_chains) + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) - new_tuners = tuning_alg.(new_chains) - new_outputs = DensitySampleVector.(new_chains) - next_cycle!.(new_chains) - tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) - ncandidates += n + filter!(isvalidchain, new_chains) - @debug "Testing $(length(new_tuners)) candidate MCMC chain(s)." + new_tuners = tuning_alg.(new_chains) + new_outputs = DensitySampleVector.(new_chains) + next_cycle!.(new_chains) + tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) + ncandidates += n - mcmc_iterate!( - new_outputs, new_chains, new_tuners; - max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), - callback = callback, - nonzero_weights = nonzero_weights - ) + @debug "Testing $(length(new_tuners)) candidate MCMC chain(s)." - viable_idxs = findall(isviablechain.(new_chains)) - viable_tuners = new_tuners[viable_idxs] - viable_chains = new_chains[viable_idxs] - viable_outputs = new_outputs[viable_idxs] + mcmc_iterate!( + new_outputs, new_chains, new_tuners; + max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), + callback = callback, + nonzero_weights = nonzero_weights + ) + @info length.(new_outputs) + + viable_idxs = findall(isviablechain.(new_chains)) - @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." + append!(viable_tuners, new_tuners[viable_idxs]) + append!(viable_chains, new_chains[viable_idxs]) + append!(viable_outputs, new_outputs[viable_idxs]) + end + + @debug "Found $(length(viable_tuners)) viable MCMC chain(s)." if !isempty(viable_tuners) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_tuners), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_tuners) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + mcmc_iterate!( viable_outputs, viable_chains, viable_tuners; max_nsteps = init_alg.nsteps_init, - callback = callback, + callback = (kwargs...)-> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end, nonzero_weights = nonzero_weights ) + ProgressMeter.finish!(progress_meter) + nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains])) good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains) @debug "Found $(length(viable_tuners)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples." @@ -136,10 +185,17 @@ function mcmc_init!( append!(outputs, view(viable_outputs, good_idxs)) end - cycle += 1 + init_tries += 1 + end + + # TODO AC + if true + @unpack chains, tuners, outputs = _cluster_selection(chains, tuners, outputs) + length(tuners) < nchains && error("Failed to generate $nchains viable MCMC chains") + else + length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") end - length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") m = nchains tidxs = LinearIndices(tuners) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index 6084b6f3b..685f80836 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -56,13 +56,30 @@ function bat_sample_impl( get_mcmc_tuning(mcmc_algorithm), algorithm.nonzero_weights, algorithm.store_burnin ? algorithm.callback : nop_func, - context + context, ) if !algorithm.store_burnin chain_outputs .= DensitySampleVector.(chains) end + run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, chain_outputs=chain_outputs) + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + samples_notrafo = inverse(trafo).(samples_trafo) + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator) +end + +function _run_sample_impl( + density::AnyMeasureOrDensity, + algorithm::MCMCSampling, + chains::AbstractVector{<:MCMCIterator}, + tuners, + context::BATContext; + description::AbstractString="MCMC iterate", + chain_outputs=DensitySampleVector.(chains) +) mcmc_burnin!( algorithm.store_burnin ? chain_outputs : nothing, tuners, @@ -76,19 +93,44 @@ function bat_sample_impl( next_cycle!.(chains) + progress_meter = ProgressMeter.Progress(algorithm.nchains * algorithm.nsteps, desc=description, barlen=80 - length(description), dt=0.1) + mcmc_iterate!( chain_outputs, chains; max_nsteps = algorithm.nsteps, nonzero_weights = algorithm.nonzero_weights, - callback = algorithm.callback + callback = (kwargs...) -> let pm=progress_meter, callback=algorithm.callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end, ) + ProgressMeter.finish!(progress_meter) + output = DensitySampleVector(first(chains)) isnothing(output) || append!.(Ref(output), chain_outputs) samples_trafo = varshape(density).(output) + (result_trafo = samples_trafo, generator = MCMCSampleGenerator(chains)) +end + +function _bat_sample_continue( + target::AnyMeasureOrDensity, + algorithm::MCMCSampling, + generator::MCMCSampleGenerator, + context, + ;description::AbstractString = "MCMC iterate" +) + @unpack chains = generator + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.trafo, density_notrafo, context) + + chain_outputs = DensitySampleVector.(chains) + + tuners = map(v -> get_mcmc_tuning(getproperty(v, :algorithm))(v), chains) + + run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, description=description, chain_outputs=chain_outputs) + samples_trafo, generator_new = run_sampling.result_trafo, run_sampling.generator + samples_notrafo = inverse(trafo).(samples_trafo) - (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = MCMCSampleGenerator(chains)) + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator_new) end diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index 206199df6..cf98a33ed 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -48,13 +48,18 @@ function mcmc_burnin!( tuning_reinit!.(tuners, chains, burnin_alg.nsteps_per_cycle) + desc_string = string("Burnin cycle ", cycles, "/max_cycles=", burnin_alg.max_ncycles," for nchains=", length(chains)) + progress_meter = ProgressMeter.Progress(length(chains)*burnin_alg.nsteps_per_cycle, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + mcmc_iterate!( new_outputs, chains, tuners, max_nsteps = burnin_alg.nsteps_per_cycle, nonzero_weights = nonzero_weights, - callback = callback + callback = (kwargs...) -> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(progress_meter) ; end, ) + ProgressMeter.finish!(progress_meter) + tuning_update!.(tuners, chains, new_outputs) isnothing(outputs) || append!.(outputs, new_outputs) diff --git a/src/samplers/samplers.jl b/src/samplers/samplers.jl index de7264c18..ad6b149d2 100644 --- a/src/samplers/samplers.jl +++ b/src/samplers/samplers.jl @@ -2,5 +2,6 @@ include("bat_sample.jl") include("mcmc/mcmc.jl") +include("transformed_mcmc/mcmc.jl") include("evaluated_measure.jl") include("importance/importance_sampler.jl") diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl new file mode 100644 index 000000000..adcc39e6d --- /dev/null +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -0,0 +1,228 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +""" + struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm + +MCMC chain pool initialization strategy. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm + init_tries_per_chain::ClosedInterval{Int64} = ClosedInterval(8, 128) + nsteps_init::Int64 = 1000 + initval_alg::InitvalAlgorithm = InitFromTarget() +end + +export TransformedMCMCChainPoolInit + + +function apply_trafo_to_init(trafo::Function, initalg::TransformedMCMCChainPoolInit) + TransformedMCMCChainPoolInit( + initalg.init_tries_per_chain, + initalg.nsteps_init, + apply_trafo_to_init(trafo, initalg.initval_alg) + ) +end + + + +function _construct_chain( + rngpart::RNGPartition, + id::Integer, + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + initval_alg::InitvalAlgorithm, + parent_context::BATContext +) + new_context = set_rng(parent_context, AbstractRNG(rngpart, id)) + v_init = bat_initval(density, initval_alg, new_context).result + return TransformedMCMCIterator(algorithm, density, id, v_init, new_context) +end + +_gen_chains( + rngpart::RNGPartition, + ids::AbstractRange{<:Integer}, + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + initval_alg::InitvalAlgorithm, + context::BATContext +) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids] + +#TODO +function mcmc_init!( + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + nchains::Integer, + init_alg::TransformedMCMCChainPoolInit, + tuning_alg::TransformedMCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner + nonzero_weights::Bool, + callback::Function, + context::BATContext +) + @info "TransformedMCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)." + + initval_alg = init_alg.initval_alg + + min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains + max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains + + rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates)) + + ncandidates::Int = 0 + + @debug "Generating dummy MCMC chain to determine chain, output and tuner types." #TODO: remove! + + dummy_context = deepcopy(context) + dummy_initval = unshaped(bat_initval(density, InitFromTarget(), dummy_context).result, varshape(density)) + dummy_chain = TransformedMCMCIterator(algorithm, density, 1, dummy_initval, dummy_context) + dummy_tuner = get_tuner(tuning_alg, dummy_chain) + dummy_temperer = get_temperer(algorithm.tempering, density) + + chains = similar([dummy_chain], 0) + tuners = similar([dummy_tuner], 0) + temperers = similar([dummy_temperer], 0) + + init_tries::Int = 1 + + while length(tuners) < min_nviable && ncandidates < max_ncandidates + viable_tuners = similar(tuners, 0) + viable_chains = similar(chains, 0) + viable_temperers = similar(temperers, 0) + viable_outputs = [] #similar(outputs, 0) #TODO + + # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. + while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates + n = min(min_nviable, max_ncandidates - ncandidates) + @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." + + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) + + filter!(isvalidchain, new_chains) + + new_tuners = get_tuner.(Ref(tuning_alg), new_chains) + new_temperers = fill(get_temperer(algorithm.tempering, density), size(new_tuners,1)) + + next_cycle!.(new_chains) + + tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) + ncandidates += n + + @debug "Testing $(length(new_chains)) candidate MCMC chain(s)." + + transformed_mcmc_iterate!( + new_chains, new_tuners, new_temperers, + max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), + callback = callback, + nonzero_weights = nonzero_weights + ) + + # testing if chains are viable: + viable_idxs = findall(isviablechain.(new_chains)) + + new_outputs = getproperty.(new_chains, :samples) #TODO ? + + append!(viable_tuners, new_tuners[viable_idxs]) + append!(viable_chains, new_chains[viable_idxs]) + append!(viable_outputs, new_outputs[viable_idxs]) + append!(viable_temperers, new_temperers[viable_idxs]) + + end + + @debug "Found $(length(viable_tuners)) viable MCMC chain(s)." + + if !isempty(viable_chains) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_tuners), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_tuners) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + transformed_mcmc_iterate!( + viable_chains, viable_tuners, viable_temperers; + max_nsteps = init_alg.nsteps_init, + callback = (kwargs...)-> let pm=progress_meter; ProgressMeter.next!(pm) ; end, + nonzero_weights = nonzero_weights + ) + ProgressMeter.finish!(progress_meter) + nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains])) + good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains) + @debug "Found $(length(viable_chains)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples." + + append!(chains, view(viable_chains, good_idxs)) + append!(tuners, view(viable_tuners, good_idxs)) + append!(temperers, view(viable_temperers, good_idxs)) + end + + init_tries += 1 + end + + outputs = getproperty.(chains, :samples) + + length(chains) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") + + m = nchains + tidxs = LinearIndices(chains) + n = length(tidxs) + + modes = hcat(broadcast(samples -> Array(bat_findmode(samples, MaxDensitySearch(), context).result), outputs)...) + + final_chains = similar(chains, 0) + final_tuners = similar(tuners, 0) + final_temperers = similar(temperers, 0) + final_outputs = similar(outputs, 0) + + # TODO: should we put this into a function? + if 2 <= m < size(modes, 2) + clusters = kmeans(modes, m, init = KmCentralityAlg()) + clusters.converged || error("k-means clustering of MCMC chains did not converge") + + mincosts = fill(Inf, m) + chain_sel_idxs = fill(0, m) + + for i in tidxs + j = clusters.assignments[i] + if clusters.costs[i] < mincosts[j] + mincosts[j] = clusters.costs[i] + chain_sel_idxs[j] = i + end + end + + @assert all(j -> j in tidxs, chain_sel_idxs) + + for i in sort(chain_sel_idxs) + push!(final_chains, chains[i]) + push!(final_tuners, tuners[i]) + push!(final_temperers, temperers[i]) + push!(final_outputs, outputs[i]) + end + elseif m == 1 + i = findmax(nsamples.(chains))[2] + push!(final_chains, chains[i]) + push!(final_tuners, tuners[i]) + push!(final_temperers, temperers[i]) + push!(final_outputs, outputs[i]) + else + @assert length(chains) == nchains + resize!(final_chains, nchains) + copyto!(final_chains, chains) + + @assert length(tuners) == nchains + resize!(final_tuners, nchains) + copyto!(final_tuners, tuners) + + @assert length(temperers) == nchains + resize!(final_temperers, nchains) + copyto!(final_temperers, temperers) + + @assert length(outputs) == nchains + resize!(final_outputs, nchains) + copyto!(final_outputs, outputs) + end + + @info "Selected $(length(final_chains)) MCMC chain(s)." + #tuning_postinit!.(final_tuners, final_chains, final_outputs) #TODO: implement + + (chains = final_chains, tuners = final_tuners, temperers = final_temperers, outputs = final_outputs) +end diff --git a/src/samplers/transformed_mcmc/mcmc.jl b/src/samplers/transformed_mcmc/mcmc.jl new file mode 100644 index 000000000..742e76195 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc.jl @@ -0,0 +1,16 @@ +using AffineMaps + +#include("mcmc_utils.jl") + +include("mcmc_weighting.jl") +include("proposaldist.jl") +include("mcmc_sampleid.jl") +include("mcmc_algorithm.jl") +include("mcmc_stats.jl") +include("mcmc_sample.jl") +include("tempering.jl") +include("mcmc_iterate.jl") +include("mcmc_tuning/mcmc_tuning.jl") +include("chain_pool_init.jl") +include("mcmc_convergence.jl") +include("multi_cycle_burnin.jl") diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl new file mode 100644 index 000000000..f7e5aa8a6 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -0,0 +1,242 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + + +function get_mcmc_tuning end #TODO: still needed + + +""" + abstract type TransformedMCMCInitAlgorithm + +Abstract type for MCMC initialization algorithms. +""" +abstract type TransformedMCMCInitAlgorithm end +export TransformedMCMCInitAlgorithm + +#TODO AC: reactivate +#apply_trafo_to_init(trafo::Function, initalg::TransformedMCMCInitAlgorithm) = initalg + + + +""" + abstract type TransformedMCMCTuningAlgorithm + +Abstract type for MCMC tuning algorithms. +""" +abstract type TransformedMCMCTuningAlgorithm end +export TransformedMCMCTuningAlgorithm + + + +""" + abstract type TransformedMCMCBurninAlgorithm + +Abstract type for MCMC burn-in algorithms. +""" +abstract type TransformedMCMCBurninAlgorithm end +export TransformedMCMCBurninAlgorithm + + + +@with_kw struct TransformedMCMCIteratorInfo + id::Int32 + cycle::Int32 + tuned::Bool + converged::Bool +end + + +# TODO AC: reactivate +# """ +# abstract type MCMCIterator end + +# Represents the current state of an MCMC chain. + +# !!! note + +# The details of the `MCMCIterator` and `MCMCAlgorithm` API (see below) +# currently do not form part of the stable API and are subject to change +# without deprecation. + +# To implement a new MCMC algorithm, subtypes of both [`MCMCAlgorithm`](@ref) +# and `MCMCIterator` are required. + +# The following methods must be defined for subtypes of `MCMCIterator` (e.g. +# `SomeMCMCIter<:MCMCIterator`): + +# ```julia + +# BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity + +# BAT.getcontext(chain::SomeMCMCIter)::BATContext + +# BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo + +# BAT.nsteps(chain::SomeMCMCIter)::Int + +# BAT.nsamples(chain::SomeMCMCIter)::Int + +# BAT.current_sample(chain::SomeMCMCIter)::DensitySample + +# BAT.sample_type(chain::SomeMCMCIter)::Type{<:DensitySample} + +# BAT.samples_available(chain::SomeMCMCIter, nonzero_weights::Bool = false)::Bool + +# BAT.get_samples!(samples::DensitySampleVector, chain::SomeMCMCIter, nonzero_weights::Bool)::typeof(samples) + +# BAT.next_cycle!(chain::SomeMCMCIter)::SomeMCMCIter + +# BAT.mcmc_step!( +# chain::SomeMCMCIter +# callback::Function, +# )::nothing +# ``` + +# The following methods are implemented by default: + +# ```julia +# getalgorithm(chain::MCMCIterator) +# getmeasure(chain::MCMCIterator) +# DensitySampleVector(chain::MCMCIterator) +# mcmc_iterate!(chain::MCMCIterator, ...) +# mcmc_iterate!(chains::AbstractVector{<:MCMCIterator}, ...) +# isvalidchain(chain::MCMCIterator) +# isviablechain(chain::MCMCIterator) +# ``` +# """ +# abstract type MCMCIterator end +# export MCMCIterator + + +#TODO AC: reactivate +# function Base.show(io::IO, chain::MCMCIterator) +# print(io, Base.typename(typeof(chain)).name, "(") +# print(io, "id = "); show(io, mcmc_info(chain).id) +# print(io, ", nsamples = "); show(io, nsamples(chain)) +# print(io, ", density = "); show(io, getmeasure(chain)) +# print(io, ")") +# end + + +function getalgorithm end + +function getmeasure end + +function mcmc_info end + +function nsteps end + +function nsamples end + +function current_sample end + +function sample_type end + +function samples_available end + +function get_samples! end + +function next_cycle! end + +function mcmc_step! end + + +# TODO AC: reactivate +#DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain), totalndof(getmeasure(chain))) + + +abstract type TransformedAbstractMCMCTunerInstance end + + +function tuning_init! end + +function tuning_postinit! end + +function tuning_reinit! end + +function tuning_update! end + +function tuning_finalize! end + +function tuning_callback end + + +function mcmc_init! end + +function mcmc_burnin! end + + +function isvalidchain end + +function isviablechain end + + + +function mcmc_iterate! end + +""" + BAT.TransformedMCMCSampleGenerator + +*BAT-internal, not part of stable public API.* + +MCMC sample generator. + +Constructors: + +```julia +TransformedMCMCSampleGenerator(chain::AbstractVector{<:MCMCIterator}) +``` +""" +struct TransformedMCMCSampleGenerator{ + T<:AbstractVector{<:MCMCIterator}, + A<:AbstractSamplingAlgorithm, +} <: AbstractSampleGenerator + chains::T + algorithm::A +end + +getalgorithm(sg::TransformedMCMCSampleGenerator) = sg.algorithm + +function Base.show(io::IO, generator::TransformedMCMCSampleGenerator) + if get(io, :compact, false) + print(io, nameof(typeof(generator)), "(") + if !isempty(generator.chains) + show(io, first(generator.chains)) + print(io, ", …") + end + print(io, ")") + else + println(io, nameof(typeof(generator)), ":") + chains = generator.chains + nchains = length(chains) + n_tuned_chains = count(c -> c.info.tuned, chains) + n_converged_chains = count(c -> c.info.converged, chains) + print(io, "algorithm: ") + show(io, "text/plain", getalgorithm(generator)) + println(io) + println(io, "number of chains:", repeat(' ', 12), nchains) + println(io, "number of chains tuned:", repeat(' ', 6), n_tuned_chains) + println(io, "number of chains converged:", repeat(' ', 2), n_converged_chains) + println(io, "number of points…") + println(io, repeat(' ',10), "… in 1th chain:", repeat(' ', 4), nsamples(first(chains))) + print(io, repeat(' ',10), "… on average:", repeat(' ', 6), div(sum(nsamples.(chains)), nchains)) + end +end + + +function bat_report!(md::Markdown.MD, generator::TransformedMCMCSampleGenerator) + mcalg = getalgorithm(generator) + chains = generator.chains + nchains = length(chains) + n_tuned_chains = count(c -> c.info.tuned, chains) + n_converged_chains = count(c -> c.info.converged, chains) + + markdown_append!(md, """ + ### Sample generation + + * Algorithm: MCMC, $(nameof(typeof(mcalg))) + * MCMC chains: $nchains ($n_tuned_chains tuned, $n_converged_chains converged) + """) + + return md +end diff --git a/src/samplers/transformed_mcmc/mcmc_convergence.jl b/src/samplers/transformed_mcmc/mcmc_convergence.jl new file mode 100644 index 000000000..83d568940 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_convergence.jl @@ -0,0 +1,169 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +function transformed_check_convergence!( + chains::AbstractVector{<:MCMCIterator}, + samples::AbstractVector{<:DensitySampleVector}, + algorithm::ConvergenceTest, + context::BATContext +) + result = convert(Bool, bat_convergence(samples, algorithm, context).result) + for chain in chains + chain.info = TransformedMCMCIteratorInfo(chain.info, converged = result) + end + result +end + + + +""" + gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) + gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) + +*BAT-internal, not part of stable public API.* + +Gelman-Rubin ``\$R^2\$`` for all DOF. +""" +# TODO AC: reactivate +# function gr_Rsqr end + +function gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) + m = totalndof(first(stats)) + W = mean([cs.param_stats.cov[i,i] for cs in stats, i in 1:m], dims=1)[:] + B = var([cs.param_stats.mean[i] for cs in stats, i in 1:m], dims=1)[:] + (W .+ B) ./ W +end + +#TODO AC: reactivate +# function gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) +# gr_Rsqr(TransformedMCMCBasicStats.(samples)) +# end + + + +""" + struct TransformedGelmanRubinConvergence <: ConvergenceTest + +Gelman-Rubin maximum R^2 convergence test. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedGelmanRubinConvergence <: ConvergenceTest + threshold::Float64 = 1.1 +end + +export TransformedGelmanRubinConvergence + +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedGelmanRubinConvergence, ::BATContext) + max_Rsqr = maximum(gr_Rsqr(samples)) + vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) + converged = convert(Bool, vt) + @debug begin + success_str = converged ? "have" : "have *not*" + "Chains $success_str converged, max(R^2) = $(vt.value), threshold = $(vt.threshold)" + end + (result = vt,) +end + + + +@doc doc""" + bg_R_2sqr(stats::AbstractVector{<:TransformedMCMCBasicStats}; corrected::Bool = false) + bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) + +*BAT-internal, not part of stable public API.* + +Brooks-Gelman R_2^2 for all DOF. +If normality is assumed, 'corrected' should be set to true to account for the sampling variability. +""" +function bg_R_2sqr(stats::AbstractVector{<:TransformedMCMCBasicStats}; corrected::Bool = false) + p = totalndof(first(stats)) + m = length(stats) + n = mean(Float64.(nsamples.(stats))) + + σ_W = var([cs.param_stats.cov[i,i] for cs in stats, i in 1:p], dims = 1)[:] + B = var([cs.param_stats.mean[i] for cs in stats, i in 1:p], dims = 1)[:] + W = mean([cs.param_stats.cov[i,i] for cs in stats, i in 1:p], dims = 1)[:] + + σ_sq = m * (n - 1) / (m*n - 1) * W + n * (m - 1) / (m*n - 1) * B + + R_unc = σ_sq ./ W + + if corrected == false + return R_unc + end + + σ_ij = [cs.param_stats.cov[i,i] for cs in stats, i in 1:p] + x_ij = [cs.param_stats.mean[i] for cs in stats, i in 1:p] + + cov_σx = [cov(σ_ij[:,j], x_ij[:,j]) for j in 1:p] + cov_σx_sq = [cov(σ_ij[:,j], x_ij[:,j].^2) for j in 1:p] + + N = (n-1)/n + M = (m-1)/m + V = N*σ_sq + M*B + + σ_V = N^2/m*σ_W + 2*M/(m-1)*B.^2 + 2*M*N/m*(cov_σx_sq - 2*B.*cov_σx) + d = 2 * V.^2 ./ σ_V + + R_unc.*(d.+3)./(d.+1) +end + +# TODO AC: reactivate +# function bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) +# bg_R_2sqr(TransformedMCMCBasicStats.(samples), corrected = corrected) +# end + + + +""" + struct TransformedBrooksGelmanConvergence <: ConvergenceTest + +Brooks-Gelman maximum R^2 convergence test. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest + threshold::Float64 = 1.1 + corrected::Bool = false +end + +export TransformedBrooksGelmanConvergence + +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedBrooksGelmanConvergence, ::BATContext) + max_Rsqr = maximum(bg_R_2sqr(samples, corrected = algorithm.corrected)) + vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) + converged = convert(Bool, vt) + @debug begin + success_str = converged ? "have" : "have *not*" + "Chains $success_str converged, max(R^2) = $(vt.value), threshold = $(vt.threshold)" + end + (result = vt,) +end + + + +function bat_convergence_impl(samples::DensitySampleVector, algorithm::Union{TransformedGelmanRubinConvergence, TransformedBrooksGelmanConvergence}, context::BATContext) + # create a vector of chains + chains_ind = unique([i.chainid for i in samples.info]) + vector_chains = DensitySampleVector[] + # ToDo: Improve implementation + for i in chains_ind + mask_chain = [j.chainid == i for j in samples.info] + push!(vector_chains, samples[mask_chain]) + end + + bat_convergence_impl(vector_chains, algorithm, context) +end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl new file mode 100644 index 000000000..a59cbda78 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -0,0 +1,336 @@ +mutable struct TransformedMCMCIterator{ + PR<:RNGPartition, + D<:BATMeasure, + F, + Q<:TransformedMCMCProposal, + SV<:DensitySampleVector, + S<:DensitySample, + CTX<:BATContext, +} <: MCMCIterator + rngpart_cycle::PR + μ::D + f_transform::F + proposal::Q + samples::SV + sample_z::S + stepno::Int + n_accepted::Int + info::TransformedMCMCIteratorInfo + context::CTX +end + +getmeasure(chain::TransformedMCMCIterator) = chain.μ + +get_context(chain::TransformedMCMCIterator) = chain.context + +mcmc_info(chain::TransformedMCMCIterator) = chain.info + +nsteps(chain::TransformedMCMCIterator) = chain.stepno + +nsamples(chain::TransformedMCMCIterator) = size(chain.samples, 1) + +current_sample(chain::TransformedMCMCIterator) = last(chain.samples) + +sample_type(chain::TransformedMCMCIterator) = eltype(chain.samples) + +samples_available(chain::TransformedMCMCIterator) = size(chain.samples,1) > 0 + +isvalidchain(chain::TransformedMCMCIterator) = current_sample(chain).logd > -Inf + +isviablechain(chain::TransformedMCMCIterator) = nsamples(chain) >= 2 + +eff_acceptance_ratio(chain::TransformedMCMCIterator) = nsamples(chain) / chain.stepno + + + +#ctor +function TransformedMCMCIterator( + algorithm::TransformedMCMCSampling, + target, + id::Integer, + v_init::AbstractVector{<:Real}, + context::BATContext +) + TransformedMCMCIterator(algorithm, target, Int32(id), v_init, context) +end + + +#ctor +function TransformedMCMCIterator( + algorithm::TransformedMCMCSampling, + target, + id::Int32, + v_init::AbstractVector{<:Real}, + context::BATContext, +) + rngpart_cycle = RNGPartition(get_rng(context), 0:(typemax(Int16) - 2)) + + μ = target + proposal = algorithm.proposal + stepno = 1 + cycle = 1 + n_accepted = 0 + + adaptive_transform_spec = algorithm.adaptive_transform + g = init_adaptive_transform(adaptive_transform_spec, μ, context) + + logd_x = logdensityof(μ, v_init) + sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCTransformedSampleID(id, 1, 0), nothing) # TODO + inverse_g = inverse(g) + z = inverse_g(v_init) # sample_x.v + logd_z = logdensityof(MeasureBase.pullback(g, μ),z) + sample_z = _rebuild_density_sample(sample_x, z, logd_z) + + samples = DensitySampleVector(([sample_x.v], [sample_x.logd], [sample_x.weight], [sample_x.info], [sample_x.aux] )) + + iter = TransformedMCMCIterator( + rngpart_cycle, + target, + g, + proposal, + samples, + sample_z, + stepno, + n_accepted, + TransformedMCMCIteratorInfo(id, cycle, false, false), + context + ) + + +end + + + +function _rebuild_density_sample(s::DensitySample, x, logd, weight=1) + @unpack info, aux = s + DensitySample(x, logd, weight, info, aux) +end + + + +function propose_mcmc( + iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:TransformedMHProposal} +) + @unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter + rng = get_rng(context) + sample_x = last(samples) + x, logd_x = sample_x.v, sample_x.logd + z, logd_z = sample_z.v, sample_z.logd + + n = size(z, 1) + z_proposed = z + rand(rng, proposal.proposal_dist, n) #TODO: check if proposal is symmetric? otherwise need additional factor? + x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed) + logd_x_proposed = BAT.checked_logdensityof(μ, x_proposed) + logd_z_proposed = logd_x_proposed + ladj + @assert logd_z_proposed ≈ logdensityof(MeasureBase.pullback(f_transform, μ), z_proposed) #TODO: remove + + + # TODO AC: do we need to check symmetry of proposal distribution? + # T = typeof(logd_z) + # p_accept = if logd_z_proposed > -Inf + # # log of ratio of forward/reverse transition probability + # log_tpr = if issymmetric(proposal.proposal_dist) + # T(0) + # else + # log_tp_fwd = proposaldist_logpdf(proposaldist, proposed_params, current_params) + # log_tp_rev = proposaldist_logpdf(proposaldist, current_params, proposed_params) + # T(log_tp_fwd - log_tp_rev) + # end + + # p_accept_unclamped = exp(proposed_log_posterior - current_log_posterior - log_tpr) + # T(clamp(p_accept_unclamped, 0, 1)) + # else + # zero(T) + # end + + p_accept = clamp(exp(logd_z_proposed-logd_z), 0, 1) + + sample_z_proposed = _rebuild_density_sample(sample_z, z_proposed, logd_z_proposed) + sample_x_proposed = _rebuild_density_sample(sample_x, x_proposed, logd_x_proposed) + + return sample_x_proposed, sample_z_proposed, p_accept +end + + + +function transformed_mcmc_step!!( + iter::TransformedMCMCIterator, + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedMCMCTemperingInstance, +) + @unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter + rng = get_rng(context) + sample_x = last(samples) + x, logd_x = sample_x.v, sample_x.logd + z, logd_z = sample_z.v, sample_z.logd + @unpack n_accepted, stepno = iter + + sample_x_proposed, sample_z_proposed, p_accept = propose_mcmc(iter) + + z_proposed, logd_z_proposed = sample_z_proposed.v, sample_z_proposed.logd + x_proposed, logd_x_proposed = sample_x_proposed.v, sample_x_proposed.logd + + tuner_new, f_transform = tune_mcmc_transform!!(tuner, f_transform, p_accept, z_proposed, z, stepno, context) + + accepted = rand(rng) <= p_accept + + # f_transform may have changed + inverse_f = inverse(f_transform) + x_new, z_new, logd_x_new, logd_z_new = if accepted + x_proposed, inverse_f(x_proposed), logd_x_proposed, logd_z_proposed + else + x, inverse_f(x), logd_x, logd_z + end + + sample_x_new, sample_z_new, samples_new = if accepted + sample_x_new = DensitySample(x_new, logd_x_new, 1, TransformedMCMCTransformedSampleID(iter.info.id, iter.info.cycle, iter.stepno), nothing) + push!(samples, sample_x_new) + sample_x_new, _rebuild_density_sample(sample_z, z_new, logd_z_new), samples + else + samples.weight[end] += 1 + _rebuild_density_sample(sample_x, x_new, logd_x_new, sample_x.weight+1), _rebuild_density_sample(sample_z, z_new, logd_z_new), samples + end + + tempering_new, μ_new = temper_mcmc_target!!(tempering, μ, stepno) + + f_new = f_transform + + # iter_new = TransformedMCMCIterator(μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted), context) + iter.μ, iter.f_transform, iter.samples, iter.sample_z = μ_new, f_new, samples_new, sample_z_new + iter.n_accepted += Int(accepted) + iter.stepno += 1 + @assert iter.context === context + + return (iter, tuner_new, tempering_new) +end + + + +function transformed_mcmc_iterate!( + chain::TransformedMCMCIterator, + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedMCMCTemperingInstance; + max_nsteps::Integer = 1, + max_time::Real = Inf, + nonzero_weights::Bool = true, + callback::Function = nop_func, +) + @debug "Starting iteration over MCMC chain $(mcmc_info(chain).id) with $max_nsteps steps in max. $(@sprintf "%.1f seconds." max_time)" + + start_time = time() + last_progress_message_time = start_time + start_nsteps = nsteps(chain) + start_nsteps = nsteps(chain) + + while ( + (nsteps(chain) - start_nsteps) < max_nsteps && + (time() - start_time) < max_time + ) + transformed_mcmc_step!!(chain, tuner, tempering) + callback(Val(:mcmc_step), chain) + + #TODO: output schemes + + current_time = time() + elapsed_time = current_time - start_time + logging_interval = 5 * round(log2(elapsed_time/60 + 1) + 1) + if current_time - last_progress_message_time > logging_interval + last_progress_message_time = current_time + @debug "Iterating over MCMC chain $(mcmc_info(chain).id), completed $(nsteps(chain) - start_nsteps) (of $(max_nsteps)) steps and produced $(nsteps(chain) - start_nsteps) samples in $(@sprintf "%.1f s" elapsed_time) so far." + end + end + + current_time = time() + elapsed_time = current_time - start_time + @debug "Finished iteration over MCMC chain $(mcmc_info(chain).id), completed $(nsteps(chain) - start_nsteps) steps and produced $(nsteps(chain) - start_nsteps) samples in $(@sprintf "%.1f s" elapsed_time)." + + return nothing +end + + +function transformed_mcmc_iterate!( + chain::MCMCIterator, + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedMCMCTemperingInstance; + # tuner::TransformedAbstractMCMCTunerInstance; + max_nsteps::Integer = 1, + max_time::Real = Inf, + nonzero_weights::Bool = true, + callback::Function = nop_func +) + cb = callback# combine_callbacks(tuning_callback(tuner), callback) #TODO CA: tuning_callback + + transformed_mcmc_iterate!( + chain, tuner, tempering, + max_nsteps = max_nsteps, max_time = max_time, nonzero_weights = nonzero_weights, callback = cb + ) + + return nothing +end + + +function transformed_mcmc_iterate!( + chains::AbstractVector{<:MCMCIterator}, + tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, + temperers::AbstractVector{<:TransformedMCMCTemperingInstance}; + kwargs... +) + if isempty(chains) + @debug "No MCMC chain(s) to iterate over." + return chains + else + @debug "Starting iteration over $(length(chains)) MCMC chain(s)" + end + + @sync for i in eachindex(chains, tuners, temperers) + Base.Threads.@spawn transformed_mcmc_iterate!(chains[i], tuners[i], temperers[i]#= , tnrs[i] =#; kwargs...) + end + + return nothing +end + + +#= +# Unused? +function reset_chain( + rng::AbstractRNG, + chain::TransformedMCMCIterator, +) + rngpart_cycle = RNGPartition(rng, 0:(typemax(Int16) - 2)) + #TODO reset cycle count? + chain.rngpart_cycle = rngpart_cycle + chain.info = TransformedMCMCIteratorInfo(chain.info, cycle=0) + chain.context = set_rng(chain.context, rng) + # wants a next_cycle! + # reset_rng_counters!(chain) +end +=# + + +function reset_rng_counters!(chain::TransformedMCMCIterator) + rng = get_rng(get_context(chain)) + set_rng!(rng, chain.rngpart_cycle, chain.info.cycle) + rngpart_step = RNGPartition(rng, 0:(typemax(Int32) - 2)) + set_rng!(rng, rngpart_step, chain.stepno) + nothing +end + + +function next_cycle!( + chain::TransformedMCMCIterator, + +) + chain.info = TransformedMCMCIteratorInfo(chain.info, cycle = chain.info.cycle + 1) + chain.stepno = 0 + + reset_rng_counters!(chain) + + chain.samples[1] = last(chain.samples) + resize!(chain.samples, 1) + + chain.samples.weight[1] = 1 + chain.samples.info[1] = TransformedMCMCTransformedSampleID(chain.info.id, chain.info.cycle, chain.stepno) + + chain +end + diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl new file mode 100644 index 000000000..957342583 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -0,0 +1,161 @@ +abstract type TransformedMCMCProposal end +""" + BAT.TransformedMHProposal + +*BAT-internal, not part of stable public API.* +""" +struct TransformedMHProposal{ + D<:Union{Distribution, AbstractMeasure} +}<: TransformedMCMCProposal + proposal_dist::D +end + + +# TODO AC: find a better solution for this. Problem is that in the with_kw constructor below, we need to dispatch on this type. +struct TransformedMCMCDispatch end + +@with_kw struct TransformedMCMCSampling{ + TR<:AbstractTransformTarget, + IN<:TransformedMCMCInitAlgorithm, + BI<:TransformedMCMCBurninAlgorithm, + CT<:ConvergenceTest, + CB<:Function +} <: AbstractSamplingAlgorithm + pre_transform::TR = bat_default(TransformedMCMCDispatch, Val(:pre_transform)) + tuning_alg::TransformedMCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults + adaptive_transform::AdaptiveTransformSpec = default_adaptive_transform(tuning_alg) + proposal::TransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults + tempering = TransformedNoTransformedMCMCTempering() # TODO: use bat_defaults + nchains::Int = 4 + nsteps::Int = 10^5 + #TODO: max_time ? + init::IN = bat_default(TransformedMCMCDispatch, Val(:init), pre_transform, nchains, nsteps) #TransformedMCMCChainPoolInit()#TODO AC: use bat_defaults bat_default(MCMCSampling, Val(:init), MetropolisHastings(), pre_transform, nchains, nsteps) #TODO + burnin::BI = bat_default(TransformedMCMCDispatch, Val(:burnin), pre_transform, nchains, nsteps) + convergence::CT = TransformedBrooksGelmanConvergence() + strict::Bool = true + store_burnin::Bool = false + nonzero_weights::Bool = true + callback::CB = nop_func +end + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:pre_transform}) = PriorToGaussian() + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:nsteps}, trafo::AbstractTransformTarget, nchains::Integer) = 10^5 + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:init}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = + TransformedMCMCChainPoolInit(nsteps_init = max(div(nsteps, 100), 250)) + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:burnin}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = + TransformedMCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500)) + + + +function bat_sample_impl( + target::AnyMeasureOrDensity, + algorithm::TransformedMCMCSampling, + context::BATContext +) + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo, context) + + init = mcmc_init!( + algorithm, + density, + algorithm.nchains, + apply_trafo_to_init(trafo, algorithm.init), + algorithm.tuning_alg, + algorithm.nonzero_weights, + algorithm.store_burnin ? algorithm.callback : nop_func, + context + ) + + @unpack chains, tuners, temperers = init + + # output_init = reduce(vcat, getproperty(chains, :samples)) + + burnin_outputs_coll = if algorithm.store_burnin + DensitySampleVector(first(chains)) + else + nothing + end + + # burnin and tuning + mcmc_burnin!( + burnin_outputs_coll, + chains, + tuners, + temperers, + algorithm.burnin, + algorithm.convergence, + algorithm.strict, + algorithm.nonzero_weights, + algorithm.store_burnin ? algorithm.callback : nop_func + ) + + # sampling + run_sampling = _run_sample_impl( + density, + algorithm, + chains, + ) + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + # prepend burnin samples to output + if algorithm.store_burnin + burnin_samples_trafo = varshape(density).(burnin_outputs_coll) + append!(burnin_samples_trafo, samples_trafo) + samples_trafo = burnin_samples_trafo + end + + samples_notrafo = inverse(trafo).(samples_trafo) + + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end + +#= +function _bat_sample_continue( + target::AnyMeasureOrDensity, + generator::TransformedMCMCSampleGenerator, + ;description::AbstractString = "MCMC iterate" +) + @unpack algorithm, chains = generator + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo) + + run_sampling = _run_sample_impl(density, algorithm, chains, description=description) + + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + samples_notrafo = inverse(trafo).(samples_trafo) + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end +=# + +function _run_sample_impl( + density::AnyMeasureOrDensity, + algorithm::TransformedMCMCSampling, + chains::AbstractVector{<:MCMCIterator}, + ;description::AbstractString = "MCMC iterate" +) + next_cycle!.(chains) + + progress_meter = ProgressMeter.Progress(algorithm.nchains*algorithm.nsteps, desc=description, barlen=80-length(description), dt=0.1) + + # tuners are set to 'NoOpTuner' for the sampling phase + transformed_mcmc_iterate!( + chains, + get_tuner.(Ref(TransformedMCMCNoOpTuning()),chains), + get_temperer.(Ref(TransformedNoTransformedMCMCTempering()), chains), + max_nsteps = algorithm.nsteps, #TODO: maxtime + nonzero_weights = algorithm.nonzero_weights, + callback = (kwargs...) -> let pm=progress_meter; ProgressMeter.next!(pm) ; end, + ) + ProgressMeter.finish!(progress_meter) + + output = reduce(vcat, getproperty.(chains, :samples)) + samples_trafo = varshape(density).(output) + + (result_trafo = samples_trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end diff --git a/src/samplers/transformed_mcmc/mcmc_sampleid.jl b/src/samplers/transformed_mcmc/mcmc_sampleid.jl new file mode 100644 index 000000000..fb072cd25 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_sampleid.jl @@ -0,0 +1,61 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +abstract type TransformedSampleID end + +struct TransformedMCMCTransformedSampleID{ + T<:Int32, + U<:Int64, +} <: TransformedSampleID + chainid::T + chaincycle::T + stepno::U +end + +function TransformedMCMCTransformedSampleID( + chainid::Integer, + chaincycle::Integer, + stepno::Integer, +) + TransformedMCMCTransformedSampleID(Int32(chainid), Int32(chaincycle), Int64(stepno)) +end + +const TransformedMCMCTransformedSampleIDVector{TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} = StructArray{ + TransformedMCMCTransformedSampleID, + 1, + NamedTuple{(:chainid, :chaincycle, :stepno), Tuple{TV,TV,UV}}, + Int +} + + +function TransformedMCMCTransformedSampleIDVector(contents::Tuple{TV,TV,UV}) where {TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} + StructArray{TransformedMCMCTransformedSampleID}(contents)::TransformedMCMCTransformedSampleIDVector{TV,UV} +end + +TransformedMCMCTransformedSampleIDVector(::UndefInitializer, len::Integer) = TransformedMCMCTransformedSampleIDVector(( + Vector{Int32}(undef, len), Vector{Int32}(undef, len), + Vector{Int64}(undef, len) +)) + +TransformedMCMCTransformedSampleIDVector() = TransformedMCMCTransformedSampleIDVector(undef, 0) + + +_create_undef_vector(::Type{TransformedMCMCTransformedSampleID}, len::Integer) = TransformedMCMCTransformedSampleIDVector(undef, len) + + +# Specialize comparison, currently StructArray seems fall back to `(==)(A::AbstractArray, B::AbstractArray)` +import Base.== +function(==)(A::TransformedMCMCTransformedSampleIDVector, B::TransformedMCMCTransformedSampleIDVector) + A.chainid == B.chainid && + A.chaincycle == B.chaincycle && + A.stepno == B.stepno +end + + +function Base.merge!(X::TransformedMCMCTransformedSampleIDVector, Xs::TransformedMCMCTransformedSampleIDVector...) + for Y in Xs + append!(X, Y) + end + X +end + +Base.merge(X::TransformedMCMCTransformedSampleIDVector, Xs::TransformedMCMCTransformedSampleIDVector...) = merge!(deepcopy(X), Xs...) diff --git a/src/samplers/transformed_mcmc/mcmc_stats.jl b/src/samplers/transformed_mcmc/mcmc_stats.jl new file mode 100644 index 000000000..214eafd4e --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_stats.jl @@ -0,0 +1,122 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +abstract type TransformedAbstractMCMCStats end +TransformedAbstractMCMCStats + + + +struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end + + +Base.push!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats + +Base.append!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats + + + +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats + param_stats::BasicMvStatistics{P,FrequencyWeights} + logtf_stats::BasicUvStatistics{L,FrequencyWeights} + mode::Vector{P} + + function TransformedMCMCBasicStats{L,P}(m::Integer) where {L<:Real,P<:Real} + param_stats = BasicMvStatistics{P,FrequencyWeights}(m) + logtf_stats = BasicUvStatistics{L,FrequencyWeights}() + mode = fill(P(NaN), m) + + new{L,P}( + param_stats, + logtf_stats, + mode + ) + end +end + + +function TransformedMCMCBasicStats(::Type{S}, ndof::Integer) where { + PT<:Real, T, W, S<:DensitySample{<:AbstractVector{PT},T,W} +} + SL = promote_type(T, Float64) + SP = promote_type(PT, W, Float64) + TransformedMCMCBasicStats{SL,SP}(ndof) +end + +TransformedMCMCBasicStats(chain::MCMCIterator) = TransformedMCMCBasicStats(sample_type(chain), totalndof(getmeasure(chain))) + +function TransformedMCMCBasicStats(sv::DensitySampleVector{<:AbstractVector{<:Real}}) + stats = TransformedMCMCBasicStats(eltype(sv), innersize(sv.v, 1)) + append!(stats, sv) +end + +TransformedMCMCBasicStats(sv::DensitySampleVector) = TransformedMCMCBasicStats(unshaped.(sv)) + + +function Base.empty!(stats::TransformedMCMCBasicStats) + empty!(stats.param_stats) + empty!(stats.logtf_stats) + fill!(stats.mode, eltype(stats.mode)(NaN)) + + stats +end + + +function Base.push!(stats::TransformedMCMCBasicStats, s::DensitySample) + push!(stats.param_stats, s.v, s.weight) + if s.logd > stats.logtf_stats.maximum + stats.mode .= s.v + end + push!(stats.logtf_stats, s.logd, s.weight) + stats +end + + +function Base.append!(stats::TransformedMCMCBasicStats, sv::DensitySampleVector) + for i in eachindex(sv) + p = sv.v[i] + w = sv.weight[i] + l = sv.logd[i] + push!(stats.param_stats, p, w) # Memory allocation (view)! + if sv.logd[i] > stats.logtf_stats.maximum + stats.mode .= p # Memory allocation (view)! + end + push!(stats.logtf_stats, l, w) + stats + end + stats +end + + +ValueShapes.totalndof(stats::TransformedMCMCBasicStats) = stats.param_stats.m + +nsamples(stats::TransformedMCMCBasicStats) = stats.param_stats.cov.sum_w + +function Base.merge!(target::TransformedMCMCBasicStats, others::TransformedMCMCBasicStats...) + for x in others + if (x.logtf_stats.maximum > target.logtf_stats.maximum) + target.mode .= x.mode + end + merge!(target.param_stats, x.param_stats) + merge!(target.logtf_stats, x.logtf_stats) + end + target +end + +Base.merge(a::TransformedMCMCBasicStats, bs::TransformedMCMCBasicStats...) = merge!(deepcopy(a), bs...) + + +function reweight_relative!(stats::TransformedMCMCBasicStats, reweighting_factor::Real) + reweight_relative!(stats.param_stats, reweighting_factor) + reweight_relative!(stats.logtf_stats, reweighting_factor) + + stats +end + + +function _bat_stats(mcmc_stats::TransformedMCMCBasicStats) + ( + mode = mcmc_stats.mode, + mean = mcmc_stats.param_stats.mean, + cov = mcmc_stats.param_stats.cov + ) +end diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl new file mode 100644 index 000000000..3416ebfd8 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -0,0 +1,55 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm + +No-op tuning, marks MCMC chains as tuned without performing any other changes +on them. Useful if chains are pre-tuned or tuning is an internal part of the +MCMC sampler implementation. +""" +struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end +export TransformedMCMCNoOpTuning + + + +struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end + +(tuning::TransformedMCMCNoOpTuning)(chain::MCMCIterator) = TransformedMCMCNoOpTuner() +get_tuner(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) = TransformedMCMCNoOpTuner() + + +function TransformedMCMCNoOpTuning(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) + TransformedMCMCNoOpTuner() +end + + +function tuning_init!(tuner::TransformedMCMCNoOpTuning, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + nothing +end + + + +function tune_mcmc_transform!!( + tuner::TransformedMCMCNoOpTuner, + transform, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int, + context::BATContext +) + return (tuner, transform) + +end + +tuning_postinit!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +tuning_reinit!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, max_nsteps::Integer) = nothing + +tuning_update!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +tuning_finalize!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator) = nothing + +tuning_callback(::TransformedMCMCNoOpTuning) = nop_func diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl new file mode 100644 index 000000000..75aac0a51 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -0,0 +1,143 @@ +@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm + "Controls the weight given to new covariance information in adapting the + proposal distribution." + λ::Float64 = 0.5 + + "Metropolis-Hastings acceptance ratio target, tuning will try to adapt + the proposal distribution to bring the acceptance ratio inside this interval." + α::IntervalSets.ClosedInterval{Float64} = ClosedInterval(0.15, 0.35) + + "Controls how much the spread of the proposal distribution is + widened/narrowed depending on the current MH acceptance ratio." + β::Float64 = 1.5 + + "Interval for allowed scale/spread of the proposal distribution." + c::IntervalSets.ClosedInterval{Float64} = ClosedInterval(1e-4, 1e2) + + "Reweighting factor. Take accumulated sample statistics of previous + tuning cycles into account with a relative weight of `r`. Set to + `0` to completely reset sample statistics between each tuning cycle." + r::Real = 0.5 +end + +mutable struct TransformedProposalCovTuner{ + S<:TransformedMCMCBasicStats +} <: TransformedAbstractMCMCTunerInstance + config::TransformedAdaptiveMHTuning + stats::S + iteration::Int + scale::Float64 +end + + +function TransformedProposalCovTuner(tuning::TransformedAdaptiveMHTuning, chain::MCMCIterator) + m = totalndof(getmeasure(chain)) + scale = 2.38^2 / m + TransformedProposalCovTuner(tuning, TransformedMCMCBasicStats(chain), 1, scale) +end + +get_tuner(tuning::TransformedAdaptiveMHTuning, chain::MCMCIterator) = TransformedProposalCovTuner(tuning, chain) +default_adaptive_transform(tuner::TransformedAdaptiveMHTuning) = TriangularAffineTransform() + + +function tuning_init!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + + nothing +end + +tuning_reinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, max_nsteps::Integer) = nothing + + +function tuning_postinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) + # The very first samples of a chain can be very valuable to init tuner + # stats, especially if the chain gets stuck early after: + stats = tuner.stats + append!(stats, samples) +end + +# this function is called once after each tuning cycle +g_state = nothing +function tuning_update!(tuner::TransformedProposalCovTuner, chain::TransformedMCMCIterator, samples::DensitySampleVector) + global g_state = (;tuner, chain) + + stats = tuner.stats + stats_reweight_factor = tuner.config.r + reweight_relative!(stats, stats_reweight_factor) + # empty!.(stats) + append!(stats, samples) + + + config = tuner.config + + α_min = minimum(config.α) + α_max = maximum(config.α) + + c_min = minimum(config.c) + c_max = maximum(config.c) + + β = config.β + + t = tuner.iteration + λ = config.λ + c = tuner.scale + + transform = chain.f_transform + + A = transform.A + Σ_old = A*A' + + S = convert(Array, stats.param_stats.cov) + a_t = 1 / t^λ + new_Σ_unscal = (1 - a_t) * (Σ_old/c) + a_t * S + + α = eff_acceptance_ratio(chain) + + max_log_posterior = stats.logtf_stats.maximum + + if α_min <= α <= α_max + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + @debug "MCMC chain $(chain.info.id) tuned, acceptance ratio = $(Float32(α)), proposal scale = $(Float32(c)), max. log posterior = $(Float32(max_log_posterior))" + else + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + @debug "MCMC chain $(chain.info.id) *not* tuned, acceptance ratio = $(Float32(α)), proposal scale = $(Float32(c)), max. log posterior = $(Float32(max_log_posterior))" + + if α > α_max && c < c_max + tuner.scale = c * β + elseif α < α_min && c > c_min + tuner.scale = c / β + end + end + + Σ_new = new_Σ_unscal * tuner.scale + + S_new = cholesky(Positive, Σ_new) + chain.f_transform = Mul(S_new.L) + tuner.iteration += 1 + + nothing + +end + + +tuning_finalize!(tuner::TransformedProposalCovTuner, chain::MCMCIterator) = nothing + +tuning_callback(::TransformedProposalCovTuner) = nop_func + +# default_adaptive_transform(tuner::TransformedProposalCovTuner) = TriangularAffineTransform() + + +# this function is called in each mcmc_iterate step during tuning +function tune_mcmc_transform!!( + tuner::TransformedProposalCovTuner, + transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int, + context::BATContext +) + + return (tuner, transform) +end + diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl new file mode 100644 index 000000000..07baad293 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -0,0 +1,87 @@ +@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning + target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? + σ_target_acceptance::Float64 = 0.05 + gamma::Float64 = 2/3 +end + +@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance + config::TransformedRAMTuner + nsteps::Int = 0 +end +TransformedRAMTunerInstance(ram::TransformedRAMTuner) = TransformedRAMTunerInstance(config = ram) + +get_tuner(tuning::TransformedRAMTuner, chain::MCMCIterator) = TransformedRAMTunerInstance(tuning) + + +function tuning_init!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) # TODO ? + tuner.nsteps = 0 + + return nothing +end + + +tuning_postinit!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +# TODO AC: is this still needed? +# function tuning_postinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) +# # The very first samples of a chain can be very valuable to init tuner +# # stats, especially if the chain gets stuck early after: +# stats = tuner.stats +# append!(stats, samples) +# end + +tuning_reinit!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, max_nsteps::Integer) = nothing + + + + + +function tuning_update!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, samples::DensitySampleVector) + α_min, α_max = map(op -> op(1, tuner.config.σ_target_acceptance), [-,+]) .* tuner.config.target_acceptance + α = eff_acceptance_ratio(chain) + + max_log_posterior = maximum(samples.logd) + + if α_min <= α <= α_max + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + @debug "MCMC chain $(chain.info.id) tuned, acceptance ratio = $(Float32(α)), max. log posterior = $(Float32(max_log_posterior))" + else + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + @debug "MCMC chain $(chain.info.id) *not* tuned, acceptance ratio = $(Float32(α)), max. log posterior = $(Float32(max_log_posterior))" + end +end + +tuning_finalize!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator) = nothing + +# tuning_callback(::TransformedRAMTuner) = nop_func + + + +default_adaptive_transform(tuner::TransformedRAMTuner) = TriangularAffineTransform() + +function tune_mcmc_transform!!( + tuner::TransformedRAMTunerInstance, + transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int, + context::BATContext +) + @unpack target_acceptance, gamma = tuner.config + n = size(z_current,1) + η = min(1, n * tuner.nsteps^(-gamma)) + + s_L = transform.A + + u = z_proposed-z_current + M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L' + + S = cholesky(Positive, M) + transform_new = Mul(S.L) + + tuner.nsteps += 1 + + return (tuner, transform_new) +end diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl new file mode 100644 index 000000000..c1c7a0b18 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl @@ -0,0 +1,3 @@ +include("mcmc_noop_tuner.jl") +include("mcmc_ram_tuner.jl") +include("mcmc_proposalcov_tuner.jl") \ No newline at end of file diff --git a/src/samplers/transformed_mcmc/mcmc_utils.jl b/src/samplers/transformed_mcmc/mcmc_utils.jl new file mode 100644 index 000000000..bce9f6bac --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_utils.jl @@ -0,0 +1,29 @@ +# TODO AC: File not included as it would overwrite BAT.jl functions + + +function _cov_with_fallback(d) + rng = bat_determ_rng() + smplr = bat_sampler(d) + T = float(eltype(rand(rng, smplr))) + n = totalndof(varshape(d)) + C = fill(T(NaN), n, n) + try + C[:] = cov(d) + catch err + if err isa MethodError + C[:] = cov(nestedview(rand(rng, smplr, 10^5))) + else + throw(err) + end + end + return C +end + +_approx_cov(target::Distribution) = _cov_with_fallback(target) +_approx_cov(target::DistLikeMeasure) = _cov_with_fallback(target) +_approx_cov(target::AbstractPosteriorMeasure) = _approx_cov(getprior(target)) +_approx_cov(target::BAT.Transformed{<:Any,<:BAT.DistributionTransform}) = + BAT._approx_cov(target.trafo.target_dist) +_approx_cov(target::Renormalized) = _approx_cov(parent(target)) +_approx_cov(target::WithDiff) = _approx_cov(parent(target)) + diff --git a/src/samplers/transformed_mcmc/mcmc_weighting.jl b/src/samplers/transformed_mcmc/mcmc_weighting.jl new file mode 100644 index 000000000..d1fa41df7 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_weighting.jl @@ -0,0 +1,53 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} + +Abstract class for weighting schemes for MCMC samples. + +Weight values will have type `T`. +""" +abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} end +export TransformedAbstractMCMCWeightingScheme + + +sample_weight_type(::Type{<:TransformedAbstractMCMCWeightingScheme{T}}) where {T} = T + + + +""" + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} + +Sample weighting scheme suitable for sampling algorithms which may repeated +samples multiple times in direct succession (e.g. +[`MetropolisHastings`](@ref)). The repeated sample is stored only once, +with a weight equal to the number of times it has been repeated (e.g. +because a Markov chain has not moved during a sampling step). + +Constructors: + +* ```$(FUNCTIONNAME)()``` +""" +struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end +export TransformedRepetitionWeighting + +TransformedRepetitionWeighting() = TransformedRepetitionWeighting{Int}() + + +""" + TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} + +Sample weighting scheme suitable for accept/reject-based sampling algorithms +(e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples +become part of the output, with a weight proportional to their original +acceptance probability. + +Constructors: + +* ```$(FUNCTIONNAME)()``` +""" +struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end +export TransformedARPWeighting + +TransformedARPWeighting() = TransformedARPWeighting{Float64}() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl new file mode 100644 index 000000000..700ddd320 --- /dev/null +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -0,0 +1,110 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm + +A multi-cycle MCMC burn-in algorithm. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm + nsteps_per_cycle::Int64 = 10000 + max_ncycles::Int = 30 + nsteps_final::Int64 = div(nsteps_per_cycle, 10) +end + +export TransformedMCMCMultiCycleBurnin + + +function mcmc_burnin!( + outputs::Union{DensitySampleVector,Nothing}, + chains::AbstractVector{<:MCMCIterator}, + tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, + temperers::AbstractVector{<:TransformedMCMCTemperingInstance}, + burnin_alg::TransformedMCMCMultiCycleBurnin, + convergence_test::ConvergenceTest, + strict_mode::Bool, + nonzero_weights::Bool, + callback::Function +) + nchains = length(chains) + + @info "Begin tuning of $nchains MCMC chain(s)." + + cycles = zero(Int) + successful = false + while !successful && cycles < burnin_alg.max_ncycles + cycles += 1 + + next_cycle!.(chains) + + tuning_reinit!.(tuners, chains, burnin_alg.nsteps_per_cycle) + + desc_string = string("Burnin cycle ", cycles, "/max_cycles=", burnin_alg.max_ncycles," for nchains=", length(chains)) + progress_meter = ProgressMeter.Progress(length(chains)*burnin_alg.nsteps_per_cycle, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + + transformed_mcmc_iterate!( + chains, tuners, temperers, + max_nsteps = burnin_alg.nsteps_per_cycle, + nonzero_weights = nonzero_weights, + callback = (kwargs...) -> let pm=progress_meter; ProgressMeter.next!(progress_meter) ; end, + ) + ProgressMeter.finish!(progress_meter) + + new_outputs = getproperty.(chains, :samples) + + tuning_update!.(tuners, chains, new_outputs) + + isnothing(outputs) || append!(outputs, reduce(vcat, new_outputs)) + + transformed_check_convergence!(chains, new_outputs, convergence_test, BATContext()) # TODO AC: Rename + + # check_tuned/update_tuners... + ntuned = count(c -> c.info.tuned, chains) + nconverged = count(c -> c.info.converged, chains) + successful = (ntuned == nconverged == nchains) + + callback(Val(:mcmc_burnin), tuners, chains) + + @info "MCMC Tuning cycle $cycles finished, $nchains chains, $ntuned tuned, $nconverged converged." + end + + tuning_finalize!.(tuners, chains) + + if successful + @info "MCMC tuning of $nchains chains successful after $cycles cycle(s)." + else + msg = "MCMC tuning of $nchains chains aborted after $cycles cycle(s)." + if strict_mode + throw(ErrorException(msg)) + else + @warn msg + end + end + + if burnin_alg.nsteps_final > 0 + @info "Running post-tuning stabilization steps for $nchains MCMC chain(s)." + + # turn off tuning + next_cycle!.(chains) + tuners = TransformedMCMCNoOpTuning().(chains) + + # TODO AC: what about tempering? + + transformed_mcmc_iterate!( + chains, tuners, temperers, + max_nsteps = burnin_alg.nsteps_final, + nonzero_weights = nonzero_weights, + callback = callback + ) + end + + successful +end diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl new file mode 100644 index 000000000..231ce0be9 --- /dev/null +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -0,0 +1,204 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +""" + abstract type TransformedAbstractProposalDist + +*BAT-internal, not part of stable public API.* + +The following functions must be implemented for subtypes: + +* `BAT.proposaldist_logpdf` +* `BAT.proposal_rand!` +* `ValueShapes.totalndof`, returning the number of DOF (i.e. dimensionality). +* `LinearAlgebra.issymmetric`, indicating whether p(a -> b) == p(b -> a) holds true. +""" +abstract type TransformedAbstractProposalDist end + +# TODO AC: reactivate +# """ +# proposaldist_logpdf( +# p::AbstractArray, +# pdist::TransformedAbstractProposalDist, +# v_proposed::AbstractVector, +# v_current:::AbstractVector +# ) + +# *BAT-internal, not part of stable public API.* + +# Returns log(PDF) value of `pdist` for transitioning from current to proposed +# variate/parameters. +# """#function proposaldist_logpdf end + +# TODO: Implement proposaldist_logpdf for included proposal distributions + + +# TODO AC: reactivate +# """ +# function proposal_rand!( +# rng::AbstractRNG, +# pdist::TransformedGenericProposalDist, +# v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, +# v_current::Union{AbstractVector,VectorOfSimilarVectors} +# ) + +# *BAT-internal, not part of stable public API.* + +# Generate one or multiple proposed variate/parameter vectors, based on one or +# multiple previous vectors. + +# Input: + +# * `rng`: Random number generator to use +# * `pdist`: Proposal distribution to use +# * `v_current`: Old values (vector or column vectors, if a matrix) + +# Output is stored in + +# * `v_proposed`: New values (vector or column vectors, if a matrix) + +# The caller must guarantee: + +# * `size(v_current, 1) == size(v_proposed, 1)` +# * `size(v_current, 2) == size(v_proposed, 2)` or `size(v_current, 2) == 1` +# * `v_proposed !== v_current` (no aliasing) + +# Implementations of `proposal_rand!` must be thread-safe. +# """ +# function proposal_rand! end + + + +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist + d::D + sampler_f::SamplerF + s::S + + function TransformedGenericProposalDist{D,SamplerF}(d::D, sampler_f::SamplerF) where {D<:Distribution{Multivariate},SamplerF} + s = sampler_f(d) + new{D,SamplerF, typeof(s)}(d, sampler_f, s) + end + +end + + +TransformedGenericProposalDist(d::D, sampler_f::SamplerF) where {D<:Distribution{Multivariate},SamplerF} = + TransformedGenericProposalDist{D,SamplerF}(d, sampler_f) + +TransformedGenericProposalDist(d::Distribution{Multivariate}) = TransformedGenericProposalDist(d, bat_sampler) + +TransformedGenericProposalDist(D::Type{<:Distribution{Multivariate}}, varndof::Integer, args...) = + TransformedGenericProposalDist(D, Float64, varndof, args...) + + +Base.similar(q::TransformedGenericProposalDist, d::Distribution{Multivariate}) = + TransformedGenericProposalDist(d, q.sampler_f) + +function Base.convert(::Type{TransformedAbstractProposalDist}, q::TransformedGenericProposalDist, T::Type{<:AbstractFloat}, varndof::Integer) + varndof != totalndof(q) && throw(ArgumentError("q has wrong number of DOF")) + q +end + + +get_cov(q::TransformedGenericProposalDist) = get_cov(q.d) +set_cov(q::TransformedGenericProposalDist, Σ::PosDefMatLike) = similar(q, set_cov(q.d, Σ)) + + +function proposaldist_logpdf( + pdist::TransformedGenericProposalDist, + v_proposed::AbstractVector, + v_current::AbstractVector +) + params_diff = v_proposed .- v_current # TODO: Avoid memory allocation + logpdf(pdist.d, params_diff) +end + + +function proposal_rand!( + rng::AbstractRNG, + pdist::TransformedGenericProposalDist, + v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, + v_current::Union{AbstractVector,VectorOfSimilarVectors} +) + rand!(rng, pdist.s, flatview(v_proposed)) + params_new_flat = flatview(v_proposed) + params_new_flat .+= flatview(v_current) + v_proposed +end + + +ValueShapes.totalndof(pdist::TransformedGenericProposalDist) = length(pdist.d) + +LinearAlgebra.issymmetric(pdist::TransformedGenericProposalDist) = issymmetric_around_origin(pdist.d) + + + +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist + d::D + scale::Vector{T} + sampler_f::SamplerF + s::S +end + + +TransformedGenericUvProposalDist(d::Distribution{Univariate}, scale::Vector{<:AbstractFloat}, samplerF) = + TransformedGenericUvProposalDist(d, scale, samplerF, samplerF(d)) + +TransformedGenericUvProposalDist(d::Distribution{Univariate}, scale::Vector{<:AbstractFloat}) = + TransformedGenericUvProposalDist(d, scale, bat_sampler) + + +ValueShapes.totalndof(pdist::TransformedGenericUvProposalDist) = size(pdist.scale, 1) + +LinearAlgebra.issymmetric(pdist::TransformedGenericUvProposalDist) = issymmetric_around_origin(pdist.d) + +function BAT.proposaldist_logpdf( + pdist::TransformedGenericUvProposalDist, + v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, + v_current::Union{AbstractVector,VectorOfSimilarVectors} +) + params_diff = (flatview(v_proposed) .- flatview(v_current)) ./ pdist.scale # TODO: Avoid memory allocation + sum_first_dim(logpdf.(pdist.d, params_diff)) # TODO: Avoid memory allocation +end + +function BAT.proposal_rand!( + rng::AbstractRNG, + pdist::TransformedGenericUvProposalDist, + v_proposed::AbstractVector, + v_current::AbstractVector +) + v_proposed .= v_current + dim = rand(rng, eachindex(pdist.scale)) + v_proposed[dim] += pdist.scale[dim] * rand(rng, pdist.s) + v_proposed +end + + + +abstract type TransformedProposalDistSpec end + + +struct TransformedMvTDistProposal <: TransformedProposalDistSpec + df::Float64 +end + +TransformedMvTDistProposal() = TransformedMvTDistProposal(1.0) + + +(ps::TransformedMvTDistProposal)(T::Type{<:AbstractFloat}, varndof::Integer) = + TransformedGenericProposalDist(MvTDist, T, varndof, convert(T, ps.df)) + +function TransformedGenericProposalDist(::Type{MvTDist}, T::Type{<:AbstractFloat}, varndof::Integer, df = one(T)) + Σ = PDMat(Matrix(ScalMat(varndof, one(T)))) + μ = Fill(zero(eltype(Σ)), varndof) + M = typeof(Σ) + d = Distributions.GenericMvTDist(convert(T, df), μ, Σ) + TransformedGenericProposalDist(d) +end + + +struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec + df::Float64 +end + +(ps::TransformedUvTDistProposalSpec)(T::Type{<:AbstractFloat}, varndof::Integer) = + TransformedGenericUvProposalDist(TDist(convert(T, ps.df)), fill(one(T), varndof)) diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl new file mode 100644 index 000000000..4e0c4a005 --- /dev/null +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -0,0 +1,18 @@ +abstract type TransformedMCMCTempering end +struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end + +""" + temper_mcmc_target!!(tempering::TransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) +""" +function temper_mcmc_target!! end + + + +abstract type TransformedMCMCTemperingInstance end + +struct NoTransformedMCMCTemperingInstance <: TransformedMCMCTemperingInstance end + +temper_mcmc_target!!(tempering::NoTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ + +get_temperer(tempering::TransformedNoTransformedMCMCTempering, density::BATMeasure) = NoTransformedMCMCTemperingInstance() +get_temperer(tempering::TransformedNoTransformedMCMCTempering, chain::MCMCIterator) = get_temperer(tempering, chain.μ) diff --git a/src/transforms/adaptive_transform.jl b/src/transforms/adaptive_transform.jl new file mode 100644 index 000000000..f44ffef4f --- /dev/null +++ b/src/transforms/adaptive_transform.jl @@ -0,0 +1,40 @@ +abstract type AdaptiveTransformSpec end + + +struct CustomTransform{F} <: AdaptiveTransformSpec + f::F +end + +CustomTransform() = CustomTransform(identity) + +function init_adaptive_transform( + adaptive_transform::CustomTransform, + density, + context +) + return adaptive_transform.f +end + + + +struct TriangularAffineTransform <: AdaptiveTransformSpec end + +function init_adaptive_transform( + adaptive_transform::TriangularAffineTransform, + density, + context +) + M = _approx_cov(density) + s = cholesky(M).L + g = Mul(s) + + return g +end + + + +struct DiagonalAffineTransform <: AdaptiveTransformSpec end + + + + diff --git a/src/transforms/transforms.jl b/src/transforms/transforms.jl index 6c19878a2..fdd814e9a 100644 --- a/src/transforms/transforms.jl +++ b/src/transforms/transforms.jl @@ -2,3 +2,4 @@ include("trafo_utils.jl") include("distribution_transform.jl") +include("adaptive_transform.jl")