# Exploring how NPIs bias estimates of $s$

In this notebook, we use a simple model and synthetic data to explore how non-pharmaceutical interventions (NPIs) affect the estimate of the relative transmission advantage $s$ obtained with the population-genetics model.

Let's first import some modules.

In [None]:
import cmdstanpy
import numpy as np
import scipy.stats as sts
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.special import expit
from matplotlib.gridspec import GridSpec

In [None]:
# Download and build CmdStan, which cmdstanpy needs to compile and run the
# Stan model used below.
# NOTE(review): a fresh install takes several minutes; confirm whether the
# installed cmdstanpy version skips the install when CmdStan is already
# present, otherwise run this cell only once.
cmdstanpy.install_cmdstan()

## ODE model

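We define an SEIR model with two variants, wild type ($w$) and mutant ($m$), as a system of ODEs. The mutant is $(1+s)$ times as transmissible as the wild type. (In the code, two additional state variables accumulate the incidences $\alpha E_w$ and $\alpha E_m$.)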

\begin{equation}
\begin{split}
\frac{dS}{dt} &= -\beta S (I_w + (1+s) I_m) \\
\frac{dE_w}{dt} &= \beta S I_w - \alpha E_w \\
\frac{dE_m}{dt} &= \beta (1+s) S I_m - \alpha E_m \\
\frac{dI_w}{dt} &= \alpha E_w - \gamma I_w \\
\frac{dI_m}{dt} &= \alpha E_m - \gamma I_m \\
\end{split}
\end{equation}

The parameter $\beta$ is a function of time, given by

\begin{equation}
\beta(t) = (1-H_t) \beta_0 + H_t \beta_1
\end{equation}
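where $H_t = (1 + e^{-(t-t_1)})^{-1}$ is a smoothed step function. As a result, $\beta$ switches from $\beta_0$ (before the intervention) to $\beta_1$ (after the intervention) around time $t_1$.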

In [None]:
def betat(t, beta0, beta1, t1):
    """Time-varying transmission rate with a smoothed switch at ``t1``.

    beta(t) = (1 - H) * beta0 + H * beta1 with H = expit(t - t1),
    i.e. H = 1 / (1 + exp(-(t - t1))). Hence beta is close to beta0 well
    before t1, close to beta1 well after t1, and equals (beta0 + beta1) / 2
    at t = t1. Works elementwise on numpy arrays.
    """
    switch = expit(t - t1)  # logistic step from 0 to 1, centred on t1
    before, after = beta0 * (1 - switch), beta1 * switch
    return before + after

def ode_seir(t, y, par):
    """Right-hand side of the two-variant SEIR system.

    State ``y`` = (S, E_w, E_m, I_w, I_m, C_w, C_m). C_w and C_m accumulate
    the E -> I flows (alpha * E) of each variant; their current values are
    not used here.
    ``par`` = (beta0, beta1, t1, alpha, gamma, s, zeta, p0). zeta and p0 only
    enter the initial condition and are ignored here.
    Returns the time derivatives as an array in the same order as ``y``.
    """
    beta0, beta1, t1, alpha, gamma, s, _zeta, _p0 = par
    S, Ew, Em, Iw, Im, _cum_w, _cum_m = y

    beta = betat(t, beta0, beta1, t1)
    # force of infection per variant; the mutant is (1 + s) times as transmissible
    force_w = beta * Iw
    force_m = beta * Im * (1+s)

    onset_w = alpha * Ew   # E_w -> I_w flow
    onset_m = alpha * Em   # E_m -> I_m flow

    derivs = [
        -S * (force_w + force_m),
        S * force_w - onset_w,
        S * force_m - onset_m,
        onset_w - gamma * Iw,
        onset_m - gamma * Im,
        onset_w,
        onset_m,
    ]
    return np.array(derivs)

def get_init(par):
    """Initial state placed on the early exponential-growth eigenvector.

    ``par`` = (beta0, beta1, t1, alpha, gamma, s, zeta, p0). A fraction
    ``zeta`` of the population is in E or I, and a fraction ``p0`` of that
    carries the mutant. Within each variant, E and I are split according to
    the growing eigenvector of the linearised (S = 1) E-I system at
    transmission rate beta0 (wild type) or beta0 * (1 + s) (mutant).
    Both cumulative-incidence entries start at 0. beta1 and t1 are unused.

    Returns an array (S, E_w, E_m, I_w, I_m, C_w, C_m).
    """
    beta0, beta1, t1, alpha, gamma, s, zeta, p0 = par

    def growth_rate(b):
        # larger root r of (r + alpha) * (r + gamma) = alpha * b
        return 0.5*(-(alpha + gamma) + np.sqrt((alpha+gamma)**2 + 4*alpha*(b - gamma)))

    def infectious_share(r):
        # I / (E + I) on the growing eigenvector, since I = alpha * E / (r + gamma)
        return alpha / (r + alpha + gamma)

    share_w = infectious_share(growth_rate(beta0))
    share_m = infectious_share(growth_rate(beta0*(1+s)))

    return np.array([
        1-zeta,                    # S
        (1-share_w)*zeta*(1-p0),   # E_w
        (1-share_m)*zeta*p0,       # E_m
        share_w*zeta*(1-p0),       # I_w
        share_m*zeta*p0,           # I_m
        0,                         # cumulative E_w -> I_w
        0,                         # cumulative E_m -> I_m
    ])

def gen_data(N, M, par, t_span, ivp_kwargs=None, random_state=None):
    """Simulate sequencing counts of the mutant variant.

    Solves the two-variant SEIR model for ``par`` over ``t_span``. It then
    evaluates the mutant frequency among exposed individuals,
    f_m = E_m / (E_w + E_m), at ``N`` equally spaced times and draws
    binomial counts of mutant samples.

    Parameters
    ----------
    N : int
        Number of observation times (equally spaced, endpoints included).
    M : float
        Mean of the Poisson-distributed number of samples per time point.
    par : tuple
        (beta0, beta1, t1, alpha, gamma, s, zeta, p0), as used by ``ode_seir``.
    t_span : tuple of float
        (t_start, t_end) integration interval.
    ivp_kwargs : dict, optional
        Extra keyword arguments for ``solve_ivp``. They override the defaults
        ``rtol=1e-8, atol=1e-8``. ``dense_output`` is always True.
    random_state : None, int, numpy Generator or RandomState, optional
        Source of randomness for the Poisson and binomial draws. None uses
        numpy's global random state.

    Returns
    -------
    ts : ndarray, shape (N,)
        Observation times.
    Fm : ndarray, shape (N,)
        Number of mutant samples at each time.
    Ms : ndarray, shape (N,)
        Total number of samples at each time.
    """
    y0 = get_init(par)
    # The states start at O(zeta) (~1e-6 or smaller). solve_ivp's default
    # atol=1e-6 would tolerate errors as large as the states themselves and
    # distort the early mutant frequencies, so use tight tolerances by default.
    solver_opts = {"rtol": 1e-8, "atol": 1e-8}
    if ivp_kwargs is not None:
        solver_opts.update(ivp_kwargs)
    solver_opts["dense_output"] = True  # sol.sol is needed below
    sol = solve_ivp(lambda t, y: ode_seir(t, y, par), t_span, y0, **solver_opts)

    ts = np.linspace(*t_span, N)
    ys = sol.sol(ts)
    fm = ys[2, :] / (ys[1, :] + ys[2, :])
    # guard against tiny round-off outside [0, 1], which binom.rvs rejects
    fm = np.clip(fm, 0.0, 1.0)

    if isinstance(random_state, (int, np.integer)):
        # One generator feeds both draws. Passing the same int seed to both
        # would correlate the Poisson and binomial streams.
        random_state = np.random.default_rng(random_state)
    Ms = sts.poisson.rvs(M, size=N, random_state=random_state)
    Fm = sts.binom.rvs(Ms, fm, random_state=random_state)
    return ts, Fm, Ms
    

In [None]:
## parameters of the first experiment
beta0 = 0.8    # transmission rate before the intervention (t < t1)
beta1 = 0.1    # transmission rate after the intervention (t > t1)
t1 = 25        # centre of the smoothed switch from beta0 to beta1
zeta = 5e-6    # initial fraction of the population in E or I
alpha = 1/3    # rate of progression E -> I
gamma = 1/4    # rate of leaving I
R0 = beta0/gamma
s = 0.3        # relative transmission advantage of the mutant
p0 = 0.25      # initial mutant fraction among E and I

## tight solver tolerances: the initial states are O(zeta), far below
## solve_ivp's default atol (1e-6)
kwarg_ivp = {
    "dense_output" : True,
    "rtol" : 1e-8,
    "atol" : 1e-8
}

print("R0 =", R0)

par = (beta0, beta1, t1, alpha, gamma, s, zeta, p0)

y0 = get_init(par)

print(y0)

t_span = (0, 50)
ts = np.linspace(*t_span, 1000)  # fine grid for plotting the trajectories

sol = solve_ivp(lambda t, y: ode_seir(t,y,par), t_span, y0, **kwarg_ivp)

## synthetic data: 50 time points, Poisson(100) samples each.
## NOTE(review): gen_data integrates the ODE on its own (kwarg_ivp is not
## passed), and this dataset is regenerated in the fitting cell further down.
tobs, numvar, numtotal = gen_data(50, 100, par, t_span)

In [None]:
ys = sol.sol(ts)
# mutant frequency among exposed individuals: f_m = E_m / (E_w + E_m)
fs = ys[2,:] / (ys[1,:] + ys[2,:])

fig, axs = plt.subplots(2,1, figsize=(7,7))

## top: prevalence of each variant
axs[0].plot(ts, ys[3], label = '$I_w$')
axs[0].plot(ts, ys[4], label = '$I_m$')
axs[0].set_xlabel('time ($t$)')
axs[0].set_ylabel('fraction infected')
axs[0].legend()

## bottom: true mutant frequency and simulated observations
axs[1].plot(ts, fs, color='k')
axs[1].scatter(tobs, numvar / numtotal, s=10)

for t, k, n in zip(tobs, numvar, numtotal):
    # 95% Jeffreys interval: quantiles of Beta(k + 1/2, n - k + 1/2)
    CI = sts.beta.interval(0.95, k+0.5, n-k+0.5)
    axs[1].plot([t, t], CI, color='k', alpha=0.5)

axs[1].set_xlabel('time ($t$)')
axs[1].set_ylabel('mutant frequency ($f_{mt}$)')

In [None]:
# Build the Stan model that every fit below samples from.
# NOTE(review): the path is relative to the notebook's working directory.
sm = cmdstanpy.CmdStanModel(stan_file="../popgen_simple.stan")

In [None]:
## Simulate a fresh dataset with the parameters above and fit the Stan model.
N = 50           # number of observation times
M = 100          # mean number of samples per observation time
t_span = (0,50)
tobs, numvar, numtotal = gen_data(N, M, par, t_span)

# T_G = 1/alpha + 1/gamma: mean time in E plus mean time in I
data_dict = dict(
    N=N,
    NumSam=numtotal,
    NumVar=numvar,
    T=tobs,
    T_G=1/alpha + 1/gamma,
)

sam = sm.sample(data=data_dict, output_dir="../stan-cache/")

In [None]:
chain = sam.stan_variables()

## posterior of s, compared with the value of s used to simulate the data
fig, ax = plt.subplots(1, 1)

ax.hist(chain["s"], 50)
ax.axvline(s, color='k', linestyle='--', label='true $s$')
ax.set_xlabel('$s$')
ax.set_ylabel('number of posterior draws')
ax.legend()

print(np.mean(chain["s"]))

In [None]:
fig, ax = plt.subplots(1, 1)

## plot data: observed mutant fraction with 95% Jeffreys intervals
ax.scatter(tobs, numvar / numtotal, s=10, color='k')

for t, k, n in zip(tobs, numvar, numtotal):
    CI = sts.beta.interval(0.95, k+0.5, n-k+0.5)
    ax.plot([t, t], CI, color='k', alpha=0.5)

## plot fit: posterior mean and central 95% band of phat
## (assumes phat holds draws in rows, one column per observation time)
phat = chain["phat"]
pl, pu = np.percentile(phat, axis=0, q=[2.5, 97.5])
pm = np.mean(phat, axis=0)

ax.plot(tobs, pm, color='tab:blue')
ax.fill_between(tobs, pl, pu, color='tab:blue', alpha=0.3)

ax.set_xlabel('time ($t$)')
ax.set_ylabel('mutant frequency ($f_{mt}$)')

## Estimate $s$ for several values of $\beta_1/\beta_0$

In [None]:
## parameters for the scan over intervention strength
beta0 = 0.6     # transmission rate before the intervention
t1 = 30         # centre of the smoothed switch from beta0 to beta1
zeta = 5e-6     # initial fraction of the population in E or I
alpha = 1/3     # rate of progression E -> I
gamma = 1/4     # rate of leaving I
R0 = beta0/gamma
s = 0.35        # relative transmission advantage of the mutant
p0 = 0.05       # initial mutant fraction among E and I
N = 50          # number of observation times
M = 100         # mean number of samples per observation time
t_span = (0,70)

ratios = np.linspace(0, 1, 11)   # beta1 / beta0 = 0.0, 0.1, ..., 1.0
s_samples = []   # posterior draws of s, one entry per ratio
phats = []       # posterior draws of phat, one entry per ratio
datas = []       # simulated (tobs, numvar, numtotal), one tuple per ratio
sols = []        # ODE solutions used for plotting, one per ratio

for i, r in enumerate(ratios):
    beta1 = beta0 * r
    par = (beta0, beta1, t1, alpha, gamma, s, zeta, p0)
    y0 = get_init(par)
    # kept for plotting only; gen_data integrates the ODE again on its own
    sol = solve_ivp(lambda t, y: ode_seir(t,y,par), t_span, y0, **kwarg_ivp)
    sols.append(sol)
    tobs, numvar, numtotal = gen_data(N, M, par, t_span)
    datas.append((tobs, numvar, numtotal))

    data_dict = {
        "N" : N,
        "NumSam" : numtotal,
        "NumVar" : numvar,
        "T" : tobs,
        "T_G" : 1/alpha + 1/gamma
    }
    
    # NOTE(review): sampling is unseeded; pass seed=... for reproducible runs
    sam = sm.sample(data=data_dict, output_dir="../stan-cache/")
    
    s_samples.append(sam.stan_variable("s"))
    phats.append(sam.stan_variable("phat"))

In [None]:
def diff_r_pade(x):
    """[1/1] Padé approximant of ``diff_r(x)`` around x = 0.

    With c = alpha * beta0 and D = (gamma - alpha)**2 + 4 * c (the
    discriminant at x = 0), returns c * sqrt(D) * x / (D + c * x).
    Uses the module-level ``alpha``, ``gamma`` and ``beta0``.
    """
    disc = (gamma-alpha)**2 + 4*alpha*beta0
    numerator = alpha*beta0*np.sqrt(disc) * x
    denominator = disc + alpha*beta0*x
    return numerator / denominator

def get_r(x):
    """Early exponential growth rate with transmission rate beta0 * (1 + x).

    Linearising the SEIR model around S = 1, the growth rate r is the larger
    root of (r + alpha) * (r + gamma) = alpha * beta0 * (1 + x):

        r = (-(alpha + gamma) + sqrt((gamma - alpha)**2 + 4*alpha*beta0*(1 + x))) / 2

    At x = 0 this is algebraically the wild-type growth rate that
    ``get_init`` computes. Uses the module-level ``alpha``, ``gamma`` and
    ``beta0``, and works elementwise on numpy arrays.
    """
    D = (gamma-alpha)**2 + 4*alpha*beta0*(1+x)
    # the linear term is -(gamma + alpha), not its square
    return 0.5*(-(gamma+alpha) + np.sqrt(D))

def diff_r(x):
    """Exact change in the early growth rate, get_r(x) - get_r(0), when beta0 becomes beta0 * (1 + x)."""
    r_base = get_r(0)
    r_new = get_r(x)
    return r_new - r_base

def diff_r_lin(x):
    """First-order (linear in x) approximation of ``diff_r(x)``.

    Written in terms of R0 = beta0 / gamma and T_G = 1/alpha + 1/gamma.
    It is algebraically equal to x * alpha * beta0 / sqrt(D0), where
    D0 = (gamma - alpha)**2 + 4 * alpha * beta0. Uses the module-level
    ``alpha``, ``gamma`` and ``beta0``.
    """
    repro_number = beta0/gamma
    gen_interval = 1/alpha + 1/gamma
    correction = np.sqrt(1 + 4*(repro_number-1)/(gen_interval*(alpha+gamma)))
    return x*repro_number/gen_interval / correction

### Make figure for the supplement

In [None]:
## Supplementary figure:
## A = prevalence, B = mutant frequency, C = example fit, D = estimates of s.
fig = plt.figure(figsize=(10,8))
gs = GridSpec(2,2)

## panel D: posterior of s for each beta1/beta0 (violins + boxes, 2.5-97.5% whiskers)
ax = fig.add_subplot(gs[1,1])
c = 'tab:blue'
w = 0.05
ax.violinplot(s_samples, positions=ratios, showextrema=False, widths=w)
ax.boxplot(
    s_samples, positions=ratios, widths=w,
    showfliers=False, whis=(2.5, 97.5),
    boxprops=dict(color=c), capprops=dict(color=c),
    whiskerprops=dict(color=c), medianprops=dict(color=c),
)
ax.set_xlim(-0.1,1.1)
ax.set_xticks(ratios)
ax.set_xticklabels([f'{x:0.1f}' for x in ratios])
ax.set_xlabel("reduction in transmission rate ($\\beta_1 / \\beta_0$)")
ax.set_ylabel("estimate of $s$")

# reference line: exact growth-rate difference scaled by T_G = 1/alpha + 1/gamma
T_G = 1/alpha + 1/gamma
s_adj = diff_r(s) * T_G
ax.axhline(y=s_adj, color='k')

## panels A and B: trajectories of every scenario on one shared time grid
bx = fig.add_subplot(gs[0,0])
cx = fig.add_subplot(gs[0,1])
ts = np.linspace(0, t_span[1], 1000)
for i, sol in enumerate(sols):
    ys = sol.sol(ts)
    bx.plot(ts, ys[3] + ys[4], color='k', alpha=0.5)             # I_w + I_m
    cx.plot(ts, ys[2] / (ys[1] + ys[2]), color='k', alpha=0.5)   # E_m / (E_w + E_m)

ymax = 0.002
bx.set_ylim(0-ymax/50, ymax + ymax/50)
bx.set_yticks(np.linspace(0, ymax, 5))
bx.set_ylabel('fraction infected ($I$)')

## panel C: data and fit for one example scenario
dx = fig.add_subplot(gs[1,0])
idx = 3
tobs, numvar, numtotal = datas[idx]

dx.scatter(tobs, numvar / numtotal, s=10, color='k')
for t, k, n in zip(tobs, numvar, numtotal):
    CI = sts.beta.interval(0.95, k+0.5, n-k+0.5)   # 95% Jeffreys interval
    dx.plot([t, t], CI, color='k', alpha=0.5)

phat = phats[idx]
pl, pu = np.percentile(phat, axis=0, q=[2.5, 97.5])
pm = np.mean(phat, axis=0)
dx.plot(tobs, pm, color='tab:blue')
dx.fill_between(tobs, pl, pu, color='tab:blue', alpha=0.3)
dx.text(0.1, 0.9, f"$\\beta_1 / \\beta_0 = {ratios[idx]:0.1f}$",
        ha='left', va='top', transform=dx.transAxes)

## shared formatting
for panel in (cx, dx):
    panel.set_ylabel("mutant frequency ($f_{mt}$)")

xmin, xmax = -2, t_span[1]+2
for panel in (bx, cx, dx):
    # shade the period after the switch to beta1 (t > t1)
    panel.axvspan(t1, xmax, color='r', alpha=0.2, linewidth=0)
    panel.set_xlim(xmin, xmax)
    panel.set_xlabel('time ($t$)')

for panel, letter in zip([bx, cx, dx, ax], 'ABCD'):
    panel.text(-0.18, 1.04, letter, fontsize='xx-large', transform=panel.transAxes)

fig.align_ylabels()
fig.savefig("../effect-NPI-on-s.pdf", bbox_inches='tight')

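### How accurate is the Padé approximation?

We compare the exact change in growth rate $\Delta r(x) = r(x) - r(0)$ with its [1/1] Padé approximant and with its linearization around $x = 0$.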

In [None]:
## Compare the exact growth-rate difference with its Padé and linear approximations.
xs = np.linspace(-1, 2.0, 100)

fig, ax = plt.subplots(1, 1)

# The three functions use numpy arithmetic and np.sqrt, so they accept arrays directly.
for func, label in ((diff_r_pade, "Pade"), (diff_r, "exact"), (diff_r_lin, "linear")):
    ys = func(xs)
    ax.plot(xs, ys, label=label)

ax.set_xlabel("relative transmission advantage $x$")
ax.set_ylabel("change in growth rate $\\Delta r$")
ax.legend()