# Simulate and fit the SIjR model

In this notebook, we simulate data from a $SI^jR$ model, and fit a Stan model to this data.

In [None]:
import numpy as np
import cmdstanpy
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from scipy.integrate import solve_ivp
import scipy.stats as sts

In [None]:
## use a larger default font size in all figures of this notebook
plt.rcParams['font.size'] = 18

## Compile the Stan models

In [None]:
## Build one CmdStanModel object per Stan file. The two files contain the two
## variants of the SIjR model; `sm_smooth` is the one used by the sampling
## cell further down, `sm_fixed` is available as an alternative.
## smoothed hypoexponential approximation
sm_smooth = cmdstanpy.CmdStanModel(stan_file="../stan-models/SIjR-model-smooth.stan")
## fixed hypoexponential approximation
sm_fixed = cmdstanpy.CmdStanModel(stan_file="../stan-models/SIjR-model-fixed.stan")

## Generate synthetic data

In [None]:
def sir_model(t, y, beta, tau, j):
    """
    Right-hand side of the SI^jR ODE system.

    The state vector is laid out as y = [S, I_1, ..., I_k, R]: the first
    entry is the susceptible compartment, the last entry is the removed
    compartment and every entry in between is an infectious stage.

    Parameters
    ----------
    t : current time (not used; the system is autonomous)
    y : state vector as described above
    beta : infection rate
    tau, j : each infectious stage is left at rate j/tau

    Returns
    -------
    numpy array with the time derivative of every entry of y
    """
    y = np.asarray(y)
    n_comp = len(y)
    stage_rate = j/tau
    ## new infections per unit time: all infectious stages contribute equally
    new_infections = beta * y[0] * sum(y[1:-1])

    deriv = np.zeros(n_comp)
    deriv[0] = -new_infections
    deriv[1] = new_infections - stage_rate * y[1]
    ## stages 2..k: inflow from the previous stage, outflow to the next
    deriv[2:n_comp-1] = stage_rate * (y[1:n_comp-2] - y[2:n_comp-1])
    ## removed compartment collects the outflow of the last stage
    deriv[-1] = stage_rate * y[-2]
    return deriv

def gen_incidence_data(beta_gt, j_gt, eps_gt, tau_gt, N, t0, tmax, M, random_state=None):
    """
    Simulate noisy incidence data from the SI^jR model.

    The ODE system `sir_model` is integrated from t0 to tmax, starting with a
    fraction eps_gt in the first infectious stage and 1-eps_gt susceptible.
    The expected incidence in each of the N equally long observation
    intervals is the drop in the susceptible fraction over that interval,
    multiplied by M. Observed counts are Poisson draws around that value.

    Parameters
    ----------
    beta_gt, j_gt, eps_gt, tau_gt : ground-truth model parameters
        (infection rate, number of infectious stages, initial infected
        fraction, and the tau for which each stage is left at rate j/tau)
    N : number of observation intervals
    t0, tmax : start and end time of the simulation
    M : scaling factor from susceptible fraction to expected case counts
    random_state : optional seed or generator forwarded to scipy's `rvs`;
        None (default) uses NumPy's global random state

    Returns
    -------
    t_obs : numpy array with the N observation times (ends of the intervals)
    Y : list of N simulated (integer) case counts
    Yhat : list of the N expected case counts

    Raises
    ------
    RuntimeError if the ODE solver reports a failure.
    """
    ## initial condition: S, I_1, ..., I_j, R
    u0 = np.zeros(j_gt+2)
    u0[0] = 1-eps_gt
    u0[1] = eps_gt
    ## observation times are the right ends of N equal intervals on [t0, tmax];
    ## offsetting by t0 keeps them inside the integration span when t0 != 0
    dt = (tmax-t0)/N
    t_obs = np.linspace(t0+dt, tmax, N)
    sol = solve_ivp(lambda t, u: sir_model(t, u, beta_gt, tau_gt, j_gt), (t0, tmax), u0, t_eval=t_obs)
    if not sol.success:
        raise RuntimeError(f"ODE integration failed: {sol.message}")
    ## susceptible fraction at t0 and at every observation time
    S = np.concatenate([[u0[0]], sol.y[0]])
    ## expected incidence = decrease of S per interval; clip tiny negative
    ## values caused by solver round-off, as a Poisson mean must be >= 0
    Yhat = np.clip(-np.diff(S) * M, 0, None)
    Y = sts.poisson.rvs(Yhat, random_state=random_state)
    return t_obs, Y.tolist(), list(Yhat)

