Simon Frost (@sdwfrost), 2022-12-22
The classical ODE version of the SIR model is:
- Deterministic
- Continuous in time
- Continuous in state
In this notebook, we infer the parameter values from a simulated dataset and use profile likelihood, as implemented in ProfileLikelihood.jl, to quantify the uncertainty in the parameter estimates.
using OrdinaryDiffEq
using ProfileLikelihood
using StatsFuns
using Random
using Distributions
using Optimization
using OptimizationOptimJL
using QuasiMonteCarlo
using CairoMakie
using LaTeXStrings
using DataFrames
The following function computes the derivatives of the model and writes them in-place into du. A fourth variable, C, is included to track the cumulative number of infections.
function sir_ode!(du, u, p, t)
(S, I, R, C) = u
(β, c, γ) = p
N = S+I+R
infection = β*c*I/N*S
recovery = γ*I
@inbounds begin
du[1] = -infection
du[2] = infection - recovery
du[3] = recovery
du[4] = infection
end
nothing
end;
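For reference, the system encoded by sir_ode! can be written out as equations. This is a direct transcription of the code above, using the same symbols (β, c, γ, and N = S + I + R); nothing new is assumed.

```latex
\begin{aligned}
\frac{dS}{dt} &= -\beta c \frac{I}{N} S, \\
\frac{dI}{dt} &= \beta c \frac{I}{N} S - \gamma I, \\
\frac{dR}{dt} &= \gamma I, \\
\frac{dC}{dt} &= \beta c \frac{I}{N} S, \qquad N = S + I + R .
\end{aligned}
```

Because dC/dt equals the infection term, C(t) accumulates every new infection and never decreases.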
We set the timespan for simulations, tspan, initial conditions, u₀, and parameter values, p (which are unpacked above as [β, c, γ]).
δt = 1.0
tmax = 40.0
tspan = (0.0,tmax);
u₀ = [990.0, 10.0, 0.0, 0.0]; # S, I, R, C
p = [0.05,10.0,0.25]; # β, c, γ
prob_ode = ODEProblem(sir_ode!, u₀, tspan, p)
sol_ode = solve(prob_ode, Tsit5(), saveat=δt);
We convert the output to an Array for further processing.
out = Array(sol_ode);
The following code demonstrates how to plot the time series using Makie.jl.
colors = [:blue, :red, :green, :purple]
legends = ["S", "I", "R", "C"]
fig = Figure()
ax = Axis(fig[1, 1])
for i = 1:4
lines!(ax, sol_ode.t, out[i,:], label = legends[i], color = colors[i])
end
axislegend(ax)
ax.xlabel = "Time"
ax.ylabel = "Number"
fig
The cumulative counts are extracted, and the new cases per day are calculated from the cumulative counts.
C = out[4,:];
X = C[2:end] .- C[1:(end-1)];
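The differencing step above can also be written with diff from Base. The sketch below uses a small made-up vector, C_demo (hypothetical values, not output from the model), to show that the two forms agree and that n cumulative values give n - 1 increments; with saveat=δt over 0 to 40 days, this means 41 saved points and 40 daily counts.

```julia
# Hypothetical cumulative counts, standing in for `C = out[4,:]` above
C_demo = [0.0, 4.5, 11.2, 20.9, 34.6]

# Manual differencing, as in the notebook
X_manual = C_demo[2:end] .- C_demo[1:(end-1)]

# `diff` from Base computes the same first differences
X_diff = diff(C_demo)

# Both forms give one increment per pair of consecutive time points
@assert X_manual == X_diff
@assert length(X_diff) == length(C_demo) - 1
```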
Although the ODE system is deterministic, we can add measurement error to the counts of new cases. Here, a Poisson distribution is used, although a negative binomial could also be used (which would introduce an additional parameter for the variance).
Random.seed!(1234);
data = rand.(Poisson.(X));
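As a sketch of the negative binomial alternative mentioned above: Distributions.jl parameterises NegativeBinomial(r, p) by a size r and a success probability p, so the expected counts have to be converted to p. The vector X_demo and the dispersion value r below are hypothetical and chosen only for illustration; they are not part of the original analysis.

```julia
using Distributions
using Random

Random.seed!(1234)

# Hypothetical expected daily counts, standing in for `X` above
X_demo = [5.0, 12.0, 30.0, 18.0]

# Hypothetical dispersion parameter (an assumption): smaller r gives more overdispersion
r = 10.0

# NegativeBinomial(r, p) has mean r*(1-p)/p, so p = r/(r + μ) gives
# mean μ and variance μ + μ^2/r
nb = NegativeBinomial.(r, r ./ (r .+ X_demo))

# One overdispersed count per time point
data_nb = rand.(nb)
```

Fitting such data would add r as an extra parameter to estimate alongside i₀ and β.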
ProfileLikelihood.jl expects a log-likelihood function that takes the parameter vector, θ, the data, and the integrator used for the model; see the documentation on the integrator interface of DifferentialEquations.jl for more details.
function ll(θ, data, integrator)
(i0,β) = θ
integrator.p[1] = β
integrator.p[2] = 10.0
integrator.p[3] = 0.25
I = i0*1000.0
u₀=[1000.0-I,I,0.0,0.0]
reinit!(integrator, u₀)
solve!(integrator)
sol = integrator.sol
out = Array(sol)
C = out[4,:]
X = C[2:end] .- C[1:(end-1)]
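# Reject parameter values that give non-positive expected counts;
# -Inf marks them as infinitely unlikely when the log-likelihood is maximised
if any(X .<= 0)
return -Inf
end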
sum(logpdf.(Poisson.(X),data))
end;
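The last line of ll evaluates a Poisson log-likelihood. Written out, with y_t denoting the observed count on day t (the entries of data) and X_t(θ) the model-predicted number of new cases on that day (notation introduced here for illustration), it is:

```latex
\ell(\theta) = \sum_{t=1}^{T} \Big[ y_t \log X_t(\theta) - X_t(\theta) - \log (y_t!) \Big],
\qquad \theta = (i_0, \beta), \quad T = 40 .
```

The logarithm requires X_t(θ) > 0, which is why the function checks for non-positive increments before evaluating the sum.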
We specify the lower and upper bounds of the parameter values, lb and ub respectively, and the initial parameter values, θ₀.
lb = [0.0, 0.0]
ub = [1.0, 1.0]
θ = [0.01, 0.05] # true values of i₀ and β used to simulate the data
θ₀ = [0.01, 0.1];
The following shows how to obtain a single log-likelihood value for a set of parameter values using the integrator interface.
integrator = init(prob_ode, Tsit5(); saveat = δt) # takes the same arguments as `solve`
ll(θ₀, data, integrator)
-1709.5158069805168
We use the log-likelihood function, ll, to define a LikelihoodProblem, along with initial parameter values, θ₀, the function describing the model, sir_ode!, the initial conditions, u₀, and the maximum time.
syms = [:i₀, :β]
prob = LikelihoodProblem(
ll, θ₀, sir_ode!, u₀, tmax;
syms=syms,
data=data,
ode_parameters=p, # temp values for p
ode_kwargs=(verbose=false, saveat=δt),
f_kwargs=(adtype=Optimization.AutoFiniteDiff(),),
prob_kwargs=(lb=lb, ub=ub),
ode_alg=Tsit5()
);
Now that we have defined the LikelihoodProblem, we can obtain the maximum likelihood estimate of the parameters using one of the algorithms in Optimization.jl. Here, we use NelderMead from Optim.jl, imported with using OptimizationOptimJL at the beginning of the notebook.
sol = mle(prob, NelderMead())
θ̂ = get_mle(sol);
Similar code can be used to obtain the profile likelihood intervals.
prof = profile(prob, sol; alg=NelderMead(), parallel=false)
confints = get_confidence_intervals(prof);
fig = plot_profiles(prof; latex_names=[L"i_0", L"\beta"])
fig
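As background for reading the profile plots, the quantities involved can be written as follows. This is the standard textbook construction, assuming the usual chi-squared calibration of the likelihood ratio; the exact defaults used by profile are best checked in the ProfileLikelihood.jl documentation. Here ψ is the parameter of interest (i₀ or β), ω is the other parameter, and θ̂ is the maximum likelihood estimate.

```latex
\ell_p(\psi) = \max_{\omega} \, \ell(\psi, \omega),
\qquad
\hat{\ell}_p(\psi) = \ell_p(\psi) - \ell(\hat{\theta}),
\qquad
\mathrm{CI}_{95\%} = \Big\{ \psi : \hat{\ell}_p(\psi) \ge -\tfrac{1}{2} \chi^2_{1,\,0.95} \approx -1.92 \Big\} .
```

Under this convention, the interval for each parameter is the range over which the normalised profile log-likelihood stays above roughly -1.92.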
The following shows the fitted parameter estimates and the 95% confidence intervals based on profile likelihood.
ENV["COLUMNS"]=80
df_res = DataFrame(
Parameters = [:i₀, :β],
CILower = [confints[i][1] for i in 1:2],
CIUpper = [confints[i][2] for i in 1:2],
FittedValues = θ̂,
TrueValues = [0.01,0.05],
NominalStartValues = θ₀
)
df_res
2×6 DataFrame
 Row │ Parameters  CILower     CIUpper    FittedValues  TrueValues  NominalSta ⋯
     │ Symbol      Float64     Float64    Float64       Float64     Float64    ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ i₀          0.00730514  0.0121367  0.00946008          0.01             ⋯
   2 │ β           0.0479153   0.0519872  0.0499348           0.05
                                                                1 column omitted
ProfileLikelihood.jl also provides a function to generate prediction intervals based on the profile likelihood intervals for individual parameters, and to combine the parameter-wise intervals to create a single interval. This requires a function that takes a vector of parameters, θ, with a second argument that can be used to pass information such as the time span and the number of data points.
function prediction_function(θ, data)
(i0,β) = θ
tspan = data["tspan"]
npts = data["npts"]
t2 = LinRange(tspan[1]+1, tspan[2], npts)
t1 = LinRange(tspan[1], tspan[2]-1, npts)
I = i0*1000.0
prob = remake(prob_ode,u0=[1000.0-I,I,0.0,0.0],p=[β,10.0,0.25],tspan=tspan)
sol = solve(prob,Tsit5())
return sol(t2)[4,:] .- sol(t1)[4,:]
end;
npts = 1000
t_pred = LinRange(tspan[1]+1, tspan[2], npts)
d = Dict("tspan" => tspan, "npts" => npts);
exact_soln = prediction_function([0.01,0.05], d)
mle_soln = prediction_function(θ̂, d);
parameter_wise, union_intervals, all_curves, param_range =
get_prediction_intervals(prediction_function,
prof,
d);
The following figure shows individual intervals and the combined interval.
fig = Figure(fontsize=32, resolution=(1800, 900))
alp = join('a':'b')
latex_names = [L"i_0", L"\beta"]
for i in 1:2
ax = Axis(fig[1, i], title=L"(%$(alp[i])): Profile-wise PI for %$(latex_names[i])",
titlealign=:left, width=400, height=300)
band!(ax, t_pred, getindex.(parameter_wise[i], 1), getindex.(parameter_wise[i], 2), color=(:grey, 0.7), transparency=true)
lines!(ax, t_pred, exact_soln, color=:red)
lines!(ax, t_pred, mle_soln, color=:blue, linestyle=:dash)
lines!(ax, t_pred, getindex.(parameter_wise[i], 1), color=:black, linewidth=3)
lines!(ax, t_pred, getindex.(parameter_wise[i], 2), color=:black, linewidth=3)
end
ax = Axis(fig[1,3], title=L"(c):$ $ Union of all intervals",
titlealign=:left, width=400, height=300)
band!(ax, t_pred, getindex.(union_intervals, 1), getindex.(union_intervals, 2), color=(:grey, 0.7), transparency=true)
lines!(ax, t_pred, getindex.(union_intervals, 1), color=:black, linewidth=3)
lines!(ax, t_pred, getindex.(union_intervals, 2), color=:black, linewidth=3)
lines!(ax, t_pred, exact_soln, color=:red)
lines!(ax, t_pred, mle_soln, color=:blue, linestyle=:dash)
fig