Skip to content

Commit

Permalink
Add ODEObservable and update linear ODE example
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed Jun 19, 2024
1 parent 44fb5bf commit 41a27b9
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 261 deletions.
607 changes: 359 additions & 248 deletions examples/linearode/Manifest.toml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/linearode/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SimulationBasedInference = "78927d98-f421-490e-8789-96b006983a5c"
Expand Down
41 changes: 30 additions & 11 deletions examples/linearode/linearode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ using OrdinaryDiffEq
using Plots, StatsPlots
import Random

# extensions
using DynamicHMC

using DisplayAs #hide

# and then initialize a random number generator for reproducibility.
Expand All @@ -38,35 +41,36 @@ odeprob = ODEProblem(ode_func, [1.0], tspan, ode_p)
# state from the ODE integrator. The `SimulatorObservable(name, func, t0, tsave, coords)` additionally takes
# an initial time point, a vector of observed time points, and a tuple specifying the shape or coordiantes of
# the observable at each time point. Here, `(1,)` indicates that the state is a one-dimensional vector.
tsave = tspan[1]+0.1:0.2:tspan[end];
dt = 0.2
tsave = tspan[begin]+dt:dt:tspan[end];
n_obs = length(tsave);
observable = SimulatorObservable(:y, integrator -> integrator.u, tspan[1], tsave, (1,), samplerate=0.01);
observable = ODEObservable(:y, odeprob, tsave, samplerate=0.01);
forward_prob = SimulatorForwardProblem(odeprob, observable)

# In order to set up our synthetic example, we need to some data to infer from.
# Here we generate the data by running the forward model and adding Gaussian noise.
ode_solver = Tsit5()
forward_sol = solve(forward_prob, ode_solver, saveat=0.01);
forward_sol = solve(forward_prob, ode_solver);
true_obs = get_observable(forward_sol, :y)
noise_scale = 0.05
noisy_obs = true_obs .+ noise_scale*randn(rng, n_obs);
# plot the results
plot(forward_sol.sol, label="True solution", linewidth=3, color=:black)
plot(true_obs, label="True solution", linewidth=3, color=:black)
plt = scatter!(tsave, noisy_obs, label="Noisy observations", alpha=0.5)
DisplayAs.Text(DisplayAs.PNG(plt)) #hide

# Here we set our priors. We use a weakly informative `Beta(2,2)` prior which puts
# less mass at the tails. We could also use a flat prior `Beta(1,1)` if we wanted to
# be more agnostic to further minimize the influence of the prior.
model_prior = prior=Beta(2,2));
noise_scale_prior = prior=Exponential(noise_scale));
noise_scale_prior = prior=Exponential(0.1));
p1 = Plots.plot(model_prior.dist.α)
p2 = Plots.plot(noise_scale_prior.dist.σ)
plt = Plots.plot(p1, p2)
DisplayAs.Text(DisplayAs.PNG(plt)) #hide

# Now we assign a simple Gaussian likelihood for the obsevation/noise model.
lik = SimulatorLikelihood(IsoNormal, observable, noisy_obs, noise_scale_prior);
lik = IsotropicGaussianLikelihood(observable, noisy_obs, noise_scale_prior);
nothing #hide

# We now have all of the ingredients needed to set up and solve the inference problem.
Expand Down Expand Up @@ -96,7 +100,7 @@ posterior_obs_std_enis = std(prior_ens_obs, weights(importance_weights), 2)[:,1]
posterior_mean_enis = mean(prior_ens, weights(importance_weights))

