/
bridge_sampling_integration.jl
143 lines (114 loc) · 5.07 KB
/
bridge_sampling_integration.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# This file is a part of BAT.jl, licensed under the MIT License (MIT).
"""
    struct BridgeSampling <: IntegrationAlgorithm

*Experimental feature, not part of stable public API.*

Bridge sampling integration algorithm.

Estimates the integral of a target measure from samples of that target. The
target must carry samples. It is first transformed according to `trafo`. The
first half of the samples is used to fit a multivariate normal proposal. The
integral is then estimated from the second half of the samples together with
IID draws from that proposal, using an iterative bridge-sampling scheme.

Constructors:

* ```$(FUNCTIONNAME)(; fields...)```

Fields:

$(TYPEDFIELDS)
"""
@with_kw struct BridgeSampling{TR<:AbstractTransformTarget,ESS<:EffSampleSizeAlgorithm} <: IntegrationAlgorithm
    # Transformation applied to the target (via `transform_and_unshape`) before integrating.
    trafo::TR = PriorToGaussian()
    # Effective-sample-size algorithm used when estimating the error of the integral.
    essalg::ESS = EffSampleSizeFromAC()
    # If `true`, throw an error when the iterative integral estimate does not converge;
    # otherwise only emit a warning.
    strict::Bool = true
    # ToDo: add argument for proposal density generator
end
export BridgeSampling
# Integrate `target` with the bridge-sampling algorithm. `target` must carry samples.
#
# Steps:
# 1. Transform the target according to `algorithm.trafo`.
# 2. Renormalize it with `auto_renormalize`.
# 3. Estimate the integral and its error from the renormalized measure and samples.
# 4. Remove the renormalization factor again.
#
# Step 4 is done in log space and exponentiated as a `BigFloat`, so that a large
# `logweight` cannot overflow or underflow a `Float64`.
#
# Returns a named tuple `(result, logweight)`. `result` is a `Measurements`
# value of exp(log(value) - logweight) ± exp(log(error) - logweight).
function bat_integrate_impl(target::EvaluatedMeasure, algorithm::BridgeSampling, context::BATContext)
    @argcheck !isnothing(target.samples)

    trafo_target, _ = transform_and_unshape(algorithm.trafo, target, context)
    normalized_target, logweight = auto_renormalize(trafo_target)

    raw_value, raw_error = bridge_sampling_integral(
        normalized_target.measure,
        normalized_target.samples,
        algorithm.strict,
        algorithm.essalg,
        context,
    )

    big_value = exp(BigFloat(log(raw_value) - logweight))
    big_error = exp(BigFloat(log(raw_error) - logweight))

    return (result = Measurements.measurement(big_value, big_error), logweight = logweight)
