Skip to content
4 changes: 2 additions & 2 deletions src/core/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,10 @@ end

function Base.run(mm::MarginalModel; ntimesteps::Int=typemax(Int))
run(mm.base, ntimesteps=ntimesteps)
run(mm.marginal, ntimesteps=ntimesteps)
run(mm.modified, ntimesteps=ntimesteps)
end

function build(mm::MarginalModel)
build(mm.base)
build(mm.marginal)
build(mm.modified)
end
2 changes: 1 addition & 1 deletion src/core/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ modelinstance_def(m::Model) = modeldef(modelinstance(m))

is_built(m::Model) = !(dirty(m.md) || modelinstance(m) === nothing)

is_built(mm::MarginalModel) = (is_built(mm.base) && is_built(mm.marginal))
is_built(mm::MarginalModel) = (is_built(mm.base) && is_built(mm.modified))

@delegate compinstance(m::Model, name::Symbol) => mi
@delegate has_comp(m::Model, name::Symbol) => md
Expand Down
12 changes: 10 additions & 2 deletions src/core/types/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ from those of another `marginal` Model` that has a difference of `delta`.
"""
struct MarginalModel <: AbstractModel
base::Model
marginal::Model
modified::Model
delta::Float64

function MarginalModel(base::Model, delta::Float64=1.0)
Expand All @@ -43,5 +43,13 @@ struct MarginalModel <: AbstractModel
end

function Base.getindex(mm::MarginalModel, comp_name::Symbol, name::Symbol)
return (mm.marginal[comp_name, name] .- mm.base[comp_name, name]) ./ mm.delta
return (mm.modified[comp_name, name] .- mm.base[comp_name, name]) ./ mm.delta
end

function Base.getproperty(base::MarginalModel, s::Symbol)
if (s == :marginal)
@warn("Use of 'MarginalModel.marginal' will be deprecated, in favor of 'MarginalModel.modified'");
return getfield(base, :modified);
end
return getfield(base, s);
end
6 changes: 3 additions & 3 deletions src/mcs/montecarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ function _copy_sim_params(sim_inst::SimulationInstance{T}) where T <: AbstractSi
for m in sim_inst.models
if m isa MarginalModel
push!(flat_model_list, m.base)
push!(flat_model_list, m.marginal)
push!(flat_model_list, m.modified)
else
push!(flat_model_list, m)
end
Expand All @@ -255,7 +255,7 @@ function _restore_sim_params!(sim_inst::SimulationInstance{T},
for m in sim_inst.models
if m isa MarginalModel
push!(flat_model_list, m.base)
push!(flat_model_list, m.marginal)
push!(flat_model_list, m.modified)
else
push!(flat_model_list, m)
end
Expand Down Expand Up @@ -358,7 +358,7 @@ function _perturb_params!(sim_inst::SimulationInstance{T}, trialnum::Int) where

for m in sim_inst.models
# If it's a MarginalModel, need to perturb the params in both the base and marginal modeldefs
mds = m isa MarginalModel ? [m.base.mi.md, m.marginal.mi.md] : [m.mi.md]
mds = m isa MarginalModel ? [m.base.mi.md, m.modified.mi.md] : [m.mi.md]
for md in mds
for trans in sim_inst.sim_def.translist
param = external_param(md, trans.paramname)
Expand Down
6 changes: 3 additions & 3 deletions test/mcs/test_marginalmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function post_trial1(si, trialnum, ntimesteps, tup)
end

si = run(simdef, mm1, 2, post_trial_func = post_trial1, results_in_memory = true)
@test all(iszero, si.results[1][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.marginal)
@test all(iszero, si.results[1][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.modified)


# Test running a vector of MarginalModels
Expand All @@ -35,7 +35,7 @@ end

si = run(simdef, [mm1, mm2], 2, post_trial_func = post_trial2, results_in_memory = true)
@test all(iszero, si.results[1][:emissions, :E_Global][!, :E_Global]) # Test that the regular model has non-zero emissions
@test all(iszero, si.results[2][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.marginal)
@test all(iszero, si.results[2][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.modified)


# Test running a vector of a Model and a MarginalModel
Expand All @@ -48,4 +48,4 @@ end

si = run(simdef, [m, mm1], 2, post_trial_func = post_trial3, results_in_memory = true)
@test all(!iszero, si.results[1][:emissions, :E_Global][!, :E_Global]) # Test that the regular model has non-zero emissions
@test all(iszero, si.results[2][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.marginal)
@test all(iszero, si.results[2][:emissions, :E_Global][!, :E_Global]) # Test that the marginal emission saved from the MarginalModel are zeros (because there's no difference between mm.base and mm.modified)
4 changes: 2 additions & 2 deletions test/test_marginal_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ set_param!(model1, :compA, :parA, x1)

mm = MarginalModel(model1, .5)

model2 = mm.marginal
model2 = mm.modified
update_param!(model2, :parA, x2)

run(mm)
Expand All @@ -33,7 +33,7 @@ end

mm2 = create_marginal_model(model1, 0.5)

mm2_marginal = mm2.marginal
mm2_marginal = @test_logs (:warn, "Use of 'MarginalModel.marginal' will be deprecated, in favor of 'MarginalModel.modified'") mm2.marginal
update_param!(mm2_marginal, :parA, x2)

run(mm2)
Expand Down