# Now we plot the prior vs. the posterior predictions.
plot(tsave, true_obs, label="True solution", c=:black, linewidth=2, title="Importance weighted posterior predictions")
plot(tsave, true_obs, label="True solution", c=:black, linewidth=2, title="Linear ODE posterior predictions (EnIS)")
plot!(tsave, prior_ens_obs_mean, label="Prior", c=:gray, linestyle=:dash, ribbon=2*prior_ens_obs_std, alpha=0.5, linewidth=2)
plot!(tsave, posterior_obs_mean_enis, label="Posterior", c=:blue, linestyle=:dash, ribbon=2*posterior_obs_std_enis, alpha=0.7, linewidth=2)
plt = scatter!(tsave, noisy_obs, label="Noisy observations", c=:orange)
Expand Down Expand Up @@ -136,11 +140,26 @@ plot!(tsave, posterior_obs_mean_eks, label="Posterior", c=:blue, linestyle=:dash
plt = scatter!(tsave, noisy_obs, label="Noisy observations", c=:black)
DisplayAs.Text(DisplayAs.PNG(plt)) #hide

# Finally, we can plot the posterior predictions of all of the algorithms and compare.
# Now solve using the gold standard No U-turn sampler (NUTS). This will take a few minutes to run.
# Note that this would generally not be feasible more expensive simulators.
hmc_sol = @time solve(inference_prob, MCMC(NUTS()), num_samples=1000, rng=rng);
posterior_hmc = transpose(Array(hmc_sol.result))
posterior_mean_hmc = mean(posterior_hmc, dims=2)
posterior_obs_hmc = reduce(hcat, map(out -> out.y, hmc_sol.storage.outputs))
posterior_obs_mean_hmc = mean(posterior_obs_hmc, dims=2)[:,1]
posterior_obs_std_hmc = std(posterior_obs_hmc, dims=2)[:,1]
plot(tsave, true_obs, label="True solution", c=:black, linewidth=2, title="EKS")
plot!(tsave, prior_ens_obs_mean, label="Prior", c=:gray, linestyle=:dash, ribbon=2*prior_ens_obs_std, alpha=0.5, linewidth=2)
plot!(tsave, posterior_obs_mean_enis, label="Posterior (EnIS)", linestyle=:dash, ribbon=2*posterior_obs_std_enis, alpha=0.5, linewidth=3)
plot!(tsave, posterior_obs_mean_esmda, label="Posterior (ES-MDA)", linestyle=:dash, ribbon=2*posterior_obs_std_esmda, alpha=0.5, linewidth=3)
plot!(tsave, posterior_obs_mean_eks, dims=2, label="Posterior (EKS)", linestyle=:dash, ribbon=2*posterior_obs_std_eks, alpha=0.5, linewidth=3)
plot!(tsave, posterior_obs_mean_hmc, label="Posterior", c=:blue, linestyle=:dash, ribbon=2*posterior_obs_std_hmc, alpha=0.7, linewidth=2)
plt = scatter!(tsave, noisy_obs, label="Noisy observations", c=:black)

# Finally, we can plot the posterior predictions of all of the algorithms and compare.
plot(tsave, true_obs, label="True solution", c=:black, linewidth=2, title="Linear ODE: Inference algorithm comparison", dpi=300, xlabel="time")
plot!(tsave, prior_ens_obs_mean, label="Prior", c=:gray, linestyle=:dash, ribbon=2*prior_ens_obs_std, alpha=0.4, linewidth=2)
plot!(tsave, posterior_obs_mean_enis, label="Posterior (EnIS)", linestyle=:dash, ribbon=2*posterior_obs_std_enis, alpha=0.4, linewidth=3)
plot!(tsave, posterior_obs_mean_esmda, label="Posterior (ES-MDA)", linestyle=:dash, ribbon=2*posterior_obs_std_esmda, alpha=0.4, linewidth=3)
plot!(tsave, posterior_obs_mean_eks, label="Posterior (EKS)", linestyle=:dash, ribbon=2*posterior_obs_std_eks, alpha=0.4, linewidth=3)
plot!(tsave, posterior_obs_mean_hmc, label="Posterior (HMC)", linestyle=:dash, ribbon=2*posterior_obs_std_hmc, alpha=0.4, linewidth=3)
plt = scatter!(tsave, noisy_obs, label="Noisy observations", c=:black)
savefig("res/linearode_poseterior_preds_comparison.png") #hide
DisplayAs.Text(DisplayAs.PNG(plt)) #hide
Binary file added res/linearode_poseterior_preds_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/linearode_posterior_preds_enis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion src/SimulationBasedInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ include("inference_problem.jl")
export SimulatorForwardSolver
include("forward_solve.jl")

export SimulatorODEForwardSolver
export SimulatorODEForwardSolver, ODEObservable
include("forward_solve_ode.jl")

include("emulators/Emulators.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/ensembles/ensemble_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
function sample_ensemble_predictive(
sol::SimulatorInferenceSolution{<:EnsembleInferenceAlgorithm},
new_storage::SimulationData=SimulationArrayStorage();
num_samples_per_sim::Int=10,
num_samples_per_sim::Int=1,
pred_transform=identity,
iterations=[],
rng::Random.AbstractRNG=Random.default_rng(),
Expand Down
4 changes: 4 additions & 0 deletions src/forward_solve_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ function SimulatorForwardProblem(
return SimulatorForwardProblem(prob, named_observables, SimulatorODEConfig(stepfunc, obs_to_prob_time))
end

function ODEObservable(name::Symbol, prob::AbstractODEProblem, tsave; obsfunc=identity, kwargs...)
return SimulatorObservable(name, integrator -> obsfunc(integrator.u), prob.tspan[1], tsave, size(prob.u0); kwargs...)
end

"""
SimulatorODEForwardSolver{algType,uType,tType,iip,integratorType<:AbstractODEIntegrator{algType,iip,uType,tType}} <: AbstractODEIntegrator{algType,iip,uType,tType}
Expand Down

0 comments on commit 41a27b9

Please sign in to comment.