In [None]:
import matplotlib.pyplot as plt
import scipy.stats as sts
import numpy as np
import cmdstanpy ## import stan interface for Python
from scipy.integrate import solve_ivp
from matplotlib.gridspec import GridSpec
import os

import sys
sys.path.append("..")

import stancourse.utilities as util
from stancourse import plots

if os.name == "nt": ## adds compiler to path in Windows
    ## CmdStan compiles models to C++ executables; on Windows the C++ toolchain
    ## is not on PATH by default, so let cmdstanpy locate it and add it for this process
    cmdstanpy.utils.cxx_toolchain_path() 

# Advanced Stan Models

* Using the ODE integrator
* Viral Dynamics model example
* Choosing a good initial parameter guess
* Multi-threading with `map_rect`
* Hierarchical Viral Dynamics model

## ODEs in Stan

Consider the initial value problem with $x \in \mathbb{R}^n$, parameter vector $\theta \in \mathbb{R}^k$ and data $D$.

\begin{equation}
\dot{x} = F(t, x, \theta)\,,\quad x(t_0) = x_0(\theta)
\end{equation}

ODE integrators in Stan

```cpp
vector[n] sol[N] = ode_rk45(ode_sys, x0, t0, ts, ...);
```
where the arguments are given by
```cpp
vector ode_sys(real t, vector x, ...) { // defined in functions block
    vector[num_elements(x)] dydt;
    /* implement F */
    return dydt;
}
vector[n] x0; // initial state
real t0; // initial time
real ts[N]; // save state at these times
```
* higher-order function (in the sense that one of the arguments is itself a function)
* variadic function (in the sense that we can pass an arbitrary number of parameters)

## ODEs in Stan

* Runge-Kutta 4-5 (`ode_rk45`): RK with adaptive step-size control; fast for non-stiff ODEs
* Backwards Differentiation Formula (`ode_bdf`): Slower but works for stiff systems
* Cash-Karp (`ode_ckrk`) *should perform better for systems that exhibit rapidly varying solutions*

**How on earth can Stan do automatic differentiation with ODE models??**

Let $x$ denote the solution of our IVP. For each observation time $t$, we need the gradients 
\begin{equation}
g_{ij}(t) = \frac{\partial [x_i(t)]}{\partial \theta_j}\, \quad i = 1,\dots, n \quad j = 1,\dots,k
\end{equation}
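These can be computed by augmenting the system of ODEs with $n\times k$ sensitivity equations. They are obtained by differentiating $\dot{x} = F(t, x, \theta)$ with respect to $\theta_j$ and applying the chain rule:
\begin{equation}
\frac{d g_{ij}}{dt} = \sum_{l=1}^{n} \frac{\partial F_i}{\partial x_l}\, g_{lj} + \frac{\partial F_i}{\partial \theta_j}\,, \quad
g_{ij}(t_0) = \frac{\partial x_{0,i}}{\partial \theta_j}\,,
\quad i = 1,\dots, n \quad j = 1,\dots,k
\end{equation}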

**So this means that for an ODE with $n$ equations and $k$ parameters, Stan integrates a system of $n + k\times n$ ODEs**

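Recent addition to Stan: the **adjoint solver** (`ode_adjoint_tol_ctl`). Instead of $n\times k$ sensitivity equations, it solves the $n$ forward ODEs and then a backward problem with roughly $n + k$ equations. This scales much better when $k$ is large, but it needs more fine-tuning and has more overhead.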

## Simple ODE model in Stan

Consider the "standard" viral dynamics model

\begin{equation}
\begin{split}
 \frac{dT}{dt} &= \lambda - d_T T - \beta VT /T_0 \\
 \frac{dI}{dt} &= \beta VT /T_0 - \delta I \\
 \frac{dV}{dt} &= p I - c V
\end{split}
\end{equation}

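Reduce this to a 2D system using the quasi-steady-state approximation (QSSA) $V = (p/c)\, I$. In the code below we take $p/c = 1$ (so $V = I$) and write $d_I$ for $\delta$.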

### Generate some data
\begin{equation}
 {\rm VL_n} \sim {\rm Lognormal}(\log(V(t_n)), \sigma)
\end{equation}

In [None]:
## "true" parameter values used to simulate the single-patient data set
T0 = 1e6        # initial target-cell number; also the scale in the infection term beta*T*I/T0
d_T = 0.05      # per-capita loss rate of target cells
lam = T0 * d_T  # target-cell supply rate, chosen so that T0 is the infection-free equilibrium
d_I = 0.5       # per-capita loss rate of infected cells
beta = 1.7      # infection rate
I0 = 1          # initial number of infected cells
sigma = 0.5     # standard deviation of the log-scale measurement noise

def vd_ode(t, y, T0, d_T, d_I, beta, lam):
    """Right-hand side of the QSSA-reduced viral dynamics model.

    Parameters
    ----------
    t : float
        Time. Unused (the system is autonomous), but required by ``solve_ivp``.
    y : sequence of two floats
        Current state ``(T, I)``.
    T0, d_T, d_I, beta, lam : float
        Model parameters (see the model equations in the notebook).

    Returns
    -------
    numpy.ndarray
        ``[dT/dt, dI/dt]``.
    """
    T, I = y
    infection = beta * T * I / T0  # flux from T into I
    return np.array([
        lam - d_T * T - infection,
        infection - d_I * I,
    ])


def gen_vd_data(params, N, tmax, lod=None, random_state=None):
    """Simulate a noisy (optionally left-censored) viral-load time series.

    Parameters
    ----------
    params : tuple
        ``(T0, d_T, d_I, beta, lam, I0, sigma)``. ``sigma`` is the
        log-scale standard deviation of the lognormal observation noise.
    N : int
        Number of observations, equally spaced on ``(0, tmax]``.
    tmax : float
        Time of the last observation.
    lod : float, optional
        Limit of detection. Observations below it are replaced by ``lod``
        and flagged as left-censored. ``None`` (default) disables censoring.
    random_state : None, int, numpy.random.Generator or RandomState, optional
        Passed to ``scipy.stats.lognorm.rvs``. ``None`` (default) uses
        NumPy's global random state.

    Returns
    -------
    ObsTime : numpy.ndarray, shape (N,)
        Observation times.
    Ihat : numpy.ndarray, shape (N,)
        Noise-free model value of ``I`` at ``ObsTime``.
    VL : numpy.ndarray, shape (N,)
        Lognormal observations with median ``Ihat`` (censored at ``lod`` if given).
    CC : numpy.ndarray of int, shape (N,)
        Censoring codes: ``1`` = left-censored, ``0`` = uncensored.
    sol : scipy.integrate OdeResult
        Full ``solve_ivp`` result, with dense output.

    Raises
    ------
    RuntimeError
        If the ODE solver fails. Otherwise the arrays would silently be
        shorter than ``N``.
    """
    T0, d_T, d_I, beta, lam, I0, sigma = params
    ObsTime = np.linspace(tmax / N, tmax, N)
    sol = solve_ivp(
        vd_ode, (0, tmax), [T0, I0], args=(T0, d_T, d_I, beta, lam),
        dense_output=True, t_eval=ObsTime,
    )
    if not sol.success:
        raise RuntimeError(f"ODE integration failed: {sol.message}")
    Ihat = sol.y[1]
    VL = np.asarray(
        sts.lognorm.rvs(scale=Ihat, s=sigma, random_state=random_state), dtype=float
    )
    if lod is None:
        CC = np.zeros(VL.shape, dtype=int)
    else:
        # Always return arrays, so callers get the same types with or without censoring.
        censored = VL < lod
        CC = censored.astype(int)
        VL = np.where(censored, lod, VL)
    return ObsTime, Ihat, VL, CC, sol

N = 20      # number of observations
tmax = 50   # time of the last observation

params = (T0, d_T, d_I, beta, lam, I0, sigma)

## gen_vd_data draws its noise from NumPy's global RNG: seed it so that
## "Restart & Run All" reproduces the same data set
np.random.seed(2023)
ObsTime, Ihat, VL, CC, sol = gen_vd_data(params, N, tmax)

fig, axs = plt.subplots(2, 1, figsize=(7,5), sharex=True)
ts = np.linspace(0, tmax, 1000)
states = sol.sol(ts)  # dense solution on the fine grid, one row per state
for state, lab in zip(states, ["$T$", "$I$"]):
    axs[0].plot(ts, state, label=lab)
axs[0].legend(loc=1)

axs[1].plot(ObsTime, VL, color='k')
axs[1].scatter(ObsTime, VL, color='k', label="VL data")
axs[1].legend(loc=1)
axs[1].set_xlabel("time post infection")
axs[1].set_ylabel("viral load")

In [None]:
fig ## show simulation and data

In [None]:
util.show_stan_model("../stan-models/vd_model.stan", lines=(1,15))

In [None]:
util.show_stan_model("../stan-models/vd_model.stan", lines=(15,30))

In [None]:
util.show_stan_model("../stan-models/vd_model.stan", lines=(31,None))

### Compile and fit model

In [None]:
%%time

sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/vd_model.stan")

data = {
    "N" : N,
    "ObsTime" : ObsTime,
    "VL" : VL,
    "T0" : T0
}

sam = sm.sample(
    data=data, 
    chains=1
)

### Fitted model trajectories and data, parameter estimates

In [None]:
yhat = sam.stan_variable("yhat")

fig, ax = plt.subplots(1, 1, figsize=(5,3))

mI = np.median(yhat[:,:,1], axis=0)
lI, uI = np.percentile(yhat[:,:,1], axis=0, q=[2.5, 97.5])

ax.plot(ObsTime, mI)
ax.fill_between(ObsTime, lI, uI, alpha=0.4)
ax.scatter(ObsTime, VL, color='k')
ax.set_yscale('log')

In [None]:
fig ## show model fit

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5,3))

parnames = ["beta", "d_T", "d_I", "I0", "sigma"]
pretty_parnames = ["$\\beta$", "$d_T$", "$d_I$", "$I_0$", "$\\sigma$"]
parvals = [beta, d_T, d_I, I0, sigma]
parests = [sam.stan_variable(pn) for pn in parnames]
ax.violinplot(parests, showextrema=False)
pos = range(1,len(parnames)+1)
ax.scatter(pos, parvals, color='k', s=5)
ax.set_xticks(pos)
ax.set_xticklabels(pretty_parnames)
ax.set_ylim(0,4)

In [None]:
fig ## show parameter estimates

## Exercise

Viral load has a limit of detection (usually 200 or 50 copies per mL), therefore some viral load measurements are left-censored (i.e. the reported value is an upper bound of the actual value). 

Suppose that in addition to the viral load data `VL` we also have censoring information. The array `CC` contains censoring codes, where `0` means "uncensored" and `1` means left-censored. Open the Stan model `vd_model.stan` in the `stan-models` directory and modify the code to account for censored data.

*Hint: make use of the `lognormal_lcdf` function.*

## Initial guesses: predator-prey model example

Lotka-Volterra predator-prey model

\begin{equation}
\frac{dx}{dt} = a x - b xy \,,\quad
\frac{dy}{dt} = cb xy - d y 
\end{equation}
With initial conditions $x(t_0) = x_0$ and $y(t_0) = y_0$.

Observations ($K$ is a constant determining the sampling volume)
\begin{equation}
 X_i \sim {\rm Poisson}(K x(t_i)) \,,\quad
 Y_i \sim {\rm Poisson}(K y(t_i))\,,\quad i = 1,\dots, N
\end{equation}

In [None]:
def lv_sys(t, u, a, b, c, d):
    """Lotka-Volterra vector field in the ``fun(t, y, ...)`` form used by ``solve_ivp``.

    ``u = (x, y)`` holds the prey and predator concentrations; ``t`` is unused.
    Returns ``np.array([dx/dt, dy/dt])``.
    """
    prey, predator = u
    prey_rate = a*prey - b*prey*predator
    predator_rate = c*b*prey*predator - d*predator
    return np.array([prey_rate, predator_rate])

N = 30
t0 = 0
tmax = 30
t_span = (0, tmax)
ObsTime = np.linspace(tmax/N, tmax, N)
u0 = np.array([1, 1])
a, b, c, d = 1, 0.4, 0.4, 0.5
sol = solve_ivp(lambda t, u: lv_sys(t, u, a, b, c, d), t_span, u0, 
                t_eval=ObsTime, dense_output=True, rtol=1e-6, atol=1e-6)

K = 10
PreyObs = [sts.poisson.rvs(K*y) for y in sol.y[0]]
PredObs = [sts.poisson.rvs(K*y) for y in sol.y[1]]

Nsim = 100
SimTime = np.linspace(tmax/Nsim, tmax, Nsim)

fig, axs = plt.subplots(2, 1, figsize=(7,4), sharex=True)

ts = np.linspace(*t_span, 1000)

axs[0].plot(ts, sol.sol(ts)[0], label="prey")
axs[0].plot(ts, sol.sol(ts)[1], label="predator")
axs[0].set_ylabel("concentration")

axs[0].legend()

axs[1].scatter(ObsTime, PreyObs, label="prey")
axs[1].scatter(ObsTime, PredObs, label="predator")

axs[1].legend()
axs[1].set_ylabel("samples")

In [None]:
fig ## predator-prey trajectories and data

In [None]:
## show the Lotka-Volterra Stan model in two parts (lines 1-19, then the rest)
util.show_stan_model("../stan-models/lotka.stan", lines=(1,19))

In [None]:
util.show_stan_model("../stan-models/lotka.stan", lines=(20,None))

In [None]:
## Stan data: observed counts at ObsTime, plus the finer SimTime grid for simulated trajectories
data_dict = dict(
    N=N,
    Prey=PreyObs, Predator=PredObs,
    t0=t0, ObsTime=ObsTime,
    K=K,
    Nsim=Nsim, SimTime=SimTime,
)
## try sampling a couple of times and check the fit
sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/lotka.stan")
sam = sm.sample(
    data=data_dict,
    chains=1,
    show_progress=True,
    refresh=5,
)

In [None]:
def _band(ax, t, draws, color, label=None, q=(2.5, 97.5)):
    """Plot the mean of ``draws`` (draw x time) over ``t`` with a central percentile band."""
    ax.plot(t, np.mean(draws, axis=0), color=color, label=label)
    ax.fill_between(t, *np.percentile(draws, axis=0, q=q), color=color, alpha=0.3)


def _violins(ax, datasets):
    """Draw one violin per dataset at x = 1, 2, ...; zero-spread datasets get a short line.

    ``Axes.violinplot`` fits a kernel density estimate, which fails with a
    singular-matrix error for a sample with zero spread. This happens, e.g., for
    every parameter of a ``fixed_param=True`` run. Such samples are drawn as a
    horizontal tick at their single value instead.
    """
    samples = [np.ravel(ds) for ds in datasets]
    positions = np.arange(1, len(samples) + 1)
    varies = np.array([np.ptp(s) > 0 for s in samples], dtype=bool)
    if varies.any():
        ax.violinplot([s for s, v in zip(samples, varies) if v], positions=positions[varies])
    for pos, s, v in zip(positions, samples, varies):
        if not v:
            ax.hlines(s[0], pos - 0.25, pos + 0.25, color='tab:blue', lw=2)


def plot_lv_fit(sam):
    """Plot a Lotka-Volterra fit: trajectories, predictive samples and parameter estimates.

    Left column: posterior mean and 95% band of the latent concentrations
    (``uhat``) and of the simulated counts (``PreySim``/``PredatorSim``),
    together with the observed counts. Right column: posterior distributions
    of the initial state ``u0`` and of ``a, b, c, d``, with the values used to
    simulate the data as black dots.

    Reads the module-level ``SimTime``, ``ObsTime``, ``PreyObs``, ``PredObs``,
    ``u0`` and ``a, b, c, d`` from the data-generating cell.

    Parameters
    ----------
    sam : cmdstanpy.CmdStanMCMC
        Result of ``sm.sample`` for ``lotka.stan``. Runs with
        ``fixed_param=True``, where every draw is identical, are supported.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(9,4))
    gs = GridSpec(2,5)
    ax1 = fig.add_subplot(gs[0, :4])
    ax2 = fig.add_subplot(gs[1, :4], sharex=ax1)
    bx1 = fig.add_subplot(gs[0,4])
    bx2 = fig.add_subplot(gs[1,4])

    chain = sam.stan_variables()
    uhat = chain["uhat"]  # indexed (draw, time, species); species 0 = prey, 1 = predator

    ## latent concentrations
    _band(ax1, SimTime, uhat[:, :, 0], 'tab:blue', label="prey")
    _band(ax1, SimTime, uhat[:, :, 1], 'tab:orange', label="predator")
    ax1.set_ylabel("concentration")
    ax1.legend()

    ## observed counts and posterior predictive simulations
    ax2.scatter(ObsTime, PreyObs, color='tab:blue', label="prey")
    ax2.scatter(ObsTime, PredObs, color='tab:orange', label="predator")
    _band(ax2, SimTime, chain["PreySim"], 'tab:blue')
    _band(ax2, SimTime, chain["PredatorSim"], 'tab:orange')
    ax2.legend()
    ax2.set_ylabel("samples")
    ax2.set_xlabel("time")

    ## initial state: u0 = (x0, y0), i.e. prey and predator at t0
    u0_draws = np.asarray(chain["u0"])
    _violins(bx1, [u0_draws[:, j] for j in range(u0_draws.shape[1])])
    bx1.set_xticks([1,2])
    bx1.set_xticklabels(["$x_0$", "$y_0$"])
    bx1.scatter([1,2], u0, color='k')

    ## rate parameters
    parnames = ["a", "b", "c", "d"]
    _violins(bx2, [chain[pn] for pn in parnames])
    bx2.scatter([1,2,3,4], [a,b,c,d], color='k')
    bx2.set_xticks([1,2,3,4])
    bx2.set_xticklabels(parnames)

    fig.tight_layout()
    return fig

In [None]:
## fit obtained without user-supplied initial values
fig = plot_lv_fit(sam)

**choose a good initial guess**

In [None]:
init_dict = {
    "u0" : u0,
    "a" : a,
    "b" : b,
    "c" : c,
    "d" : d
}

sam2 = sm.sample(chains=1, data=data_dict, inits=init_dict, show_progress=True, refresh=5)

In [None]:
fig = plot_lv_fit(sam2)

**test the initial guess by running with fixed parameters**

In [None]:
## A hand-picked initial guess that differs from the true values.
## With fixed_param=True the sampler does not move the parameters away from these
## initial values, so the plot shows the trajectories implied by this guess.
## NOTE(review): every draw is identical in such a run; plot_lv_fit must not pass
## zero-spread samples to Axes.violinplot (its KDE cannot handle them).
init_dict = {
    "u0" : [0.5,2],
    "a" : 2,
    "b" : 0.3,
    "c" : 0.9,
    "d" : 0.3
}
sam3 = sm.sample(chains=1, data=data_dict, inits=init_dict, fixed_param=True)
fig = plot_lv_fit(sam3)

## Solving problems with ODE models

* In the warmup phase, Stan can quickly walk away from your carefully chosen initial guess. Avoid this by reducing the initial step size ($\epsilon$) of the HMC leapfrog integrator:
`sam = sm.sample(step_size=0.01, ...)`

* Some (unlikely) parts of the parameter space may lead to **stiff ODEs**. With an explicit solver (e.g. `ode_rk45`), this leads to extremely small integration steps. Either use `ode_bdf` (an implicit solver), or try to restrict the parameter space by choosing more informative priors.

* Try to avoid ODE models with time discontinuities in the vector field $F$. Suppose there is a discontinuity at time $t_1$. Solution 1: call the integrator twice, once up to $t_1$ and once from $t_1$ on. Solution 2: replace step functions with smooth approximations $H_u(t) = (1 + e^{-u(t-t_1)})^{-1}$, which converge to the Heaviside step function $H(t - t_1)$ as $u\rightarrow \infty$.

Solution 1:
```cpp
vector[n] sol1[N1] = ode_rk45(ode_sys, u0, t0, ts1, ...); // last element of ts1 is t1
vector[n] u1 = sol1[N1]; // state at t1
vector[n] sol2[N2] = ode_rk45(ode_sys, u1, t1, ts2, ...);
```


## Multi-threading with `map_rect`

* Models with several independent, computationally intensive parts
* Example: ODE models with $R$ repeated experiments

Higher-order function `map_rect` can be used to distribute computations over multiple CPU cores.
Syntax:
```cpp
vector[M] xs = map_rect(fun, pop_par_vec, unit_par_vecs, data_reals, data_ints);
```
Where the arguments have the following types
```cpp
vector fun(pop_par, unit_par, data_real, data_int) { /* function body */ }
vector[k] pop_par_vec; // shared parameters
vector[l] unit_par_vecs[R]; // unit-specific parameter vectors
real data_reals[R,m]; // real data
int data_ints[R,n]; // integer data
```
**`map_rect` returns the concatenation of all result vectors**. The individual result vectors can therefore have different lengths.

## Panel VL data

Suppose we now have multiple ($R$) viral load timeseries, and we want to fit the viral dynamics model defined above.
There are slight differences between individual parameters, so we want to fit a hierarchical model (cf. NLME). For instance for the parameters $\beta_r$ we assume a shared prior
\begin{equation}
\beta_r \sim {\rm LogNormal}(m_{\beta}, s_{\beta})\,, \quad r = 1,\dots, R
\end{equation}
And we also estimate the population-level parameters $m_{\beta}, s_{\beta}$

In [None]:
def gen_vd_params(d_I, d_T, beta, I0, s, random_state=None):
    """Draw one individual's parameters as lognormal perturbations of population values.

    Each parameter is multiplied by an independent ``LogNormal(0, s)`` factor,
    so the population value is the median of the individual-level distribution.

    Parameters
    ----------
    d_I, d_T, beta, I0 : float
        Population-level parameter values.
    s : float
        Log-scale standard deviation of the random effects (must be > 0).
    random_state : None, int, numpy.random.Generator or RandomState, optional
        Source of randomness. ``None`` (default) uses NumPy's global random
        state. An int seeds a single generator that is shared by all four draws.

    Returns
    -------
    tuple
        ``(d_I, d_T, beta, I0)`` for one individual (note: not the argument order).
    """
    # Turn an int seed into one generator. Passing the same int to every rvs call
    # would reseed each draw and give four identical factors.
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.default_rng(random_state)

    def perturb(value):
        return value * sts.lognorm.rvs(s=s, random_state=random_state)

    # tuple elements are evaluated left to right, so the draw order is d_I, d_T, beta, I0
    return (perturb(d_I), perturb(d_T), perturb(beta), perturb(I0))

## population-level ("typical") parameter values
T0_pop = 1e6
d_T_pop = 0.05
d_I_pop = 0.5
beta_pop = 1.7
lam_pop = T0_pop * d_T_pop
I0_pop = 0.1
sigma = 0.5
s_ran = 0.2  # log-scale sd of the individual random effects

R = 12      # number of individuals
N = 20      # observations per individual
tmax = 50   # time of the last observation

## Unit-specific (_r) names keep the loop from overwriting globals read elsewhere:
## d_I, beta, ... (true values in the single-patient plot) and ObsTime (used by plot_lv_fit).
VLs = []
ObsTimes = []
for r in range(R):
    d_I_r, d_T_r, beta_r, I0_r = gen_vd_params(d_I_pop, d_T_pop, beta_pop, I0_pop, s_ran)
    params_r = (T0_pop, d_T_r, d_I_r, beta_r, d_T_r*T0_pop, I0_r, sigma)
    ObsTime_r, Ihat_r, VL_r, CC_r, sol_r = gen_vd_data(params_r, N, tmax)
    ObsTimes.append(ObsTime_r)
    VLs.append(VL_r)

nrows = 3
fig, axs = plt.subplots(nrows, R//nrows, figsize=(10,4), sharex=True, sharey=True)

for i, ax in enumerate(axs.flatten()):
    t = ObsTimes[i]
    VL = VLs[i]
    ax.plot(t, VL, marker='o', markersize=3)
    ax.set_yscale('log')
    
fig.tight_layout()
fig.text(0, 0.5, "Viral Load", rotation=90, va='center')
fig.text(0.5, 0, "days post infection", ha='center')

In [None]:
fig ## VL panel data

In [None]:
## show the hierarchical panel model piece by piece (lines 1-14, then 15-30)
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(1,14))

In [None]:
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(15,30))

## functions block (continued)

In [None]:
## remaining parts of the panel model: lines 31-44, 45-70, 71-87 and 88 to the end
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(31,44))

In [None]:
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(45,70))

In [None]:
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(71,87))

In [None]:
util.show_stan_model("../stan-models/vd_model_panel.stan", lines=(88,None))

In [None]:
## compile model
sm = cmdstanpy.CmdStanModel(
    stan_file="../stan-models/vd_model_panel.stan",
    cpp_options={"STAN_THREADS": True}    
)

## prepare data
data_dict = {
    "R" : R,
    "N" : [len(VL) for VL in VLs],
    "VL" : np.array(VLs), ## no padding required
    "ObsTime" : np.array(ObsTimes), ## no padding required here
    "T0" : T0
}

## choose reasonable initial values
init_dict = {
    "beta" : beta_pop * np.ones(R),
    "d_I" : d_I_pop * np.ones(R),
    "d_T" : d_T_pop * np.ones(R),
    "I0" : I0_pop * np.ones(R),
}

## fit model
sam = sm.sample(
    data=data_dict, inits=init_dict, iter_warmup=1000, iter_sampling=1000,
    chains=1, output_dir="../stan-cache/",
    show_progress=True, refresh=1,
    threads_per_chain=4
)

In [None]:
nrows = 3
fig, axs = plt.subplots(nrows, R//nrows, figsize=(10,5), sharex=True, sharey=True)

VLhat = sam.stan_variable("VLhat")

for i, ax in enumerate(axs.flatten()):
    ## plot data
    t = ObsTimes[i]
    VL = VLs[i]
    ax.scatter(t, VL, color='k', s=5)
    ax.set_yscale('log')
    ## plot fit
    lV, mV, uV = np.percentile(VLhat[:,i,:], axis=0, q=[2.5, 50, 97.5])
    ax.plot(t, mV)
    ax.fill_between(t, lV, uV, alpha=0.4, color='tab:blue')
    
fig.tight_layout()
fig.text(0, 0.5, "Viral Load", rotation=90, va='center')
fig.text(0.5, 0, "days post infection", ha='center')

In [None]:
fig ## show individual fits to panel data

**parameter estimates** and posterior predictive distributions

In [None]:
parnames = ["beta", "d_T", "d_I", "I0"]
parvals = [beta_pop, d_T_pop, d_I_pop, I0_pop]

fig, axs = plt.subplots(4, 4, figsize=(10,5))

xlims = [[1,4], [0.01, 0.1], [0.3, 0.8], [0, 0.3]]

for i, pn in enumerate(parnames):
    m = sam.stan_variable("m_" + pn)
    s = sam.stan_variable("s_" + pn)
    ppd = sam.stan_variable("ppd_" + pn)
    plots.density(axs[0,i], np.exp(m))
    plots.density(axs[2,i], ppd, color='tab:red')
    plots.density(axs[3,i], s, color='tab:orange')
    for r in range(R):
        p = sam.stan_variable(pn)[:,r]
        plots.density(axs[1,i], p, alpha=0.4)
    axs[0,i].set_xlim(*xlims[i])
    axs[1,i].set_xlim(*xlims[i])
    axs[2,i].set_xlim(*xlims[i])
    axs[3,i].set_xlim(0, 1)
    axs[0,i].axvline(x=parvals[i], color='k')
    axs[3,i].axvline(x=0.2, color='k')
    axs[0,i].set_title(pn)
    
axs[0,0].set_ylabel("pop mean")
axs[1,0].set_ylabel("units")
axs[2,0].set_ylabel("ppd")
axs[3,0].set_ylabel("std")
fig.tight_layout()
fig.align_ylabels()

In [None]:
fig ## show parameter estimates

## Further reading

**Choice of Priors**
Prior recommendations by the Stan team can be found [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations). This is in general not an easy task.

**The Bayesian Workflow**
How do we build a model, iteratively improve it, validate it etc. Read [this paper](https://arxiv.org/pdf/2011.01808.pdf) for developing good strategies.

**Fitting ODE models with Stan**
A [tutorial](https://mc-stan.org/users/documentation/case-studies/planetary_motion/planetary_motion.html) explaining common problems with ODEs in Stan and possible solutions. The Bayesian workflow applied to [epidemic models](https://arxiv.org/pdf/2006.02985.pdf).

**[Torsten](https://metrumresearchgroup.github.io/Torsten/)** is an extension of Stan for PK/PD modeling.