Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c1cec60
use Tranformed prefix for all structs
Jun 30, 2023
9cbacf2
use Transformed prefix for all types
Jun 30, 2023
6cbe392
update deps
Jun 30, 2023
e32abca
add AdaptiveTransform
Jun 30, 2023
50c02ba
include transformed_mcmc
Jun 30, 2023
b0794f0
Revert "use Transformed prefix for all types"
Cornelius-G Jun 30, 2023
e1d7aba
chnage replace_type script
Cornelius-G Jun 30, 2023
66fa0ad
use Transformed prefix for all abstract types
Jun 30, 2023
6c39fee
new TransformedMCMCSampling and old MCMCSampling now both working
Cornelius-G Jun 30, 2023
f640fbc
move example
Cornelius-G Jul 3, 2023
d0ce010
use full matrix instead of lower cholesky in AdaptiveMHTuner
Cornelius-G Jul 3, 2023
972d967
use cholesky lower for AdaptiveMHTuning
Cornelius-G Jul 4, 2023
756827f
Merge commit 'c9d7fd98bf61d05fedc0be1643b41d7aabe43c98' into RenameTr…
oschulz Jul 6, 2023
9cb2625
Adapt transformed MCMC code to using BATContext
oschulz Jul 6, 2023
bb97b3a
FIx TransformedAdaptiveMHTuning and example
oschulz Jul 6, 2023
b2d5449
Merge branch 'main' into RenameTransformed
oschulz Jul 9, 2023
119c3fe
Fix Project.toml
oschulz Jul 9, 2023
7f45cee
Fix include order in transformed_mcmc
oschulz Jul 9, 2023
9457054
Fix transformed_check_convergence!
oschulz Jul 9, 2023
6552733
Fix transformed bat_sample_impl and mcmc_burnin!
oschulz Jul 9, 2023
8dc1ebf
Merge branch 'main' into trafo-merge
oschulz Jul 9, 2023
0a027e8
Adapt transformed example to API changes
oschulz Jul 9, 2023
85dc22a
RAMTuner properly persist stepno through multi_cycle_burnin
waldie11 Jul 11, 2023
9602973
Merge remote-tracking branch 'origin/main' into RenameTransformed
waldie11 Jul 11, 2023
f8f4b8d
Merge branch 'main' into RenameTransformed
waldie11 Jul 19, 2023
109a5c9
rewrite mcmc_init! for optimized overall runtime
waldie11 Jul 24, 2023
756d0e4
add infrastructure to ease continue of chains
waldie11 Jul 25, 2023
3647b6e
spaces in return of bat_sample_impl
waldie11 Jul 25, 2023
902257d
introduce _bat_sample_continue
waldie11 Jul 25, 2023
8cc015b
ProgressMeter for known infrastructure
waldie11 Jul 25, 2023
41b817d
ahmc evaluates params [NaN,...] in times
waldie11 Jul 25, 2023
0c4a703
clustered init suggestion
waldie11 Jul 27, 2023
d988a84
switch to median in cluster selection
waldie11 Jul 28, 2023
9909eed
forward best cluster
waldie11 Jul 28, 2023
fc93e18
_cluster_selection correct forward of chains&tuner
waldie11 Jul 28, 2023
3372df3
_cluster_selection proper fail criterion
waldie11 Jul 28, 2023
d6c01b3
viable_idxs corrected
waldie11 Aug 2, 2023
a27f521
Merge branch 'main' into RenameTransformed
Cornelius-G Sep 25, 2023
b6565c3
quick fixes
Cornelius-G Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"julia.environmentPath": "C:\\Users\\Cornelius\\.julia\\environments\\v1.9"
}
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand Down Expand Up @@ -119,7 +120,7 @@ IrrationalConstants = "0.1, 0.2"
KernelDensity = "0.5, 0.6"
LaTeXStrings = "1"
MacroTools = "0.5"
MeasureBase = "0.12, 0.13, 0.14"
MeasureBase = "0.14"
Measurements = "2"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Expand Down
67 changes: 67 additions & 0 deletions examples/dev-internal/transformed_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using BAT
using BAT.MeasureBase
using AffineMaps
using ChangesOfVariables
using BAT.LinearAlgebra
using BAT.Distributions
using BAT.InverseFunctions
import BAT: TransformedMCMCIterator, TransformedAdaptiveMHTuning, TransformedRAMTuner, TransformedMHProposal, TransformedNoTransformedMCMCTempering, transformed_mcmc_step!!, TransformedMCMCTransformedSampleID
using Random123, PositiveFactorizations
using AutoDiffOperators
import AdvancedHMC