def sample_generation_interval(alpha, beta):
    """
    Draw one random generation interval.

    A duration is first sampled from a Gamma distribution with shape alpha+1
    and rate beta (the "biased" duration); the returned value is then drawn
    uniformly between 0 and that duration.
    """
    ## biased duration: Gamma(shape=alpha+1, rate=beta)
    biased_duration = sts.gamma.rvs(alpha+1, scale=1/beta)
    ## uniformly distributed point within the sampled duration
    interval = sts.uniform.rvs(scale=biased_duration)
    return interval
   
def density_generation_interval(t, alpha, beta):
    """
    Probability density of the generation interval at time t.

    Computed as the survival function of a Gamma(shape=alpha, rate=beta)
    distribution evaluated at t, divided by the mean alpha/beta of that
    distribution.
    """
    mean_duration = alpha/beta
    survival = sts.gamma.sf(t, alpha, scale=1/beta)
    return survival/mean_duration

In [None]:
## ground-truth parameters used to simulate the data
beta_gt = 0.5 ## infection rate
j_gt = 4 ## number of infectious stages
eps_gt = 1e-3 ## initial infected fraction
tau_gt = 5.0 ## each infectious stage is left at rate j_gt/tau_gt
M = 1000 ## scaling from susceptible fraction to expected case counts
t0, tmax = 0, 40 ## start and end time of the simulation
N = 40 ## number of incidence observations
L = 100 ## number of sampled generation intervals

## Fix the seed of NumPy's global random state, which scipy.stats uses by
## default, so that "Restart & Run All" reproduces the same synthetic data.
SEED = 12345
np.random.seed(SEED)

## incidence data
t_obs, Y, Yhat_gt = gen_incidence_data(beta_gt, j_gt, eps_gt, tau_gt, N, t0, tmax, M)

## generation interval data
GenInt = [sample_generation_interval(j_gt, j_gt/tau_gt) for _ in range(L)]

In [None]:
## plot the data

fig, (ax, bx) = plt.subplots(1,2, figsize=(14,5))

## left panel: simulated case counts (dots) and expected incidence (line)
ax.scatter(t_obs, Y)
ax.plot(t_obs, Yhat_gt)

## right panel: empirical survival function of the sampled generation intervals.
## The step function starts at 1 for t=0 and drops by 1/len(GenInt) at every
## sorted sample; the grid length follows the data so x and y always match.
SrtGenInt = [0] + sorted(GenInt)
Fs = np.linspace(0, 1, len(SrtGenInt))

bx.step(SrtGenInt, Fs[::-1], where='post', color='k')

ax.set_xlabel("time")
ax.set_ylabel("incidence")

bx.set_xlabel("generation interval")
bx.set_ylabel("empirical survival function")

## Fit the stan model to data

Either use the smoothed or fixed approximation.

In [None]:
Kmax = 10 ## limits the max number of used compartments

NumGridPts = 100 ## points for simulation and plotting

STAN_SEED = 20240 ## fixed sampler seed, so that the posterior draws are reproducible

## data passed to the Stan program
data_dict = {
    "N" : N,
    "M" : M,
    "T" : t_obs,
    "Y" : Y,
    "K" : Kmax,
    "h" : 1e-2, ## required for smoothed approximation
    "L" : L, ## if zero, no addl data
    "GenInt" : GenInt,
    "NumGridPts" : NumGridPts,
}

## start the chains at the ground-truth parameter values
init_dict = {
    "beta" : beta_gt,
    "eps" : eps_gt,
    "tau" : tau_gt
}

## fit model:
sam = sm_smooth.sample(data=data_dict, chains=2, iter_sampling=1000, iter_warmup=1000,
                inits=init_dict, show_progress=False, output_dir="../stan-cache/",
                step_size=0.01, seed=STAN_SEED)

## extract parameters:
traces = sam.stan_variables()

## Make a plot of the fit, data and parameter estimates

In [None]:
## Combined figure: posterior parameter distributions (A), incidence fit (B)
## and generation-interval fit (C)
parnames = ["beta", "tau", "j", "eps"]
pretty_parnames = ["$\\beta$", "$\\tau$", "$j$", "$\\epsilon$"]
gt_vals = [beta_gt, tau_gt, j_gt, eps_gt]
n = len(parnames)

fig = plt.figure(figsize=(14,13))

d = 4 ## every panel of the n x n parameter grid spans d x d GridSpec cells
gs = GridSpec(d*n,d*n, wspace=1.2, hspace=1.2)

axs = [] ## axes of the joint-posterior (below-diagonal) panels
bxs = [] ## axes of the marginal-posterior (diagonal) panels

