# Inference of a function map using Turing.jl
Simon Frost (@sdwfrost), 2025-10-04

## Introduction

In this notebook we demonstrate how to implement a discrete-time function map model in Julia and perform Bayesian inference using Turing.jl. The model is a simple SIR model with an additional state variable to track cumulative incidence. We assume that the number of observed cases in each time step is a binomial sample of the true incidence, with a reporting probability `q`. The structure of this notebook is similar to that of the [Markov POMP tutorial](https://github.com/epirecipes/sir-julia/markdown/markov_pomp/markov_pomp.md). The difference is that choosing a function map rather than a Markov model lets us use automatic differentiation and the NUTS sampler in Turing.jl. This works because (a) the model involves no random number generation during inference, so its output is a deterministic, differentiable function of the parameters, and (b) all the parameters are continuous.

## Libraries

```julia
using Turing
using MCMCChains
using Distributions
using Random
using Plots
using StatsPlots
using Base.Threads;
```

The binomial distribution in `Distributions.jl` only accepts integer values for the number of trials and the number of successes. The incidence in our model is continuous, so we use a custom `GeneralizedBinomial` distribution that can handle non-integer values. We include the `GeneralizedBinomial` distribution from a separate file.

```julia
include("generalized_binomial.jl")
import .GeneralizedBinomialExt: GeneralizedBinomial;
```
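The contents of `generalized_binomial.jl` are not shown in this notebook. The sketch below illustrates the general idea only and is not that file's actual implementation. A binomial log-density can be extended to non-integer `n` and `k` by writing the binomial coefficient with log-gamma functions. The function name `generalized_binomial_logpdf` is hypothetical. The sketch assumes SpecialFunctions.jl is installed in the active environment.

```julia
using Distributions
using SpecialFunctions: loggamma

# Hypothetical sketch, not the contents of generalized_binomial.jl.
# The binomial coefficient is written with log-gamma functions, so that both the
# "number of trials" n and the "number of successes" k may be non-integer.
# Assumes 0 ≤ k ≤ n and 0 < p < 1.
function generalized_binomial_logpdf(n::Real, p::Real, k::Real)
    logchoose = loggamma(n + 1) - loggamma(k + 1) - loggamma(n - k + 1)
    return logchoose + k * log(p) + (n - k) * log1p(-p)
end

# Check: for integer arguments the sketch agrees with the standard binomial
integer_check = generalized_binomial_logpdf(10.0, 0.3, 4.0) ≈
                Distributions.logpdf(Binomial(10, 0.3), 4)

# Non-integer arguments are also accepted
noninteger_value = generalized_binomial_logpdf(10.5, 0.3, 4.2)
```

`Distributions.logpdf` is called with its module prefix on purpose. A later cell in this notebook defines its own `logpdf` function in `Main`. That definition would fail if the unqualified name had already been resolved to `Distributions.logpdf`.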

## Utility functions

To assist in comparison with the continuous time models, we define a function that takes a constant rate, `r`, over a timespan, `t`, and converts it to a proportion.

```julia
@inline function rate_to_proportion(r,t)
    1-exp(-r*t)
end;
```
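For reference, `rate_to_proportion` is the cumulative distribution function of an exponential waiting time. It gives the probability that an event occurring at a constant rate $r$ happens within a time $\delta t$; write this proportion as $\pi(r, \delta t)$. For small $r\,\delta t$ it is close to the naive $r\,\delta t$. Unlike $r\,\delta t$, it always lies in $[0, 1)$ for $r \ge 0$. It is also consistent under splitting a step in two when the rate is constant:

$$
\pi(r, \delta t) = 1 - e^{-r\,\delta t} \approx r\,\delta t \quad (r\,\delta t \ll 1),
\qquad
1 - \left(1 - \pi(r, \delta t/2)\right)^2 = 1 - e^{-r\,\delta t} = \pi(r, \delta t).
$$

In the SIR map the force of infection $\beta I/N$ changes from step to step, so this consistency holds only approximately there.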

## Transitions

We define a function that takes the 'old' state variables, `u`, and writes the 'new' state variables into `du`. Note that the timestep, `δt`, is passed as an explicit parameter; the time argument, `t`, is not used.

```julia
function sir_map!(du,u,p,t)
    (S, I, R, C) = u
    (β, γ, q, N, δt) = p
    # New infections and recoveries during the time step
    infection = rate_to_proportion(β*I/N, δt)*S
    recovery = rate_to_proportion(γ, δt)*I
    @inbounds begin
        du[1] = S - infection            # susceptible
        du[2] = I + infection - recovery # infected
        du[3] = R + recovery             # recovered
        du[4] = C + infection            # cumulative incidence
    end
    nothing
end;
```
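Reading directly from `sir_map!`, one step of the map from step $k$ to step $k+1$ can be written as follows. Here $\pi_I$ and $\pi_R$ are the per-step infection and recovery proportions:

$$
\begin{aligned}
S_{k+1} &= S_k - \pi_I S_k,\\
I_{k+1} &= I_k + \pi_I S_k - \pi_R I_k,\\
R_{k+1} &= R_k + \pi_R I_k,\\
C_{k+1} &= C_k + \pi_I S_k,
\end{aligned}
\qquad
\pi_I = 1 - e^{-\beta I_k \delta t / N},\quad
\pi_R = 1 - e^{-\gamma\,\delta t}.
$$

Adding the first three lines shows that $S_k + I_k + R_k$ stays at its initial value. Combining the second to fourth lines shows that $I_k + R_k - C_k$ is also constant over time.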

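We also define a function that iterates the map for `nsteps` time steps, storing the state at each step as a column of the output matrix.

```julia
function solve_map(f, u0, nsteps, p)
    # Pre-allocate the output. The element type is promoted over the initial
    # state and the parameters, so the array can also hold automatic
    # differentiation (dual) numbers even if only the parameters carry them
    T = promote_type(eltype(u0), map(typeof, values(p))...)
    sol = Matrix{T}(undef, length(u0), nsteps + 1)
    # Initialize the first column with the initial state
    sol[:, 1] = u0
    # Iterate over the time steps
    @inbounds for t in 2:nsteps+1
        u = @view sol[:, t-1] # Get the current state
        du = @view sol[:, t]  # Prepare the next state
        f(du, u, p, t)        # Update du (t is the step index, not the time)
    end
    return sol
end;
```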

## Time domain

```julia
δt = 1.0 # Time step
nsteps = 40
tmax = nsteps*δt
t = 0.0:δt:tmax;
```

## Initial conditions

Note that we define the state variables as floating point rather than as integers (cf. the Markov model examples), as we will be treating the initial number of infected individuals as a continuous parameter in our inference.

```julia
u0 = [990.0, 10.0, 0.0, 0.0];
```

## Parameter values

We define the parameters as a named tuple; this will make it easier to modify individual parameters during inference.

```julia
p = (β=0.5, γ=0.25, q=0.75, N=1000.0, δt=δt);
```
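The following illustrates why a named tuple is convenient here. `merge` returns a new named tuple in which fields from its second argument replace those of the first, keeping the original field order. The order matters because `sir_map!` destructures `p` by position. The names `p_demo`, `p_demo_new`, and the `*_demo` variables are hypothetical, chosen so as not to overwrite `p` or the later state variables.

```julia
# Hypothetical copy of the parameter tuple (a separate name so `p` is untouched)
p_demo = (β=0.5, γ=0.25, q=0.75, N=1000.0, δt=1.0)

# Replace β and q; all other fields are carried over unchanged
p_demo_new = merge(p_demo, (β=0.6, q=0.5))

# Field order is preserved, so positional destructuring (as in `sir_map!`) still works
(β_demo, γ_demo, q_demo, N_demo, δt_demo) = p_demo_new
```

A single replacement needs a trailing comma, e.g. `merge(p, (β=0.6,))`. Without it, `(β=0.6)` is a parenthesised assignment rather than a named tuple.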

## Running the model

```julia
sol_map = solve_map(sir_map!, u0, nsteps, p);
```

## Post-processing

We unpack the solution into separate variables for convenience.

```julia
# Each row of the solution is one state variable over time
S, I, R, C = eachrow(sol_map);
```

## Plotting

We can now plot the results.

```julia
plot(t,
     [S I R C],
     label=["S" "I" "R" "C"],
     xlabel="Time",
     ylabel="Number")
```

## Inference using Turing.jl

We first simulate some observed data, `Y`, by taking a binomial sample of the incidence (the difference in cumulative cases between time steps) using a parameter `q`, which represents the fraction of cases that are reported. We use `GeneralizedBinomial` to allow for non-integer values of `C`.

```julia
Y = rand.(GeneralizedBinomial.(C[2:end]-C[1:end-1], p.q));
```

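To provide additional information on the reporting probability, we also draw a sample of `ZN` individuals at the final time point and record how many of them are in the recovered compartment (i.e. have been infected and recovered). The case data alone largely constrain the product of `q` and the total number of infections. An independent estimate of the final proportion recovered helps to separate the two, and hence to estimate the reporting level, `q`.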

```julia
ZN = 100
Z = rand(GeneralizedBinomial(ZN, R[end]/p.N));
```
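A rough heuristic, not part of the model, shows how the two data sources combine. Each infection is added to both `I` and `C`, and each recovery moves individuals from `I` to `R`. The quantity $I + R - C$ is therefore constant over time. With $R_0 = C_0 = 0$, the final cumulative incidence at the last step $T$ is $C_T = R_T + I_T - I_0$. Assuming the `GeneralizedBinomial` mean is close to the binomial mean $n\,p$, we have approximately

$$
\mathbb{E}\Big[\sum_{k=1}^{T} Y_k\Big] \approx q\,C_T,
\qquad
\mathbb{E}[Z] \approx Z_N \frac{R_T}{N},
\qquad\Rightarrow\qquad
q \approx \frac{\sum_{k} Y_k}{N\,Z / Z_N},
$$

where $Z_N$ is `ZN`. The last step holds when $I_T - I_0$ is small relative to $R_T$. The model-based inference below uses the full likelihood rather than this approximation.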

We now define a function to calculate the log probability density of the observed case data, `Y`, given the output from `solve_map`. The function does not run the model itself. It takes the model solution and extracts the predicted incidence per time step, i.e. the difference in cumulative incidence, `C`, between successive steps. It then sums the log probability densities of the observations under the `GeneralizedBinomial` distribution with reporting probability `q`.

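```julia
# Note: this defines a new function `logpdf` in `Main`, separate from
# `Distributions.logpdf`, which is therefore called with its module prefix below.
function logpdf(Y, sol, q)
    # Predicted incidence: change in cumulative incidence (row 4) per time step
    X = diff(sol[4, :])
    ll = 0.0
    for i in eachindex(Y)
        ll += Distributions.logpdf(GeneralizedBinomial(X[i], q), Y[i])
    end
    return ll
end;
```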
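In symbols, let $C_k$ be the cumulative incidence at step $k$ ($k = 0, \dots, T$, with $T$ = `nsteps`). The function returns

$$
\ell(\beta, I_0, q) = \sum_{k=1}^{T} \log f\left(Y_k \mid X_k, q\right),
\qquad X_k = C_k - C_{k-1},
$$

where $f(\cdot \mid n, q)$ is the `GeneralizedBinomial` density with size $n$ and probability $q$. The $C_k$ depend on $\beta$ and $I_0$ through `solve_map`.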

### Estimation using case data only

The Turing model takes four arguments: the observed data, `Y`, the initial conditions, `u0`, the number of time steps, `nsteps`, and the fixed parameters, `p`. It defines priors for the parameters we want to estimate (`β`, `I₀`, and `q`). It builds the initial conditions from `I₀` and the population size `p.N`, so the `u0` argument is not used inside the model. It then updates the parameter tuple with the current MCMC values and adds the log-likelihood of the case data to the model using `Turing.@addlogprob!`.

```julia
@model function sir_map_estimate_q(Y, u0, nsteps, p)
    # Priors for the parameters we want to estimate
    β ~ Uniform(0.25, 0.75)
    I₀ ~ Uniform(5.0, 50.0)
    q ~ Uniform(0.1, 0.9)

    # Create parameter tuple with current MCMC values
    p_new = merge(p, (β = β, q = q))
    u0_new = [p.N - I₀, I₀, 0.0, 0.0]

    # Solve the model with the current parameters
    sol = solve_map(sir_map!, u0_new, nsteps, p_new)

    # Add the log-likelihood of the cases to the model
    Turing.@addlogprob! logpdf(Y, sol, q)

    return nothing
end;
```

```julia
sir_model_estimate_q = sir_map_estimate_q(Y, u0, nsteps, p)
chain_estimate_q = sample(sir_model_estimate_q, NUTS(0.65), 10000; progress=false);
```

```julia
describe(chain_estimate_q)
```

```julia
plot(chain_estimate_q)
```

To assess how well the parameters are recovered, we repeat the inference on `nsims` datasets simulated from the true parameter values. For each dataset we record two quantities for each parameter. The first is the posterior mean. The second is the proportion of posterior samples at or below the true value, i.e. the posterior cumulative distribution function (CDF) evaluated at the true value.

```julia
nsims = 1000
I₀_means = Array{Float64}(undef, nsims)
β_means = Array{Float64}(undef, nsims)
q_means = Array{Float64}(undef, nsims)
I₀_coverage = Array{Float64}(undef, nsims)
β_coverage = Array{Float64}(undef, nsims)
q_coverage = Array{Float64}(undef, nsims)
Threads.@threads for i in 1:nsims
    Y_sim = rand.(GeneralizedBinomial.(C[2:end]-C[1:end-1], p.q))
    r = sample(sir_map_estimate_q(Y_sim, u0, nsteps, p),
               NUTS(1000,0.65),
               10000;
               verbose=false,
               progress=false,
               initial_params=(β=0.5, I₀=10.0, q=0.75))
    I₀_means[i] = mean(r[:I₀])
    I₀_cov = sum(r[:I₀] .<= u0[2]) / length(r[:I₀])
    β_means[i] = mean(r[:β])
    β_cov = sum(r[:β] .<= p.β) / length(r[:β])
    q_means[i] = mean(r[:q])
    q_cov = sum(r[:q] .<= p.q) / length(r[:q])
    I₀_coverage[i] = I₀_cov
    β_coverage[i] = β_cov
    q_coverage[i] = q_cov
end;
```

```julia
# Convenience function to check if the true value is within the credible interval
function in_credible_interval(x, lwr=0.025, upr=0.975)
    return lwr <= x <= upr
end;
```

```julia
pl_β_coverage = histogram(β_coverage, bins=0:0.1:1.0, label=false, title="β", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
pl_I₀_coverage = histogram(I₀_coverage, bins=0:0.1:1.0, label=false, title="I₀", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
pl_q_coverage = histogram(q_coverage, bins=0:0.1:1.0, label=false, title="q", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
plot(pl_β_coverage, pl_I₀_coverage, pl_q_coverage, layout=(1,3), plot_title="Distribution of CDF of true value")
```

The coverage of the 95% credible intervals is the proportion of simulations in which the posterior CDF at the true value lies between 0.025 and 0.975, i.e. in which the true value falls within the central 95% credible interval.

```julia
sum(in_credible_interval.(β_coverage)) / nsims
```

```julia
sum(in_credible_interval.(I₀_coverage)) / nsims
```

```julia
sum(in_credible_interval.(q_coverage)) / nsims
```
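A toy example, not part of the SIR analysis, can serve as a sanity check of this procedure. If the credible intervals have their nominal frequentist coverage, the posterior CDF values at the true value are uniform on $[0, 1]$. About 95% of them should then fall in $[0.025, 0.975]$. The sketch uses a hypothetical normal-mean model with known unit variance and a flat prior. For this model the posterior after one observation `y` is exactly `Normal(y, 1)`. The names `posterior_cdf_at_truth`, `cdf_vals`, and `toy_coverage` are illustrative. A local random number generator is used so the notebook's global random state is left unchanged.

```julia
using Distributions, Random, Statistics

# Hypothetical toy model: θ is a normal mean with known unit variance and a flat
# prior, so the posterior given a single observation y is Normal(y, 1).
function posterior_cdf_at_truth(θ_true, nreps; rng = MersenneTwister(42))
    map(1:nreps) do _
        y = rand(rng, Normal(θ_true, 1.0))  # simulate one dataset at the true value
        cdf(Normal(y, 1.0), θ_true)         # posterior CDF evaluated at the truth
    end
end

cdf_vals = posterior_cdf_at_truth(2.0, 10_000)
# Same criterion as `in_credible_interval`: truth inside the central 95% interval
toy_coverage = mean(0.025 .<= cdf_vals .<= 0.975)
```

Here `toy_coverage` should be close to 0.95, and a histogram of `cdf_vals` should be close to flat. In the SIR study, the prior and the finite number of MCMC samples can both cause departures from this ideal.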

We can also look at the distribution of the posterior means, which should fall around the true value.

```julia
pl_β_means = histogram(β_means, label=false, title="β", ylabel="Density", density=true, xrotation=45, xlim=(0.48, 0.52))
vline!(pl_β_means, [p.β], label="True value")
pl_I₀_means = histogram(I₀_means, label=false, title="I₀", ylabel="Density", density=true, xrotation=45, xlim=(5.0,15.0))
vline!(pl_I₀_means, [u0[2]], label="True value")
pl_q_means = histogram(q_means, label=false, title="q", ylabel="Density", density=true, xrotation=45, xlim=(0.65,0.85))
vline!(pl_q_means, [p.q], label="True value")
plot(pl_β_means, pl_I₀_means, pl_q_means, layout=(1,3), plot_title="Distribution of posterior means")
```

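### Estimation using case data and final prevalence survey

We now add the final survey, `Z`, to the model. `Z` is treated as a `GeneralizedBinomial` sample of size `ZN`. Its probability is the proportion recovered at the final time step, `sol[3, end]/p.N`, clamped to [0, 1].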

```julia
@model function sir_map_estimate_q_prevalence(Y, Z, ZN, u0, nsteps, p)
    # Priors for the parameters we want to estimate
    β ~ Uniform(0.25, 0.75)
    I₀ ~ Uniform(5.0, 50.0)
    q ~ Uniform(0.1, 0.9)

    # Create parameter tuple with current MCMC values
    p_new = merge(p, (β = β, q = q))
    u0_new = [p.N - I₀, I₀, 0.0, 0.0]

    # Solve the model with the current parameters
    sol = solve_map(sir_map!, u0_new, nsteps, p_new)

    # Add the log-likelihood of the cases to the model
    Turing.@addlogprob! logpdf(Y, sol, q)

    # Calculate contribution from end prevalence study
    zp = sol[3,end]/p.N
    zp = clamp(zp, 0.0, 1.0) # To ensure boundedness
    Z ~ GeneralizedBinomial(ZN, zp)

    return nothing
end;
```

```julia
sir_model_estimate_q_prevalence = sir_map_estimate_q_prevalence(Y, Z, ZN, u0, nsteps, p)
chain_estimate_q_prevalence = sample(sir_model_estimate_q_prevalence, NUTS(0.65), 10000; progress=false);
```

```julia
describe(chain_estimate_q_prevalence)
```

```julia
plot(chain_estimate_q_prevalence)
```

We repeat the simulation study. This time we simulate both the case data and the final survey for each replicate.

```julia
I₀_prev_means = Array{Float64}(undef, nsims)
β_prev_means = Array{Float64}(undef, nsims)
q_prev_means = Array{Float64}(undef, nsims)
I₀_prev_coverage = Array{Float64}(undef, nsims)
β_prev_coverage = Array{Float64}(undef, nsims)
q_prev_coverage = Array{Float64}(undef, nsims)
Threads.@threads for i in 1:nsims
    Y_sim = rand.(GeneralizedBinomial.(C[2:end]-C[1:end-1], p.q))
    Z_sim = rand(GeneralizedBinomial(ZN, R[end]/p.N))
    r = sample(sir_map_estimate_q_prevalence(Y_sim, Z_sim, ZN, u0, nsteps, p),
               NUTS(1000,0.65),
               10000;
               verbose=false,
               progress=false,
               initial_params=(β=0.5, I₀=10.0, q=0.75))
    I₀_prev_means[i] = mean(r[:I₀])
    I₀_cov = sum(r[:I₀] .<= u0[2]) / length(r[:I₀])
    β_prev_means[i] = mean(r[:β])
    β_cov = sum(r[:β] .<= p.β) / length(r[:β])
    q_prev_means[i] = mean(r[:q])
    q_cov = sum(r[:q] .<= p.q) / length(r[:q])
    I₀_prev_coverage[i] = I₀_cov
    β_prev_coverage[i] = β_cov
    q_prev_coverage[i] = q_cov
end;
```

```julia
pl_β_prev_coverage = histogram(β_prev_coverage, bins=0:0.1:1.0, label=false, title="β", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
pl_I₀_prev_coverage = histogram(I₀_prev_coverage, bins=0:0.1:1.0, label=false, title="I₀", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
pl_q_prev_coverage = histogram(q_prev_coverage, bins=0:0.1:1.0, label=false, title="q", ylabel="Density", density=true, xrotation=45, xlim=(0.0,1.0))
plot(pl_β_prev_coverage, pl_I₀_prev_coverage, pl_q_prev_coverage, layout=(1,3), plot_title="Distribution of CDF of true value")
```

The coverage of the 95% credible intervals is given by the proportion of simulations where the true value is within the interval.

```julia
sum(in_credible_interval.(β_prev_coverage)) / nsims
```

```julia
sum(in_credible_interval.(I₀_prev_coverage)) / nsims
```

```julia
sum(in_credible_interval.(q_prev_coverage)) / nsims
```

We can also look at the distribution of the posterior means, which should fall around the true value.

```julia
pl_β_prev_means = histogram(β_prev_means, label=false, title="β", ylabel="Density", density=true, xrotation=45, xlim=(0.48, 0.52))
vline!(pl_β_prev_means, [p.β], label="True value")
pl_I₀_prev_means = histogram(I₀_prev_means, label=false, title="I₀", ylabel="Density", density=true, xrotation=45, xlim=(5.0,15.0))
vline!(pl_I₀_prev_means, [u0[2]], label="True value")
pl_q_prev_means = histogram(q_prev_means, label=false, title="q", ylabel="Density", density=true, xrotation=45, xlim=(0.65,0.85))
vline!(pl_q_prev_means, [p.q], label="True value")
plot(pl_β_prev_means, pl_I₀_prev_means, pl_q_prev_means, layout=(1,3), plot_title="Distribution of posterior means")
```

## Discussion

The use of a continuous-state, deterministic model with continuous parameters allowed us to use the NUTS sampler in Turing.jl, which is generally more efficient than standard Metropolis-Hastings. In this example, the final prevalence survey, intended to inform the reporting probability, did not appear to add much information to the inference.