import BAT: mcmc_iterate!, transformed_mcmc_iterate!, TransformedMCMCSampling

#ENV["JULIA_DEBUG"] = "BAT"

context = BATContext(ad = ADModule(:ForwardDiff))

posterior = BAT.example_posterior()

my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)


density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior)
density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo, context)

s = cholesky(Positive, BAT._approx_cov(density)).L
f = BAT.CustomTransform(Mul(s))

my_result = @time BAT.bat_sample_impl(posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f), context)

my_samples = my_result.result



using Plots
plot(my_samples)

r_mh = @time BAT.bat_sample_impl(posterior, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc = @time BAT.bat_sample_impl(posterior, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)

plot(bat_sample(posterior).result)

using BAT.Distributions
using BAT.ValueShapes
prior2 = NamedTupleDist(ShapedAsNT,
b = [4.2, 3.3],
a = Exponential(1.0),
c = Normal(1.0,3.0),
d = product_distribution(Weibull.(ones(2),1)),
e = Beta(1.0, 1.0),
f = MvNormal([0.3,-2.9],Matrix([1.7 0.5;0.5 2.3]))
)

posterior.likelihood.density._log_f(rand(posterior.prior))

posterior.likelihood.density._log_f(rand(prior2))

posterior2 = PosteriorDensity(BAT.logfuncdensity(posterior.likelihood.density._log_f), prior2)


@profview r_ram2 = @time BAT.bat_sample_impl(posterior2, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000), context)

@profview r_mh2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling( nchains=4, nsteps=4*100000, store_burnin=true), context)

r_hmc2 = @time BAT.bat_sample_impl(posterior2, MCMCSampling(mcalg=HamiltonianMC(), nchains=4, nsteps=4*20000), context)
3 changes: 2 additions & 1 deletion ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ function AHMCIterator(
throw(ErrorException("HamiltonianMC requires an ADSelector to be specified in the BAT context"))
end

f = checked_logdensityof(density)
# TODO AC ToDo!: discuss with @oschulz
f = logdensityof(density)
fg = valgrad_func(f, adsel)

init_hamiltonian = AdvancedHMC.Hamiltonian(metric, f, fg)
Expand Down
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import HypothesisTests
import MeasureBase
import Measurements
import NamedArrays
import ProgressMeter
import Random123
import Sobol
import StableRNGs
Expand Down
110 changes: 83 additions & 27 deletions src/samplers/mcmc/chain_pool_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,41 @@ _gen_chains(
context::BATContext
) = [_construct_chain(rngpart, id, algorithm, density, initval_alg, context) for id in ids]

# TODO AC discuss
function _cluster_selection(
chains::AbstractVector{<:MCMCIterator},
tuners,
outputs::AbstractVector{<:DensitySampleVector},
scale::Real=3,
decision_range_skip::Real=0.9,
)
logds_by_chain = [view(s.logd,(floor(Int,decision_range_skip*length(s))):length(s)) for s in outputs]
medians = [median(x) for x in logds_by_chain]
stddevs = [std(x) for x in logds_by_chain]

# yet uncategoriesed
uncat = eachindex(chains, tuners, outputs, logds_by_chain, stddevs, medians)

# clustered indices
cidxs = Vector{Vector{eltype(uncat)}}()
# categories all to clusters
while length(uncat) > 0
idxmin = findmin(view(stddevs,uncat))[2]

cidx_sel = map(means_remaining_uncat -> abs(means_remaining_uncat-medians[uncat[idxmin]]) < scale*stddevs[uncat[idxmin]], view(medians,uncat))

push!(cidxs, uncat[cidx_sel])
uncat = uncat[.!cidx_sel]
end
medians_c = [ median(reduce(vcat, view(logds_by_chain, ids))) for ids in cidxs]
idx_order = sortperm(medians_c, rev=true)

chains_by_cluster = [ view(chains, ids) for ids in cidxs[idx_order]]
tuners_by_cluster = [ view(tuners, ids) for ids in cidxs[idx_order]]
outputs_by_cluster = [ view(outputs, ids) for ids in cidxs[idx_order]]
( chains = chains_by_cluster[1], tuners = tuners_by_cluster[1], outputs = outputs_by_cluster[1], )
end


function mcmc_init!(
algorithm::MCMCAlgorithm,
Expand All @@ -66,13 +101,13 @@ function mcmc_init!(
callback::Function,
context::BATContext
)
@info "MCMCChainPoolInit: trying to generate $nchains viable MCMC chain(s)."

initval_alg = init_alg.initval_alg

min_nviable::Int = minimum(init_alg.init_tries_per_chain) * nchains
max_ncandidates::Int = maximum(init_alg.init_tries_per_chain) * nchains

@info "MCMCChainPoolInit: trying to generate $(min_nviable) viable MCMC chain(s)."

rngpart = RNGPartition(get_rng(context), Base.OneTo(max_ncandidates))

ncandidates::Int = 0
Expand All @@ -87,46 +122,60 @@ function mcmc_init!(
chains = similar([dummy_chain], 0)
tuners = similar([dummy_tuner], 0)
outputs = similar([DensitySampleVector(dummy_chain)], 0)
cycle::Int = 1
init_tries::Int = 1

while length(tuners) < min_nviable && ncandidates < max_ncandidates
n = min(min_nviable, max_ncandidates - ncandidates)
@debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain(s)."
viable_tuners = similar(tuners, 0)
viable_chains = similar(chains, 0)
viable_outputs = similar(outputs, 0)

new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context)
# as the iteration after viable check is more costly, fill up to be at least capable to skip a complete reiteration.
while length(viable_tuners) < min_nviable-length(tuners) && ncandidates < max_ncandidates
n = min(min_nviable, max_ncandidates - ncandidates)
@debug "Generating $n $(init_tries > 1 ? "additional " : "")candidate MCMC chain(s)."

filter!(isvalidchain, new_chains)
new_chains = _gen_chains(rngpart, ncandidates .+ (one(Int64):n), algorithm, density, initval_alg, context)

new_tuners = tuning_alg.(new_chains)
new_outputs = DensitySampleVector.(new_chains)
next_cycle!.(new_chains)
tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init)
ncandidates += n
filter!(isvalidchain, new_chains)

@debug "Testing $(length(new_tuners)) candidate MCMC chain(s)."
new_tuners = tuning_alg.(new_chains)
new_outputs = DensitySampleVector.(new_chains)
next_cycle!.(new_chains)
tuning_init!.(new_tuners, new_chains, init_alg.nsteps_init)
ncandidates += n

mcmc_iterate!(
new_outputs, new_chains, new_tuners;
max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50),
callback = callback,
nonzero_weights = nonzero_weights
)
@debug "Testing $(length(new_tuners)) candidate MCMC chain(s)."