for i1, pn1 in enumerate(parnames):
    for i2, pn2 in enumerate(parnames):
        if i1 > i2:
            ## joint posterior: scatter of the draws, colored by their KDE density
            ax = fig.add_subplot(gs[d*i1:d*(i1+1),d*i2:d*(i2+1)])
            xs, ys = traces[pn2], traces[pn1]
            xy = np.vstack([xs, ys])
            z = sts.gaussian_kde(xy)(xy)
            ax.scatter(xs, ys, s=5, c=z)
            ## only the left column and the bottom row show axis labels
            if i2 != 0:
                ax.yaxis.set_visible(False)
            else:
                ax.set_ylabel(pretty_parnames[i1])
            if i1 != n-1:
                ax.xaxis.set_visible(False)
            else:
                ax.set_xlabel(pretty_parnames[i2])
            axs.append(ax)
        elif i1 == i2:
            ## marginal posterior: KDE of the draws, vertical line at the ground truth
            ax = fig.add_subplot(gs[d*i1:d*(i1+1),d*i2:d*(i2+1)])
            xs = traces[pn1]
            us = np.linspace(np.min(xs), np.max(xs), 1000)
            z = sts.gaussian_kde(xs)(us)
            ax.plot(us, z, linewidth=2)
            ax.fill_between(us, z, alpha=0.5, linewidth=0)
            ax.set_title(pretty_parnames[i1])
            ax.axvline(x=gt_vals[i1], color='k')
            ax.yaxis.tick_right()
            ax.yaxis.set_label_position("right")
            ax.set_ylabel("density")
            ## collect the axes created here (not an axes object left over
            ## from another cell)
            bxs.append(ax)

## plot data and fit
ax = fig.add_subplot(gs[:3,9:])

Yhat = traces["Yhat"]
p = [2.5, 97.5]
Yhat_CrI = np.percentile(Yhat, p, axis=0)
Yhat_mean = np.mean(Yhat, axis=0)

Ysim = traces["Ysim"]
Ysim_CrI = np.percentile(Ysim, p, axis=0)

## darker band: interval of the expected incidence; lighter band: interval of
## the simulated observations
ax.plot(t_obs, Yhat_mean, color='tab:blue', zorder=1, label="prediction")
ax.fill_between(t_obs, *Yhat_CrI, color='tab:blue', alpha=0.5)
ax.fill_between(t_obs, *Ysim_CrI, color='tab:blue', alpha=0.3)

ax.scatter(t_obs, Y, color='k', zorder=2, label="data")

ax.set_ylabel("observed cases ($C$)")
ax.set_xlabel("time ($t$)")
ax.legend(fontsize='xx-small')

## plot serial interval data and fit
ax = fig.add_subplot(gs[4:7,11:])

## empirical survival function of the sampled intervals
SrtGenInt = [0] + sorted(GenInt)
Fs = np.linspace(0, 1, L+1)

ax.step(SrtGenInt, Fs[::-1], where='post', color='k', zorder=2)

## "gen_int_surv" is exponentiated before plotting
gen_int_surv = np.exp(traces["gen_int_surv"])
## NOTE(review): assumes the Stan program evaluates gen_int_surv on this same
## grid of NumGridPts points between 0 and max(GenInt) -- confirm in the model
ts = np.linspace(0, max(GenInt), NumGridPts)
gen_int_surv_mean = np.mean(gen_int_surv, axis=0)
gen_int_surv_low = np.percentile(gen_int_surv, axis=0, q=2.5)
gen_int_surv_high = np.percentile(gen_int_surv, axis=0, q=97.5)

ax.plot(ts, gen_int_surv_mean, linewidth=3, zorder=1)
ax.fill_between(ts, gen_int_surv_low, gen_int_surv_high, linewidth=0, alpha=0.5, zorder=1)

ax.set_ylabel("$Pr[T > t]$")
ax.set_xlabel("length generation interval ($t$)")

## add labels

fig.text(0.1, 0.9, 'A', fontsize='x-large')
fig.text(0.55, 0.9, 'B', fontsize='x-large')
fig.text(0.6, 0.675, 'C', fontsize='x-large')

fig.align_ylabels()

fig.savefig("../fit-posterior-density.png", bbox_inches='tight', dpi=300)

### Separate figures for fit and posterior

Simulated data and posterior predictive checks

* A: cases and predicted incidence
* B: empirical survival function of the generation intervals and model predictions of the survival function

In [None]:
## Two-panel figure: incidence fit (A) and generation-interval fit (B)
fig, axs = plt.subplots(1,2, figsize=(14,4))

## --- posterior summaries of the incidence ---
p = [2.5, 97.5]

Yhat = traces["Yhat"]
Yhat_CrI = np.percentile(Yhat, p, axis=0)
Yhat_mean = list(map(np.mean, Yhat.T))

