Skip to content

Commit

Permalink
Fix failing EnIS test
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed May 22, 2024
1 parent 3532b8f commit cd8bbb6
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 232 deletions.
7 changes: 5 additions & 2 deletions src/ensembles/importance_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,18 @@ function importance_weights(obs::AbstractVector, pred::AbstractMatrix, R::Diagon
n_obs, ensemble_size = size(pred)
@assert n_obs == length(obs)
residual = obs .- pred
loglik = dropdims(-0.5*(1/diag(R))*residual.^2, dims=1)
loglik = sum(-0.5*inv(R)*residual.^2, dims=1)[1,:]

# Log of normalizing constant
log_z = logsumexp(loglik)

# Weights
logw = loglik .- log_z
weights = exp.(logw)
@assert length(weights) == ensemble_size && round(sum(weights), digits=10) == 1.0 "particle weights do not sum to unity!"

# sanity checks
@assert length(weights) == ensemble_size "weight vector has incorrect dimensions: $(size(weights))"
@assert abs(sum(weights) - 1.0) < sqrt(eps()) "particle weights do not sum to unity! ∑w=$(sum(weights))"

Neff = round(1/sum(weights.^2))

Expand Down
31 changes: 19 additions & 12 deletions src/forward_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct SimulatorForwardProblem{probType,obsType,configType,names} <: SciMLBase.A
config::configType
end

const SimulatorSciMLForwardProblem = SimulatorForwardProblem{<:SciMLBase.AbstractSciMLProblem}
const SimulatorSciMLForwardProblem{probType} = SimulatorForwardProblem{probType} where {probType<:SciMLBase.AbstractSciMLProblem}

"""
SimulatorForwardProblem(prob::SciMLBase.AbstractSciMLProblem, observables::SimulatorObservable...)
Expand Down Expand Up @@ -42,25 +42,32 @@ parameters `p0` and a default transient observable.
SimulatorForwardProblem(f, p0::AbstractVector) = SimulatorForwardProblem(f, p0, SimulatorObservable(:y, state -> state.u))

"""
SciMLBase.remaker_of(forward_prob::SimulatorForwardProblem)
Returns a function which will rebuild a `SimulatorForwardProblem` from its arguments.
The remaker function additionally provides a keyword argument `copy_observables` which,
if `true`, will `deepcopy` the observables to ensure independence. The default setting is `true`.
"""
function SciMLBase.remaker_of(forward_prob::SimulatorForwardProblem)
function remake_forward_prob(;
SciMLBase.remake(
forward_prob::SimulatorForwardProblem;
prob=forward_prob.prob,
observables=forward_prob.observables,
config=forward_prob.config,
copy_observables=true,
kwargs...
)
new_observables = copy_observables ? deepcopy(observables) : observables
return SimulatorForwardProblem(remake(prob; kwargs...), new_observables, config)
end
Rebuilds a `SimulatorForwardProblem` from its individual components. If `copy_observables=true`,
then `remake` will `deepcopy` the observables to ensure independence. The default setting is `true`.
"""
function SciMLBase.remake(
forward_prob::SimulatorForwardProblem;
prob=forward_prob.prob,
observables=forward_prob.observables,
config=forward_prob.config,
copy_observables=true,
kwargs...
)
new_observables = copy_observables ? deepcopy(observables) : observables
return SimulatorForwardProblem(remake(prob; kwargs...), new_observables, config)
end

SciMLBase.remaker_of(forward_prob::SimulatorForwardProblem) = (;kwargs...) -> remake(forward_prob; kwargs...)

# DiffEqBase dispatches to make solve/init interface work correctly
DiffEqBase.check_prob_alg_pairing(prob::SimulatorSciMLForwardProblem, alg) = DiffEqBase.check_prob_alg_pairing(prob.prob, alg)
DiffEqBase.isinplace(prob::SimulatorSciMLForwardProblem) = DiffEqBase.isinplace(prob.prob)
Expand Down
4 changes: 2 additions & 2 deletions test/ensembles/enis_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ end

@testset "EnIS: evensen_scalar_nonlinear" begin
x_true = 1.0
b_true = 0.2
b_true = 0.1
σ_y = 0.1
ensemble_size = 1024
x_prior = Normal(0,1)
# log-normal with mean 0.1 and stddev 0.2
# b_prior = autoprior(0.1, 0.2, lower=0.0, upper=Inf)
b_prior = Bijectors.TransformedDistribution(Normal(log(0.2), 1), Base.Fix1(broadcast, exp))
b_prior = Bijectors.TransformedDistribution(Normal(0.0, 1.0), Base.Fix1(broadcast, exp))
rng = Random.MersenneTwister(1234)
testprob = evensen_scalar_nonlinear(x_true, b_true; n_obs=100, rng, x_prior, b_prior)
transform = bijector(testprob.prior.model)
Expand Down
Loading

0 comments on commit cd8bbb6

Please sign in to comment.