viable_idxs = findall(isviablechain.(new_chains))
viable_tuners = new_tuners[viable_idxs]
viable_chains = new_chains[viable_idxs]
viable_outputs = new_outputs[viable_idxs]
mcmc_iterate!(
new_outputs, new_chains, new_tuners;
max_nsteps = clamp(div(init_alg.nsteps_init, 5), 10, 50),
callback = callback,
nonzero_weights = nonzero_weights
)
@info length.(new_outputs)

viable_idxs = findall(isviablechain.(new_chains))

@debug "Found $(length(viable_idxs)) viable MCMC chain(s)."
append!(viable_tuners, new_tuners[viable_idxs])
append!(viable_chains, new_chains[viable_idxs])
append!(viable_outputs, new_outputs[viable_idxs])
end

@debug "Found $(length(viable_tuners)) viable MCMC chain(s)."

if !isempty(viable_tuners)
desc_string = string("Init try ", init_tries, " for nvalid=", length(viable_tuners), " of min_nviable=", length(tuners), "/", min_nviable )
progress_meter = ProgressMeter.Progress(length(viable_tuners) * init_alg.nsteps_init, desc=desc_string, barlen=80-length(desc_string), dt=0.1)

mcmc_iterate!(
viable_outputs, viable_chains, viable_tuners;
max_nsteps = init_alg.nsteps_init,
callback = callback,
callback = (kwargs...)-> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end,
nonzero_weights = nonzero_weights
)

ProgressMeter.finish!(progress_meter)