Ysim = traces["Ysim"]
Ysim_CrI = np.percentile(Ysim, p, axis=0)

## --- empirical and posterior survival function of the intervals ---
SrtGenInt = [0, *sorted(GenInt)]
Fs = np.linspace(0, 1, L+1)

gen_int_surv = np.exp(traces["gen_int_surv"])
ts = np.linspace(0, max(GenInt), NumGridPts)
gen_int_surv_mean = np.mean(gen_int_surv, axis=0)
gen_int_surv_low = np.percentile(gen_int_surv, 2.5, axis=0)
gen_int_surv_high = np.percentile(gen_int_surv, 97.5, axis=0)

## --- panel A: data and fit ---
ax = axs[0]

ax.scatter(t_obs, Y, color='k', zorder=2, label="data")
ax.plot(t_obs, Yhat_mean, color='tab:blue', zorder=1, linewidth=2, label="prediction")
## darker band: expected incidence; lighter band: simulated observations
ax.fill_between(t_obs, Yhat_CrI[0], Yhat_CrI[1], color='tab:blue', alpha=0.5)
ax.fill_between(t_obs, Ysim_CrI[0], Ysim_CrI[1], color='tab:blue', alpha=0.3)

ax.set_ylabel("observed cases ($C$)")
ax.set_xlabel("time ($t$)")
ax.legend(fontsize='x-small')

## --- panel B: serial interval data and fit ---
ax = axs[1]

ax.step(SrtGenInt, Fs[::-1], where='post', color='k', zorder=2, label="data")
ax.plot(ts, gen_int_surv_mean, linewidth=2, zorder=1, label="prediction")
ax.fill_between(ts, gen_int_surv_low, gen_int_surv_high, linewidth=0, alpha=0.5, zorder=1)

ax.set_ylabel("$Pr[T > t]$")
ax.set_xlabel("length generation interval ($t$)")
ax.legend(fontsize='x-small')

## panel letters, placed just above the top-left corner of each axes
for X, ax in zip("AB", axs):
    ax.text(-0.05, 1.05, X, fontsize='x-large', transform=ax.transAxes)

fig.savefig("../fit.pdf", bbox_inches='tight', dpi=300)

Marginal and joint posterior parameter distributions

In [None]:
## Posterior-only figure: marginal densities on the diagonal, joint
## distributions of every parameter pair below the diagonal
parnames = ["beta", "tau", "j", "eps"]
pretty_parnames = ["$\\beta$", "$\\tau$", "$j$", "$\\epsilon$"]
gt_vals = [beta_gt, tau_gt, j_gt, eps_gt]
n = len(parnames)

fig = plt.figure(figsize=(14,13))

gs = GridSpec(n,n, wspace=0.2, hspace=0.2)

axs = [] ## axes of the joint-posterior (below-diagonal) panels
bxs = [] ## axes of the marginal-posterior (diagonal) panels

for i1, pn1 in enumerate(parnames):
    for i2, pn2 in enumerate(parnames):
        if i1 > i2:
            ## joint posterior: scatter of the draws, colored by their KDE density
            ax = fig.add_subplot(gs[i1,i2])
            xs, ys = traces[pn2], traces[pn1]
            xy = np.vstack([xs, ys])
            z = sts.gaussian_kde(xy)(xy)
            ax.scatter(xs, ys, s=5, c=z)
            ## only the left column and the bottom row show axis labels
            if i2 != 0:
                ax.yaxis.set_visible(False)
            else:
                ax.set_ylabel(pretty_parnames[i1])
            if i1 != n-1:
                ax.xaxis.set_visible(False)
            else:
                ax.set_xlabel(pretty_parnames[i2])
            axs.append(ax)
        elif i1 == i2:
            ## marginal posterior: KDE of the draws, vertical line at the ground truth
            ax = fig.add_subplot(gs[i1,i2])
            xs = traces[pn1]
            us = np.linspace(np.min(xs), np.max(xs), 1000)
            z = sts.gaussian_kde(xs)(us)
            ax.plot(us, z, linewidth=2)
            ax.fill_between(us, z, alpha=0.5, linewidth=0)
            ax.set_title(pretty_parnames[i1])
            ax.axvline(x=gt_vals[i1], color='k')
            ax.yaxis.tick_right()
            ax.yaxis.set_label_position("right")
            ax.set_ylabel("density")
            ## collect the axes created here (not an axes object left over
            ## from another cell)
            bxs.append(ax)

fig.align_ylabels()

fig.savefig("../posterior-density.png", bbox_inches='tight', dpi=300)