From c1cec60b1293ad5230aa1523fd00e6386c434c95 Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 12:49:47 +0200 Subject: [PATCH 01/33] use Tranformed prefix for all structs Co-Authored-By: Cornelius-G --- .../transformed_mcmc/chain_pool_init.jl | 215 ++++++++++++ src/samplers/transformed_mcmc/example.jl | 59 ++++ src/samplers/transformed_mcmc/mcmc.jl | 16 + .../transformed_mcmc/mcmc_algorithm.jl | 242 +++++++++++++ .../transformed_mcmc/mcmc_convergence.jl | 165 +++++++++ src/samplers/transformed_mcmc/mcmc_iterate.jl | 329 ++++++++++++++++++ src/samplers/transformed_mcmc/mcmc_sample.jl | 159 +++++++++ .../transformed_mcmc/mcmc_sampleid.jl | 61 ++++ src/samplers/transformed_mcmc/mcmc_stats.jl | 122 +++++++ .../mcmc_tuning/mcmc_noop_tuner.jl | 55 +++ .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 145 ++++++++ .../mcmc_tuning/mcmc_ram_tuner.jl | 89 +++++ .../mcmc_tuning/mcmc_tuning.jl | 3 + src/samplers/transformed_mcmc/mcmc_utils.jl | 27 ++ .../transformed_mcmc/mcmc_weighting.jl | 53 +++ .../transformed_mcmc/multi_cycle_burnin.jl | 110 ++++++ src/samplers/transformed_mcmc/proposaldist.jl | 205 +++++++++++ .../transformed_mcmc/replace_type_list.sh | 14 + src/samplers/transformed_mcmc/struct_list.jl | 46 +++ src/samplers/transformed_mcmc/tempering.jl | 18 + 20 files changed, 2133 insertions(+) create mode 100644 src/samplers/transformed_mcmc/chain_pool_init.jl create mode 100644 src/samplers/transformed_mcmc/example.jl create mode 100644 src/samplers/transformed_mcmc/mcmc.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_algorithm.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_convergence.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_iterate.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_sample.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_sampleid.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_stats.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_utils.jl create mode 100644 src/samplers/transformed_mcmc/mcmc_weighting.jl create mode 100644 src/samplers/transformed_mcmc/multi_cycle_burnin.jl create mode 100644 src/samplers/transformed_mcmc/proposaldist.jl create mode 100644 src/samplers/transformed_mcmc/replace_type_list.sh create mode 100644 src/samplers/transformed_mcmc/struct_list.jl create mode 100644 src/samplers/transformed_mcmc/tempering.jl diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl new file mode 100644 index 000000000..20acac8ad --- /dev/null +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -0,0 +1,215 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +""" + struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm + +MCMC chain pool initialization strategy. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm + init_tries_per_chain::ClosedInterval{Int64} = ClosedInterval(8, 128) + nsteps_init::Int64 = 1000 + initval_alg::InitvalAlgorithm = InitFromTarget() +end + +export TransformedMCMCChainPoolInit + + +function apply_trafo_to_init(trafo::Function, initalg::TransformedMCMCChainPoolInit) + TransformedMCMCChainPoolInit( + initalg.init_tries_per_chain, + initalg.nsteps_init, + apply_trafo_to_init(trafo, initalg.initval_alg) + ) +end + + + +function _construct_chain( + rngpart::RNGPartition, + id::Integer, + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + initval_alg::InitvalAlgorithm +) + rng = AbstractRNG(rngpart, id) + v_init = bat_initval(rng, density, initval_alg).result + + TransformedMCMCIterator(rng, algorithm, density, id, v_init) +end + +_gen_chains( + rngpart::RNGPartition, + ids::AbstractRange{<:Integer}, + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + initval_alg::InitvalAlgorithm +) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] + +#TODO +function mcmc_init!( + rng::AbstractRNG, + algorithm::TransformedMCMCSampling, + density::AbstractMeasureOrDensity, + nchains::Integer, + init_alg::TransformedMCMCChainPoolInit, + tuning_alg::MCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner + nonzero_weights::Bool, + callback::Function +) + @info "TransformedMCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)." + + initval_alg = init_alg.initval_alg + + min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains + max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains + + rngpart = RNGPartition(rng, Base.OneTo(max_ncandidates)) + + ncandidates::Int = 0 + + @debug "Generating dummy MCMC chain to determine chain, output and tuner types." #TODO: remove! + + dummy_initval = unshaped(bat_initval(rng, density, InitFromTarget()).result, varshape(density)) + dummy_chain = TransformedMCMCIterator(rng, algorithm, density, 1, dummy_initval) + dummy_tuner = get_tuner(tuning_alg, dummy_chain) + dummy_temperer = get_temperer(algorithm.tempering, density) + + chains = similar([dummy_chain], 0) + tuners = similar([dummy_tuner], 0) + temperers = similar([dummy_temperer], 0) + + init_tries::Int = 1 + + while length(tuners) < min_nviable && ncandidates < max_ncandidates + + n = min(min_nviable, max_ncandidates - ncandidates) + @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." + + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg) + + filter!(isvalidchain, new_chains) + + new_tuners = get_tuner.(Ref(tuning_alg), new_chains) + new_temperers = fill(get_temperer(algorithm.tempering, density), size(new_tuners,1)) + + next_cycle!.(new_chains) + + tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) + ncandidates += n + + @debug "Testing $(length(new_chains)) candidate MCMC chain(s)." + + transformed_mcmc_iterate!( + new_chains, new_tuners, new_temperers, + max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), + callback = callback, + nonzero_weights = nonzero_weights + ) + + # testing if chains are viable: + viable_idxs = findall(isviablechain.(new_chains)) + viable_temperers = new_temperers[viable_idxs] + viable_tuners = new_tuners[viable_idxs] + viable_chains = new_chains[viable_idxs] + + @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." + + if !isempty(viable_chains) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_idxs), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_idxs) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + transformed_mcmc_iterate!( + viable_chains, viable_tuners, viable_temperers; + max_nsteps = init_alg.nsteps_init, + callback = (kwargs...)-> let pm=progress_meter; ProgressMeter.next!(pm) ; end, + nonzero_weights = nonzero_weights + ) + ProgressMeter.finish!(progress_meter) + nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains])) + good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains) + @debug "Found $(length(viable_chains)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples." + + + append!(chains, view(viable_chains, good_idxs)) + append!(tuners, view(viable_tuners, good_idxs)) + append!(temperers, view(viable_temperers, good_idxs)) + end + + init_tries += 1 + end + + outputs = getproperty.(chains, :samples) + + length(chains) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") + + m = nchains + tidxs = LinearIndices(chains) + n = length(tidxs) + + modes = hcat(broadcast(samples -> Array(bat_findmode(rng, samples, MaxDensitySearch()).result), outputs)...) + + final_chains = similar(chains, 0) + final_tuners = similar(tuners, 0) + final_temperers = similar(temperers, 0) + final_outputs = similar(outputs, 0) + + # TODO: should we put this into a function? + if 2 <= m < size(modes, 2) + clusters = kmeans(modes, m, init = KmCentralityAlg()) + clusters.converged || error("k-means clustering of MCMC chains did not converge") + + mincosts = fill(Inf, m) + chain_sel_idxs = fill(0, m) + + for i in tidxs + j = clusters.assignments[i] + if clusters.costs[i] < mincosts[j] + mincosts[j] = clusters.costs[i] + chain_sel_idxs[j] = i + end + end + + @assert all(j -> j in tidxs, chain_sel_idxs) + + for i in sort(chain_sel_idxs) + push!(final_chains, chains[i]) + push!(final_tuners, tuners[i]) + push!(final_temperers, temperers[i]) + push!(final_outputs, outputs[i]) + end + elseif m == 1 + i = findmax(nsamples.(chains))[2] + push!(final_chains, chains[i]) + push!(final_tuners, tuners[i]) + push!(final_temperers, temperers[i]) + push!(final_outputs, outputs[i]) + else + @assert length(chains) == nchains + resize!(final_chains, nchains) + copyto!(final_chains, chains) + + @assert length(tuners) == nchains + resize!(final_tuners, nchains) + copyto!(final_tuners, tuners) + + @assert length(temperers) == nchains + resize!(final_temperers, nchains) + copyto!(final_temperers, temperers) + + @assert length(outputs) == nchains + resize!(final_outputs, nchains) + copyto!(final_outputs, outputs) + end + + @info "Selected $(length(final_chains)) MCMC chain(s)." + #tuning_postinit!.(final_tuners, final_chains, final_outputs) #TODO: implement + + (chains = final_chains, tuners = final_tuners, temperers = final_temperers, outputs = final_outputs) +end diff --git a/src/samplers/transformed_mcmc/example.jl b/src/samplers/transformed_mcmc/example.jl new file mode 100644 index 000000000..8d270dc20 --- /dev/null +++ b/src/samplers/transformed_mcmc/example.jl @@ -0,0 +1,59 @@ +using BAT +using BAT.MeasureBase +using AffineMaps +using ChangesOfVariables +using BAT.LinearAlgebra +using BAT.Distributions +using BAT.InverseFunctions +import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoMCMCTempering, transformed_mcmc_step!!, TransformedMCMCSampleID +using BAT.Random123 + +import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling + +#ENV["JULIA_DEBUG"] = "BAT" + +rng = Philox4x() + +posterior = BAT.example_posterior() + +my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) +my_samples = my_result.result + +mh_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(tuning_alg=TransformedAdaptiveMHTuning(), pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) + +(;chain, tuner) = BAT.g_state + + +using Plots +plot(my_samples) + +r_mh = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) ) + +r_hmc = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) ) + +plot(bat_sample(posterior).result) + +using BAT.Distributions +using BAT.ValueShapes +prior2 = NamedTupleDist(ShapedAsNT, + b = [4.2, 3.3], + a = Exponential(1.0), + c = Normal(1.0,3.0), + d = product_distribution(Weibull.(ones(2),1)), + e = Beta(1.0, 1.0), + f = MvNormal([0.3,-2.9],Matrix([1.7 0.5;0.5 2.3])) + ) + +posterior.likelihood.density._log_f(rand(posterior.prior)) + +posterior.likelihood.density._log_f(rand(prior2)) + +posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2) + + +@profview r_ram2 = @time BAT.bat_sample_impl(rng, posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) + +@profview r_mh2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) ) + +r_hmc2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) ) + diff --git a/src/samplers/transformed_mcmc/mcmc.jl b/src/samplers/transformed_mcmc/mcmc.jl new file mode 100644 index 000000000..2e7efec54 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc.jl @@ -0,0 +1,16 @@ +using AffineMaps + +include("mcmc_utils.jl") + +include("mcmc_weighting.jl") +include("proposaldist.jl") +include("mcmc_sampleid.jl") +include("mcmc_algorithm.jl") +include("mcmc_stats.jl") +include("mcmc_tuning/mcmc_tuning.jl") +include("mcmc_convergence.jl") +include("tempering.jl") +include("mcmc_sample.jl") +include("mcmc_iterate.jl") +include("multi_cycle_burnin.jl") +include("chain_pool_init.jl") \ No newline at end of file diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl new file mode 100644 index 000000000..07d4accf6 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -0,0 +1,242 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + + +function get_mcmc_tuning end #TODO: still needed + + +""" + abstract type MCMCInitAlgorithm + +Abstract type for MCMC initialization algorithms. +""" +abstract type MCMCInitAlgorithm end +export MCMCInitAlgorithm + +apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg + + + +""" + abstract type MCMCTuningAlgorithm + +Abstract type for MCMC tuning algorithms. +""" +abstract type MCMCTuningAlgorithm end +export MCMCTuningAlgorithm + + + +""" + abstract type MCMCBurninAlgorithm + +Abstract type for MCMC burn-in algorithms. +""" +abstract type MCMCBurninAlgorithm end +export MCMCBurninAlgorithm + + + +@with_kw struct TransformedMCMCIteratorInfo + id::Int32 + cycle::Int32 + tuned::Bool + converged::Bool +end + + +""" + abstract type MCMCIterator end + +Represents the current state of an MCMC chain. + +!!! note + + The details of the `MCMCIterator` and `MCMCAlgorithm` API (see below) + currently do not form part of the stable API and are subject to change + without deprecation. + +To implement a new MCMC algorithm, subtypes of both [`MCMCAlgorithm`](@ref) +and `MCMCIterator` are required. + +The following methods must be defined for subtypes of `MCMCIterator` (e.g. +`SomeMCMCIter<:MCMCIterator`): + +```julia + +BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity + +BAT.getrng(chain::SomeMCMCIter)::AbstractRNG + +BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo + +BAT.nsteps(chain::SomeMCMCIter)::Int + +BAT.nsamples(chain::SomeMCMCIter)::Int + +BAT.current_sample(chain::SomeMCMCIter)::DensitySample + +BAT.sample_type(chain::SomeMCMCIter)::Type{<:DensitySample} + +BAT.samples_available(chain::SomeMCMCIter, nonzero_weights::Bool = false)::Bool + +BAT.get_samples!(samples::DensitySampleVector, chain::SomeMCMCIter, nonzero_weights::Bool)::typeof(samples) + +BAT.next_cycle!(chain::SomeMCMCIter)::SomeMCMCIter + +BAT.mcmc_step!( + chain::SomeMCMCIter + callback::Function, +)::nothing +``` + +The following methods are implemented by default: + +```julia +getalgorithm(chain::MCMCIterator) +getmeasure(chain::MCMCIterator) +DensitySampleVector(chain::MCMCIterator) +mcmc_iterate!(chain::MCMCIterator, ...) +mcmc_iterate!(chains::AbstractVector{<:MCMCIterator}, ...) +isvalidchain(chain::MCMCIterator) +isviablechain(chain::MCMCIterator) +``` +""" +abstract type MCMCIterator end +export MCMCIterator + + +function Base.show(io::IO, chain::MCMCIterator) + print(io, Base.typename(typeof(chain)).name, "(") + print(io, "id = "); show(io, mcmc_info(chain).id) + print(io, ", nsamples = "); show(io, nsamples(chain)) + print(io, ", density = "); show(io, getmeasure(chain)) + print(io, ")") +end + + +function getalgorithm end + +function getmeasure end + +function getrng end + +function mcmc_info end + +function nsteps end + +function nsamples end + +function current_sample end + +function sample_type end + +function samples_available end + +function get_samples! end + +function next_cycle! end + +function mcmc_step! end + + + +DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain), totalndof(getmeasure(chain))) + + + +abstract type AbstractMCMCTunerInstance end + + +function tuning_init! end + +function tuning_postinit! end + +function tuning_reinit! end + +function tuning_update! end + +function tuning_finalize! end + +function tuning_callback end + + +function mcmc_init! end + +function mcmc_burnin! end + + +function isvalidchain end + +function isviablechain end + + + +function mcmc_iterate! end + +""" + BAT.TransformedMCMCSampleGenerator + +*BAT-internal, not part of stable public API.* + +MCMC sample generator. + +Constructors: + +```julia +TransformedMCMCSampleGenerator(chain::AbstractVector{<:MCMCIterator}) +``` +""" +struct TransformedMCMCSampleGenerator{ + T<:AbstractVector{<:MCMCIterator}, + A<:AbstractSamplingAlgorithm, +} <: AbstractSampleGenerator + chains::T + algorithm::A +end + +getalgorithm(sg::TransformedMCMCSampleGenerator) = sg.algorithm + +function Base.show(io::IO, generator::TransformedMCMCSampleGenerator) + if get(io, :compact, false) + print(io, nameof(typeof(generator)), "(") + if !isempty(generator.chains) + show(io, first(generator.chains)) + print(io, ", …") + end + print(io, ")") + else + println(io, nameof(typeof(generator)), ":") + chains = generator.chains + nchains = length(chains) + n_tuned_chains = count(c -> c.info.tuned, chains) + n_converged_chains = count(c -> c.info.converged, chains) + print(io, "algorithm: ") + show(io, "text/plain", getalgorithm(generator)) + println(io) + println(io, "number of chains:", repeat(' ', 12), nchains) + println(io, "number of chains tuned:", repeat(' ', 6), n_tuned_chains) + println(io, "number of chains converged:", repeat(' ', 2), n_converged_chains) + println(io, "number of points…") + println(io, repeat(' ',10), "… in 1th chain:", repeat(' ', 4), nsamples(first(chains))) + print(io, repeat(' ',10), "… on average:", repeat(' ', 6), div(sum(nsamples.(chains)), nchains)) + end +end + + +function bat_report!(md::Markdown.MD, generator::TransformedMCMCSampleGenerator) + mcalg = getalgorithm(generator) + chains = generator.chains + nchains = length(chains) + n_tuned_chains = count(c -> c.info.tuned, chains) + n_converged_chains = count(c -> c.info.converged, chains) + + markdown_append!(md, """ + ### Sample generation + + * Algorithm: MCMC, $(nameof(typeof(mcalg))) + * MCMC chains: $nchains ($n_tuned_chains tuned, $n_converged_chains converged) + """) + + return md +end diff --git a/src/samplers/transformed_mcmc/mcmc_convergence.jl b/src/samplers/transformed_mcmc/mcmc_convergence.jl new file mode 100644 index 000000000..9036f5b27 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_convergence.jl @@ -0,0 +1,165 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +function check_convergence!( + chains::AbstractVector{<:MCMCIterator}, + samples::AbstractVector{<:DensitySampleVector}, + algorithm::ConvergenceTest, +) + result = convert(Bool, bat_convergence(samples, algorithm).result) + for chain in chains + chain.info = TransformedMCMCIteratorInfo(chain.info, converged = result) + end + result +end + + + +""" + gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) + gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) + +*BAT-internal, not part of stable public API.* + +Gelman-Rubin ``\$R^2\$`` for all DOF. +""" +function gr_Rsqr end + +function gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) + m = totalndof(first(stats)) + W = mean([cs.param_stats.cov[i,i] for cs in stats, i in 1:m], dims=1)[:] + B = var([cs.param_stats.mean[i] for cs in stats, i in 1:m], dims=1)[:] + (W .+ B) ./ W +end + +function gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) + gr_Rsqr(TransformedMCMCBasicStats.(samples)) +end + + + +""" + struct TransformedGelmanRubinConvergence <: ConvergenceTest + +Gelman-Rubin maximum R^2 convergence test. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedGelmanRubinConvergence <: ConvergenceTest + threshold::Float64 = 1.1 +end + +export TransformedGelmanRubinConvergence + +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedGelmanRubinConvergence) + max_Rsqr = maximum(gr_Rsqr(samples)) + vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) + converged = convert(Bool, vt) + @debug begin + success_str = converged ? "have" : "have *not*" + "Chains $success_str converged, max(R^2) = $(vt.value), threshold = $(vt.threshold)" + end + (result = vt,) +end + + + +@doc doc""" + bg_R_2sqr(stats::AbstractVector{<:TransformedMCMCBasicStats}; corrected::Bool = false) + bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) + +*BAT-internal, not part of stable public API.* + +Brooks-Gelman R_2^2 for all DOF. +If normality is assumed, 'corrected' should be set to true to account for the sampling variability. +""" +function bg_R_2sqr(stats::AbstractVector{<:TransformedMCMCBasicStats}; corrected::Bool = false) + p = totalndof(first(stats)) + m = length(stats) + n = mean(Float64.(nsamples.(stats))) + + σ_W = var([cs.param_stats.cov[i,i] for cs in stats, i in 1:p], dims = 1)[:] + B = var([cs.param_stats.mean[i] for cs in stats, i in 1:p], dims = 1)[:] + W = mean([cs.param_stats.cov[i,i] for cs in stats, i in 1:p], dims = 1)[:] + + σ_sq = m * (n - 1) / (m*n - 1) * W + n * (m - 1) / (m*n - 1) * B + + R_unc = σ_sq ./ W + + if corrected == false + return R_unc + end + + σ_ij = [cs.param_stats.cov[i,i] for cs in stats, i in 1:p] + x_ij = [cs.param_stats.mean[i] for cs in stats, i in 1:p] + + cov_σx = [cov(σ_ij[:,j], x_ij[:,j]) for j in 1:p] + cov_σx_sq = [cov(σ_ij[:,j], x_ij[:,j].^2) for j in 1:p] + + N = (n-1)/n + M = (m-1)/m + V = N*σ_sq + M*B + + σ_V = N^2/m*σ_W + 2*M/(m-1)*B.^2 + 2*M*N/m*(cov_σx_sq - 2*B.*cov_σx) + d = 2 * V.^2 ./ σ_V + + R_unc.*(d.+3)./(d.+1) +end + +function bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) + bg_R_2sqr(TransformedMCMCBasicStats.(samples), corrected = corrected) +end + + + +""" + struct TransformedBrooksGelmanConvergence <: ConvergenceTest + +Brooks-Gelman maximum R^2 convergence test. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest + threshold::Float64 = 1.1 + corrected::Bool = false +end + +export TransformedBrooksGelmanConvergence + +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedBrooksGelmanConvergence) + max_Rsqr = maximum(bg_R_2sqr(samples, corrected = algorithm.corrected)) + vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) + converged = convert(Bool, vt) + @debug begin + success_str = converged ? "have" : "have *not*" + "Chains $success_str converged, max(R^2) = $(vt.value), threshold = $(vt.threshold)" + end + (result = vt,) +end + + + +function bat_convergence_impl(samples::DensitySampleVector, algorithm::Union{TransformedGelmanRubinConvergence, TransformedBrooksGelmanConvergence}) + # create a vector of chains + chains_ind = unique([i.chainid for i in samples.info]) + vector_chains = DensitySampleVector[] + # ToDo: Improve implementation + for i in chains_ind + mask_chain = [j.chainid == i for j in samples.info] + push!(vector_chains, samples[mask_chain]) + end + + bat_convergence_impl(vector_chains, algorithm) +end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl new file mode 100644 index 000000000..5f7028c6d --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -0,0 +1,329 @@ +mutable struct TransformedMCMCIterator{ + R<:AbstractRNG, + PR<:RNGPartition, + D<:BATMeasure, + F, + Q<:MCMCProposal, + SV<:DensitySampleVector, + S<:DensitySample, +} <: MCMCIterator + rng::R + rngpart_cycle::PR + μ::D + f_transform::F + proposal::Q + samples::SV + sample_z::S + stepno::Int + n_accepted::Int + info::TransformedMCMCIteratorInfo +end + +getmeasure(chain::TransformedMCMCIterator) = chain.μ + +getrng(chain::TransformedMCMCIterator) = chain.rng + +mcmc_info(chain::TransformedMCMCIterator) = chain.info + +nsteps(chain::TransformedMCMCIterator) = chain.stepno + +nsamples(chain::TransformedMCMCIterator) = size(chain.samples, 1) + +current_sample(chain::TransformedMCMCIterator) = last(chain.samples) + +sample_type(chain::TransformedMCMCIterator) = eltype(chain.samples) + +samples_available(chain::TransformedMCMCIterator) = size(chain.samples,1) > 0 + +isvalidchain(chain::TransformedMCMCIterator) = current_sample(chain).logd > -Inf + +isviablechain(chain::TransformedMCMCIterator) = nsamples(chain) >= 2 + +eff_acceptance_ratio(chain::TransformedMCMCIterator) = nsamples(chain) / chain.stepno + + + +#ctor +function TransformedMCMCIterator( + rng::AbstractRNG, + algorithm::TransformedMCMCSampling, + target, + id::Integer, + v_init::AbstractVector{<:Real} +) + TransformedMCMCIterator(rng, algorithm, target, Int32(id), v_init) +end + + +#ctor +function TransformedMCMCIterator( + rng::AbstractRNG, + algorithm::TransformedMCMCSampling, + target, + id::Int32, + v_init::AbstractVector{<:Real}, +) + rngpart_cycle = RNGPartition(rng, 0:(typemax(Int16) - 2)) + + μ = target + proposal = algorithm.proposal + stepno = 1 + cycle = 1 + n_accepted = 0 + + adaptive_transform_spec = algorithm.adaptive_transform + g = init_adaptive_transform(rng, adaptive_transform_spec, μ) + + logd_x = logdensityof(μ, v_init) + sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCSampleID(id, 1, 0), nothing) # TODO + inverse_g = inverse(g) + z = inverse_g(v_init) # sample_x.v + logd_z = logdensityof(MeasureBase.pullback(g, μ),z) + sample_z = _rebuild_density_sample(sample_x, z, logd_z) + + samples = DensitySampleVector(([sample_x.v], [sample_x.logd], [sample_x.weight], [sample_x.info], [sample_x.aux] )) + + iter = TransformedMCMCIterator( + rng, + rngpart_cycle, + target, + g, + proposal, + samples, + sample_z, + stepno, + n_accepted, + TransformedMCMCIteratorInfo(id, cycle, false, false) + ) + + +end + + + +function _rebuild_density_sample(s::DensitySample, x, logd, weight=1) + @unpack info, aux = s + DensitySample(x, logd, weight, info, aux) +end + + + +function propose_mcmc( + iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:Any, <:TransformedMHProposal} +) + @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter + sample_x = last(samples) + x, logd_x = sample_x.v, sample_x.logd + z, logd_z = sample_z.v, sample_z.logd + + n = size(z, 1) + z_proposed = z + rand(rng, proposal.proposal_dist, n) #TODO: check if proposal is symmetric? otherwise need additional factor? + x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed) + logd_x_proposed = BAT.checked_logdensityof(μ, x_proposed) + logd_z_proposed = logd_x_proposed + ladj + @assert logd_z_proposed ≈ logdensityof(MeasureBase.pullback(f_transform, μ), z_proposed) #TODO: remove + + + # TODO AC: do we need to check symmetry of proposal distribution? + # T = typeof(logd_z) + # p_accept = if logd_z_proposed > -Inf + # # log of ratio of forward/reverse transition probability + # log_tpr = if issymmetric(proposal.proposal_dist) + # T(0) + # else + # log_tp_fwd = proposaldist_logpdf(proposaldist, proposed_params, current_params) + # log_tp_rev = proposaldist_logpdf(proposaldist, current_params, proposed_params) + # T(log_tp_fwd - log_tp_rev) + # end + + # p_accept_unclamped = exp(proposed_log_posterior - current_log_posterior - log_tpr) + # T(clamp(p_accept_unclamped, 0, 1)) + # else + # zero(T) + # end + + p_accept = clamp(exp(logd_z_proposed-logd_z), 0, 1) + + sample_z_proposed = _rebuild_density_sample(sample_z, z_proposed, logd_z_proposed) + sample_x_proposed = _rebuild_density_sample(sample_x, x_proposed, logd_x_proposed) + + return sample_x_proposed, sample_z_proposed, p_accept +end + + + +function transformed_mcmc_step!!( + iter::TransformedMCMCIterator, + tuner::AbstractMCMCTunerInstance, + tempering::MCMCTemperingInstance, +) + @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter + sample_x = last(samples) + x, logd_x = sample_x.v, sample_x.logd + z, logd_z = sample_z.v, sample_z.logd + @unpack n_accepted, stepno = iter + + sample_x_proposed, sample_z_proposed, p_accept = propose_mcmc(iter) + + z_proposed, logd_z_proposed = sample_z_proposed.v, sample_z_proposed.logd + x_proposed, logd_x_proposed = sample_x_proposed.v, sample_x_proposed.logd + + tuner_new, f_transform = tune_mcmc_transform!!(rng, tuner, f_transform, p_accept, z_proposed, z, stepno) + + accepted = rand(rng) <= p_accept + + # f_transform may have changed + inverse_f = inverse(f_transform) + x_new, z_new, logd_x_new, logd_z_new = if accepted + x_proposed, inverse_f(x_proposed), logd_x_proposed, logd_z_proposed + else + x, inverse_f(x), logd_x, logd_z + end + + sample_x_new, sample_z_new, samples_new = if accepted + sample_x_new = DensitySample(x_new, logd_x_new, 1, TransformedMCMCSampleID(iter.info.id, iter.info.cycle, iter.stepno), nothing) + push!(samples, sample_x_new) + sample_x_new, _rebuild_density_sample(sample_z, z_new, logd_z_new), samples + else + samples.weight[end] += 1 + _rebuild_density_sample(sample_x, x_new, logd_x_new, sample_x.weight+1), _rebuild_density_sample(sample_z, z_new, logd_z_new), samples + end + + tempering_new, μ_new = temper_mcmc_target!!(tempering, μ, stepno) + + f_new = f_transform + + # iter_new = TransformedMCMCIterator(rng, μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted)) + iter.rng = rng + iter.μ, iter.f_transform, iter.samples, iter.sample_z = μ_new, f_new, samples_new, sample_z_new + iter.n_accepted += Int(accepted) + iter.stepno += 1 + + return (iter, tuner_new, tempering_new) +end + + + +function transformed_mcmc_iterate!( + chain::TransformedMCMCIterator, + tuner::AbstractMCMCTunerInstance, + tempering::MCMCTemperingInstance; + max_nsteps::Integer = 1, + max_time::Real = Inf, + nonzero_weights::Bool = true, + callback::Function = nop_func, +) + @debug "Starting iteration over MCMC chain $(mcmc_info(chain).id) with $max_nsteps steps in max. $(@sprintf "%.1f seconds." max_time)" + + start_time = time() + last_progress_message_time = start_time + start_nsteps = nsteps(chain) + start_nsteps = nsteps(chain) + + while ( + (nsteps(chain) - start_nsteps) < max_nsteps && + (time() - start_time) < max_time + ) + transformed_mcmc_step!!(chain, tuner, tempering) + callback(Val(:mcmc_step), chain) + + #TODO: output schemes + + current_time = time() + elapsed_time = current_time - start_time + logging_interval = 5 * round(log2(elapsed_time/60 + 1) + 1) + if current_time - last_progress_message_time > logging_interval + last_progress_message_time = current_time + @debug "Iterating over MCMC chain $(mcmc_info(chain).id), completed $(nsteps(chain) - start_nsteps) (of $(max_nsteps)) steps and produced $(nsteps(chain) - start_nsteps) samples in $(@sprintf "%.1f s" elapsed_time) so far." + end + end + + current_time = time() + elapsed_time = current_time - start_time + @debug "Finished iteration over MCMC chain $(mcmc_info(chain).id), completed $(nsteps(chain) - start_nsteps) steps and produced $(nsteps(chain) - start_nsteps) samples in $(@sprintf "%.1f s" elapsed_time)." + + return nothing +end + + +function transformed_mcmc_iterate!( + chain::MCMCIterator, + tuner::AbstractMCMCTunerInstance, + tempering::MCMCTemperingInstance; + # tuner::AbstractMCMCTunerInstance; + max_nsteps::Integer = 1, + max_time::Real = Inf, + nonzero_weights::Bool = true, + callback::Function = nop_func +) + cb = callback# combine_callbacks(tuning_callback(tuner), callback) #TODO CA: tuning_callback + + transformed_mcmc_iterate!( + chain, tuner, tempering, + max_nsteps = max_nsteps, max_time = max_time, nonzero_weights = nonzero_weights, callback = cb + ) + + return nothing +end + + +function transformed_mcmc_iterate!( + chains::AbstractVector{<:MCMCIterator}, + tuners::AbstractVector{<:AbstractMCMCTunerInstance}, + temperers::AbstractVector{<:MCMCTemperingInstance}; + kwargs... +) + if isempty(chains) + @debug "No MCMC chain(s) to iterate over." + return chains + else + @debug "Starting iteration over $(length(chains)) MCMC chain(s)" + end + + @sync for i in eachindex(chains, tuners, temperers) + Base.Threads.@spawn transformed_mcmc_iterate!(chains[i], tuners[i], temperers[i]#= , tnrs[i] =#; kwargs...) + end + + return nothing +end + + +function reset_chain( + rng::AbstractRNG, + chain::TransformedMCMCIterator, +) + rngpart_cycle = RNGPartition(rng, 0:(typemax(Int16) - 2)) + #TODO reset cycle count? + chain.rngpart_cycle = rngpart_cycle + chain.info = TransformedMCMCIteratorInfo(chain.info, cycle=0) + # wants a next_cycle! + # reset_rng_counters!(chain) +end + + +function reset_rng_counters!(chain::TransformedMCMCIterator) + set_rng!(chain.rng, chain.rngpart_cycle, chain.info.cycle) + rngpart_step = RNGPartition(chain.rng, 0:(typemax(Int32) - 2)) + set_rng!(chain.rng, rngpart_step, chain.stepno) + nothing +end + + +function next_cycle!( + chain::TransformedMCMCIterator, + +) + chain.info = TransformedMCMCIteratorInfo(chain.info, cycle = chain.info.cycle + 1) + chain.stepno = 0 + + reset_rng_counters!(chain) + + chain.samples[1] = last(chain.samples) + resize!(chain.samples, 1) + + chain.samples.weight[1] = 1 + chain.samples.info[1] = TransformedMCMCSampleID(chain.info.id, chain.info.cycle, chain.stepno) + + chain +end + diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl new file mode 100644 index 000000000..b09c2972d --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -0,0 +1,159 @@ +abstract type MCMCProposal end +""" + BAT.TransformedMHProposal + +*BAT-internal, not part of stable public API.* +""" +struct TransformedMHProposal{ + D<:Union{Distribution, AbstractMeasure} +}<: MCMCProposal + proposal_dist::D +end + + +# TODO AC: find a better solution for this. Problem is that in the with_kw constructor below, we need to dispatch on this type. +struct TransformedMCMCDispatch end + +@with_kw struct TransformedMCMCSampling{ + TR<:AbstractTransformTarget, + IN<:MCMCInitAlgorithm, + BI<:MCMCBurninAlgorithm, + CT<:ConvergenceTest, + CB<:Function +} <: AbstractSamplingAlgorithm + pre_transform::TR = bat_default(TransformedMCMCDispatch, Val(:pre_transform)) + tuning_alg::MCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults + adaptive_transform::AdaptiveTransformSpec = default_adaptive_transform(tuning_alg) + proposal::MCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults + tempering = TransformedNoMCMCTempering() # TODO: use bat_defaults + nchains::Int = 4 + nsteps::Int = 10^5 + #TODO: max_time ? + init::IN = bat_default(TransformedMCMCDispatch, Val(:init), pre_transform, nchains, nsteps) #TransformedMCMCChainPoolInit()#TODO AC: use bat_defaults bat_default(MCMCSampling, Val(:init), MetropolisHastings(), pre_transform, nchains, nsteps) #TODO + burnin::BI = bat_default(TransformedMCMCDispatch, Val(:burnin), pre_transform, nchains, nsteps) + convergence::CT = TransformedBrooksGelmanConvergence() + strict::Bool = true + store_burnin::Bool = false + nonzero_weights::Bool = true + callback::CB = nop_func +end + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:pre_transform}) = PriorToGaussian() + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:nsteps}, trafo::AbstractTransformTarget, nchains::Integer) = 10^5 + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:init}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = + TransformedMCMCChainPoolInit(nsteps_init = max(div(nsteps, 100), 250)) + +bat_default(::Type{TransformedMCMCDispatch}, ::Val{:burnin}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = + TransformedMCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500)) + + + +function bat_sample_impl( + rng::AbstractRNG, + target::AnyMeasureOrDensity, + algorithm::TransformedMCMCSampling +) + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo) + + init = mcmc_init!( + rng, + algorithm, + density, + algorithm.nchains, + apply_trafo_to_init(trafo, algorithm.init), + algorithm.tuning_alg, + algorithm.nonzero_weights, + algorithm.store_burnin ? algorithm.callback : nop_func + ) + + @unpack chains, tuners, temperers = init + + # output_init = reduce(vcat, getproperty(chains, :samples)) + + burnin_outputs_coll = if algorithm.store_burnin + DensitySampleVector(first(chains)) + else + nothing + end + + # burnin and tuning + mcmc_burnin!( + burnin_outputs_coll, + chains, + tuners, + temperers, + algorithm.burnin, + algorithm.convergence, + algorithm.strict, + algorithm.nonzero_weights, + algorithm.store_burnin ? algorithm.callback : nop_func + ) + + # sampling + run_sampling = _run_sample_impl( + density, + algorithm, + chains, + ) + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + # prepend burnin samples to output + if algorithm.store_burnin + burnin_samples_trafo = varshape(density).(burnin_outputs_coll) + append!(burnin_samples_trafo, samples_trafo) + samples_trafo = burnin_samples_trafo + end + + samples_notrafo = inverse(trafo).(samples_trafo) + + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end + +function _bat_sample_continue( + target::AnyMeasureOrDensity, + generator::TransformedMCMCSampleGenerator, + ;description::AbstractString = "MCMC iterate" +) + @unpack algorithm, chains = generator + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo) + + run_sampling = _run_sample_impl(density, algorithm, chains, description=description) + + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + samples_notrafo = inverse(trafo).(samples_trafo) + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end + +function _run_sample_impl( + density::AnyMeasureOrDensity, + algorithm::TransformedMCMCSampling, + chains::AbstractVector{<:MCMCIterator}, + ;description::AbstractString = "MCMC iterate" +) + next_cycle!.(chains) + + progress_meter = ProgressMeter.Progress(algorithm.nchains*algorithm.nsteps, desc=description, barlen=80-length(description), dt=0.1) + + # tuners are set to 'NoOpTuner' for the sampling phase + transformed_mcmc_iterate!( + chains, + get_tuner.(Ref(TransformedMCMCNoOpTuning()),chains), + get_temperer.(Ref(TransformedNoMCMCTempering()), chains), + max_nsteps = algorithm.nsteps, #TODO: maxtime + nonzero_weights = algorithm.nonzero_weights, + callback = (kwargs...) -> let pm=progress_meter; ProgressMeter.next!(pm) ; end, + ) + ProgressMeter.finish!(progress_meter) + + output = reduce(vcat, getproperty.(chains, :samples)) + samples_trafo = varshape(density).(output) + + (result_trafo = samples_trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) +end diff --git a/src/samplers/transformed_mcmc/mcmc_sampleid.jl b/src/samplers/transformed_mcmc/mcmc_sampleid.jl new file mode 100644 index 000000000..1dfb92c3f --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_sampleid.jl @@ -0,0 +1,61 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +abstract type SampleID end + +struct TransformedMCMCSampleID{ + T<:Int32, + U<:Int64, +} <: SampleID + chainid::T + chaincycle::T + stepno::U +end + +function TransformedMCMCSampleID( + chainid::Integer, + chaincycle::Integer, + stepno::Integer, +) + TransformedMCMCSampleID(Int32(chainid), Int32(chaincycle), Int64(stepno)) +end + +const TransformedMCMCSampleIDVector{TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} = StructArray{ + TransformedMCMCSampleID, + 1, + NamedTuple{(:chainid, :chaincycle, :stepno), Tuple{TV,TV,UV}}, + Int +} + + +function TransformedMCMCSampleIDVector(contents::Tuple{TV,TV,UV}) where {TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} + StructArray{TransformedMCMCSampleID}(contents)::TransformedMCMCSampleIDVector{TV,UV} +end + +TransformedMCMCSampleIDVector(::UndefInitializer, len::Integer) = TransformedMCMCSampleIDVector(( + Vector{Int32}(undef, len), Vector{Int32}(undef, len), + Vector{Int64}(undef, len) +)) + +TransformedMCMCSampleIDVector() = TransformedMCMCSampleIDVector(undef, 0) + + +_create_undef_vector(::Type{TransformedMCMCSampleID}, len::Integer) = TransformedMCMCSampleIDVector(undef, len) + + +# Specialize comparison, currently StructArray seems fall back to `(==)(A::AbstractArray, B::AbstractArray)` +import Base.== +function(==)(A::TransformedMCMCSampleIDVector, B::TransformedMCMCSampleIDVector) + A.chainid == B.chainid && + A.chaincycle == B.chaincycle && + A.stepno == B.stepno +end + + +function Base.merge!(X::TransformedMCMCSampleIDVector, Xs::TransformedMCMCSampleIDVector...) + for Y in Xs + append!(X, Y) + end + X +end + +Base.merge(X::TransformedMCMCSampleIDVector, Xs::TransformedMCMCSampleIDVector...) = merge!(deepcopy(X), Xs...) diff --git a/src/samplers/transformed_mcmc/mcmc_stats.jl b/src/samplers/transformed_mcmc/mcmc_stats.jl new file mode 100644 index 000000000..cf57675d3 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_stats.jl @@ -0,0 +1,122 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +abstract type AbstractMCMCStats end +AbstractMCMCStats + + + +struct TransformedMCMCNullStats <: AbstractMCMCStats end + + +Base.push!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats + +Base.append!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats + + + +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats + param_stats::BasicMvStatistics{P,FrequencyWeights} + logtf_stats::BasicUvStatistics{L,FrequencyWeights} + mode::Vector{P} + + function TransformedMCMCBasicStats{L,P}(m::Integer) where {L<:Real,P<:Real} + param_stats = BasicMvStatistics{P,FrequencyWeights}(m) + logtf_stats = BasicUvStatistics{L,FrequencyWeights}() + mode = fill(P(NaN), m) + + new{L,P}( + param_stats, + logtf_stats, + mode + ) + end +end + + +function TransformedMCMCBasicStats(::Type{S}, ndof::Integer) where { + PT<:Real, T, W, S<:DensitySample{<:AbstractVector{PT},T,W} +} + SL = promote_type(T, Float64) + SP = promote_type(PT, W, Float64) + TransformedMCMCBasicStats{SL,SP}(ndof) +end + +TransformedMCMCBasicStats(chain::MCMCIterator) = TransformedMCMCBasicStats(sample_type(chain), totalndof(getmeasure(chain))) + +function TransformedMCMCBasicStats(sv::DensitySampleVector{<:AbstractVector{<:Real}}) + stats = TransformedMCMCBasicStats(eltype(sv), innersize(sv.v, 1)) + append!(stats, sv) +end + +TransformedMCMCBasicStats(sv::DensitySampleVector) = TransformedMCMCBasicStats(unshaped.(sv)) + + +function Base.empty!(stats::TransformedMCMCBasicStats) + empty!(stats.param_stats) + empty!(stats.logtf_stats) + fill!(stats.mode, eltype(stats.mode)(NaN)) + + stats +end + + +function Base.push!(stats::TransformedMCMCBasicStats, s::DensitySample) + push!(stats.param_stats, s.v, s.weight) + if s.logd > stats.logtf_stats.maximum + stats.mode .= s.v + end + push!(stats.logtf_stats, s.logd, s.weight) + stats +end + + +function Base.append!(stats::TransformedMCMCBasicStats, sv::DensitySampleVector) + for i in eachindex(sv) + p = sv.v[i] + w = sv.weight[i] + l = sv.logd[i] + push!(stats.param_stats, p, w) # Memory allocation (view)! + if sv.logd[i] > stats.logtf_stats.maximum + stats.mode .= p # Memory allocation (view)! + end + push!(stats.logtf_stats, l, w) + stats + end + stats +end + + +ValueShapes.totalndof(stats::TransformedMCMCBasicStats) = stats.param_stats.m + +nsamples(stats::TransformedMCMCBasicStats) = stats.param_stats.cov.sum_w + +function Base.merge!(target::TransformedMCMCBasicStats, others::TransformedMCMCBasicStats...) + for x in others + if (x.logtf_stats.maximum > target.logtf_stats.maximum) + target.mode .= x.mode + end + merge!(target.param_stats, x.param_stats) + merge!(target.logtf_stats, x.logtf_stats) + end + target +end + +Base.merge(a::TransformedMCMCBasicStats, bs::TransformedMCMCBasicStats...) = merge!(deepcopy(a), bs...) + + +function reweight_relative!(stats::TransformedMCMCBasicStats, reweighting_factor::Real) + reweight_relative!(stats.param_stats, reweighting_factor) + reweight_relative!(stats.logtf_stats, reweighting_factor) + + stats +end + + +function _bat_stats(mcmc_stats::TransformedMCMCBasicStats) + ( + mode = mcmc_stats.mode, + mean = mcmc_stats.param_stats.mean, + cov = mcmc_stats.param_stats.cov + ) +end diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl new file mode 100644 index 000000000..c5a1cca01 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -0,0 +1,55 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm + +No-op tuning, marks MCMC chains as tuned without performing any other changes +on them. Useful if chains are pre-tuned or tuning is an internal part of the +MCMC sampler implementation. +""" +struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end +export TransformedMCMCNoOpTuning + + + +struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end + +(tuning::TransformedMCMCNoOpTuning)(chain::MCMCIterator) = TransformedMCMCNoOpTuner() +get_tuner(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) = TransformedMCMCNoOpTuner() + + +function TransformedMCMCNoOpTuning(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) + TransformedMCMCNoOpTuner() +end + + +function tuning_init!(tuner::TransformedMCMCNoOpTuning, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + nothing +end + + + +function tune_mcmc_transform!!( + rng::AbstractRNG, + tuner::TransformedMCMCNoOpTuner, + transform, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int +) + return (tuner, transform) + +end + +tuning_postinit!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +tuning_reinit!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, max_nsteps::Integer) = nothing + +tuning_update!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +tuning_finalize!(tuner::TransformedMCMCNoOpTuner, chain::MCMCIterator) = nothing + +tuning_callback(::TransformedMCMCNoOpTuning) = nop_func diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl new file mode 100644 index 000000000..93772cdf6 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -0,0 +1,145 @@ +@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm + "Controls the weight given to new covariance information in adapting the + proposal distribution." + λ::Float64 = 0.5 + + "Metropolis-Hastings acceptance ratio target, tuning will try to adapt + the proposal distribution to bring the acceptance ratio inside this interval." + α::IntervalSets.ClosedInterval{Float64} = ClosedInterval(0.15, 0.35) + + "Controls how much the spread of the proposal distribution is + widened/narrowed depending on the current MH acceptance ratio." + β::Float64 = 1.5 + + "Interval for allowed scale/spread of the proposal distribution." + c::IntervalSets.ClosedInterval{Float64} = ClosedInterval(1e-4, 1e2) + + "Reweighting factor. Take accumulated sample statistics of previous + tuning cycles into account with a relative weight of `r`. Set to + `0` to completely reset sample statistics between each tuning cycle." + r::Real = 0.5 +end + +mutable struct TransformedProposalCovTuner{ + S<:TransformedMCMCBasicStats +} <: AbstractMCMCTunerInstance + config::TransformedAdaptiveMHTuning + stats::S + iteration::Int + scale::Float64 +end + + +function TransformedProposalCovTuner(tuning::TransformedAdaptiveMHTuning, chain::MCMCIterator) + m = totalndof(getmeasure(chain)) + scale = 2.38^2 / m + TransformedProposalCovTuner(tuning, TransformedMCMCBasicStats(chain), 1, scale) +end + +get_tuner(tuning::TransformedAdaptiveMHTuning, chain::MCMCIterator) = TransformedProposalCovTuner(tuning, chain) +default_adaptive_transform(tuner::TransformedAdaptiveMHTuning) = TriangularAffineTransform() + + +function tuning_init!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + + nothing +end + +tuning_reinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, max_nsteps::Integer) = nothing + + +function tuning_postinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) + # The very first samples of a chain can be very valuable to init tuner + # stats, especially if the chain gets stuck early after: + stats = tuner.stats + append!(stats, samples) +end + +# this function is called once after each tuning cycle +g_state = nothing +function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) + global g_state = (;tuner, chain) + + stats = tuner.stats + stats_reweight_factor = tuner.config.r + reweight_relative!(stats, stats_reweight_factor) + # empty!.(stats) + append!(stats, samples) + + + config = tuner.config + + α_min = minimum(config.α) + α_max = maximum(config.α) + + c_min = minimum(config.c) + c_max = maximum(config.c) + + β = config.β + + t = tuner.iteration + λ = config.λ + c = tuner.scale + + transform = chain.f_transform + + + #TODO AC: check with Oli + S_L = transform.A + Σ_old = S_L + + S = convert(Array, stats.param_stats.cov) + a_t = 1 / t^λ + new_Σ_unscal = (1 - a_t) * (Σ_old/c) + a_t * S + + α = eff_acceptance_ratio(chain) + + max_log_posterior = stats.logtf_stats.maximum + + if α_min <= α <= α_max + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + @debug "MCMC chain $(chain.info.id) tuned, acceptance ratio = $(Float32(α)), proposal scale = $(Float32(c)), max. log posterior = $(Float32(max_log_posterior))" + else + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + @debug "MCMC chain $(chain.info.id) *not* tuned, acceptance ratio = $(Float32(α)), proposal scale = $(Float32(c)), max. log posterior = $(Float32(max_log_posterior))" + + if α > α_max && c < c_max + tuner.scale = c * β + elseif α < α_min && c > c_min + tuner.scale = c / β + end + end + + Σ_new = new_Σ_unscal * tuner.scale + #TODO AC: check + S = cholesky(Positive, Σ_new) + chain.f_transform = Mul(S.L) + tuner.iteration += 1 + + nothing + +end + + +tuning_finalize!(tuner::TransformedProposalCovTuner, chain::MCMCIterator) = nothing + +tuning_callback(::TransformedProposalCovTuner) = nop_func + +# default_adaptive_transform(tuner::TransformedProposalCovTuner) = TriangularAffineTransform() + + +# this function is called in each mcmc_iterate step during tuning +function tune_mcmc_transform!!( + rng::AbstractRNG, + tuner::TransformedProposalCovTuner, + transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int +) + + return (tuner, transform) +end + diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl new file mode 100644 index 000000000..f3ea576e2 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -0,0 +1,89 @@ +@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning + target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? + σ_target_acceptance::Float64 = 0.05 + gamma::Float64 = 2/3 +end + +@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance + config::TransformedRAMTuner + nsteps::Int = 0 +end +TransformedRAMTunerInstance(ram::TransformedRAMTuner) = TransformedRAMTunerInstance(config = ram) + +get_tuner(tuning::TransformedRAMTuner, chain::MCMCIterator) = TransformedRAMTunerInstance(tuning) + + +function tuning_init!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, max_nsteps::Integer) + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) # TODO ? + tuner.nsteps = 0 + + return nothing +end + + +tuning_postinit!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, samples::DensitySampleVector) = nothing + +# TODO AC: is this still needed? +# function tuning_postinit!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) +# # The very first samples of a chain can be very valuable to init tuner +# # stats, especially if the chain gets stuck early after: +# stats = tuner.stats +# append!(stats, samples) +# end + +tuning_reinit!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, max_nsteps::Integer) = nothing + + + + + +function tuning_update!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator, samples::DensitySampleVector) + α_min, α_max = map(op -> op(1, tuner.config.σ_target_acceptance), [-,+]) .* tuner.config.target_acceptance + α = eff_acceptance_ratio(chain) + + max_log_posterior = maximum(samples.logd) + + if α_min <= α <= α_max + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = true) + @debug "MCMC chain $(chain.info.id) tuned, acceptance ratio = $(Float32(α)), max. log posterior = $(Float32(max_log_posterior))" + else + chain.info = TransformedMCMCIteratorInfo(chain.info, tuned = false) + @debug "MCMC chain $(chain.info.id) *not* tuned, acceptance ratio = $(Float32(α)), max. log posterior = $(Float32(max_log_posterior))" + end +end + +tuning_finalize!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator) = nothing + +# tuning_callback(::TransformedRAMTuner) = nop_func + + + +default_adaptive_transform(tuner::TransformedRAMTuner) = TriangularAffineTransform() + +function tune_mcmc_transform!!( + rng::AbstractRNG, + tuner::TransformedRAMTunerInstance, + transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, + p_accept::Real, + z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead + z_current::Vector{<:Float64}, + stepno::Int +) + @unpack target_acceptance, gamma = tuner.config + n = size(z_current,1) + η = min(1, n * stepno^(-gamma)) + + s_L = transform.A + + u = z_proposed-z_current + M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L' + + S = cholesky(Positive, M) + transform_new = Mul(S.L) + + tuner.nsteps += 1 + + return (tuner, transform_new) +end + + diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl new file mode 100644 index 000000000..c1c7a0b18 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_tuning.jl @@ -0,0 +1,3 @@ +include("mcmc_noop_tuner.jl") +include("mcmc_ram_tuner.jl") +include("mcmc_proposalcov_tuner.jl") \ No newline at end of file diff --git a/src/samplers/transformed_mcmc/mcmc_utils.jl b/src/samplers/transformed_mcmc/mcmc_utils.jl new file mode 100644 index 000000000..b87662cca --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_utils.jl @@ -0,0 +1,27 @@ + +function _cov_with_fallback(d) + rng = bat_determ_rng() + smplr = bat_sampler(d) + T = float(eltype(rand(rng, smplr))) + n = totalndof(varshape(d)) + C = fill(T(NaN), n, n) + try + C[:] = cov(d) + catch err + if err isa MethodError + C[:] = cov(nestedview(rand(rng, smplr, 10^5))) + else + throw(err) + end + end + return C +end + +_approx_cov(target::Distribution) = _cov_with_fallback(target) +_approx_cov(target::DistLikeMeasure) = _cov_with_fallback(target) +_approx_cov(target::AbstractPosteriorMeasure) = _approx_cov(getprior(target)) +_approx_cov(target::BAT.Transformed{<:Any,<:BAT.DistributionTransform}) = + BAT._approx_cov(target.trafo.target_dist) +_approx_cov(target::Renormalized) = _approx_cov(parent(target)) +_approx_cov(target::WithDiff) = _approx_cov(parent(target)) + diff --git a/src/samplers/transformed_mcmc/mcmc_weighting.jl b/src/samplers/transformed_mcmc/mcmc_weighting.jl new file mode 100644 index 000000000..478e735e5 --- /dev/null +++ b/src/samplers/transformed_mcmc/mcmc_weighting.jl @@ -0,0 +1,53 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + abstract type AbstractMCMCWeightingScheme{T<:Real} + +Abstract class for weighting schemes for MCMC samples. + +Weight values will have type `T`. +""" +abstract type AbstractMCMCWeightingScheme{T<:Real} end +export AbstractMCMCWeightingScheme + + +sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T + + + +""" + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + +Sample weighting scheme suitable for sampling algorithms which may repeated +samples multiple times in direct succession (e.g. +[`MetropolisHastings`](@ref)). The repeated sample is stored only once, +with a weight equal to the number of times it has been repeated (e.g. +because a Markov chain has not moved during a sampling step). + +Constructors: + +* ```$(FUNCTIONNAME)()``` +""" +struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +export TransformedRepetitionWeighting + +TransformedRepetitionWeighting() = TransformedRepetitionWeighting{Int}() + + +""" + TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + +Sample weighting scheme suitable for accept/reject-based sampling algorithms +(e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples +become part of the output, with a weight proportional to their original +acceptance probability. + +Constructors: + +* ```$(FUNCTIONNAME)()``` +""" +struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end +export TransformedARPWeighting + +TransformedARPWeighting() = TransformedARPWeighting{Float64}() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl new file mode 100644 index 000000000..139ed6a2b --- /dev/null +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -0,0 +1,110 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm + +A multi-cycle MCMC burn-in algorithm. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm + nsteps_per_cycle::Int64 = 10000 + max_ncycles::Int = 30 + nsteps_final::Int64 = div(nsteps_per_cycle, 10) +end + +export TransformedMCMCMultiCycleBurnin + + +function mcmc_burnin!( + outputs::Union{DensitySampleVector,Nothing}, + chains::AbstractVector{<:MCMCIterator}, + tuners::AbstractVector{<:AbstractMCMCTunerInstance}, + temperers::AbstractVector{<:MCMCTemperingInstance}, + burnin_alg::TransformedMCMCMultiCycleBurnin, + convergence_test::ConvergenceTest, + strict_mode::Bool, + nonzero_weights::Bool, + callback::Function +) + nchains = length(chains) + + @info "Begin tuning of $nchains MCMC chain(s)." + + cycles = zero(Int) + successful = false + while !successful && cycles < burnin_alg.max_ncycles + cycles += 1 + + next_cycle!.(chains) + + tuning_reinit!.(tuners, chains, burnin_alg.nsteps_per_cycle) + + desc_string = string("Burnin cycle ", cycles, "/max_cycles=", burnin_alg.max_ncycles," for nchains=", length(chains)) + progress_meter = ProgressMeter.Progress(length(chains)*burnin_alg.nsteps_per_cycle, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + + transformed_mcmc_iterate!( + chains, tuners, temperers, + max_nsteps = burnin_alg.nsteps_per_cycle, + nonzero_weights = nonzero_weights, + callback = (kwargs...) -> let pm=progress_meter; ProgressMeter.next!(progress_meter) ; end, + ) + ProgressMeter.finish!(progress_meter) + + new_outputs = getproperty.(chains, :samples) + + tuning_update!.(tuners, chains, new_outputs) + + isnothing(outputs) || append!(outputs, reduce(vcat, new_outputs)) + + check_convergence!(chains, new_outputs, convergence_test) + + # check_tuned/update_tuners... + ntuned = count(c -> c.info.tuned, chains) + nconverged = count(c -> c.info.converged, chains) + successful = (ntuned == nconverged == nchains) + + callback(Val(:mcmc_burnin), tuners, chains) + + @info "MCMC Tuning cycle $cycles finished, $nchains chains, $ntuned tuned, $nconverged converged." + end + + tuning_finalize!.(tuners, chains) + + if successful + @info "MCMC tuning of $nchains chains successful after $cycles cycle(s)." + else + msg = "MCMC tuning of $nchains chains aborted after $cycles cycle(s)." + if strict_mode + throw(ErrorException(msg)) + else + @warn msg + end + end + + if burnin_alg.nsteps_final > 0 + @info "Running post-tuning stabilization steps for $nchains MCMC chain(s)." + + # turn off tuning + next_cycle!.(chains) + tuners = TransformedMCMCNoOpTuning().(chains) + + # TODO AC: what about tempering? + + transformed_mcmc_iterate!( + chains, tuners, temperers, + max_nsteps = burnin_alg.nsteps_final, + nonzero_weights = nonzero_weights, + callback = callback + ) + end + + successful +end diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl new file mode 100644 index 000000000..fd57cfb35 --- /dev/null +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -0,0 +1,205 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + + +""" + abstract type AbstractProposalDist + +*BAT-internal, not part of stable public API.* + +The following functions must be implemented for subtypes: + +* `BAT.proposaldist_logpdf` +* `BAT.proposal_rand!` +* `ValueShapes.totalndof`, returning the number of DOF (i.e. dimensionality). +* `LinearAlgebra.issymmetric`, indicating whether p(a -> b) == p(b -> a) holds true. +""" +abstract type AbstractProposalDist end + + +""" + proposaldist_logpdf( + p::AbstractArray, + pdist::AbstractProposalDist, + v_proposed::AbstractVector, + v_current:::AbstractVector + ) + +*BAT-internal, not part of stable public API.* + +Returns log(PDF) value of `pdist` for transitioning from current to proposed +variate/parameters. +""" +function proposaldist_logpdf end + +# TODO: Implement proposaldist_logpdf for included proposal distributions + + +""" + function proposal_rand!( + rng::AbstractRNG, + pdist::TransformedGenericProposalDist, + v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, + v_current::Union{AbstractVector,VectorOfSimilarVectors} + ) + +*BAT-internal, not part of stable public API.* + +Generate one or multiple proposed variate/parameter vectors, based on one or +multiple previous vectors. + +Input: + +* `rng`: Random number generator to use +* `pdist`: Proposal distribution to use +* `v_current`: Old values (vector or column vectors, if a matrix) + +Output is stored in + +* `v_proposed`: New values (vector or column vectors, if a matrix) + +The caller must guarantee: + +* `size(v_current, 1) == size(v_proposed, 1)` +* `size(v_current, 2) == size(v_proposed, 2)` or `size(v_current, 2) == 1` +* `v_proposed !== v_current` (no aliasing) + +Implementations of `proposal_rand!` must be thread-safe. +""" +function proposal_rand! end + + + +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist + d::D + sampler_f::SamplerF + s::S + + function TransformedGenericProposalDist{D,SamplerF}(d::D, sampler_f::SamplerF) where {D<:Distribution{Multivariate},SamplerF} + s = sampler_f(d) + new{D,SamplerF, typeof(s)}(d, sampler_f, s) + end + +end + + +TransformedGenericProposalDist(d::D, sampler_f::SamplerF) where {D<:Distribution{Multivariate},SamplerF} = + TransformedGenericProposalDist{D,SamplerF}(d, sampler_f) + +TransformedGenericProposalDist(d::Distribution{Multivariate}) = TransformedGenericProposalDist(d, bat_sampler) + +TransformedGenericProposalDist(D::Type{<:Distribution{Multivariate}}, varndof::Integer, args...) = + TransformedGenericProposalDist(D, Float64, varndof, args...) + + +Base.similar(q::TransformedGenericProposalDist, d::Distribution{Multivariate}) = + TransformedGenericProposalDist(d, q.sampler_f) + +function Base.convert(::Type{AbstractProposalDist}, q::TransformedGenericProposalDist, T::Type{<:AbstractFloat}, varndof::Integer) + varndof != totalndof(q) && throw(ArgumentError("q has wrong number of DOF")) + q +end + + +get_cov(q::TransformedGenericProposalDist) = get_cov(q.d) +set_cov(q::TransformedGenericProposalDist, Σ::PosDefMatLike) = similar(q, set_cov(q.d, Σ)) + + +function proposaldist_logpdf( + pdist::TransformedGenericProposalDist, + v_proposed::AbstractVector, + v_current::AbstractVector +) + params_diff = v_proposed .- v_current # TODO: Avoid memory allocation + logpdf(pdist.d, params_diff) +end + + +function proposal_rand!( + rng::AbstractRNG, + pdist::TransformedGenericProposalDist, + v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, + v_current::Union{AbstractVector,VectorOfSimilarVectors} +) + rand!(rng, pdist.s, flatview(v_proposed)) + params_new_flat = flatview(v_proposed) + params_new_flat .+= flatview(v_current) + v_proposed +end + + +ValueShapes.totalndof(pdist::TransformedGenericProposalDist) = length(pdist.d) + +LinearAlgebra.issymmetric(pdist::TransformedGenericProposalDist) = issymmetric_around_origin(pdist.d) + + + +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist + d::D + scale::Vector{T} + sampler_f::SamplerF + s::S +end + + +TransformedGenericUvProposalDist(d::Distribution{Univariate}, scale::Vector{<:AbstractFloat}, samplerF) = + TransformedGenericUvProposalDist(d, scale, samplerF, samplerF(d)) + +TransformedGenericUvProposalDist(d::Distribution{Univariate}, scale::Vector{<:AbstractFloat}) = + TransformedGenericUvProposalDist(d, scale, bat_sampler) + + +ValueShapes.totalndof(pdist::TransformedGenericUvProposalDist) = size(pdist.scale, 1) + +LinearAlgebra.issymmetric(pdist::TransformedGenericUvProposalDist) = issymmetric_around_origin(pdist.d) + +function BAT.proposaldist_logpdf( + pdist::TransformedGenericUvProposalDist, + v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, + v_current::Union{AbstractVector,VectorOfSimilarVectors} +) + params_diff = (flatview(v_proposed) .- flatview(v_current)) ./ pdist.scale # TODO: Avoid memory allocation + sum_first_dim(logpdf.(pdist.d, params_diff)) # TODO: Avoid memory allocation +end + +function BAT.proposal_rand!( + rng::AbstractRNG, + pdist::TransformedGenericUvProposalDist, + v_proposed::AbstractVector, + v_current::AbstractVector +) + v_proposed .= v_current + dim = rand(rng, eachindex(pdist.scale)) + v_proposed[dim] += pdist.scale[dim] * rand(rng, pdist.s) + v_proposed +end + + + +abstract type ProposalDistSpec end + + +struct TransformedMvTDistProposal <: ProposalDistSpec + df::Float64 +end + +TransformedMvTDistProposal() = TransformedMvTDistProposal(1.0) + + +(ps::TransformedMvTDistProposal)(T::Type{<:AbstractFloat}, varndof::Integer) = + TransformedGenericProposalDist(MvTDist, T, varndof, convert(T, ps.df)) + +function TransformedGenericProposalDist(::Type{MvTDist}, T::Type{<:AbstractFloat}, varndof::Integer, df = one(T)) + Σ = PDMat(Matrix(ScalMat(varndof, one(T)))) + μ = Fill(zero(eltype(Σ)), varndof) + M = typeof(Σ) + d = Distributions.GenericMvTDist(convert(T, df), μ, Σ) + TransformedGenericProposalDist(d) +end + + +struct TransformedUvTDistProposalSpec <: ProposalDistSpec + df::Float64 +end + +(ps::TransformedUvTDistProposalSpec)(T::Type{<:AbstractFloat}, varndof::Integer) = + TransformedGenericUvProposalDist(TDist(convert(T, ps.df)), fill(one(T), varndof)) diff --git a/src/samplers/transformed_mcmc/replace_type_list.sh b/src/samplers/transformed_mcmc/replace_type_list.sh new file mode 100644 index 000000000..ec9216d63 --- /dev/null +++ b/src/samplers/transformed_mcmc/replace_type_list.sh @@ -0,0 +1,14 @@ +# find . -name \*.jl -execdir sh -c +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCInitAlgorithm\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTuningAlgorithm\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCBurninAlgorithm\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCIterator\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCTunerInstance\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCProposal\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(SampleID\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCStats\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCWeightingScheme\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(AbstractProposalDist\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(ProposalDistSpec\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTempering\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTemperingInstance\)/ Transformed\1/g" {} \; diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl new file mode 100644 index 000000000..4ab047f02 --- /dev/null +++ b/src/samplers/transformed_mcmc/struct_list.jl @@ -0,0 +1,46 @@ + + + + +struct TransformedGelmanRubinConvergence <: ConvergenceTest +# Constructors: +# @with_kw struct TransformedGelmanRubinConvergence <: ConvergenceTest + struct TransformedBrooksGelmanConvergence <: ConvergenceTest +# Constructors: +# @with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest +struct TransformedMCMCNullStats <: AbstractMCMCStats end +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} +# Constructors: +# struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +# Constructors: +struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedMvTDistProposal <: ProposalDistSpec +struct TransformedUvTDistProposalSpec <: ProposalDistSpec + struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +# Constructors: +# @with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +struct TransformedNoMCMCTempering <: MCMCTempering end +# struct NoMCMCTemperingInstance <: MCMCTemperingInstance end + struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# Constructors: +# @with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# function _construct_chain( +# ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] +struct TransformedMCMCSampleID{ +struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end +@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm +mutable struct TransformedProposalCovTuner{ +@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance +@with_kw struct TransformedMCMCIteratorInfo +Constructors: +struct TransformedMCMCSampleGenerator{ +mutable struct TransformedMCMCIterator{ +struct TransformedMHProposal{ +# TODO AC: find a better solution for this. Problem is that in the with_kw constructor below, we need to dispatch on this type. +struct TransformedMCMCDispatch end +@with_kw struct TransformedMCMCSampling{ \ No newline at end of file diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl new file mode 100644 index 000000000..1601c7ae0 --- /dev/null +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -0,0 +1,18 @@ +abstract type MCMCTempering end +struct TransformedNoMCMCTempering <: MCMCTempering end + +""" + temper_mcmc_target!!(tempering::MCMCTemperingInstance, μ::BATMeasure, stepno::Integer) +""" +function temper_mcmc_target!! end + + + +abstract type MCMCTemperingInstance end + +struct NoMCMCTemperingInstance <: MCMCTemperingInstance end + +temper_mcmc_target!!(tempering::NoMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ + +get_temperer(tempering::TransformedNoMCMCTempering, density::BATMeasure) = NoMCMCTemperingInstance() +get_temperer(tempering::TransformedNoMCMCTempering, chain::MCMCIterator) = get_temperer(tempering, chain.μ) From 9cbacf2a200310e7761a7b309460244c3e711cc4 Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 12:51:07 +0200 Subject: [PATCH 02/33] use Transformed prefix for all types --- .../transformed_mcmc/chain_pool_init.jl | 4 +- .../transformed_mcmc/mcmc_algorithm.jl | 26 ++++++------ src/samplers/transformed_mcmc/mcmc_iterate.jl | 2 +- src/samplers/transformed_mcmc/mcmc_sample.jl | 4 +- .../transformed_mcmc/mcmc_sampleid.jl | 4 +- src/samplers/transformed_mcmc/mcmc_stats.jl | 6 +-- .../mcmc_tuning/mcmc_noop_tuner.jl | 6 +-- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 4 +- .../mcmc_tuning/mcmc_ram_tuner.jl | 4 +- .../transformed_mcmc/mcmc_weighting.jl | 14 +++---- .../transformed_mcmc/multi_cycle_burnin.jl | 4 +- src/samplers/transformed_mcmc/proposaldist.jl | 14 +++---- src/samplers/transformed_mcmc/struct_list.jl | 40 +++++++++---------- src/samplers/transformed_mcmc/tempering.jl | 8 ++-- 14 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index 20acac8ad..ee50360ac 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). """ - struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm + struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm MCMC chain pool initialization strategy. @@ -13,7 +13,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +@with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm init_tries_per_chain::ClosedInterval{Int64} = ClosedInterval(8, 128) nsteps_init::Int64 = 1000 initval_alg::InitvalAlgorithm = InitFromTarget() diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index 07d4accf6..72d6e27d2 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -6,34 +6,34 @@ function get_mcmc_tuning end #TODO: still needed """ - abstract type MCMCInitAlgorithm + abstract type TransformedMCMCInitAlgorithm Abstract type for MCMC initialization algorithms. """ -abstract type MCMCInitAlgorithm end -export MCMCInitAlgorithm +abstract type TransformedMCMCInitAlgorithm end +export TransformedMCMCInitAlgorithm apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg """ - abstract type MCMCTuningAlgorithm + abstract type TransformedMCMCTuningAlgorithm Abstract type for MCMC tuning algorithms. """ -abstract type MCMCTuningAlgorithm end -export MCMCTuningAlgorithm +abstract type TransformedMCMCTuningAlgorithm end +export TransformedMCMCTuningAlgorithm """ - abstract type MCMCBurninAlgorithm + abstract type TransformedMCMCBurninAlgorithm Abstract type for MCMC burn-in algorithms. """ -abstract type MCMCBurninAlgorithm end -export MCMCBurninAlgorithm +abstract type TransformedMCMCBurninAlgorithm end +export TransformedMCMCBurninAlgorithm @@ -46,7 +46,7 @@ end """ - abstract type MCMCIterator end + abstract type TransformedMCMCIterator end Represents the current state of an MCMC chain. @@ -102,8 +102,8 @@ isvalidchain(chain::MCMCIterator) isviablechain(chain::MCMCIterator) ``` """ -abstract type MCMCIterator end -export MCMCIterator +abstract type TransformedMCMCIterator end +export TransformedMCMCIterator function Base.show(io::IO, chain::MCMCIterator) @@ -145,7 +145,7 @@ DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain -abstract type AbstractMCMCTunerInstance end +abstract type TransformedAbstractMCMCTunerInstance end function tuning_init! end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index 5f7028c6d..939ba64dd 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -6,7 +6,7 @@ mutable struct TransformedMCMCIterator{ Q<:MCMCProposal, SV<:DensitySampleVector, S<:DensitySample, -} <: MCMCIterator +} <: TransformedMCMCIterator rng::R rngpart_cycle::PR μ::D diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index b09c2972d..f4e96f0f4 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -1,4 +1,4 @@ -abstract type MCMCProposal end +abstract type TransformedMCMCProposal end """ BAT.TransformedMHProposal @@ -6,7 +6,7 @@ abstract type MCMCProposal end """ struct TransformedMHProposal{ D<:Union{Distribution, AbstractMeasure} -}<: MCMCProposal +}<: TransformedMCMCProposal proposal_dist::D end diff --git a/src/samplers/transformed_mcmc/mcmc_sampleid.jl b/src/samplers/transformed_mcmc/mcmc_sampleid.jl index 1dfb92c3f..e2e53766f 100644 --- a/src/samplers/transformed_mcmc/mcmc_sampleid.jl +++ b/src/samplers/transformed_mcmc/mcmc_sampleid.jl @@ -1,11 +1,11 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type SampleID end +abstract type TransformedSampleID end struct TransformedMCMCSampleID{ T<:Int32, U<:Int64, -} <: SampleID +} <: TransformedSampleID chainid::T chaincycle::T stepno::U diff --git a/src/samplers/transformed_mcmc/mcmc_stats.jl b/src/samplers/transformed_mcmc/mcmc_stats.jl index cf57675d3..8fccc353e 100644 --- a/src/samplers/transformed_mcmc/mcmc_stats.jl +++ b/src/samplers/transformed_mcmc/mcmc_stats.jl @@ -1,12 +1,12 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type AbstractMCMCStats end +abstract type TransformedAbstractMCMCStats end AbstractMCMCStats -struct TransformedMCMCNullStats <: AbstractMCMCStats end +struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end Base.push!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats @@ -15,7 +15,7 @@ Base.append!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats param_stats::BasicMvStatistics{P,FrequencyWeights} logtf_stats::BasicUvStatistics{L,FrequencyWeights} mode::Vector{P} diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl index c5a1cca01..8352a88c7 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -2,18 +2,18 @@ """ - TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm + TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm No-op tuning, marks MCMC chains as tuned without performing any other changes on them. Useful if chains are pre-tuned or tuning is an internal part of the MCMC sampler implementation. """ -struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end export TransformedMCMCNoOpTuning -struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end +struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end (tuning::TransformedMCMCNoOpTuning)(chain::MCMCIterator) = TransformedMCMCNoOpTuner() get_tuner(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) = TransformedMCMCNoOpTuner() diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index 93772cdf6..405470a38 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -1,4 +1,4 @@ -@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm +@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm "Controls the weight given to new covariance information in adapting the proposal distribution." λ::Float64 = 0.5 @@ -22,7 +22,7 @@ end mutable struct TransformedProposalCovTuner{ S<:TransformedMCMCBasicStats -} <: AbstractMCMCTunerInstance +} <: TransformedAbstractMCMCTunerInstance config::TransformedAdaptiveMHTuning stats::S iteration::Int diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl index f3ea576e2..ae1572fbe 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -1,10 +1,10 @@ -@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? σ_target_acceptance::Float64 = 0.05 gamma::Float64 = 2/3 end -@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance +@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance config::TransformedRAMTuner nsteps::Int = 0 end diff --git a/src/samplers/transformed_mcmc/mcmc_weighting.jl b/src/samplers/transformed_mcmc/mcmc_weighting.jl index 478e735e5..2d9662cae 100644 --- a/src/samplers/transformed_mcmc/mcmc_weighting.jl +++ b/src/samplers/transformed_mcmc/mcmc_weighting.jl @@ -2,14 +2,14 @@ """ - abstract type AbstractMCMCWeightingScheme{T<:Real} + abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} Abstract class for weighting schemes for MCMC samples. Weight values will have type `T`. """ -abstract type AbstractMCMCWeightingScheme{T<:Real} end -export AbstractMCMCWeightingScheme +abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} end +export TransformedAbstractMCMCWeightingScheme sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T @@ -17,7 +17,7 @@ sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T """ - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for sampling algorithms which may repeated samples multiple times in direct succession (e.g. @@ -29,14 +29,14 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end export TransformedRepetitionWeighting TransformedRepetitionWeighting() = TransformedRepetitionWeighting{Int}() """ - TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for accept/reject-based sampling algorithms (e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples @@ -47,7 +47,7 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end +struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end export TransformedARPWeighting TransformedARPWeighting() = TransformedARPWeighting{Float64}() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index 139ed6a2b..9edff8ccd 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -2,7 +2,7 @@ """ - struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm + struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm A multi-cycle MCMC burn-in algorithm. @@ -14,7 +14,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +@with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm nsteps_per_cycle::Int64 = 10000 max_ncycles::Int = 30 nsteps_final::Int64 = div(nsteps_per_cycle, 10) diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl index fd57cfb35..94fb9bf3e 100644 --- a/src/samplers/transformed_mcmc/proposaldist.jl +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -2,7 +2,7 @@ """ - abstract type AbstractProposalDist + abstract type TransformedAbstractProposalDist *BAT-internal, not part of stable public API.* @@ -13,7 +13,7 @@ The following functions must be implemented for subtypes: * `ValueShapes.totalndof`, returning the number of DOF (i.e. dimensionality). * `LinearAlgebra.issymmetric`, indicating whether p(a -> b) == p(b -> a) holds true. """ -abstract type AbstractProposalDist end +abstract type TransformedAbstractProposalDist end """ @@ -69,7 +69,7 @@ function proposal_rand! end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist d::D sampler_f::SamplerF s::S @@ -133,7 +133,7 @@ LinearAlgebra.issymmetric(pdist::TransformedGenericProposalDist) = issymmetric_a -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist d::D scale::Vector{T} sampler_f::SamplerF @@ -175,10 +175,10 @@ end -abstract type ProposalDistSpec end +abstract type TransformedProposalDistSpec end -struct TransformedMvTDistProposal <: ProposalDistSpec +struct TransformedMvTDistProposal <: TransformedProposalDistSpec df::Float64 end @@ -197,7 +197,7 @@ function TransformedGenericProposalDist(::Type{MvTDist}, T::Type{<:AbstractFloat end -struct TransformedUvTDistProposalSpec <: ProposalDistSpec +struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec df::Float64 end diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl index 4ab047f02..4a16d922b 100644 --- a/src/samplers/transformed_mcmc/struct_list.jl +++ b/src/samplers/transformed_mcmc/struct_list.jl @@ -8,34 +8,34 @@ struct TransformedGelmanRubinConvergence <: ConvergenceTest struct TransformedBrooksGelmanConvergence <: ConvergenceTest # Constructors: # @with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest -struct TransformedMCMCNullStats <: AbstractMCMCStats end -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} +struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} # Constructors: -# struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +# struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end # Constructors: -struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist -struct TransformedMvTDistProposal <: ProposalDistSpec -struct TransformedUvTDistProposalSpec <: ProposalDistSpec - struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedMvTDistProposal <: TransformedProposalDistSpec +struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec + struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm # Constructors: -# @with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm -struct TransformedNoMCMCTempering <: MCMCTempering end -# struct NoMCMCTemperingInstance <: MCMCTemperingInstance end - struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# @with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm +struct TransformedNoMCMCTempering <: TransformedMCMCTempering end +# struct NoMCMCTemperingInstance <: TransformedMCMCTemperingInstance end + struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm # Constructors: -# @with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# @with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm # function _construct_chain( # ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] struct TransformedMCMCSampleID{ -struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end -struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end -@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm +struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end +@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm mutable struct TransformedProposalCovTuner{ -@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning -@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance +@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance @with_kw struct TransformedMCMCIteratorInfo Constructors: struct TransformedMCMCSampleGenerator{ diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl index 1601c7ae0..869488886 100644 --- a/src/samplers/transformed_mcmc/tempering.jl +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -1,5 +1,5 @@ -abstract type MCMCTempering end -struct TransformedNoMCMCTempering <: MCMCTempering end +abstract type TransformedMCMCTempering end +struct TransformedNoMCMCTempering <: TransformedMCMCTempering end """ temper_mcmc_target!!(tempering::MCMCTemperingInstance, μ::BATMeasure, stepno::Integer) @@ -8,9 +8,9 @@ function temper_mcmc_target!! end -abstract type MCMCTemperingInstance end +abstract type TransformedMCMCTemperingInstance end -struct NoMCMCTemperingInstance <: MCMCTemperingInstance end +struct NoMCMCTemperingInstance <: TransformedMCMCTemperingInstance end temper_mcmc_target!!(tempering::NoMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ From 6cbe392a6b38658f5171e474e75f3f66f749319e Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 12:54:05 +0200 Subject: [PATCH 03/33] update deps --- Project.toml | 5 ++++- src/BAT.jl | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 592a5f10d..0f005d083 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "3.0.0-DEV" [deps] AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" +AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -47,6 +48,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" @@ -83,6 +85,7 @@ BATUltraNestExt = "UltraNest" [compat] AdvancedHMC = "0.3, 0.4" +AffineMaps = "≥ 0.2" ArgCheck = "1, 2.0" ArraysOfArrays = "0.4, 0.5, 0.6" ChainRulesCore = "0.9.44, 0.10, 1" @@ -114,7 +117,7 @@ IrrationalConstants = "0.1, 0.2" KernelDensity = "0.5, 0.6" LaTeXStrings = "1" MacroTools = "0.5" -MeasureBase = "0.12, 0.13, 0.14" +MeasureBase = "0.14" Measurements = "2" NamedArrays = "0.9" NestedSamplers = "0.8" diff --git a/src/BAT.jl b/src/BAT.jl index 6af849d7a..c3ce8b1fc 100644 --- a/src/BAT.jl +++ b/src/BAT.jl @@ -16,6 +16,7 @@ using Printf using Random using Statistics +using AffineMaps using ArgCheck using ArraysOfArrays using ChangesOfVariables @@ -60,6 +61,7 @@ import Measurements import NamedArrays import NLSolversBase import Optim +import ProgressMeter import Random123 import Sobol import StableRNGs From e32abcac08d2020fc2491cab5b1c4cc2644dd214 Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 13:02:01 +0200 Subject: [PATCH 04/33] add AdaptiveTransform --- src/transforms/adaptive_transform.jl | 40 ++++++++++++++++++++++++++++ src/transforms/transforms.jl | 1 + 2 files changed, 41 insertions(+) create mode 100644 src/transforms/adaptive_transform.jl diff --git a/src/transforms/adaptive_transform.jl b/src/transforms/adaptive_transform.jl new file mode 100644 index 000000000..6dba6955d --- /dev/null +++ b/src/transforms/adaptive_transform.jl @@ -0,0 +1,40 @@ +abstract type AdaptiveTransformSpec end + + +struct CustomTransform{F} <: AdaptiveTransformSpec + f::F +end + +CustomTransform() = CustomTransform(identity) + +function init_adaptive_transform( + rng::AbstractRNG, + adaptive_transform::CustomTransform, + density +) + return adaptive_transform +end + + + +struct TriangularAffineTransform <: AdaptiveTransformSpec end + +function init_adaptive_transform( + rng::AbstractRNG, + adaptive_transform::TriangularAffineTransform, + density +) + M = _approx_cov(density) + s = cholesky(M).L + g = Mul(s) + + return g +end + + + +struct DiagonalAffineTransform <: AdaptiveTransformSpec end + + + + diff --git a/src/transforms/transforms.jl b/src/transforms/transforms.jl index 6c19878a2..fdd814e9a 100644 --- a/src/transforms/transforms.jl +++ b/src/transforms/transforms.jl @@ -2,3 +2,4 @@ include("trafo_utils.jl") include("distribution_transform.jl") +include("adaptive_transform.jl") From 50c02ba9d9c24ba02c133ad2878cda438135f76a Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 13:08:50 +0200 Subject: [PATCH 05/33] include transformed_mcmc --- src/samplers/samplers.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/samplers/samplers.jl b/src/samplers/samplers.jl index 8b09eb56b..5f558199e 100644 --- a/src/samplers/samplers.jl +++ b/src/samplers/samplers.jl @@ -2,5 +2,6 @@ include("bat_sample.jl") include("mcmc/mcmc.jl") +include("transformed_mcmc/mcmc.jl") include("sampled_density.jl") include("importance/importance_sampler.jl") From b0794f07b156bb19d0e4e1250adf7a6b384ff697 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Fri, 30 Jun 2023 13:52:56 +0200 Subject: [PATCH 06/33] Revert "use Transformed prefix for all types" This reverts commit 9cbacf2a200310e7761a7b309460244c3e711cc4. --- .../transformed_mcmc/chain_pool_init.jl | 4 +- .../transformed_mcmc/mcmc_algorithm.jl | 26 ++++++------ src/samplers/transformed_mcmc/mcmc_iterate.jl | 2 +- src/samplers/transformed_mcmc/mcmc_sample.jl | 4 +- .../transformed_mcmc/mcmc_sampleid.jl | 4 +- src/samplers/transformed_mcmc/mcmc_stats.jl | 6 +-- .../mcmc_tuning/mcmc_noop_tuner.jl | 6 +-- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 4 +- .../mcmc_tuning/mcmc_ram_tuner.jl | 4 +- .../transformed_mcmc/mcmc_weighting.jl | 14 +++---- .../transformed_mcmc/multi_cycle_burnin.jl | 4 +- src/samplers/transformed_mcmc/proposaldist.jl | 14 +++---- src/samplers/transformed_mcmc/struct_list.jl | 40 +++++++++---------- src/samplers/transformed_mcmc/tempering.jl | 8 ++-- 14 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index ee50360ac..20acac8ad 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). """ - struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm + struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm MCMC chain pool initialization strategy. @@ -13,7 +13,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm +@with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm init_tries_per_chain::ClosedInterval{Int64} = ClosedInterval(8, 128) nsteps_init::Int64 = 1000 initval_alg::InitvalAlgorithm = InitFromTarget() diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index 72d6e27d2..07d4accf6 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -6,34 +6,34 @@ function get_mcmc_tuning end #TODO: still needed """ - abstract type TransformedMCMCInitAlgorithm + abstract type MCMCInitAlgorithm Abstract type for MCMC initialization algorithms. """ -abstract type TransformedMCMCInitAlgorithm end -export TransformedMCMCInitAlgorithm +abstract type MCMCInitAlgorithm end +export MCMCInitAlgorithm apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg """ - abstract type TransformedMCMCTuningAlgorithm + abstract type MCMCTuningAlgorithm Abstract type for MCMC tuning algorithms. """ -abstract type TransformedMCMCTuningAlgorithm end -export TransformedMCMCTuningAlgorithm +abstract type MCMCTuningAlgorithm end +export MCMCTuningAlgorithm """ - abstract type TransformedMCMCBurninAlgorithm + abstract type MCMCBurninAlgorithm Abstract type for MCMC burn-in algorithms. """ -abstract type TransformedMCMCBurninAlgorithm end -export TransformedMCMCBurninAlgorithm +abstract type MCMCBurninAlgorithm end +export MCMCBurninAlgorithm @@ -46,7 +46,7 @@ end """ - abstract type TransformedMCMCIterator end + abstract type MCMCIterator end Represents the current state of an MCMC chain. @@ -102,8 +102,8 @@ isvalidchain(chain::MCMCIterator) isviablechain(chain::MCMCIterator) ``` """ -abstract type TransformedMCMCIterator end -export TransformedMCMCIterator +abstract type MCMCIterator end +export MCMCIterator function Base.show(io::IO, chain::MCMCIterator) @@ -145,7 +145,7 @@ DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain -abstract type TransformedAbstractMCMCTunerInstance end +abstract type AbstractMCMCTunerInstance end function tuning_init! end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index 939ba64dd..5f7028c6d 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -6,7 +6,7 @@ mutable struct TransformedMCMCIterator{ Q<:MCMCProposal, SV<:DensitySampleVector, S<:DensitySample, -} <: TransformedMCMCIterator +} <: MCMCIterator rng::R rngpart_cycle::PR μ::D diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index f4e96f0f4..b09c2972d 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -1,4 +1,4 @@ -abstract type TransformedMCMCProposal end +abstract type MCMCProposal end """ BAT.TransformedMHProposal @@ -6,7 +6,7 @@ abstract type TransformedMCMCProposal end """ struct TransformedMHProposal{ D<:Union{Distribution, AbstractMeasure} -}<: TransformedMCMCProposal +}<: MCMCProposal proposal_dist::D end diff --git a/src/samplers/transformed_mcmc/mcmc_sampleid.jl b/src/samplers/transformed_mcmc/mcmc_sampleid.jl index e2e53766f..1dfb92c3f 100644 --- a/src/samplers/transformed_mcmc/mcmc_sampleid.jl +++ b/src/samplers/transformed_mcmc/mcmc_sampleid.jl @@ -1,11 +1,11 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type TransformedSampleID end +abstract type SampleID end struct TransformedMCMCSampleID{ T<:Int32, U<:Int64, -} <: TransformedSampleID +} <: SampleID chainid::T chaincycle::T stepno::U diff --git a/src/samplers/transformed_mcmc/mcmc_stats.jl b/src/samplers/transformed_mcmc/mcmc_stats.jl index 8fccc353e..cf57675d3 100644 --- a/src/samplers/transformed_mcmc/mcmc_stats.jl +++ b/src/samplers/transformed_mcmc/mcmc_stats.jl @@ -1,12 +1,12 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type TransformedAbstractMCMCStats end +abstract type AbstractMCMCStats end AbstractMCMCStats -struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end +struct TransformedMCMCNullStats <: AbstractMCMCStats end Base.push!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats @@ -15,7 +15,7 @@ Base.append!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats param_stats::BasicMvStatistics{P,FrequencyWeights} logtf_stats::BasicUvStatistics{L,FrequencyWeights} mode::Vector{P} diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl index 8352a88c7..c5a1cca01 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -2,18 +2,18 @@ """ - TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm + TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm No-op tuning, marks MCMC chains as tuned without performing any other changes on them. Useful if chains are pre-tuned or tuning is an internal part of the MCMC sampler implementation. """ -struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end export TransformedMCMCNoOpTuning -struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end +struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end (tuning::TransformedMCMCNoOpTuning)(chain::MCMCIterator) = TransformedMCMCNoOpTuner() get_tuner(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) = TransformedMCMCNoOpTuner() diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index 405470a38..93772cdf6 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -1,4 +1,4 @@ -@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm +@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm "Controls the weight given to new covariance information in adapting the proposal distribution." λ::Float64 = 0.5 @@ -22,7 +22,7 @@ end mutable struct TransformedProposalCovTuner{ S<:TransformedMCMCBasicStats -} <: TransformedAbstractMCMCTunerInstance +} <: AbstractMCMCTunerInstance config::TransformedAdaptiveMHTuning stats::S iteration::Int diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl index ae1572fbe..f3ea576e2 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -1,10 +1,10 @@ -@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? σ_target_acceptance::Float64 = 0.05 gamma::Float64 = 2/3 end -@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance +@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance config::TransformedRAMTuner nsteps::Int = 0 end diff --git a/src/samplers/transformed_mcmc/mcmc_weighting.jl b/src/samplers/transformed_mcmc/mcmc_weighting.jl index 2d9662cae..478e735e5 100644 --- a/src/samplers/transformed_mcmc/mcmc_weighting.jl +++ b/src/samplers/transformed_mcmc/mcmc_weighting.jl @@ -2,14 +2,14 @@ """ - abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} + abstract type AbstractMCMCWeightingScheme{T<:Real} Abstract class for weighting schemes for MCMC samples. Weight values will have type `T`. """ -abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} end -export TransformedAbstractMCMCWeightingScheme +abstract type AbstractMCMCWeightingScheme{T<:Real} end +export AbstractMCMCWeightingScheme sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T @@ -17,7 +17,7 @@ sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T """ - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for sampling algorithms which may repeated samples multiple times in direct succession (e.g. @@ -29,14 +29,14 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end +struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end export TransformedRepetitionWeighting TransformedRepetitionWeighting() = TransformedRepetitionWeighting{Int}() """ - TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} + TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for accept/reject-based sampling algorithms (e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples @@ -47,7 +47,7 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end +struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end export TransformedARPWeighting TransformedARPWeighting() = TransformedARPWeighting{Float64}() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index 9edff8ccd..139ed6a2b 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -2,7 +2,7 @@ """ - struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm + struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm A multi-cycle MCMC burn-in algorithm. @@ -14,7 +14,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm +@with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm nsteps_per_cycle::Int64 = 10000 max_ncycles::Int = 30 nsteps_final::Int64 = div(nsteps_per_cycle, 10) diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl index 94fb9bf3e..fd57cfb35 100644 --- a/src/samplers/transformed_mcmc/proposaldist.jl +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -2,7 +2,7 @@ """ - abstract type TransformedAbstractProposalDist + abstract type AbstractProposalDist *BAT-internal, not part of stable public API.* @@ -13,7 +13,7 @@ The following functions must be implemented for subtypes: * `ValueShapes.totalndof`, returning the number of DOF (i.e. dimensionality). * `LinearAlgebra.issymmetric`, indicating whether p(a -> b) == p(b -> a) holds true. """ -abstract type TransformedAbstractProposalDist end +abstract type AbstractProposalDist end """ @@ -69,7 +69,7 @@ function proposal_rand! end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist d::D sampler_f::SamplerF s::S @@ -133,7 +133,7 @@ LinearAlgebra.issymmetric(pdist::TransformedGenericProposalDist) = issymmetric_a -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist d::D scale::Vector{T} sampler_f::SamplerF @@ -175,10 +175,10 @@ end -abstract type TransformedProposalDistSpec end +abstract type ProposalDistSpec end -struct TransformedMvTDistProposal <: TransformedProposalDistSpec +struct TransformedMvTDistProposal <: ProposalDistSpec df::Float64 end @@ -197,7 +197,7 @@ function TransformedGenericProposalDist(::Type{MvTDist}, T::Type{<:AbstractFloat end -struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec +struct TransformedUvTDistProposalSpec <: ProposalDistSpec df::Float64 end diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl index 4a16d922b..4ab047f02 100644 --- a/src/samplers/transformed_mcmc/struct_list.jl +++ b/src/samplers/transformed_mcmc/struct_list.jl @@ -8,34 +8,34 @@ struct TransformedGelmanRubinConvergence <: ConvergenceTest struct TransformedBrooksGelmanConvergence <: ConvergenceTest # Constructors: # @with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest -struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} +struct TransformedMCMCNullStats <: AbstractMCMCStats end +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} # Constructors: -# struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end +# struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end # Constructors: -struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist -struct TransformedMvTDistProposal <: TransformedProposalDistSpec -struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec - struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm +struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedMvTDistProposal <: ProposalDistSpec +struct TransformedUvTDistProposalSpec <: ProposalDistSpec + struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm # Constructors: -# @with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm -struct TransformedNoMCMCTempering <: TransformedMCMCTempering end -# struct NoMCMCTemperingInstance <: TransformedMCMCTemperingInstance end - struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm +# @with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +struct TransformedNoMCMCTempering <: MCMCTempering end +# struct NoMCMCTemperingInstance <: MCMCTemperingInstance end + struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm # Constructors: -# @with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm +# @with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm # function _construct_chain( # ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] struct TransformedMCMCSampleID{ -struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end -struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end -@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm +struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end +@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm mutable struct TransformedProposalCovTuner{ -@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning -@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance +@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance @with_kw struct TransformedMCMCIteratorInfo Constructors: struct TransformedMCMCSampleGenerator{ diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl index 869488886..1601c7ae0 100644 --- a/src/samplers/transformed_mcmc/tempering.jl +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -1,5 +1,5 @@ -abstract type TransformedMCMCTempering end -struct TransformedNoMCMCTempering <: TransformedMCMCTempering end +abstract type MCMCTempering end +struct TransformedNoMCMCTempering <: MCMCTempering end """ temper_mcmc_target!!(tempering::MCMCTemperingInstance, μ::BATMeasure, stepno::Integer) @@ -8,9 +8,9 @@ function temper_mcmc_target!! end -abstract type TransformedMCMCTemperingInstance end +abstract type MCMCTemperingInstance end -struct NoMCMCTemperingInstance <: TransformedMCMCTemperingInstance end +struct NoMCMCTemperingInstance <: MCMCTemperingInstance end temper_mcmc_target!!(tempering::NoMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ From e1d7abaa1807a5ef8b169942f689e9d0c331f3b6 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Fri, 30 Jun 2023 13:57:00 +0200 Subject: [PATCH 07/33] chnage replace_type script --- src/samplers/transformed_mcmc/mcmc.jl | 2 +- .../transformed_mcmc/mcmc_algorithm.jl | 18 +++++++------ .../transformed_mcmc/mcmc_convergence.jl | 12 +++++---- src/samplers/transformed_mcmc/mcmc_iterate.jl | 2 +- src/samplers/transformed_mcmc/mcmc_sample.jl | 2 +- src/samplers/transformed_mcmc/mcmc_utils.jl | 2 ++ .../transformed_mcmc/multi_cycle_burnin.jl | 2 +- .../transformed_mcmc/replace_type_list.sh | 25 +++++++++---------- 8 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc.jl b/src/samplers/transformed_mcmc/mcmc.jl index 2e7efec54..8ef0b3c67 100644 --- a/src/samplers/transformed_mcmc/mcmc.jl +++ b/src/samplers/transformed_mcmc/mcmc.jl @@ -1,6 +1,6 @@ using AffineMaps -include("mcmc_utils.jl") +#include("mcmc_utils.jl") include("mcmc_weighting.jl") include("proposaldist.jl") diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index 07d4accf6..e50e52163 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -13,7 +13,8 @@ Abstract type for MCMC initialization algorithms. abstract type MCMCInitAlgorithm end export MCMCInitAlgorithm -apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg +#TODO AC: reactivate +#apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg @@ -106,13 +107,14 @@ abstract type MCMCIterator end export MCMCIterator -function Base.show(io::IO, chain::MCMCIterator) - print(io, Base.typename(typeof(chain)).name, "(") - print(io, "id = "); show(io, mcmc_info(chain).id) - print(io, ", nsamples = "); show(io, nsamples(chain)) - print(io, ", density = "); show(io, getmeasure(chain)) - print(io, ")") -end +#TODO AC: reactivate +# function Base.show(io::IO, chain::MCMCIterator) +# print(io, Base.typename(typeof(chain)).name, "(") +# print(io, "id = "); show(io, mcmc_info(chain).id) +# print(io, ", nsamples = "); show(io, nsamples(chain)) +# print(io, ", density = "); show(io, getmeasure(chain)) +# print(io, ")") +# end function getalgorithm end diff --git a/src/samplers/transformed_mcmc/mcmc_convergence.jl b/src/samplers/transformed_mcmc/mcmc_convergence.jl index 9036f5b27..412e3b1d7 100644 --- a/src/samplers/transformed_mcmc/mcmc_convergence.jl +++ b/src/samplers/transformed_mcmc/mcmc_convergence.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -function check_convergence!( +function transformed_check_convergence!( chains::AbstractVector{<:MCMCIterator}, samples::AbstractVector{<:DensitySampleVector}, algorithm::ConvergenceTest, @@ -23,7 +23,8 @@ end Gelman-Rubin ``\$R^2\$`` for all DOF. """ -function gr_Rsqr end +# TODO AC: reactivate +# function gr_Rsqr end function gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) m = totalndof(first(stats)) @@ -32,9 +33,10 @@ function gr_Rsqr(stats::AbstractVector{<:TransformedMCMCBasicStats}) (W .+ B) ./ W end -function gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) - gr_Rsqr(TransformedMCMCBasicStats.(samples)) -end +#TODO AC: reactivate +# function gr_Rsqr(samples::AbstractVector{<:DensitySampleVector}) +# gr_Rsqr(TransformedMCMCBasicStats.(samples)) +# end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index 5f7028c6d..fe927fbfc 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -3,7 +3,7 @@ mutable struct TransformedMCMCIterator{ PR<:RNGPartition, D<:BATMeasure, F, - Q<:MCMCProposal, + Q<:TransformedMCMCProposal, SV<:DensitySampleVector, S<:DensitySample, } <: MCMCIterator diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index b09c2972d..04a2fd3a7 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -24,7 +24,7 @@ struct TransformedMCMCDispatch end pre_transform::TR = bat_default(TransformedMCMCDispatch, Val(:pre_transform)) tuning_alg::MCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults adaptive_transform::AdaptiveTransformSpec = default_adaptive_transform(tuning_alg) - proposal::MCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults + proposal::TransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults tempering = TransformedNoMCMCTempering() # TODO: use bat_defaults nchains::Int = 4 nsteps::Int = 10^5 diff --git a/src/samplers/transformed_mcmc/mcmc_utils.jl b/src/samplers/transformed_mcmc/mcmc_utils.jl index b87662cca..bce9f6bac 100644 --- a/src/samplers/transformed_mcmc/mcmc_utils.jl +++ b/src/samplers/transformed_mcmc/mcmc_utils.jl @@ -1,3 +1,5 @@ +# TODO AC: File not included as it would overwrite BAT.jl functions + function _cov_with_fallback(d) rng = bat_determ_rng() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index 139ed6a2b..5d987e11f 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -64,7 +64,7 @@ function mcmc_burnin!( isnothing(outputs) || append!(outputs, reduce(vcat, new_outputs)) - check_convergence!(chains, new_outputs, convergence_test) + transformed_check_convergence!(chains, new_outputs, convergence_test) # TODO AC: Rename # check_tuned/update_tuners... ntuned = count(c -> c.info.tuned, chains) diff --git a/src/samplers/transformed_mcmc/replace_type_list.sh b/src/samplers/transformed_mcmc/replace_type_list.sh index ec9216d63..8919c9185 100644 --- a/src/samplers/transformed_mcmc/replace_type_list.sh +++ b/src/samplers/transformed_mcmc/replace_type_list.sh @@ -1,14 +1,13 @@ # find . -name \*.jl -execdir sh -c -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCInitAlgorithm\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTuningAlgorithm\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCBurninAlgorithm\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCIterator\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCTunerInstance\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCProposal\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(SampleID\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCStats\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(AbstractMCMCWeightingScheme\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(AbstractProposalDist\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(ProposalDistSpec\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTempering\)/ Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/ \(MCMCTemperingInstance\)/ Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCInitAlgorithm\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCTuningAlgorithm\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCBurninAlgorithm\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCTunerInstance\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCProposal\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(SampleID\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCStats\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCWeightingScheme\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(AbstractProposalDist\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(ProposalDistSpec\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCTempering\)/Transformed\1/g" {} \; +find . -name \*.jl -execdir sed -i -e "s/\(MCMCTemperingInstance\)/Transformed\1/g" {} \; From 66fa0adf8778116fec8338ffd3eece6e3b891989 Mon Sep 17 00:00:00 2001 From: AG Date: Fri, 30 Jun 2023 13:58:45 +0200 Subject: [PATCH 08/33] use Transformed prefix for all abstract types --- .../transformed_mcmc/chain_pool_init.jl | 6 +-- src/samplers/transformed_mcmc/example.jl | 2 +- .../transformed_mcmc/mcmc_algorithm.jl | 22 +++++----- src/samplers/transformed_mcmc/mcmc_iterate.jl | 26 ++++++------ src/samplers/transformed_mcmc/mcmc_sample.jl | 16 +++---- .../transformed_mcmc/mcmc_sampleid.jl | 30 ++++++------- src/samplers/transformed_mcmc/mcmc_stats.jl | 8 ++-- .../mcmc_tuning/mcmc_noop_tuner.jl | 6 +-- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 4 +- .../mcmc_tuning/mcmc_ram_tuner.jl | 4 +- .../transformed_mcmc/mcmc_weighting.jl | 16 +++---- .../transformed_mcmc/multi_cycle_burnin.jl | 8 ++-- src/samplers/transformed_mcmc/proposaldist.jl | 18 ++++---- src/samplers/transformed_mcmc/struct_list.jl | 42 +++++++++---------- src/samplers/transformed_mcmc/tempering.jl | 16 +++---- 15 files changed, 112 insertions(+), 112 deletions(-) diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index 20acac8ad..ba0cffb26 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). """ - struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm + struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm MCMC chain pool initialization strategy. @@ -13,7 +13,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +@with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm init_tries_per_chain::ClosedInterval{Int64} = ClosedInterval(8, 128) nsteps_init::Int64 = 1000 initval_alg::InitvalAlgorithm = InitFromTarget() @@ -60,7 +60,7 @@ function mcmc_init!( density::AbstractMeasureOrDensity, nchains::Integer, init_alg::TransformedMCMCChainPoolInit, - tuning_alg::MCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner + tuning_alg::TransformedMCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner nonzero_weights::Bool, callback::Function ) diff --git a/src/samplers/transformed_mcmc/example.jl b/src/samplers/transformed_mcmc/example.jl index 8d270dc20..d59d98975 100644 --- a/src/samplers/transformed_mcmc/example.jl +++ b/src/samplers/transformed_mcmc/example.jl @@ -5,7 +5,7 @@ using ChangesOfVariables using BAT.LinearAlgebra using BAT.Distributions using BAT.InverseFunctions -import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoMCMCTempering, transformed_mcmc_step!!, TransformedMCMCSampleID +import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID using BAT.Random123 import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index e50e52163..fedc92080 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -6,35 +6,35 @@ function get_mcmc_tuning end #TODO: still needed """ - abstract type MCMCInitAlgorithm + abstract type TransformedMCMCInitAlgorithm Abstract type for MCMC initialization algorithms. """ -abstract type MCMCInitAlgorithm end -export MCMCInitAlgorithm +abstract type TransformedMCMCInitAlgorithm end +export TransformedMCMCInitAlgorithm #TODO AC: reactivate -#apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg +#apply_trafo_to_init(trafo::Function, initalg::TransformedMCMCInitAlgorithm) = initalg """ - abstract type MCMCTuningAlgorithm + abstract type TransformedMCMCTuningAlgorithm Abstract type for MCMC tuning algorithms. """ -abstract type MCMCTuningAlgorithm end -export MCMCTuningAlgorithm +abstract type TransformedMCMCTuningAlgorithm end +export TransformedMCMCTuningAlgorithm """ - abstract type MCMCBurninAlgorithm + abstract type TransformedMCMCBurninAlgorithm Abstract type for MCMC burn-in algorithms. """ -abstract type MCMCBurninAlgorithm end -export MCMCBurninAlgorithm +abstract type TransformedMCMCBurninAlgorithm end +export TransformedMCMCBurninAlgorithm @@ -147,7 +147,7 @@ DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain -abstract type AbstractMCMCTunerInstance end +abstract type TransformedAbstractMCMCTunerInstance end function tuning_init! end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index fe927fbfc..ee686ce0f 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -3,7 +3,7 @@ mutable struct TransformedMCMCIterator{ PR<:RNGPartition, D<:BATMeasure, F, - Q<:TransformedMCMCProposal, + Q<:TransformedTransformedMCMCProposal, SV<:DensitySampleVector, S<:DensitySample, } <: MCMCIterator @@ -75,7 +75,7 @@ function TransformedMCMCIterator( g = init_adaptive_transform(rng, adaptive_transform_spec, μ) logd_x = logdensityof(μ, v_init) - sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCSampleID(id, 1, 0), nothing) # TODO + sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCTransformedSampleID(id, 1, 0), nothing) # TODO inverse_g = inverse(g) z = inverse_g(v_init) # sample_x.v logd_z = logdensityof(MeasureBase.pullback(g, μ),z) @@ -154,8 +154,8 @@ end function transformed_mcmc_step!!( iter::TransformedMCMCIterator, - tuner::AbstractMCMCTunerInstance, - tempering::MCMCTemperingInstance, + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedTransformedMCMCTemperingInstance, ) @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter sample_x = last(samples) @@ -181,7 +181,7 @@ function transformed_mcmc_step!!( end sample_x_new, sample_z_new, samples_new = if accepted - sample_x_new = DensitySample(x_new, logd_x_new, 1, TransformedMCMCSampleID(iter.info.id, iter.info.cycle, iter.stepno), nothing) + sample_x_new = DensitySample(x_new, logd_x_new, 1, TransformedMCMCTransformedSampleID(iter.info.id, iter.info.cycle, iter.stepno), nothing) push!(samples, sample_x_new) sample_x_new, _rebuild_density_sample(sample_z, z_new, logd_z_new), samples else @@ -206,8 +206,8 @@ end function transformed_mcmc_iterate!( chain::TransformedMCMCIterator, - tuner::AbstractMCMCTunerInstance, - tempering::MCMCTemperingInstance; + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedTransformedMCMCTemperingInstance; max_nsteps::Integer = 1, max_time::Real = Inf, nonzero_weights::Bool = true, @@ -248,9 +248,9 @@ end function transformed_mcmc_iterate!( chain::MCMCIterator, - tuner::AbstractMCMCTunerInstance, - tempering::MCMCTemperingInstance; - # tuner::AbstractMCMCTunerInstance; + tuner::TransformedAbstractMCMCTunerInstance, + tempering::TransformedTransformedMCMCTemperingInstance; + # tuner::TransformedAbstractMCMCTunerInstance; max_nsteps::Integer = 1, max_time::Real = Inf, nonzero_weights::Bool = true, @@ -269,8 +269,8 @@ end function transformed_mcmc_iterate!( chains::AbstractVector{<:MCMCIterator}, - tuners::AbstractVector{<:AbstractMCMCTunerInstance}, - temperers::AbstractVector{<:MCMCTemperingInstance}; + tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, + temperers::AbstractVector{<:TransformedTransformedMCMCTemperingInstance}; kwargs... ) if isempty(chains) @@ -322,7 +322,7 @@ function next_cycle!( resize!(chain.samples, 1) chain.samples.weight[1] = 1 - chain.samples.info[1] = TransformedMCMCSampleID(chain.info.id, chain.info.cycle, chain.stepno) + chain.samples.info[1] = TransformedMCMCTransformedSampleID(chain.info.id, chain.info.cycle, chain.stepno) chain end diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index 04a2fd3a7..626b15094 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -1,4 +1,4 @@ -abstract type MCMCProposal end +abstract type TransformedMCMCProposal end """ BAT.TransformedMHProposal @@ -6,7 +6,7 @@ abstract type MCMCProposal end """ struct TransformedMHProposal{ D<:Union{Distribution, AbstractMeasure} -}<: MCMCProposal +}<: TransformedMCMCProposal proposal_dist::D end @@ -16,16 +16,16 @@ struct TransformedMCMCDispatch end @with_kw struct TransformedMCMCSampling{ TR<:AbstractTransformTarget, - IN<:MCMCInitAlgorithm, - BI<:MCMCBurninAlgorithm, + IN<:TransformedMCMCInitAlgorithm, + BI<:TransformedMCMCBurninAlgorithm, CT<:ConvergenceTest, CB<:Function } <: AbstractSamplingAlgorithm pre_transform::TR = bat_default(TransformedMCMCDispatch, Val(:pre_transform)) - tuning_alg::MCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults + tuning_alg::TransformedMCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults adaptive_transform::AdaptiveTransformSpec = default_adaptive_transform(tuning_alg) - proposal::TransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults - tempering = TransformedNoMCMCTempering() # TODO: use bat_defaults + proposal::TransformedTransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults + tempering = TransformedNoTransformedMCMCTempering() # TODO: use bat_defaults nchains::Int = 4 nsteps::Int = 10^5 #TODO: max_time ? @@ -145,7 +145,7 @@ function _run_sample_impl( transformed_mcmc_iterate!( chains, get_tuner.(Ref(TransformedMCMCNoOpTuning()),chains), - get_temperer.(Ref(TransformedNoMCMCTempering()), chains), + get_temperer.(Ref(TransformedNoTransformedMCMCTempering()), chains), max_nsteps = algorithm.nsteps, #TODO: maxtime nonzero_weights = algorithm.nonzero_weights, callback = (kwargs...) -> let pm=progress_meter; ProgressMeter.next!(pm) ; end, diff --git a/src/samplers/transformed_mcmc/mcmc_sampleid.jl b/src/samplers/transformed_mcmc/mcmc_sampleid.jl index 1dfb92c3f..fb072cd25 100644 --- a/src/samplers/transformed_mcmc/mcmc_sampleid.jl +++ b/src/samplers/transformed_mcmc/mcmc_sampleid.jl @@ -1,61 +1,61 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type SampleID end +abstract type TransformedSampleID end -struct TransformedMCMCSampleID{ +struct TransformedMCMCTransformedSampleID{ T<:Int32, U<:Int64, -} <: SampleID +} <: TransformedSampleID chainid::T chaincycle::T stepno::U end -function TransformedMCMCSampleID( +function TransformedMCMCTransformedSampleID( chainid::Integer, chaincycle::Integer, stepno::Integer, ) - TransformedMCMCSampleID(Int32(chainid), Int32(chaincycle), Int64(stepno)) + TransformedMCMCTransformedSampleID(Int32(chainid), Int32(chaincycle), Int64(stepno)) end -const TransformedMCMCSampleIDVector{TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} = StructArray{ - TransformedMCMCSampleID, +const TransformedMCMCTransformedSampleIDVector{TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} = StructArray{ + TransformedMCMCTransformedSampleID, 1, NamedTuple{(:chainid, :chaincycle, :stepno), Tuple{TV,TV,UV}}, Int } -function TransformedMCMCSampleIDVector(contents::Tuple{TV,TV,UV}) where {TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} - StructArray{TransformedMCMCSampleID}(contents)::TransformedMCMCSampleIDVector{TV,UV} +function TransformedMCMCTransformedSampleIDVector(contents::Tuple{TV,TV,UV}) where {TV<:AbstractVector{<:Int32},UV<:AbstractVector{<:Int64}} + StructArray{TransformedMCMCTransformedSampleID}(contents)::TransformedMCMCTransformedSampleIDVector{TV,UV} end -TransformedMCMCSampleIDVector(::UndefInitializer, len::Integer) = TransformedMCMCSampleIDVector(( +TransformedMCMCTransformedSampleIDVector(::UndefInitializer, len::Integer) = TransformedMCMCTransformedSampleIDVector(( Vector{Int32}(undef, len), Vector{Int32}(undef, len), Vector{Int64}(undef, len) )) -TransformedMCMCSampleIDVector() = TransformedMCMCSampleIDVector(undef, 0) +TransformedMCMCTransformedSampleIDVector() = TransformedMCMCTransformedSampleIDVector(undef, 0) -_create_undef_vector(::Type{TransformedMCMCSampleID}, len::Integer) = TransformedMCMCSampleIDVector(undef, len) +_create_undef_vector(::Type{TransformedMCMCTransformedSampleID}, len::Integer) = TransformedMCMCTransformedSampleIDVector(undef, len) # Specialize comparison, currently StructArray seems fall back to `(==)(A::AbstractArray, B::AbstractArray)` import Base.== -function(==)(A::TransformedMCMCSampleIDVector, B::TransformedMCMCSampleIDVector) +function(==)(A::TransformedMCMCTransformedSampleIDVector, B::TransformedMCMCTransformedSampleIDVector) A.chainid == B.chainid && A.chaincycle == B.chaincycle && A.stepno == B.stepno end -function Base.merge!(X::TransformedMCMCSampleIDVector, Xs::TransformedMCMCSampleIDVector...) +function Base.merge!(X::TransformedMCMCTransformedSampleIDVector, Xs::TransformedMCMCTransformedSampleIDVector...) for Y in Xs append!(X, Y) end X end -Base.merge(X::TransformedMCMCSampleIDVector, Xs::TransformedMCMCSampleIDVector...) = merge!(deepcopy(X), Xs...) +Base.merge(X::TransformedMCMCTransformedSampleIDVector, Xs::TransformedMCMCTransformedSampleIDVector...) = merge!(deepcopy(X), Xs...) diff --git a/src/samplers/transformed_mcmc/mcmc_stats.jl b/src/samplers/transformed_mcmc/mcmc_stats.jl index cf57675d3..214eafd4e 100644 --- a/src/samplers/transformed_mcmc/mcmc_stats.jl +++ b/src/samplers/transformed_mcmc/mcmc_stats.jl @@ -1,12 +1,12 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -abstract type AbstractMCMCStats end -AbstractMCMCStats +abstract type TransformedAbstractMCMCStats end +TransformedAbstractMCMCStats -struct TransformedMCMCNullStats <: AbstractMCMCStats end +struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end Base.push!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats @@ -15,7 +15,7 @@ Base.append!(stats::TransformedMCMCNullStats, sv::DensitySampleVector) = stats -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats param_stats::BasicMvStatistics{P,FrequencyWeights} logtf_stats::BasicUvStatistics{L,FrequencyWeights} mode::Vector{P} diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl index c5a1cca01..8352a88c7 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -2,18 +2,18 @@ """ - TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm + TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm No-op tuning, marks MCMC chains as tuned without performing any other changes on them. Useful if chains are pre-tuned or tuning is an internal part of the MCMC sampler implementation. """ -struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end export TransformedMCMCNoOpTuning -struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end +struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end (tuning::TransformedMCMCNoOpTuning)(chain::MCMCIterator) = TransformedMCMCNoOpTuner() get_tuner(tuning::TransformedMCMCNoOpTuning, chain::MCMCIterator) = TransformedMCMCNoOpTuner() diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index 93772cdf6..405470a38 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -1,4 +1,4 @@ -@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm +@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm "Controls the weight given to new covariance information in adapting the proposal distribution." λ::Float64 = 0.5 @@ -22,7 +22,7 @@ end mutable struct TransformedProposalCovTuner{ S<:TransformedMCMCBasicStats -} <: AbstractMCMCTunerInstance +} <: TransformedAbstractMCMCTunerInstance config::TransformedAdaptiveMHTuning stats::S iteration::Int diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl index f3ea576e2..ae1572fbe 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -1,10 +1,10 @@ -@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? σ_target_acceptance::Float64 = 0.05 gamma::Float64 = 2/3 end -@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance +@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance config::TransformedRAMTuner nsteps::Int = 0 end diff --git a/src/samplers/transformed_mcmc/mcmc_weighting.jl b/src/samplers/transformed_mcmc/mcmc_weighting.jl index 478e735e5..d1fa41df7 100644 --- a/src/samplers/transformed_mcmc/mcmc_weighting.jl +++ b/src/samplers/transformed_mcmc/mcmc_weighting.jl @@ -2,22 +2,22 @@ """ - abstract type AbstractMCMCWeightingScheme{T<:Real} + abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} Abstract class for weighting schemes for MCMC samples. Weight values will have type `T`. """ -abstract type AbstractMCMCWeightingScheme{T<:Real} end -export AbstractMCMCWeightingScheme +abstract type TransformedAbstractMCMCWeightingScheme{T<:Real} end +export TransformedAbstractMCMCWeightingScheme -sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T +sample_weight_type(::Type{<:TransformedAbstractMCMCWeightingScheme{T}}) where {T} = T """ - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for sampling algorithms which may repeated samples multiple times in direct succession (e.g. @@ -29,14 +29,14 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end export TransformedRepetitionWeighting TransformedRepetitionWeighting() = TransformedRepetitionWeighting{Int}() """ - TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} + TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} Sample weighting scheme suitable for accept/reject-based sampling algorithms (e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples @@ -47,7 +47,7 @@ Constructors: * ```$(FUNCTIONNAME)()``` """ -struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end +struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end export TransformedARPWeighting TransformedARPWeighting() = TransformedARPWeighting{Float64}() diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index 5d987e11f..d8cc8f3ba 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -2,7 +2,7 @@ """ - struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm + struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm A multi-cycle MCMC burn-in algorithm. @@ -14,7 +14,7 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +@with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm nsteps_per_cycle::Int64 = 10000 max_ncycles::Int = 30 nsteps_final::Int64 = div(nsteps_per_cycle, 10) @@ -26,8 +26,8 @@ export TransformedMCMCMultiCycleBurnin function mcmc_burnin!( outputs::Union{DensitySampleVector,Nothing}, chains::AbstractVector{<:MCMCIterator}, - tuners::AbstractVector{<:AbstractMCMCTunerInstance}, - temperers::AbstractVector{<:MCMCTemperingInstance}, + tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, + temperers::AbstractVector{<:TransformedTransformedMCMCTemperingInstance}, burnin_alg::TransformedMCMCMultiCycleBurnin, convergence_test::ConvergenceTest, strict_mode::Bool, diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl index fd57cfb35..d58e6e3b4 100644 --- a/src/samplers/transformed_mcmc/proposaldist.jl +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -2,7 +2,7 @@ """ - abstract type AbstractProposalDist + abstract type TransformedAbstractProposalDist *BAT-internal, not part of stable public API.* @@ -13,13 +13,13 @@ The following functions must be implemented for subtypes: * `ValueShapes.totalndof`, returning the number of DOF (i.e. dimensionality). * `LinearAlgebra.issymmetric`, indicating whether p(a -> b) == p(b -> a) holds true. """ -abstract type AbstractProposalDist end +abstract type TransformedAbstractProposalDist end """ proposaldist_logpdf( p::AbstractArray, - pdist::AbstractProposalDist, + pdist::TransformedAbstractProposalDist, v_proposed::AbstractVector, v_current:::AbstractVector ) @@ -69,7 +69,7 @@ function proposal_rand! end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist d::D sampler_f::SamplerF s::S @@ -94,7 +94,7 @@ TransformedGenericProposalDist(D::Type{<:Distribution{Multivariate}}, varndof::I Base.similar(q::TransformedGenericProposalDist, d::Distribution{Multivariate}) = TransformedGenericProposalDist(d, q.sampler_f) -function Base.convert(::Type{AbstractProposalDist}, q::TransformedGenericProposalDist, T::Type{<:AbstractFloat}, varndof::Integer) +function Base.convert(::Type{TransformedAbstractProposalDist}, q::TransformedGenericProposalDist, T::Type{<:AbstractFloat}, varndof::Integer) varndof != totalndof(q) && throw(ArgumentError("q has wrong number of DOF")) q end @@ -133,7 +133,7 @@ LinearAlgebra.issymmetric(pdist::TransformedGenericProposalDist) = issymmetric_a -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist d::D scale::Vector{T} sampler_f::SamplerF @@ -175,10 +175,10 @@ end -abstract type ProposalDistSpec end +abstract type TransformedProposalDistSpec end -struct TransformedMvTDistProposal <: ProposalDistSpec +struct TransformedMvTDistProposal <: TransformedProposalDistSpec df::Float64 end @@ -197,7 +197,7 @@ function TransformedGenericProposalDist(::Type{MvTDist}, T::Type{<:AbstractFloat end -struct TransformedUvTDistProposalSpec <: ProposalDistSpec +struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec df::Float64 end diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl index 4ab047f02..7b3982824 100644 --- a/src/samplers/transformed_mcmc/struct_list.jl +++ b/src/samplers/transformed_mcmc/struct_list.jl @@ -8,34 +8,34 @@ struct TransformedGelmanRubinConvergence <: ConvergenceTest struct TransformedBrooksGelmanConvergence <: ConvergenceTest # Constructors: # @with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest -struct TransformedMCMCNullStats <: AbstractMCMCStats end -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: AbstractMCMCStats - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} +struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end +struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats + struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} # Constructors: -# struct TransformedRepetitionWeighting{T<:Real} <: AbstractMCMCWeightingScheme{T} end +# struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end # Constructors: -struct TransformedARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: AbstractProposalDist -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: AbstractProposalDist -struct TransformedMvTDistProposal <: ProposalDistSpec -struct TransformedUvTDistProposalSpec <: ProposalDistSpec - struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm +struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end +struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist +struct TransformedMvTDistProposal <: TransformedProposalDistSpec +struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec + struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm # Constructors: -# @with_kw struct TransformedMCMCMultiCycleBurnin <: MCMCBurninAlgorithm -struct TransformedNoMCMCTempering <: MCMCTempering end -# struct NoMCMCTemperingInstance <: MCMCTemperingInstance end - struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# @with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm +struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end +# struct NoTransformedTransformedMCMCTemperingInstance <: TransformedTransformedMCMCTemperingInstance end + struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm # Constructors: -# @with_kw struct TransformedMCMCChainPoolInit <: MCMCInitAlgorithm +# @with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm # function _construct_chain( # ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] -struct TransformedMCMCSampleID{ -struct TransformedMCMCNoOpTuning <: MCMCTuningAlgorithm end -struct TransformedMCMCNoOpTuner <: AbstractMCMCTunerInstance end -@with_kw struct TransformedAdaptiveMHTuning <: MCMCTuningAlgorithm +struct TransformedMCMCTransformedSampleID{ +struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end +struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end +@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm mutable struct TransformedProposalCovTuner{ -@with_kw struct TransformedRAMTuner <: MCMCTuningAlgorithm #TODO: rename to RAMTuning -@with_kw mutable struct TransformedRAMTunerInstance <: AbstractMCMCTunerInstance +@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning +@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance @with_kw struct TransformedMCMCIteratorInfo Constructors: struct TransformedMCMCSampleGenerator{ diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl index 1601c7ae0..fcdc6886b 100644 --- a/src/samplers/transformed_mcmc/tempering.jl +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -1,18 +1,18 @@ -abstract type MCMCTempering end -struct TransformedNoMCMCTempering <: MCMCTempering end +abstract type TransformedMCMCTempering end +struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end """ - temper_mcmc_target!!(tempering::MCMCTemperingInstance, μ::BATMeasure, stepno::Integer) + temper_mcmc_target!!(tempering::TransformedTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) """ function temper_mcmc_target!! end -abstract type MCMCTemperingInstance end +abstract type TransformedTransformedMCMCTemperingInstance end -struct NoMCMCTemperingInstance <: MCMCTemperingInstance end +struct NoTransformedTransformedMCMCTemperingInstance <: TransformedTransformedMCMCTemperingInstance end -temper_mcmc_target!!(tempering::NoMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ +temper_mcmc_target!!(tempering::NoTransformedTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ -get_temperer(tempering::TransformedNoMCMCTempering, density::BATMeasure) = NoMCMCTemperingInstance() -get_temperer(tempering::TransformedNoMCMCTempering, chain::MCMCIterator) = get_temperer(tempering, chain.μ) +get_temperer(tempering::TransformedNoTransformedMCMCTempering, density::BATMeasure) = NoTransformedTransformedMCMCTemperingInstance() +get_temperer(tempering::TransformedNoTransformedMCMCTempering, chain::MCMCIterator) = get_temperer(tempering, chain.μ) From 6c39fee205e110bc41e7fc7b1ef318a7b657381d Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Fri, 30 Jun 2023 14:16:28 +0200 Subject: [PATCH 09/33] new TransformedMCMCSampling and old MCMCSampling now both working --- .../transformed_mcmc/mcmc_algorithm.jl | 86 +++++++++---------- .../transformed_mcmc/mcmc_convergence.jl | 7 +- src/samplers/transformed_mcmc/mcmc_iterate.jl | 10 +-- src/samplers/transformed_mcmc/mcmc_sample.jl | 2 +- .../transformed_mcmc/multi_cycle_burnin.jl | 2 +- src/samplers/transformed_mcmc/proposaldist.jl | 73 ++++++++-------- src/samplers/transformed_mcmc/struct_list.jl | 2 +- src/samplers/transformed_mcmc/tempering.jl | 10 +-- 8 files changed, 96 insertions(+), 96 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index fedc92080..08bb9d14a 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -46,65 +46,66 @@ export TransformedMCMCBurninAlgorithm end -""" - abstract type MCMCIterator end +# TODO AC: reactivate +# """ +# abstract type MCMCIterator end -Represents the current state of an MCMC chain. +# Represents the current state of an MCMC chain. -!!! note +# !!! note - The details of the `MCMCIterator` and `MCMCAlgorithm` API (see below) - currently do not form part of the stable API and are subject to change - without deprecation. +# The details of the `MCMCIterator` and `MCMCAlgorithm` API (see below) +# currently do not form part of the stable API and are subject to change +# without deprecation. -To implement a new MCMC algorithm, subtypes of both [`MCMCAlgorithm`](@ref) -and `MCMCIterator` are required. +# To implement a new MCMC algorithm, subtypes of both [`MCMCAlgorithm`](@ref) +# and `MCMCIterator` are required. -The following methods must be defined for subtypes of `MCMCIterator` (e.g. -`SomeMCMCIter<:MCMCIterator`): +# The following methods must be defined for subtypes of `MCMCIterator` (e.g. +# `SomeMCMCIter<:MCMCIterator`): -```julia +# ```julia -BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity +# BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity -BAT.getrng(chain::SomeMCMCIter)::AbstractRNG +# BAT.getrng(chain::SomeMCMCIter)::AbstractRNG -BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo +# BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo -BAT.nsteps(chain::SomeMCMCIter)::Int +# BAT.nsteps(chain::SomeMCMCIter)::Int -BAT.nsamples(chain::SomeMCMCIter)::Int +# BAT.nsamples(chain::SomeMCMCIter)::Int -BAT.current_sample(chain::SomeMCMCIter)::DensitySample +# BAT.current_sample(chain::SomeMCMCIter)::DensitySample -BAT.sample_type(chain::SomeMCMCIter)::Type{<:DensitySample} +# BAT.sample_type(chain::SomeMCMCIter)::Type{<:DensitySample} -BAT.samples_available(chain::SomeMCMCIter, nonzero_weights::Bool = false)::Bool +# BAT.samples_available(chain::SomeMCMCIter, nonzero_weights::Bool = false)::Bool -BAT.get_samples!(samples::DensitySampleVector, chain::SomeMCMCIter, nonzero_weights::Bool)::typeof(samples) +# BAT.get_samples!(samples::DensitySampleVector, chain::SomeMCMCIter, nonzero_weights::Bool)::typeof(samples) -BAT.next_cycle!(chain::SomeMCMCIter)::SomeMCMCIter +# BAT.next_cycle!(chain::SomeMCMCIter)::SomeMCMCIter -BAT.mcmc_step!( - chain::SomeMCMCIter - callback::Function, -)::nothing -``` +# BAT.mcmc_step!( +# chain::SomeMCMCIter +# callback::Function, +# )::nothing +# ``` -The following methods are implemented by default: +# The following methods are implemented by default: -```julia -getalgorithm(chain::MCMCIterator) -getmeasure(chain::MCMCIterator) -DensitySampleVector(chain::MCMCIterator) -mcmc_iterate!(chain::MCMCIterator, ...) -mcmc_iterate!(chains::AbstractVector{<:MCMCIterator}, ...) -isvalidchain(chain::MCMCIterator) -isviablechain(chain::MCMCIterator) -``` -""" -abstract type MCMCIterator end -export MCMCIterator +# ```julia +# getalgorithm(chain::MCMCIterator) +# getmeasure(chain::MCMCIterator) +# DensitySampleVector(chain::MCMCIterator) +# mcmc_iterate!(chain::MCMCIterator, ...) +# mcmc_iterate!(chains::AbstractVector{<:MCMCIterator}, ...) +# isvalidchain(chain::MCMCIterator) +# isviablechain(chain::MCMCIterator) +# ``` +# """ +# abstract type MCMCIterator end +# export MCMCIterator #TODO AC: reactivate @@ -142,9 +143,8 @@ function next_cycle! end function mcmc_step! end - -DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain), totalndof(getmeasure(chain))) - +# TODO AC: reactivate +#DensitySampleVector(chain::MCMCIterator) = DensitySampleVector(sample_type(chain), totalndof(getmeasure(chain))) abstract type TransformedAbstractMCMCTunerInstance end diff --git a/src/samplers/transformed_mcmc/mcmc_convergence.jl b/src/samplers/transformed_mcmc/mcmc_convergence.jl index 412e3b1d7..61c476b8b 100644 --- a/src/samplers/transformed_mcmc/mcmc_convergence.jl +++ b/src/samplers/transformed_mcmc/mcmc_convergence.jl @@ -114,9 +114,10 @@ function bg_R_2sqr(stats::AbstractVector{<:TransformedMCMCBasicStats}; corrected R_unc.*(d.+3)./(d.+1) end -function bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) - bg_R_2sqr(TransformedMCMCBasicStats.(samples), corrected = corrected) -end +# TODO AC: reactivate +# function bg_R_2sqr(samples::AbstractVector{<:DensitySampleVector}; corrected::Bool = false) +# bg_R_2sqr(TransformedMCMCBasicStats.(samples), corrected = corrected) +# end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index ee686ce0f..a2923bd8e 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -3,7 +3,7 @@ mutable struct TransformedMCMCIterator{ PR<:RNGPartition, D<:BATMeasure, F, - Q<:TransformedTransformedMCMCProposal, + Q<:TransformedMCMCProposal, SV<:DensitySampleVector, S<:DensitySample, } <: MCMCIterator @@ -155,7 +155,7 @@ end function transformed_mcmc_step!!( iter::TransformedMCMCIterator, tuner::TransformedAbstractMCMCTunerInstance, - tempering::TransformedTransformedMCMCTemperingInstance, + tempering::TransformedMCMCTemperingInstance, ) @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter sample_x = last(samples) @@ -207,7 +207,7 @@ end function transformed_mcmc_iterate!( chain::TransformedMCMCIterator, tuner::TransformedAbstractMCMCTunerInstance, - tempering::TransformedTransformedMCMCTemperingInstance; + tempering::TransformedMCMCTemperingInstance; max_nsteps::Integer = 1, max_time::Real = Inf, nonzero_weights::Bool = true, @@ -249,7 +249,7 @@ end function transformed_mcmc_iterate!( chain::MCMCIterator, tuner::TransformedAbstractMCMCTunerInstance, - tempering::TransformedTransformedMCMCTemperingInstance; + tempering::TransformedMCMCTemperingInstance; # tuner::TransformedAbstractMCMCTunerInstance; max_nsteps::Integer = 1, max_time::Real = Inf, @@ -270,7 +270,7 @@ end function transformed_mcmc_iterate!( chains::AbstractVector{<:MCMCIterator}, tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, - temperers::AbstractVector{<:TransformedTransformedMCMCTemperingInstance}; + temperers::AbstractVector{<:TransformedMCMCTemperingInstance}; kwargs... ) if isempty(chains) diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index 626b15094..e1d3c8734 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -24,7 +24,7 @@ struct TransformedMCMCDispatch end pre_transform::TR = bat_default(TransformedMCMCDispatch, Val(:pre_transform)) tuning_alg::TransformedMCMCTuningAlgorithm = TransformedRAMTuner() # TODO: use bat_defaults adaptive_transform::AdaptiveTransformSpec = default_adaptive_transform(tuning_alg) - proposal::TransformedTransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults + proposal::TransformedMCMCProposal = TransformedMHProposal(Normal()) #TODO: use bat_defaults tempering = TransformedNoTransformedMCMCTempering() # TODO: use bat_defaults nchains::Int = 4 nsteps::Int = 10^5 diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index d8cc8f3ba..2e167935c 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -27,7 +27,7 @@ function mcmc_burnin!( outputs::Union{DensitySampleVector,Nothing}, chains::AbstractVector{<:MCMCIterator}, tuners::AbstractVector{<:TransformedAbstractMCMCTunerInstance}, - temperers::AbstractVector{<:TransformedTransformedMCMCTemperingInstance}, + temperers::AbstractVector{<:TransformedMCMCTemperingInstance}, burnin_alg::TransformedMCMCMultiCycleBurnin, convergence_test::ConvergenceTest, strict_mode::Bool, diff --git a/src/samplers/transformed_mcmc/proposaldist.jl b/src/samplers/transformed_mcmc/proposaldist.jl index d58e6e3b4..231ce0be9 100644 --- a/src/samplers/transformed_mcmc/proposaldist.jl +++ b/src/samplers/transformed_mcmc/proposaldist.jl @@ -1,6 +1,5 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). - """ abstract type TransformedAbstractProposalDist @@ -15,57 +14,57 @@ The following functions must be implemented for subtypes: """ abstract type TransformedAbstractProposalDist end +# TODO AC: reactivate +# """ +# proposaldist_logpdf( +# p::AbstractArray, +# pdist::TransformedAbstractProposalDist, +# v_proposed::AbstractVector, +# v_current:::AbstractVector +# ) -""" - proposaldist_logpdf( - p::AbstractArray, - pdist::TransformedAbstractProposalDist, - v_proposed::AbstractVector, - v_current:::AbstractVector - ) - -*BAT-internal, not part of stable public API.* +# *BAT-internal, not part of stable public API.* -Returns log(PDF) value of `pdist` for transitioning from current to proposed -variate/parameters. -""" -function proposaldist_logpdf end +# Returns log(PDF) value of `pdist` for transitioning from current to proposed +# variate/parameters. +# """#function proposaldist_logpdf end # TODO: Implement proposaldist_logpdf for included proposal distributions -""" - function proposal_rand!( - rng::AbstractRNG, - pdist::TransformedGenericProposalDist, - v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, - v_current::Union{AbstractVector,VectorOfSimilarVectors} - ) +# TODO AC: reactivate +# """ +# function proposal_rand!( +# rng::AbstractRNG, +# pdist::TransformedGenericProposalDist, +# v_proposed::Union{AbstractVector,VectorOfSimilarVectors}, +# v_current::Union{AbstractVector,VectorOfSimilarVectors} +# ) -*BAT-internal, not part of stable public API.* +# *BAT-internal, not part of stable public API.* -Generate one or multiple proposed variate/parameter vectors, based on one or -multiple previous vectors. +# Generate one or multiple proposed variate/parameter vectors, based on one or +# multiple previous vectors. -Input: +# Input: -* `rng`: Random number generator to use -* `pdist`: Proposal distribution to use -* `v_current`: Old values (vector or column vectors, if a matrix) +# * `rng`: Random number generator to use +# * `pdist`: Proposal distribution to use +# * `v_current`: Old values (vector or column vectors, if a matrix) -Output is stored in +# Output is stored in -* `v_proposed`: New values (vector or column vectors, if a matrix) +# * `v_proposed`: New values (vector or column vectors, if a matrix) -The caller must guarantee: +# The caller must guarantee: -* `size(v_current, 1) == size(v_proposed, 1)` -* `size(v_current, 2) == size(v_proposed, 2)` or `size(v_current, 2) == 1` -* `v_proposed !== v_current` (no aliasing) +# * `size(v_current, 1) == size(v_proposed, 1)` +# * `size(v_current, 2) == size(v_proposed, 2)` or `size(v_current, 2) == 1` +# * `v_proposed !== v_current` (no aliasing) -Implementations of `proposal_rand!` must be thread-safe. -""" -function proposal_rand! end +# Implementations of `proposal_rand!` must be thread-safe. +# """ +# function proposal_rand! end diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl index 7b3982824..604040216 100644 --- a/src/samplers/transformed_mcmc/struct_list.jl +++ b/src/samplers/transformed_mcmc/struct_list.jl @@ -23,7 +23,7 @@ struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec # Constructors: # @with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end -# struct NoTransformedTransformedMCMCTemperingInstance <: TransformedTransformedMCMCTemperingInstance end +# struct NoTransformedMCMCTemperingInstance <: TransformedMCMCTemperingInstance end struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm # Constructors: # @with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm diff --git a/src/samplers/transformed_mcmc/tempering.jl b/src/samplers/transformed_mcmc/tempering.jl index fcdc6886b..4e0c4a005 100644 --- a/src/samplers/transformed_mcmc/tempering.jl +++ b/src/samplers/transformed_mcmc/tempering.jl @@ -2,17 +2,17 @@ abstract type TransformedMCMCTempering end struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end """ - temper_mcmc_target!!(tempering::TransformedTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) + temper_mcmc_target!!(tempering::TransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) """ function temper_mcmc_target!! end -abstract type TransformedTransformedMCMCTemperingInstance end +abstract type TransformedMCMCTemperingInstance end -struct NoTransformedTransformedMCMCTemperingInstance <: TransformedTransformedMCMCTemperingInstance end +struct NoTransformedMCMCTemperingInstance <: TransformedMCMCTemperingInstance end -temper_mcmc_target!!(tempering::NoTransformedTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ +temper_mcmc_target!!(tempering::NoTransformedMCMCTemperingInstance, μ::BATMeasure, stepno::Integer) = tempering, μ -get_temperer(tempering::TransformedNoTransformedMCMCTempering, density::BATMeasure) = NoTransformedTransformedMCMCTemperingInstance() +get_temperer(tempering::TransformedNoTransformedMCMCTempering, density::BATMeasure) = NoTransformedMCMCTemperingInstance() get_temperer(tempering::TransformedNoTransformedMCMCTempering, chain::MCMCIterator) = get_temperer(tempering, chain.μ) From f640fbc57a49e04638747076759e976febb47189 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Mon, 3 Jul 2023 15:21:50 +0200 Subject: [PATCH 10/33] move example --- .../dev-internal/transformed_example.jl | 0 .../transformed_mcmc/replace_type_list.sh | 13 ------ src/samplers/transformed_mcmc/struct_list.jl | 46 ------------------- 3 files changed, 59 deletions(-) rename src/samplers/transformed_mcmc/example.jl => examples/dev-internal/transformed_example.jl (100%) delete mode 100644 src/samplers/transformed_mcmc/replace_type_list.sh delete mode 100644 src/samplers/transformed_mcmc/struct_list.jl diff --git a/src/samplers/transformed_mcmc/example.jl b/examples/dev-internal/transformed_example.jl similarity index 100% rename from src/samplers/transformed_mcmc/example.jl rename to examples/dev-internal/transformed_example.jl diff --git a/src/samplers/transformed_mcmc/replace_type_list.sh b/src/samplers/transformed_mcmc/replace_type_list.sh deleted file mode 100644 index 8919c9185..000000000 --- a/src/samplers/transformed_mcmc/replace_type_list.sh +++ /dev/null @@ -1,13 +0,0 @@ -# find . -name \*.jl -execdir sh -c -find . -name \*.jl -execdir sed -i -e "s/\(MCMCInitAlgorithm\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(MCMCTuningAlgorithm\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(MCMCBurninAlgorithm\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCTunerInstance\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(MCMCProposal\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(SampleID\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCStats\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(AbstractMCMCWeightingScheme\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(AbstractProposalDist\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(ProposalDistSpec\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(MCMCTempering\)/Transformed\1/g" {} \; -find . -name \*.jl -execdir sed -i -e "s/\(MCMCTemperingInstance\)/Transformed\1/g" {} \; diff --git a/src/samplers/transformed_mcmc/struct_list.jl b/src/samplers/transformed_mcmc/struct_list.jl deleted file mode 100644 index 604040216..000000000 --- a/src/samplers/transformed_mcmc/struct_list.jl +++ /dev/null @@ -1,46 +0,0 @@ - - - - -struct TransformedGelmanRubinConvergence <: ConvergenceTest -# Constructors: -# @with_kw struct TransformedGelmanRubinConvergence <: ConvergenceTest - struct TransformedBrooksGelmanConvergence <: ConvergenceTest -# Constructors: -# @with_kw struct TransformedBrooksGelmanConvergence <: ConvergenceTest -struct TransformedMCMCNullStats <: TransformedAbstractMCMCStats end -struct TransformedMCMCBasicStats{L<:Real,P<:Real} <: TransformedAbstractMCMCStats - struct TransformedRepetitionWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} -# Constructors: -# struct TransformedRepetitionWeighting{T<:Real} <: TransformedAbstractMCMCWeightingScheme{T} end -# Constructors: -struct TransformedARPWeighting{T<:AbstractFloat} <: TransformedAbstractMCMCWeightingScheme{T} end -struct TransformedGenericProposalDist{D<:Distribution{Multivariate},SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist -struct TransformedGenericUvProposalDist{D<:Distribution{Univariate},T<:Real,SamplerF,S<:Sampleable} <: TransformedAbstractProposalDist -struct TransformedMvTDistProposal <: TransformedProposalDistSpec -struct TransformedUvTDistProposalSpec <: TransformedProposalDistSpec - struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm -# Constructors: -# @with_kw struct TransformedMCMCMultiCycleBurnin <: TransformedMCMCBurninAlgorithm -struct TransformedNoTransformedMCMCTempering <: TransformedMCMCTempering end -# struct NoTransformedMCMCTemperingInstance <: TransformedMCMCTemperingInstance end - struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm -# Constructors: -# @with_kw struct TransformedMCMCChainPoolInit <: TransformedMCMCInitAlgorithm -# function _construct_chain( -# ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] -struct TransformedMCMCTransformedSampleID{ -struct TransformedMCMCNoOpTuning <: TransformedMCMCTuningAlgorithm end -struct TransformedMCMCNoOpTuner <: TransformedAbstractMCMCTunerInstance end -@with_kw struct TransformedAdaptiveMHTuning <: TransformedMCMCTuningAlgorithm -mutable struct TransformedProposalCovTuner{ -@with_kw struct TransformedRAMTuner <: TransformedMCMCTuningAlgorithm #TODO: rename to RAMTuning -@with_kw mutable struct TransformedRAMTunerInstance <: TransformedAbstractMCMCTunerInstance -@with_kw struct TransformedMCMCIteratorInfo -Constructors: -struct TransformedMCMCSampleGenerator{ -mutable struct TransformedMCMCIterator{ -struct TransformedMHProposal{ -# TODO AC: find a better solution for this. Problem is that in the with_kw constructor below, we need to dispatch on this type. -struct TransformedMCMCDispatch end -@with_kw struct TransformedMCMCSampling{ \ No newline at end of file From d0ce0109e900c5a7a69baf831de7afa2d7b84ed9 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Mon, 3 Jul 2023 16:32:45 +0200 Subject: [PATCH 11/33] use full matrix instead of lower cholesky in AdaptiveMHTuner --- examples/dev-internal/transformed_example.jl | 15 ++++++++++++--- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 8 ++++---- src/transforms/adaptive_transform.jl | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/dev-internal/transformed_example.jl b/examples/dev-internal/transformed_example.jl index d59d98975..84d169f8f 100644 --- a/examples/dev-internal/transformed_example.jl +++ b/examples/dev-internal/transformed_example.jl @@ -17,11 +17,20 @@ rng = Philox4x() posterior = BAT.example_posterior() my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) -my_samples = my_result.result -mh_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(tuning_alg=TransformedAdaptiveMHTuning(), pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) -(;chain, tuner) = BAT.g_state + + +density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior) +density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo) + +c = BAT._approx_cov(density) +f = BAT.CustomTransform(Mul(c)) + +my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f)) + +my_samples = my_result.result + using Plots diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index 405470a38..7fa554961 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -85,7 +85,7 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, transform = chain.f_transform - #TODO AC: check with Oli + #TODO AC: rename S_L to S, check with Oli S_L = transform.A Σ_old = S_L @@ -113,8 +113,8 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, Σ_new = new_Σ_unscal * tuner.scale #TODO AC: check - S = cholesky(Positive, Σ_new) - chain.f_transform = Mul(S.L) + #S = cholesky(Positive, Σ_new) + chain.f_transform = Mul(Σ_new) tuner.iteration += 1 nothing @@ -133,7 +133,7 @@ tuning_callback(::TransformedProposalCovTuner) = nop_func function tune_mcmc_transform!!( rng::AbstractRNG, tuner::TransformedProposalCovTuner, - transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, + transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, p_accept::Real, z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead z_current::Vector{<:Float64}, diff --git a/src/transforms/adaptive_transform.jl b/src/transforms/adaptive_transform.jl index 6dba6955d..5c1f2aa3b 100644 --- a/src/transforms/adaptive_transform.jl +++ b/src/transforms/adaptive_transform.jl @@ -12,7 +12,7 @@ function init_adaptive_transform( adaptive_transform::CustomTransform, density ) - return adaptive_transform + return adaptive_transform.f end From 972d9677155f3ac7b763d67587f2758d225ca335 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Tue, 4 Jul 2023 09:39:30 +0200 Subject: [PATCH 12/33] use cholesky lower for AdaptiveMHTuning --- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index 7fa554961..bf5f43bd2 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -84,10 +84,8 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, transform = chain.f_transform - - #TODO AC: rename S_L to S, check with Oli - S_L = transform.A - Σ_old = S_L + A = transform.A + Σ_old = A*A' S = convert(Array, stats.param_stats.cov) a_t = 1 / t^λ @@ -112,9 +110,9 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, end Σ_new = new_Σ_unscal * tuner.scale - #TODO AC: check - #S = cholesky(Positive, Σ_new) - chain.f_transform = Mul(Σ_new) + + S_new = cholesky(Positive, Σ_new) + chain.f_transform = Mul(S_new.L) tuner.iteration += 1 nothing From 9cb262532db3c6d38124bd3ac5b60b073d1860b7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 6 Jul 2023 08:55:38 +0200 Subject: [PATCH 13/33] Adapt transformed MCMC code to using BATContext --- examples/dev-internal/transformed_example.jl | 22 ++++----- .../transformed_mcmc/chain_pool_init.jl | 31 ++++++------ .../transformed_mcmc/mcmc_algorithm.jl | 4 +- src/samplers/transformed_mcmc/mcmc_iterate.jl | 47 +++++++++++-------- src/samplers/transformed_mcmc/mcmc_sample.jl | 8 ++-- .../mcmc_tuning/mcmc_noop_tuner.jl | 4 +- .../mcmc_tuning/mcmc_proposalcov_tuner.jl | 4 +- .../mcmc_tuning/mcmc_ram_tuner.jl | 6 +-- src/transforms/adaptive_transform.jl | 8 ++-- 9 files changed, 68 insertions(+), 66 deletions(-) diff --git a/examples/dev-internal/transformed_example.jl b/examples/dev-internal/transformed_example.jl index 84d169f8f..047124d0a 100644 --- a/examples/dev-internal/transformed_example.jl +++ b/examples/dev-internal/transformed_example.jl @@ -6,19 +6,18 @@ using BAT.LinearAlgebra using BAT.Distributions using BAT.InverseFunctions import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID -using BAT.Random123 +using Random123 +using AutoDiffOperators import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling #ENV["JULIA_DEBUG"] = "BAT" -rng = Philox4x() +context = BATContext(ad = ADModule(:ForwardDiff)) posterior = BAT.example_posterior() -my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) - - +my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context) density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior) @@ -27,7 +26,7 @@ density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo) c = BAT._approx_cov(density) f = BAT.CustomTransform(Mul(c)) -my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f)) +my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context) my_samples = my_result.result @@ -36,9 +35,9 @@ my_samples = my_result.result using Plots plot(my_samples) -r_mh = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) ) +r_mh = @time BAT.bat_sample_impl(posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context) -r_hmc = @time BAT.bat_sample_impl(rng, posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) ) +r_hmc = @time BAT.bat_sample_impl(posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context) plot(bat_sample(posterior).result) @@ -60,9 +59,8 @@ posterior.likelihood.density._log_f(rand(prior2)) posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2) -@profview r_ram2 = @time BAT.bat_sample_impl(rng, posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000)) - -@profview r_mh2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true) ) +@profview r_ram2 = @time BAT.bat_sample_impl(posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context) -r_hmc2 = @time BAT.bat_sample_impl(rng, posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000) ) +@profview r_mh2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context) +r_hmc2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context) diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index ba0cffb26..c5c11ad38 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -37,12 +37,12 @@ function _construct_chain( id::Integer, algorithm::TransformedMCMCSampling, density::AbstractMeasureOrDensity, - initval_alg::InitvalAlgorithm + initval_alg::InitvalAlgorithm, + parent_context::BATContext ) - rng = AbstractRNG(rngpart, id) - v_init = bat_initval(rng, density, initval_alg).result - - TransformedMCMCIterator(rng, algorithm, density, id, v_init) + new_context = set_rng(parent_context, AbstractRNG(rngpart, id)) + v_init = bat_initval(density, initval_alg, new_context).result + return TransformedMCMCIterator(algorithm, density, id, v_init, new_context) end _gen_chains( @@ -50,19 +50,20 @@ _gen_chains( ids::AbstractRange{<:Integer}, algorithm::TransformedMCMCSampling, density::AbstractMeasureOrDensity, - initval_alg::InitvalAlgorithm -) = [_construct_chain(rngpart, id, algorithm, density, initval_alg) for id in ids] + initval_alg::InitvalAlgorithm, + context::BATContext +) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids] #TODO function mcmc_init!( - rng::AbstractRNG, algorithm::TransformedMCMCSampling, density::AbstractMeasureOrDensity, nchains::Integer, init_alg::TransformedMCMCChainPoolInit, tuning_alg::TransformedMCMCTuningAlgorithm, # TODO: part of algorithm? # MCMCTuner nonzero_weights::Bool, - callback::Function + callback::Function, + context::BATContext ) @info "TransformedMCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)." @@ -71,14 +72,15 @@ function mcmc_init!( min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains - rngpart = RNGPartition(rng, Base.OneTo(max_ncandidates)) + rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates)) ncandidates::Int = 0 @debug "Generating dummy MCMC chain to determine chain, output and tuner types." #TODO: remove! - dummy_initval = unshaped(bat_initval(rng, density, InitFromTarget()).result, varshape(density)) - dummy_chain = TransformedMCMCIterator(rng, algorithm, density, 1, dummy_initval) + dummy_context = deepcopy(context) + dummy_initval = unshaped(bat_initval(density, InitFromTarget(), dummy_context).result, varshape(density)) + dummy_chain = TransformedMCMCIterator(algorithm, density, 1, dummy_initval, dummy_context) dummy_tuner = get_tuner(tuning_alg, dummy_chain) dummy_temperer = get_temperer(algorithm.tempering, density) @@ -93,7 +95,7 @@ function mcmc_init!( n = min(min_nviable, max_ncandidates - ncandidates) @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." - new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg) + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) filter!(isvalidchain, new_chains) @@ -135,7 +137,6 @@ function mcmc_init!( nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains])) good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains) @debug "Found $(length(viable_chains)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples." - append!(chains, view(viable_chains, good_idxs)) append!(tuners, view(viable_tuners, good_idxs)) @@ -153,7 +154,7 @@ function mcmc_init!( tidxs = LinearIndices(chains) n = length(tidxs) - modes = hcat(broadcast(samples -> Array(bat_findmode(rng, samples, MaxDensitySearch()).result), outputs)...) + modes = hcat(broadcast(samples -> Array(bat_findmode(samples, MaxDensitySearch(), context).result), outputs)...) final_chains = similar(chains, 0) final_tuners = similar(tuners, 0) diff --git a/src/samplers/transformed_mcmc/mcmc_algorithm.jl b/src/samplers/transformed_mcmc/mcmc_algorithm.jl index 08bb9d14a..f7e5aa8a6 100644 --- a/src/samplers/transformed_mcmc/mcmc_algorithm.jl +++ b/src/samplers/transformed_mcmc/mcmc_algorithm.jl @@ -68,7 +68,7 @@ end # BAT.getmeasure(chain::SomeMCMCIter)::AbstractMeasureOrDensity -# BAT.getrng(chain::SomeMCMCIter)::AbstractRNG +# BAT.getcontext(chain::SomeMCMCIter)::BATContext # BAT.mcmc_info(chain::SomeMCMCIter)::TransformedMCMCIteratorInfo @@ -122,8 +122,6 @@ function getalgorithm end function getmeasure end -function getrng end - function mcmc_info end function nsteps end diff --git a/src/samplers/transformed_mcmc/mcmc_iterate.jl b/src/samplers/transformed_mcmc/mcmc_iterate.jl index a2923bd8e..a59cbda78 100644 --- a/src/samplers/transformed_mcmc/mcmc_iterate.jl +++ b/src/samplers/transformed_mcmc/mcmc_iterate.jl @@ -1,13 +1,12 @@ mutable struct TransformedMCMCIterator{ - R<:AbstractRNG, PR<:RNGPartition, D<:BATMeasure, F, Q<:TransformedMCMCProposal, SV<:DensitySampleVector, S<:DensitySample, + CTX<:BATContext, } <: MCMCIterator - rng::R rngpart_cycle::PR μ::D f_transform::F @@ -17,11 +16,12 @@ mutable struct TransformedMCMCIterator{ stepno::Int n_accepted::Int info::TransformedMCMCIteratorInfo + context::CTX end getmeasure(chain::TransformedMCMCIterator) = chain.μ -getrng(chain::TransformedMCMCIterator) = chain.rng +get_context(chain::TransformedMCMCIterator) = chain.context mcmc_info(chain::TransformedMCMCIterator) = chain.info @@ -45,25 +45,25 @@ eff_acceptance_ratio(chain::TransformedMCMCIterator) = nsamples(chain) / chain.s #ctor function TransformedMCMCIterator( - rng::AbstractRNG, algorithm::TransformedMCMCSampling, target, id::Integer, - v_init::AbstractVector{<:Real} + v_init::AbstractVector{<:Real}, + context::BATContext ) - TransformedMCMCIterator(rng, algorithm, target, Int32(id), v_init) + TransformedMCMCIterator(algorithm, target, Int32(id), v_init, context) end #ctor function TransformedMCMCIterator( - rng::AbstractRNG, algorithm::TransformedMCMCSampling, target, id::Int32, v_init::AbstractVector{<:Real}, + context::BATContext, ) - rngpart_cycle = RNGPartition(rng, 0:(typemax(Int16) - 2)) + rngpart_cycle = RNGPartition(get_rng(context), 0:(typemax(Int16) - 2)) μ = target proposal = algorithm.proposal @@ -72,7 +72,7 @@ function TransformedMCMCIterator( n_accepted = 0 adaptive_transform_spec = algorithm.adaptive_transform - g = init_adaptive_transform(rng, adaptive_transform_spec, μ) + g = init_adaptive_transform(adaptive_transform_spec, μ, context) logd_x = logdensityof(μ, v_init) sample_x = DensitySample(v_init, logd_x, 1, TransformedMCMCTransformedSampleID(id, 1, 0), nothing) # TODO @@ -84,7 +84,6 @@ function TransformedMCMCIterator( samples = DensitySampleVector(([sample_x.v], [sample_x.logd], [sample_x.weight], [sample_x.info], [sample_x.aux] )) iter = TransformedMCMCIterator( - rng, rngpart_cycle, target, g, @@ -93,7 +92,8 @@ function TransformedMCMCIterator( sample_z, stepno, n_accepted, - TransformedMCMCIteratorInfo(id, cycle, false, false) + TransformedMCMCIteratorInfo(id, cycle, false, false), + context ) @@ -109,9 +109,10 @@ end function propose_mcmc( - iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:Any, <:TransformedMHProposal} + iter::TransformedMCMCIterator{<:Any, <:Any, <:Any, <:TransformedMHProposal} ) - @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter + @unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter + rng = get_rng(context) sample_x = last(samples) x, logd_x = sample_x.v, sample_x.logd z, logd_z = sample_z.v, sample_z.logd @@ -157,7 +158,8 @@ function transformed_mcmc_step!!( tuner::TransformedAbstractMCMCTunerInstance, tempering::TransformedMCMCTemperingInstance, ) - @unpack rng, μ, f_transform, proposal, samples, sample_z, stepno = iter + @unpack μ, f_transform, proposal, samples, sample_z, stepno, context = iter + rng = get_rng(context) sample_x = last(samples) x, logd_x = sample_x.v, sample_x.logd z, logd_z = sample_z.v, sample_z.logd @@ -168,7 +170,7 @@ function transformed_mcmc_step!!( z_proposed, logd_z_proposed = sample_z_proposed.v, sample_z_proposed.logd x_proposed, logd_x_proposed = sample_x_proposed.v, sample_x_proposed.logd - tuner_new, f_transform = tune_mcmc_transform!!(rng, tuner, f_transform, p_accept, z_proposed, z, stepno) + tuner_new, f_transform = tune_mcmc_transform!!(tuner, f_transform, p_accept, z_proposed, z, stepno, context) accepted = rand(rng) <= p_accept @@ -193,11 +195,11 @@ function transformed_mcmc_step!!( f_new = f_transform - # iter_new = TransformedMCMCIterator(rng, μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted)) - iter.rng = rng + # iter_new = TransformedMCMCIterator(μ_new, f_new, proposal, samples_new, sample_z_new, stepno, n_accepted+Int(accepted), context) iter.μ, iter.f_transform, iter.samples, iter.sample_z = μ_new, f_new, samples_new, sample_z_new iter.n_accepted += Int(accepted) iter.stepno += 1 + @assert iter.context === context return (iter, tuner_new, tempering_new) end @@ -288,6 +290,8 @@ function transformed_mcmc_iterate!( end +#= +# Unused? function reset_chain( rng::AbstractRNG, chain::TransformedMCMCIterator, @@ -296,15 +300,18 @@ function reset_chain( #TODO reset cycle count? chain.rngpart_cycle = rngpart_cycle chain.info = TransformedMCMCIteratorInfo(chain.info, cycle=0) + chain.context = set_rng(chain.context, rng) # wants a next_cycle! # reset_rng_counters!(chain) end +=# function reset_rng_counters!(chain::TransformedMCMCIterator) - set_rng!(chain.rng, chain.rngpart_cycle, chain.info.cycle) - rngpart_step = RNGPartition(chain.rng, 0:(typemax(Int32) - 2)) - set_rng!(chain.rng, rngpart_step, chain.stepno) + rng = get_rng(get_context(chain)) + set_rng!(rng, chain.rngpart_cycle, chain.info.cycle) + rngpart_step = RNGPartition(rng, 0:(typemax(Int32) - 2)) + set_rng!(rng, rngpart_step, chain.stepno) nothing end diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index e1d3c8734..054126c0d 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -51,22 +51,22 @@ bat_default(::Type{TransformedMCMCDispatch}, ::Val{:burnin}, trafo::AbstractTran function bat_sample_impl( - rng::AbstractRNG, target::AnyMeasureOrDensity, - algorithm::TransformedMCMCSampling + algorithm::TransformedMCMCSampling, + context::BATContext ) density_notrafo = convert(AbstractMeasureOrDensity, target) density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo) init = mcmc_init!( - rng, algorithm, density, algorithm.nchains, apply_trafo_to_init(trafo, algorithm.init), algorithm.tuning_alg, algorithm.nonzero_weights, - algorithm.store_burnin ? algorithm.callback : nop_func + algorithm.store_burnin ? algorithm.callback : nop_func, + context ) @unpack chains, tuners, temperers = init diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl index 8352a88c7..3416ebfd8 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -32,13 +32,13 @@ end function tune_mcmc_transform!!( - rng::AbstractRNG, tuner::TransformedMCMCNoOpTuner, transform, p_accept::Real, z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead z_current::Vector{<:Float64}, - stepno::Int + stepno::Int, + context::BATContext ) return (tuner, transform) diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index bf5f43bd2..e986d2c02 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -129,13 +129,13 @@ tuning_callback(::TransformedProposalCovTuner) = nop_func # this function is called in each mcmc_iterate step during tuning function tune_mcmc_transform!!( - rng::AbstractRNG, tuner::TransformedProposalCovTuner, transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, p_accept::Real, z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead z_current::Vector{<:Float64}, - stepno::Int + stepno::Int, + context::BATContext ) return (tuner, transform) diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl index ae1572fbe..ae4b8a8e5 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -61,13 +61,13 @@ tuning_finalize!(tuner::TransformedRAMTunerInstance, chain::MCMCIterator) = noth default_adaptive_transform(tuner::TransformedRAMTuner) = TriangularAffineTransform() function tune_mcmc_transform!!( - rng::AbstractRNG, tuner::TransformedRAMTunerInstance, transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}}, p_accept::Real, z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead z_current::Vector{<:Float64}, - stepno::Int + stepno::Int, + context::BATContext ) @unpack target_acceptance, gamma = tuner.config n = size(z_current,1) @@ -85,5 +85,3 @@ function tune_mcmc_transform!!( return (tuner, transform_new) end - - diff --git a/src/transforms/adaptive_transform.jl b/src/transforms/adaptive_transform.jl index 5c1f2aa3b..f44ffef4f 100644 --- a/src/transforms/adaptive_transform.jl +++ b/src/transforms/adaptive_transform.jl @@ -8,9 +8,9 @@ end CustomTransform() = CustomTransform(identity) function init_adaptive_transform( - rng::AbstractRNG, adaptive_transform::CustomTransform, - density + density, + context ) return adaptive_transform.f end @@ -20,9 +20,9 @@ end struct TriangularAffineTransform <: AdaptiveTransformSpec end function init_adaptive_transform( - rng::AbstractRNG, adaptive_transform::TriangularAffineTransform, - density + density, + context ) M = _approx_cov(density) s = cholesky(M).L From bb97b3a4e70d79c1d6e37629b5645491a89198f8 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 6 Jul 2023 10:33:27 +0200 Subject: [PATCH 14/33] FIx TransformedAdaptiveMHTuning and example --- examples/dev-internal/transformed_example.jl | 6 +++--- .../transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/dev-internal/transformed_example.jl b/examples/dev-internal/transformed_example.jl index 047124d0a..adcfddb10 100644 --- a/examples/dev-internal/transformed_example.jl +++ b/examples/dev-internal/transformed_example.jl @@ -6,7 +6,7 @@ using BAT.LinearAlgebra using BAT.Distributions using BAT.InverseFunctions import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID -using Random123 +using Random123, PositiveFactorizations using AutoDiffOperators import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling @@ -23,8 +23,8 @@ my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_tra density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior) density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo) -c = BAT._approx_cov(density) -f = BAT.CustomTransform(Mul(c)) +s = cholesky(Positive, BAT._approx_cov(density)).L +f = BAT.CustomTransform(Mul(s)) my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context) diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl index e986d2c02..75aac0a51 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_proposalcov_tuner.jl @@ -58,7 +58,7 @@ end # this function is called once after each tuning cycle g_state = nothing -function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator, samples::DensitySampleVector) +function tuning_update!(tuner::TransformedProposalCovTuner, chain::TransformedMCMCIterator, samples::DensitySampleVector) global g_state = (;tuner, chain) stats = tuner.stats From 119c3fede94c7547327f0a12170a5c14d7ea9806 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 9 Jul 2023 11:22:38 +0200 Subject: [PATCH 15/33] Fix Project.toml --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4569a9c5b..5b02f6c40 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "3.0.0" [deps] AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661" -AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" AutoDiffOperators = "6e1301d5-4f4d-4fb5-9679-7191e22f0e0e" From 7f45cee9c7ad7847676ca7f3fdb8cf5e853a3222 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 9 Jul 2023 11:35:00 +0200 Subject: [PATCH 16/33] Fix include order in transformed_mcmc --- src/samplers/transformed_mcmc/mcmc.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc.jl b/src/samplers/transformed_mcmc/mcmc.jl index 8ef0b3c67..742e76195 100644 --- a/src/samplers/transformed_mcmc/mcmc.jl +++ b/src/samplers/transformed_mcmc/mcmc.jl @@ -7,10 +7,10 @@ include("proposaldist.jl") include("mcmc_sampleid.jl") include("mcmc_algorithm.jl") include("mcmc_stats.jl") -include("mcmc_tuning/mcmc_tuning.jl") -include("mcmc_convergence.jl") -include("tempering.jl") include("mcmc_sample.jl") +include("tempering.jl") include("mcmc_iterate.jl") +include("mcmc_tuning/mcmc_tuning.jl") +include("chain_pool_init.jl") +include("mcmc_convergence.jl") include("multi_cycle_burnin.jl") -include("chain_pool_init.jl") \ No newline at end of file From 9457054b2aacf3060a0fb1c6130564dcde2dc683 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 9 Jul 2023 11:50:50 +0200 Subject: [PATCH 17/33] Fix transformed_check_convergence! --- src/samplers/transformed_mcmc/mcmc_convergence.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc_convergence.jl b/src/samplers/transformed_mcmc/mcmc_convergence.jl index 61c476b8b..83d568940 100644 --- a/src/samplers/transformed_mcmc/mcmc_convergence.jl +++ b/src/samplers/transformed_mcmc/mcmc_convergence.jl @@ -5,8 +5,9 @@ function transformed_check_convergence!( chains::AbstractVector{<:MCMCIterator}, samples::AbstractVector{<:DensitySampleVector}, algorithm::ConvergenceTest, + context::BATContext ) - result = convert(Bool, bat_convergence(samples, algorithm).result) + result = convert(Bool, bat_convergence(samples, algorithm, context).result) for chain in chains chain.info = TransformedMCMCIteratorInfo(chain.info, converged = result) end @@ -59,7 +60,7 @@ end export TransformedGelmanRubinConvergence -function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedGelmanRubinConvergence) +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedGelmanRubinConvergence, ::BATContext) max_Rsqr = maximum(gr_Rsqr(samples)) vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) converged = convert(Bool, vt) @@ -141,7 +142,7 @@ end export TransformedBrooksGelmanConvergence -function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedBrooksGelmanConvergence) +function bat_convergence_impl(samples::AbstractVector{<:DensitySampleVector}, algorithm::TransformedBrooksGelmanConvergence, ::BATContext) max_Rsqr = maximum(bg_R_2sqr(samples, corrected = algorithm.corrected)) vt = ValueAndThreshold{max_Rsqr}(max_Rsqr, <=, algorithm.threshold) converged = convert(Bool, vt) @@ -154,7 +155,7 @@ end -function bat_convergence_impl(samples::DensitySampleVector, algorithm::Union{TransformedGelmanRubinConvergence, TransformedBrooksGelmanConvergence}) +function bat_convergence_impl(samples::DensitySampleVector, algorithm::Union{TransformedGelmanRubinConvergence, TransformedBrooksGelmanConvergence}, context::BATContext) # create a vector of chains chains_ind = unique([i.chainid for i in samples.info]) vector_chains = DensitySampleVector[] @@ -164,5 +165,5 @@ function bat_convergence_impl(samples::DensitySampleVector, algorithm::Union{Tra push!(vector_chains, samples[mask_chain]) end - bat_convergence_impl(vector_chains, algorithm) + bat_convergence_impl(vector_chains, algorithm, context) end From 655273312fc976cdb70b6a044d6da8b125daa6bc Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 9 Jul 2023 12:05:54 +0200 Subject: [PATCH 18/33] Fix transformed bat_sample_impl and mcmc_burnin! --- src/samplers/transformed_mcmc/mcmc_sample.jl | 4 +++- src/samplers/transformed_mcmc/multi_cycle_burnin.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/samplers/transformed_mcmc/mcmc_sample.jl b/src/samplers/transformed_mcmc/mcmc_sample.jl index 054126c0d..957342583 100644 --- a/src/samplers/transformed_mcmc/mcmc_sample.jl +++ b/src/samplers/transformed_mcmc/mcmc_sample.jl @@ -56,7 +56,7 @@ function bat_sample_impl( context::BATContext ) density_notrafo = convert(AbstractMeasureOrDensity, target) - density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo) + density, trafo = transform_and_unshape(algorithm.pre_transform, density_notrafo, context) init = mcmc_init!( algorithm, @@ -113,6 +113,7 @@ function bat_sample_impl( (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) end +#= function _bat_sample_continue( target::AnyMeasureOrDensity, generator::TransformedMCMCSampleGenerator, @@ -130,6 +131,7 @@ function _bat_sample_continue( (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = TransformedMCMCSampleGenerator(chains, algorithm)) end +=# function _run_sample_impl( density::AnyMeasureOrDensity, diff --git a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl index 2e167935c..700ddd320 100644 --- a/src/samplers/transformed_mcmc/multi_cycle_burnin.jl +++ b/src/samplers/transformed_mcmc/multi_cycle_burnin.jl @@ -64,7 +64,7 @@ function mcmc_burnin!( isnothing(outputs) || append!(outputs, reduce(vcat, new_outputs)) - transformed_check_convergence!(chains, new_outputs, convergence_test) # TODO AC: Rename + transformed_check_convergence!(chains, new_outputs, convergence_test, BATContext()) # TODO AC: Rename # check_tuned/update_tuners... ntuned = count(c -> c.info.tuned, chains) From 0a027e8ea0d3371794f3babb037b5f7a598170d5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 9 Jul 2023 12:41:10 +0200 Subject: [PATCH 19/33] Adapt transformed example to API changes --- examples/dev-internal/transformed_example.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dev-internal/transformed_example.jl b/examples/dev-internal/transformed_example.jl index adcfddb10..79d1ea5ff 100644 --- a/examples/dev-internal/transformed_example.jl +++ b/examples/dev-internal/transformed_example.jl @@ -8,6 +8,7 @@ using BAT.InverseFunctions import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID using Random123, PositiveFactorizations using AutoDiffOperators +import AdvancedHMC import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling @@ -21,7 +22,7 @@ my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_tra density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior) -density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo) +density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo, context) s = cholesky(Positive, BAT._approx_cov(density)).L f = BAT.CustomTransform(Mul(s)) From 85dc22aafa7b141e5a0b93ec9a65824bcc4f43fb Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 11 Jul 2023 23:24:09 +0200 Subject: [PATCH 20/33] RAMTuner properly persist stepno through multi_cycle_burnin --- src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl index ae4b8a8e5..07baad293 100644 --- a/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/transformed_mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -71,7 +71,7 @@ function tune_mcmc_transform!!( ) @unpack target_acceptance, gamma = tuner.config n = size(z_current,1) - η = min(1, n * stepno^(-gamma)) + η = min(1, n * tuner.nsteps^(-gamma)) s_L = transform.A From 109a5c97c326d69f9c8066d13ad46b9596bf3544 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 00:08:54 +0200 Subject: [PATCH 21/33] rewrite mcmc_init! for optimized overall runtime --- src/samplers/mcmc/chain_pool_init.jl | 58 ++++++++++++++++------------ 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index c9e3604ca..5f6a74a24 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -66,13 +66,13 @@ function mcmc_init!( callback::Function, context::BATContext ) - @info "MCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)." - initval_alg = init_alg.initval_alg min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains + @info "MCMCChainPoolInit: trying to generate $(min_nviable) viable MCMC chain(s)." + rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates)) ncandidates::Int = 0 @@ -87,35 +87,45 @@ function mcmc_init!( chains = similar([dummy_chain], 0) tuners = similar([dummy_tuner], 0) outputs = similar([DensitySampleVector(dummy_chain)], 0) - cycle::Int = 1 + init_tries::Int = 1 while length(tuners) < min_nviable && ncandidates < max_ncandidates - n = min(min_nviable, max_ncandidates - ncandidates) - @debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain(s)." + viable_idxs = Vector{Int}() + viable_tuners = similar(tuners, 0) + viable_chains = similar(chains, 0) + viable_outputs = similar(outputs, 0) - new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) + # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. + while length(viable_idxs) < min_nviable-length(tuners) && ncandidates < max_ncandidates + n = max(min(min_nviable, max_ncandidates - ncandidates), min(min_nviable, Base.Threads.nthreads())) + @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." - filter!(isvalidchain, new_chains) + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) - new_tuners = tuning_alg.(new_chains) - new_outputs = DensitySampleVector.(new_chains) - next_cycle!.(new_chains) - tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) - ncandidates += n + filter!(isvalidchain, new_chains) - @debug "Testing $(length(new_tuners)) candidate MCMC chain(s)." + new_tuners = tuning_alg.(new_chains) + new_outputs = DensitySampleVector.(new_chains) + next_cycle!.(new_chains) + tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) + ncandidates += n - mcmc_iterate!( - new_outputs, new_chains, new_tuners; - max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), - callback = callback, - nonzero_weights = nonzero_weights - ) + @debug "Testing $(length(new_tuners)) candidate MCMC chain(s)." + + mcmc_iterate!( + new_outputs, new_chains, new_tuners; + max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), + callback = callback, + nonzero_weights = nonzero_weights + ) + @info length.(new_outputs) + + append!(viable_idxs, findall(isviablechain.(new_chains))) - viable_idxs = findall(isviablechain.(new_chains)) - viable_tuners = new_tuners[viable_idxs] - viable_chains = new_chains[viable_idxs] - viable_outputs = new_outputs[viable_idxs] + append!(viable_tuners, new_tuners[viable_idxs]) + append!(viable_chains, new_chains[viable_idxs]) + append!(viable_outputs, new_outputs[viable_idxs]) + end @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." @@ -136,7 +146,7 @@ function mcmc_init!( append!(outputs, view(viable_outputs, good_idxs)) end - cycle += 1 + init_tries += 1 end length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") From 756d0e4bf7521d0bd7b3780b590f535378f96b6f Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:53:45 +0200 Subject: [PATCH 22/33] add infrastructure to ease continue of chains --- src/samplers/mcmc/mcmc_sample.jl | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index 6084b6f3b..fe666778f 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -56,13 +56,30 @@ function bat_sample_impl( get_mcmc_tuning(mcmc_algorithm), algorithm.nonzero_weights, algorithm.store_burnin ? algorithm.callback : nop_func, - context + context, ) if !algorithm.store_burnin chain_outputs .= DensitySampleVector.(chains) end + run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, chain_outputs=chain_outputs) + samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator + + samples_notrafo = inverse(trafo).(samples_trafo) + + (result=samples_notrafo, result_trafo=samples_trafo, trafo=trafo, generator=generator) +end + +function _run_sample_impl( + density::AnyMeasureOrDensity, + algorithm::MCMCSampling, + chains::AbstractVector{<:MCMCIterator}, + tuners, + context::BATContext; + description::AbstractString="MCMC iterate", + chain_outputs=DensitySampleVector.(chains) +) mcmc_burnin!( algorithm.store_burnin ? chain_outputs : nothing, tuners, @@ -88,7 +105,5 @@ function bat_sample_impl( isnothing(output) || append!.(Ref(output), chain_outputs) samples_trafo = varshape(density).(output) - samples_notrafo = inverse(trafo).(samples_trafo) - - (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = MCMCSampleGenerator(chains)) + (result_trafo = samples_trafo, generator = MCMCSampleGenerator(chains)) end From 3647b6e31f5974997aab85fae2cddef6c3e7bafa Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:55:46 +0200 Subject: [PATCH 23/33] spaces in return of bat_sample_impl --- src/samplers/mcmc/mcmc_sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index fe666778f..523161a63 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -68,7 +68,7 @@ function bat_sample_impl( samples_notrafo = inverse(trafo).(samples_trafo) - (result=samples_notrafo, result_trafo=samples_trafo, trafo=trafo, generator=generator) + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator) end function _run_sample_impl( From 902257d61f44a1cce0f6152623c5eef08cb9672b Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:58:33 +0200 Subject: [PATCH 24/33] introduce _bat_sample_continue --- src/samplers/mcmc/mcmc_sample.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index 523161a63..dbc4a44f2 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -107,3 +107,26 @@ function _run_sample_impl( (result_trafo = samples_trafo, generator = MCMCSampleGenerator(chains)) end + +function _bat_sample_continue( + target::AnyMeasureOrDensity, + algorithm::MCMCSampling, + generator::MCMCSampleGenerator, + context, + ;description::AbstractString = "MCMC iterate" +) + @unpack chains = generator + density_notrafo = convert(AbstractMeasureOrDensity, target) + density, trafo = transform_and_unshape(algorithm.trafo, density_notrafo, context) + + chain_outputs = DensitySampleVector.(chains) + + tuners = map(v -> get_mcmc_tuning(getproperty(v, :algorithm))(v), chains) + + run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, description=description, chain_outputs=chain_outputs) + samples_trafo, generator_new = run_sampling.result_trafo, run_sampling.generator + + samples_notrafo = inverse(trafo).(samples_trafo) + + (result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator_new) +end From 8cc015bde27b5d439191fd1d2725e7b71863f411 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 16:02:34 +0200 Subject: [PATCH 25/33] ProgressMeter for known infrastructure --- src/samplers/mcmc/chain_pool_init.jl | 7 ++++++- src/samplers/mcmc/mcmc_sample.jl | 6 +++++- src/samplers/mcmc/multi_cycle_burnin.jl | 7 ++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 5f6a74a24..2a63e4011 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -130,13 +130,18 @@ function mcmc_init!( @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." if !isempty(viable_tuners) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_idxs), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_idxs) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + mcmc_iterate!( viable_outputs, viable_chains, viable_tuners; max_nsteps = init_alg.nsteps_init, - callback = callback, + callback = (kwargs...)-> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end, nonzero_weights = nonzero_weights ) + ProgressMeter.finish!(progress_meter) + nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains])) good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains) @debug "Found $(length(viable_tuners)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples." diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index dbc4a44f2..685f80836 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -93,14 +93,18 @@ function _run_sample_impl( next_cycle!.(chains) + progress_meter = ProgressMeter.Progress(algorithm.nchains * algorithm.nsteps, desc=description, barlen=80 - length(description), dt=0.1) + mcmc_iterate!( chain_outputs, chains; max_nsteps = algorithm.nsteps, nonzero_weights = algorithm.nonzero_weights, - callback = algorithm.callback + callback = (kwargs...) -> let pm=progress_meter, callback=algorithm.callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end, ) + ProgressMeter.finish!(progress_meter) + output = DensitySampleVector(first(chains)) isnothing(output) || append!.(Ref(output), chain_outputs) samples_trafo = varshape(density).(output) diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index 206199df6..cf98a33ed 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -48,13 +48,18 @@ function mcmc_burnin!( tuning_reinit!.(tuners, chains, burnin_alg.nsteps_per_cycle) + desc_string = string("Burnin cycle ", cycles, "/max_cycles=", burnin_alg.max_ncycles," for nchains=", length(chains)) + progress_meter = ProgressMeter.Progress(length(chains)*burnin_alg.nsteps_per_cycle, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + mcmc_iterate!( new_outputs, chains, tuners, max_nsteps = burnin_alg.nsteps_per_cycle, nonzero_weights = nonzero_weights, - callback = callback + callback = (kwargs...) -> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(progress_meter) ; end, ) + ProgressMeter.finish!(progress_meter) + tuning_update!.(tuners, chains, new_outputs) isnothing(outputs) || append!.(outputs, new_outputs) From 41b817d64e3562595ede84175257603618791372 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Tue, 25 Jul 2023 17:42:55 +0200 Subject: [PATCH 26/33] ahmc evaluates params [NaN,...] in times --- ext/ahmc_impl/ahmc_sampler_impl.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index d711d72de..2c4567379 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -78,7 +78,8 @@ function AHMCIterator( throw(ErrorException("HamiltonianMC requires an ADSelector to be specified in the BAT context")) end - f = checked_logdensityof(density) + # TODO AC ToDo!: discuss with @oschulz + f = logdensityof(density) fg = valgrad_func(f, adsel) init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg) From 0c4a7030fd4dcd6a6ad7a4e4b4c9978afc0b186b Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Fri, 28 Jul 2023 01:32:14 +0200 Subject: [PATCH 27/33] clustered init suggestion --- src/samplers/mcmc/chain_pool_init.jl | 40 ++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 2a63e4011..166097899 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -55,6 +55,41 @@ _gen_chains( context::BATContext ) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids] +# TODO AC discuss +function _cluster_selection( + chains::AbstractVector{<:MCMCIterator}, + tuners, + outputs::AbstractVector{<:DensitySampleVector}, + scale::Real=3, + decision_range_skip::Real=0.9, +) + logds_by_chain = [view(s.logd,(floor(Int,decision_range_skip*length(s))):length(s)) for s in outputs] + means = [mean(x) for x in logds_by_chain] + stddevs = [std(x) for x in logds_by_chain] + + # yet uncategoriesed + uncat = eachindex(chains, tuners, outputs, logds_by_chain, stddevs, means) + + # clustered indices + cidxs = Vector{Vector{eltype(uncat)}}() + # categories all to clusters + while length(uncat) > 0 + idxmin = findmin(view(stddevs,uncat))[2] + + cidx_sel = map(means_remaining_uncat -> abs(means_remaining_uncat-means[uncat[idxmin]]) < scale*stddevs[uncat[idxmin]], view(means,uncat)) + + push!(cidxs, uncat[cidx_sel]) + uncat = uncat[.!cidx_sel] + end + means_c = [ mean(reduce(vcat, view(samples_by_chain.logd, ids))) for ids in cidxs] + idx_order = sortperm(means_c, rev=true) + + chains_by_cluster = [ reduce(vcat, view(chains, ids)) for ids in cidxs[idx_order]] + tuners_by_cluster = [ reduce(vcat, view(tuners, ids)) for ids in cidxs[idx_order]] + outputs_by_cluster = [ reduce(vcat, view(outputs, ids)) for ids in cidxs[idx_order]] + ( chains = chains_by_cluster, tuners = tuners_by_cluster, outputs = outputs_by_cluster, ) +end + function mcmc_init!( algorithm::MCMCAlgorithm, @@ -154,6 +189,11 @@ function mcmc_init!( init_tries += 1 end + # TODO AC + if true + @unpack chains, tuners, outputs = _cluster_selection(chains, tuners, outputs) + end + length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") m = nchains From d988a84ea054feca99c6fc62d438500ba1b18c2a Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Fri, 28 Jul 2023 02:00:33 +0200 Subject: [PATCH 28/33] switch to median in cluster selection --- src/samplers/mcmc/chain_pool_init.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 166097899..6f5944326 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -64,11 +64,11 @@ function _cluster_selection( decision_range_skip::Real=0.9, ) logds_by_chain = [view(s.logd,(floor(Int,decision_range_skip*length(s))):length(s)) for s in outputs] - means = [mean(x) for x in logds_by_chain] + medians = [median(x) for x in logds_by_chain] stddevs = [std(x) for x in logds_by_chain] # yet uncategoriesed - uncat = eachindex(chains, tuners, outputs, logds_by_chain, stddevs, means) + uncat = eachindex(chains, tuners, outputs, logds_by_chain, stddevs, medians) # clustered indices cidxs = Vector{Vector{eltype(uncat)}}() @@ -76,13 +76,13 @@ function _cluster_selection( while length(uncat) > 0 idxmin = findmin(view(stddevs,uncat))[2] - cidx_sel = map(means_remaining_uncat -> abs(means_remaining_uncat-means[uncat[idxmin]]) < scale*stddevs[uncat[idxmin]], view(means,uncat)) + cidx_sel = map(means_remaining_uncat -> abs(means_remaining_uncat-medians[uncat[idxmin]]) < scale*stddevs[uncat[idxmin]], view(medians,uncat)) push!(cidxs, uncat[cidx_sel]) uncat = uncat[.!cidx_sel] end - means_c = [ mean(reduce(vcat, view(samples_by_chain.logd, ids))) for ids in cidxs] - idx_order = sortperm(means_c, rev=true) + medians_c = [ median(reduce(vcat, view(logds_by_chain, ids))) for ids in cidxs] + idx_order = sortperm(medians_c, rev=true) chains_by_cluster = [ reduce(vcat, view(chains, ids)) for ids in cidxs[idx_order]] tuners_by_cluster = [ reduce(vcat, view(tuners, ids)) for ids in cidxs[idx_order]] From 9909eede73af97dd8166e3f7112cb833a510109e Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Fri, 28 Jul 2023 02:41:07 +0200 Subject: [PATCH 29/33] forward best cluster --- src/samplers/mcmc/chain_pool_init.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 6f5944326..12520f31c 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -87,7 +87,7 @@ function _cluster_selection( chains_by_cluster = [ reduce(vcat, view(chains, ids)) for ids in cidxs[idx_order]] tuners_by_cluster = [ reduce(vcat, view(tuners, ids)) for ids in cidxs[idx_order]] outputs_by_cluster = [ reduce(vcat, view(outputs, ids)) for ids in cidxs[idx_order]] - ( chains = chains_by_cluster, tuners = tuners_by_cluster, outputs = outputs_by_cluster, ) + ( chains = chains_by_cluster[1], tuners = tuners_by_cluster[1], outputs = outputs_by_cluster[1], ) end From fc93e18cbaef8c3ccabcf6c8011a02560d785d83 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Fri, 28 Jul 2023 11:41:07 +0200 Subject: [PATCH 30/33] _cluster_selection correct forward of chains&tuner --- src/samplers/mcmc/chain_pool_init.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 12520f31c..5cdc7e3ac 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -84,9 +84,9 @@ function _cluster_selection( medians_c = [ median(reduce(vcat, view(logds_by_chain, ids))) for ids in cidxs] idx_order = sortperm(medians_c, rev=true) - chains_by_cluster = [ reduce(vcat, view(chains, ids)) for ids in cidxs[idx_order]] - tuners_by_cluster = [ reduce(vcat, view(tuners, ids)) for ids in cidxs[idx_order]] - outputs_by_cluster = [ reduce(vcat, view(outputs, ids)) for ids in cidxs[idx_order]] + chains_by_cluster = [ view(chains, ids) for ids in cidxs[idx_order]] + tuners_by_cluster = [ view(tuners, ids) for ids in cidxs[idx_order]] + outputs_by_cluster = [ view(outputs, ids) for ids in cidxs[idx_order]] ( chains = chains_by_cluster[1], tuners = tuners_by_cluster[1], outputs = outputs_by_cluster[1], ) end From 3372df371f63d362822826c87bf8566fce3555e7 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Fri, 28 Jul 2023 12:14:04 +0200 Subject: [PATCH 31/33] _cluster_selection proper fail criterion --- src/samplers/mcmc/chain_pool_init.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index 5cdc7e3ac..edc16016f 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -192,9 +192,11 @@ function mcmc_init!( # TODO AC if true @unpack chains, tuners, outputs = _cluster_selection(chains, tuners, outputs) + length(tuners) < nchains && error("Failed to generate $nchains viable MCMC chains") + else + length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") end - length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains") m = nchains tidxs = LinearIndices(tuners) From d6c01b375c756ff3f119c2935efb4a17a9a56a79 Mon Sep 17 00:00:00 2001 From: waldie11 <86674066+waldie11@users.noreply.github.com> Date: Wed, 2 Aug 2023 12:12:12 +0200 Subject: [PATCH 32/33] viable_idxs corrected --- src/samplers/mcmc/chain_pool_init.jl | 13 ++-- .../transformed_mcmc/chain_pool_init.jl | 60 +++++++++++-------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index edc16016f..2fa3dc85d 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -125,14 +125,13 @@ function mcmc_init!( init_tries::Int = 1 while length(tuners) < min_nviable && ncandidates < max_ncandidates - viable_idxs = Vector{Int}() viable_tuners = similar(tuners, 0) viable_chains = similar(chains, 0) viable_outputs = similar(outputs, 0) # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. - while length(viable_idxs) < min_nviable-length(tuners) && ncandidates < max_ncandidates - n = max(min(min_nviable, max_ncandidates - ncandidates), min(min_nviable, Base.Threads.nthreads())) + while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates + n = min(min_nviable, max_ncandidates - ncandidates) @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) @@ -155,18 +154,18 @@ function mcmc_init!( ) @info length.(new_outputs) - append!(viable_idxs, findall(isviablechain.(new_chains))) + viable_idxs = findall(isviablechain.(new_chains)) append!(viable_tuners, new_tuners[viable_idxs]) append!(viable_chains, new_chains[viable_idxs]) append!(viable_outputs, new_outputs[viable_idxs]) end - @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." + @debug "Found $(length(viable_tuners)) viable MCMC chain(s)." if !isempty(viable_tuners) - desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_idxs), " of min_nviable=", length(tuners), "/", min_nviable ) - progress_meter = ProgressMeter.Progress(length(viable_idxs) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_tuners), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_tuners) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) mcmc_iterate!( viable_outputs, viable_chains, viable_tuners; diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index c5c11ad38..6d975a3b5 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -91,42 +91,50 @@ function mcmc_init!( init_tries::Int = 1 while length(tuners) < min_nviable && ncandidates < max_ncandidates + viable_tuners = similar(tuners, 0) + viable_chains = similar(chains, 0) + viable_outputs = similar(outputs, 0) - n = min(min_nviable, max_ncandidates - ncandidates) - @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." + # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. + while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates + n = min(min_nviable, max_ncandidates - ncandidates) + @debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)." - new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) + new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context) - filter!(isvalidchain, new_chains) + filter!(isvalidchain, new_chains) - new_tuners = get_tuner.(Ref(tuning_alg), new_chains) - new_temperers = fill(get_temperer(algorithm.tempering, density), size(new_tuners,1)) - - next_cycle!.(new_chains) - - tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) - ncandidates += n + new_tuners = get_tuner.(Ref(tuning_alg), new_chains) + new_temperers = fill(get_temperer(algorithm.tempering, density), size(new_tuners,1)) - @debug "Testing $(length(new_chains)) candidate MCMC chain(s)." + next_cycle!.(new_chains) - transformed_mcmc_iterate!( - new_chains, new_tuners, new_temperers, - max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), - callback = callback, - nonzero_weights = nonzero_weights - ) + tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init) + ncandidates += n - # testing if chains are viable: - viable_idxs = findall(isviablechain.(new_chains)) - viable_temperers = new_temperers[viable_idxs] - viable_tuners = new_tuners[viable_idxs] - viable_chains = new_chains[viable_idxs] + @debug "Testing $(length(new_chains)) candidate MCMC chain(s)." - @debug "Found $(length(viable_idxs)) viable MCMC chain(s)." + transformed_mcmc_iterate!( + new_chains, new_tuners, new_temperers, + max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50), + callback = callback, + nonzero_weights = nonzero_weights + ) + + # testing if chains are viable: + viable_idxs = findall(isviablechain.(new_chains)) + + append!(viable_tuners, new_tuners[viable_idxs]) + append!(viable_chains, new_chains[viable_idxs]) + append!(viable_outputs, new_outputs[viable_idxs]) + + end + + @debug "Found $(length(viable_tuners)) viable MCMC chain(s)." if !isempty(viable_chains) - desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_idxs), " of min_nviable=", length(tuners), "/", min_nviable ) - progress_meter = ProgressMeter.Progress(length(viable_idxs) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) + desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_tuners), " of min_nviable=", length(tuners), "/", min_nviable ) + progress_meter = ProgressMeter.Progress(length(viable_tuners) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1) transformed_mcmc_iterate!( viable_chains, viable_tuners, viable_temperers; max_nsteps = init_alg.nsteps_init, From b6565c393ac0ab34b85e95426fc3e1aa3f381a1d Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Mon, 25 Sep 2023 10:08:50 +0200 Subject: [PATCH 33/33] quick fixes --- .vscode/settings.json | 3 +++ src/samplers/transformed_mcmc/chain_pool_init.jl | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..95ac9a593 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "julia.environmentPath": "C:\\Users\\Cornelius\\.julia\\environments\\v1.9" +} \ No newline at end of file diff --git a/src/samplers/transformed_mcmc/chain_pool_init.jl b/src/samplers/transformed_mcmc/chain_pool_init.jl index 6d975a3b5..adcc39e6d 100644 --- a/src/samplers/transformed_mcmc/chain_pool_init.jl +++ b/src/samplers/transformed_mcmc/chain_pool_init.jl @@ -93,7 +93,8 @@ function mcmc_init!( while length(tuners) < min_nviable && ncandidates < max_ncandidates viable_tuners = similar(tuners, 0) viable_chains = similar(chains, 0) - viable_outputs = similar(outputs, 0) + viable_temperers = similar(temperers, 0) + viable_outputs = [] #similar(outputs, 0) #TODO # as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration. while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates @@ -124,9 +125,12 @@ function mcmc_init!( # testing if chains are viable: viable_idxs = findall(isviablechain.(new_chains)) + new_outputs = getproperty.(new_chains, :samples) #TODO ? + append!(viable_tuners, new_tuners[viable_idxs]) append!(viable_chains, new_chains[viable_idxs]) append!(viable_outputs, new_outputs[viable_idxs]) + append!(viable_temperers, new_temperers[viable_idxs]) end