nsamples_thresh = floor(Int, 0.8 * median([nsamples(chain) for chain in viable_chains]))
good_idxs = findall(chain -> nsamples(chain) >= nsamples_thresh, viable_chains)
@debug "Found $(length(viable_tuners)) MCMC chain(s) with at least $(nsamples_thresh) unique accepted samples."
Expand All @@ -136,10 +185,17 @@ function mcmc_init!(
append!(outputs, view(viable_outputs, good_idxs))
end

cycle += 1
init_tries += 1
end

# TODO AC
if true
@unpack chains, tuners, outputs = _cluster_selection(chains, tuners, outputs)
length(tuners) < nchains && error("Failed to generate $nchains viable MCMC chains")
else
length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains")
end

length(tuners) < min_nviable && error("Failed to generate $min_nviable viable MCMC chains")

m = nchains
tidxs = LinearIndices(tuners)
Expand Down
48 changes: 45 additions & 3 deletions src/samplers/mcmc/mcmc_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,30 @@ function bat_sample_impl(
get_mcmc_tuning(mcmc_algorithm),
algorithm.nonzero_weights,
algorithm.store_burnin ? algorithm.callback : nop_func,
context
context,
)

if !algorithm.store_burnin
chain_outputs .= DensitySampleVector.(chains)
end

run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, chain_outputs=chain_outputs)
samples_trafo, generator = run_sampling.result_trafo, run_sampling.generator

samples_notrafo = inverse(trafo).(samples_trafo)

(result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator)
end

function _run_sample_impl(
density::AnyMeasureOrDensity,
algorithm::MCMCSampling,
chains::AbstractVector{<:MCMCIterator},
tuners,
context::BATContext;
description::AbstractString="MCMC iterate",
chain_outputs=DensitySampleVector.(chains)
)
mcmc_burnin!(
algorithm.store_burnin ? chain_outputs : nothing,
tuners,
Expand All @@ -76,19 +93,44 @@ function bat_sample_impl(

next_cycle!.(chains)

progress_meter = ProgressMeter.Progress(algorithm.nchains * algorithm.nsteps, desc=description, barlen=80 - length(description), dt=0.1)

mcmc_iterate!(
chain_outputs,
chains;
max_nsteps = algorithm.nsteps,
nonzero_weights = algorithm.nonzero_weights,
callback = algorithm.callback
callback = (kwargs...) -> let pm=progress_meter, callback=algorithm.callback ; callback(kwargs) ; ProgressMeter.next!(pm) ; end,
)

ProgressMeter.finish!(progress_meter)

output = DensitySampleVector(first(chains))
isnothing(output) || append!.(Ref(output), chain_outputs)
samples_trafo = varshape(density).(output)

(result_trafo = samples_trafo, generator = MCMCSampleGenerator(chains))
end

function _bat_sample_continue(
target::AnyMeasureOrDensity,
algorithm::MCMCSampling,
generator::MCMCSampleGenerator,
context,
;description::AbstractString = "MCMC iterate"
)
@unpack chains = generator
density_notrafo = convert(AbstractMeasureOrDensity, target)
density, trafo = transform_and_unshape(algorithm.trafo, density_notrafo, context)

chain_outputs = DensitySampleVector.(chains)

tuners = map(v -> get_mcmc_tuning(getproperty(v, :algorithm))(v), chains)

run_sampling = _run_sample_impl(density, algorithm, chains, tuners, context, description=description, chain_outputs=chain_outputs)
samples_trafo, generator_new = run_sampling.result_trafo, run_sampling.generator

samples_notrafo = inverse(trafo).(samples_trafo)

(result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = MCMCSampleGenerator(chains))
(result = samples_notrafo, result_trafo = samples_trafo, trafo = trafo, generator = generator_new)
end
7 changes: 6 additions & 1 deletion src/samplers/mcmc/multi_cycle_burnin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ function mcmc_burnin!(

tuning_reinit!.(tuners, chains, burnin_alg.nsteps_per_cycle)

desc_string = string("Burnin cycle ", cycles, "/max_cycles=", burnin_alg.max_ncycles," for nchains=", length(chains))
progress_meter = ProgressMeter.Progress(length(chains)*burnin_alg.nsteps_per_cycle, desc=desc_string, barlen=80-length(desc_string), dt=0.1)

mcmc_iterate!(
new_outputs, chains, tuners,
max_nsteps = burnin_alg.nsteps_per_cycle,
nonzero_weights = nonzero_weights,
callback = callback
callback = (kwargs...) -> let pm=progress_meter, callback=callback ; callback(kwargs) ; ProgressMeter.next!(progress_meter) ; end,
)

ProgressMeter.finish!(progress_meter)

tuning_update!.(tuners, chains, new_outputs)
isnothing(outputs) || append!.(outputs, new_outputs)

Expand Down
1 change: 1 addition & 0 deletions src/samplers/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

include("bat_sample.jl")
include("mcmc/mcmc.jl")
include("transformed_mcmc/mcmc.jl")
include("evaluated_measure.jl")
include("importance/importance_sampler.jl")
Loading