end
# TODO: switch this method to take an `EvaluatedMeasure` instead of separate measure/sample arguments.
#
# Bridge-sampling estimate of the ratio of normalization constants of `target_density`
# and `proposal_density`. When the proposal is normalized, this ratio is the integral
# of the target.
#
# Inputs:
# - `target_samples` are samples of the target; their `logd` holds the target log-density.
# - `proposal_samples` are samples of the proposal; their `logd` holds the proposal
#   log-density.
#
# With l = p/q (target over proposal density), the estimate is the fixed point of
#     r <- [Σ_q w*l/(s1*l + s2*r) / N2] / [Σ_p w/(s1*l + s2*r) / N1]
# where N1 and N2 are the total sample weights and s_i = N_i/(N1 + N2). The iteration
# runs until the relative change falls below 1e-15.
#
# If the iteration counter reaches 500 before convergence:
# - when `strict` is true, an `ErrorException` is thrown;
# - otherwise a warning is logged and the latest estimate is used.
#
# The error is a root-mean-squared-error estimate. The effective sample size of the
# target-side term is computed with `ess_alg`.
#
# Returns `(value, error)` as a tuple of `Float64`.
function bridge_sampling_integral(
    target_density::BATMeasure,
    target_samples::DensitySampleVector,
    proposal_density::BATMeasure,
    proposal_samples::DensitySampleVector,
    strict::Bool,
    ess_alg::EffSampleSizeAlgorithm,
    context::BATContext
)
    max_iterations = 500
    rel_tol = 1e-15

    w1 = target_samples.weight
    w2 = proposal_samples.weight
    N1 = Int(sum(w1))
    N2 = Int(sum(w2))
    s1 = N1 / (N2 + N1)
    s2 = N2 / (N1 + N2)

    # Evaluate every log-density exactly once per sample point. The results are reused
    # by both the fixed-point iteration and the error estimate. Previously the other
    # density was re-evaluated three times per sample.
    logp_target = target_samples.logd
    logq_target = [logdensityof(proposal_density, x) for x in target_samples.v]
    logp_proposal = [logdensityof(target_density, x) for x in proposal_samples.v]
    logq_proposal = proposal_samples.logd

    #####################
    # Evaluate integral #
    #####################

    # Density ratios target/proposal, at the target samples and at the proposal samples.
    l1 = exp.(logp_target .- logq_target)
    l2 = exp.(logp_proposal .- logq_proposal)

    # Start from Float64 values so the loop variables keep a single concrete type.
    prev_int = 0.0
    current_int = 0.1
    counter = 0
    while abs(current_int - prev_int) / current_int > rel_tol
        prev_int = current_int

        numerator = 0.0
        for (w, l) in zip(w2, l2)
            numerator += w * (l / (s1 * l + s2 * prev_int))
        end
        numerator /= N2

        denominator = 0.0
        for (w, l) in zip(w1, l1)
            denominator += w / (s1 * l + s2 * prev_int)
        end
        denominator /= N1

        current_int = numerator / denominator

        if counter == max_iterations
            msg = "The iterative scheme is not converging!!"
            if strict
                throw(ErrorException(msg))
            else
                @warn(msg)
                # The counter check above fires only once, so without this `break`
                # a non-converging iteration would loop forever.
                break
            end
        end
        counter += 1
    end

    ##################
    # Evaluate error #
    ##################

    # Bind the converged value to a variable that is assigned only once. The
    # comprehensions below capture it, and a single assignment avoids boxing it.
    Z = current_int

    f1 = [
        exp(lp) / Z / (s1 * exp(lp) / Z + s2 * exp(lq))
        for (lp, lq) in zip(logp_proposal, logq_proposal)
    ]
    # Wrapped in 1-element vectors so the values can be stored as a DensitySampleVector
    # (presumably it needs vector-valued variates) for the ESS computation below.
    f2 = [
        [exp(lq) / (s1 * exp(lp) / Z + s2 * exp(lq))]
        for (lp, lq) in zip(logp_target, logq_target)
    ]
    f2_density_vector = DensitySampleVector(f2, target_samples.logd, weight = target_samples.weight)

    mean1, var1 = StatsBase.mean_and_var(f1, FrequencyWeights(proposal_samples.weight), corrected = true)
    mean2, var2 = mean(f2_density_vector)[1], cov(f2_density_vector)[1]
    # Effective sample size of the target-side term (accounts for sample autocorrelation).
    N1_eff = bat_eff_sample_size_impl(f2_density_vector, ess_alg, context).result[1]

    # Root mean squared error of the integral estimate.
    r_MSE = sqrt(var1 / (mean1^2 * N2) + (var2 / mean2^2) / N1_eff) * Z

    # Force a concrete return type; some of the intermediate values above are not
    # inferrable from the argument types alone.
    return (Float64(Z)::Float64, Float64(r_MSE)::Float64)
end
# TODO: switch this method to take an `EvaluatedMeasure` instead of separate measure/sample arguments.
#
# Bridge-sampling estimate of the integral of `target_measure` that builds its own
# proposal from the target samples:
# 1. The first ⌊n/2⌋ samples are used to fit a multivariate normal proposal
#    (sample mean and covariance; the covariance is made positive definite via
#    `cholesky(Positive, ...)`).
# 2. As many IID proposal draws as the total weight of the remaining samples are
#    generated.
# 3. The remaining samples and the proposal draws are passed to the seven-argument
#    `bridge_sampling_integral` method, whose result is returned.
function bridge_sampling_integral(
    target_measure::BATMeasure,
    target_samples::DensitySampleVector,
    strict::Bool,
    ess_alg::EffSampleSizeAlgorithm,
    context::BATContext
)
    n_total = length(target_samples.weight)
    n_fit = n_total ÷ 2
    fit_batch = target_samples[1:n_fit]
    eval_batch = target_samples[(n_fit + 1):end]

    #####################
    # proposal function #
    #####################

    fit_mean = vec(mean(fit_batch))
    fit_cov = Array(cov(fit_batch)) # TODO: other covariance approximations
    fit_cov_pd = PDMat(cholesky(Positive, fit_cov))
    gaussian_proposal = batmeasure(MvNormal(fit_mean, fit_cov_pd))

    n_draws = Int(sum(eval_batch.weight))
    proposal_draws = bat_sample_impl(gaussian_proposal, IIDSampling(nsamples = n_draws), context).result

    # NOTE(review): this re-wrapping looks redundant (already a `batmeasure`); kept as-is.
    wrapped_proposal = batmeasure(gaussian_proposal)

    return bridge_sampling_integral(
        target_measure,
        eval_batch,
        wrapped_proposal,
        proposal_draws,
        strict,
        ess_alg,
        context,
    )
end