In [None]:
import matplotlib.pyplot as plt
import scipy.stats as sts
import numpy as np
import cmdstanpy ## import stan interface for Python
import os
import datetime
from matplotlib.gridspec import GridSpec

import sys
sys.path.append("..")

import stancourse.utilities as util
from stancourse import plots

## On Windows (os.name == "nt"), ask cmdstanpy to put the C++ compiler that is
## used to build Stan models on the PATH; other platforms need no extra step here.
if os.name == "nt": ## adds compiler to path in Windows
    cmdstanpy.utils.cxx_toolchain_path() 

# Some "Simple" Stan Models

* Linear regression
* Logistic regression
* Mixture model
* Censored data

## Linear regression
\begin{equation}
Y = a X + b + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \sigma)
\end{equation}

In [None]:
## generate some random synthetic data
## (seeded, so that "Restart & Run All" reproduces exactly the same data set)
SEED = 2021
rng = np.random.default_rng(SEED)
N = 100
X = sts.norm.rvs(loc=0, scale=1, size=N, random_state=rng)
a_gt, b_gt = 0.1, 0.2  ## ground-truth slope a and intercept b
sigma_gt = 0.35        ## ground-truth standard deviation of the noise
## Y = a X + b + eps, with eps ~ N(0, sigma), as in the model equation above
Y = a_gt * X + b_gt + sts.norm.rvs(loc=0, scale=sigma_gt, size=N, random_state=rng)

## scatter plot of the synthetic observations (X, Y)
fig, ax = plt.subplots(figsize=(5,3))
ax.scatter(X, Y, color='k', s=5)
ax.set_xlabel("X")
ax.set_ylabel("Y");

**Stan model**

In [None]:
## display the minimal linear-regression Stan program (helper from stancourse.utilities)
util.show_stan_model("../stan-models/linreg_minimal.stan")

**Stan model with `generated quantities` block**

In [None]:
## display the same model extended with a `generated quantities` block
util.show_stan_model("../stan-models/linreg.stan")

In [None]:
## compile the Stan program into an executable
sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/linreg.stan")

## Stan data: the observations, plus Nsim evenly spaced X values spanning the
## observed range, at which the generated quantities are evaluated
Nsim = 250
Xsim = np.linspace(np.min(X), np.max(X), Nsim)
data_dict = dict(
    N=N,
    X=X,
    Y=Y,
    Nsim=Nsim,
    Xsim=Xsim,
)

## draw from the posterior: 4 independent chains, each with 1000 warmup
## iterations (used to adapt the sampler's tuning parameters) followed by
## 2000 retained sampling iterations
sam = sm.sample(
    data=data_dict,
    chains=4,
    iter_warmup=1000,
    iter_sampling=2000,
)

**Use the `summary()` method to get some summary statistics of the samples**

In [None]:
## table of summary statistics, one row per sampled quantity
df = sam.summary()
## restrict the display to the model parameters and lp__ (skip generated quantities)
params_of_interest = ["a", "b", "sigma", "lp__"]
df.loc[params_of_interest]

**Use the `diagnose()` method to diagnose potential problems**

We will discuss what this means in the "Debugging" session

In [None]:
## `diagnose()` runs CmdStan's diagnose utility over the sampler output and
## returns its report as text; a bare assignment would display nothing, so print it
res = sam.diagnose()
if res:
    print(res)

In [None]:
## extract parameter "traces" (posterior draws; axis 0 indexes the draws)

Yhat = sam.stan_variable("Yhat")
Ysim = sam.stan_variable("Ysim")
a_est = sam.stan_variable("a")

print("shape of array Yhat:", Yhat.shape)

## pointwise posterior summaries over the draws
mYhat = np.mean(Yhat, axis=0)
lYhat, uYhat = np.percentile(Yhat, axis=0, q=[2.5, 97.5])
lYsim, uYsim = np.percentile(Ysim, axis=0, q=[2.5, 97.5])

## posterior probability that the slope is negative = fraction of draws with a < 0
## (vectorized instead of building a Python list of all negative draws)
Pr = float(np.mean(a_est < 0))
print("fraction of samples a < 0:", Pr)
## the figure is made in the next cell

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14,3))

## left panel: data (front), posterior-mean regression line, its 95% credible
## band, and the 95% posterior predictive band (back, default zorder)
ax1.scatter(X, Y, s=5, label="data", color='k', zorder=4)

ax1.plot(Xsim, mYhat, color=plots.bl1, zorder=3, label='reg. line')
ax1.fill_between(Xsim, lYhat, uYhat, color=plots.bl2, zorder=2, label="95%CrI")
ax1.fill_between(Xsim, lYsim, uYsim, color=plots.bl3, label="post. pred.")

ax1.set_xlabel("X")
ax1.set_ylabel("Y")
ax1.legend()

## right panel: posterior density of the slope a, with the ground truth and a = 0
plots.density(ax2, a_est, color=plots.bl2, label="posterior density $a$")
ax2.axvline(x=a_gt, color='k', label='ground truth $a$')
ax2.axvline(x=0, color='r', label="$a = 0$")
ax2.set_xlabel("$a$")
## plain-text label as in the other figures; "$density$" would be typeset as a
## product of italic math symbols
ax2.set_ylabel("density")
ax2.legend()

In [None]:
## display the figure built in the previous cell
fig ## show regression and posterior density plot

## Logistic regression
**with SARS-CoV-2 variant data**
* some sequence data from India (Feb - April 2021)
* counts of alpha and delta variant (and other)

In [None]:
## peek at the first four lines of the raw tab-separated file (header + first rows)
with open("../data/india-alpha-delta.tsv") as tsv_file:
    head_lines = tsv_file.read().split('\n')[:4]
print('\n'.join(head_lines))

In [None]:
## parse the table: skip the header line and empty lines (e.g. a trailing newline);
## column 0 is the time in days since 1 Jan 2020 (see date0 below), the rest are counts
with open("../data/india-alpha-delta.tsv") as f:
    table = [[int(x) for x in row.split('\t')]
             for row in f.read().split('\n')[1:] if row != '']

Time = [row[0] for row in table]
Counts = [row[1:] for row in table]

## relative frequency of each variant per time point: divide each row by its total
counts_arr = np.array(Counts)
freqs = counts_arr / counts_arr.sum(axis=1, keepdims=True)
variants = ["alpha", "delta", "other"]

fig, ax = plt.subplots(1, 1, figsize=(7,2))
for i, variant in enumerate(variants):
    ax.plot(Time, freqs[:,i], marker='o', label=variant)

## label every other time point with its calendar date
date0 = datetime.datetime.strptime("01-01-2020", "%m-%d-%Y")

xticks = Time[::2]
dates = [date0 + datetime.timedelta(days=t) for t in xticks]
datestrs = [date.strftime("%b %d") for date in dates]
ax.set_xticks(xticks)
ax.set_xticklabels(datestrs)
ax.set_ylabel("frequency")

ax.legend()

In [None]:
## display the variant-frequency figure built in the previous cell
fig ## show data

In [None]:
## display the Stan program for the variant-frequency (logistic regression) model
util.show_stan_model("../stan-models/sars2-variants.stan")

In [None]:
sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/sars2-variants.stan")
data_dict = {
    "N" : len(Time),        ## number of time points
    "K" : len(Counts[0]),   ## number of variant categories, taken from the data (3 here)
    "Time" : Time,
    "Counts" : Counts
}
## a single chain is used for this example
sam = sm.sample(chains=1, data=data_dict)

In [None]:
## one colour per variant, shared by its data points, fitted curve and credible band
colors = ['tab:blue', 'tab:orange', 'tab:green']

fig = plt.figure(figsize=(7,3))

## 1 x 4 grid: first three columns show the fit over time, the last one the selective advantages
gs = GridSpec(1,4)
ax = fig.add_subplot(gs[:3])

for i in range(len(variants)):
    ax.scatter(Time, freqs[:,i], marker='o', label=variants[i], color=colors[i])

## posterior mean and 95% credible band of the fitted variant frequencies
p_hats = sam.stan_variable("p_hat")

p_hat_mean = np.mean(p_hats, axis=0)
p_hat_l, p_hat_u = np.percentile(p_hats, axis=0, q=[2.5, 97.5])

for i in range(len(variants)):
    ## pass the colour explicitly: the default colour cycle only happens to match
    ## `colors`, and would not under a different style / rcParams
    ax.plot(Time, p_hat_mean[:,i], color=colors[i])
    ax.fill_between(Time, p_hat_l[:,i], p_hat_u[:,i], color=colors[i], alpha=0.3)

## label the time axis (days since 1 Jan 2020) with calendar dates
date0 = datetime.datetime.strptime("01-01-2020", "%m-%d-%Y")
xticks = Time[::2]
dates = [date0 + datetime.timedelta(days=t) for t in xticks]
datestrs = [date.strftime("%b %d") for date in dates]
ax.set_xticks(xticks)
ax.set_xticklabels(datestrs)
ax.set_ylabel("frequency")

ax.legend()

## right panel: posterior distribution of the selective advantage of each variant
bx = fig.add_subplot(gs[3])

alpha = sam.stan_variable("alpha")
bx.violinplot(alpha)

bx.set_ylabel("selective advantage (per day)")

## violinplot places the violins at x = 1, 2, ..., K
bx.set_xticks(range(1, len(variants) + 1))
bx.set_xticklabels(variants)

fig.tight_layout()


In [None]:
## display the fit figure built in the previous cell
fig # show fit

## Mixture model

Example application: seroprevalence data

* $X_1, X_2, \dots, X_N$ (properly transformed) antibody titers
* With probability $p$, subject $i$ is "positive", and "negative" otherwise. 
* Negative and positive titers are normally distributed with means $\mu_1 < \mu_2$ and standard deviations $\sigma_1$ and $\sigma_2$, respectively. 

\begin{equation}
    X_i \sim \left\{\begin{array}{ll}
        \mathcal{N}(\mu_1, \sigma_1) & \mbox{if $i$ negative} \\
        \mathcal{N}(\mu_2, \sigma_2) & \mbox{if $i$ positive}
    \end{array}\right.
\end{equation}

* We don't know the status of each individual, but only the titer $X_i$

In [None]:
## ground-truth parameters of the simulated seroprevalence study
p = 0.3                    ## probability that a subject is positive
mu1, mu2 = -1, 2           ## mean titer of negative / positive subjects
sigma1, sigma2 = 0.5, 1.0  ## standard deviation of negative / positive titers

N = 1000
I = sts.bernoulli.rvs(p, size=N)  ## latent status: 1 = positive, 0 = negative
## draw all titers in one vectorized call, picking each subject's component by I
## (replaces N separate scalar rvs calls in a Python loop)
X = sts.norm.rvs(loc=np.where(I == 0, mu1, mu2), scale=np.where(I == 0, sigma1, sigma2))

## histogram of the simulated titers, normalised to a density
fig, ax = plt.subplots(figsize=(7,3))

ax.hist(X, bins=50, density=True)
ax.set_xlabel("titer $X$")
ax.set_ylabel("density")

In [None]:
## display the histogram built in the previous cell
fig ## show histogram of the data

In [None]:
## display the Stan program of the two-component mixture model
util.show_stan_model("../stan-models/mixture_model.stan")

In [None]:
sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/mixture_model.stan")

## the model only needs the number of subjects and their titers
data_dict = dict(N=N, X=X)

## a single chain is used for this example
sam = sm.sample(chains=1, data=data_dict)

sam.summary()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7,2))

## histogram of the observed titers
ax.hist(X, bins=50, density=True, label="data")
ax.set_xlabel("titer $X$")
ax.set_ylabel("density")

## posterior draws of the mixture parameters
mu_est = sam.stan_variable("mu")
sigma_est = sam.stan_variable("sigma")
p_est = sam.stan_variable("p")

## plug-in estimate: each component's normal density at the posterior-mean parameters
xs = np.linspace(np.min(X), np.max(X), 1000)
y1s, y2s = (
    sts.norm.pdf(xs, loc=np.mean(mu_est[:, k]), scale=np.mean(sigma_est[:, k]))
    for k in (0, 1)
)

## weight the components by the estimated fractions of negatives (1-p) and positives (p)
p_mean = np.mean(p_est)
ax.plot(xs, (1 - p_mean) * y1s, linewidth=3, label="negative")
ax.plot(xs, p_mean * y2s, linewidth=3, label="positive")

ax.legend()

In [None]:
## display the figure built in the previous cell
fig ## data and fit mixture model

## Exercise

**classification of subjects in the mixture model**

Open the Stan file `mixture_model.stan` in the `stan-models` directory. 
Add a `generated quantities` block to calculate for each subject the probability `ppos[i]` that $i$ is positive.

```cpp
// other model blocks...

generated quantities {
    vector[N] ppos;
    
    // put your code here
}
```

## Censored data
**Interval censoring**
* Example: HIV-1 cure research: analytical treatment interruption (ATI) experiments
* Measure the viral load (VL) at discrete time points after antiretroviral treatment interruption (say, every week)
* *Viral rebound* is defined as the time $T$ at which the VL becomes detectable
* This time $T$ is *interval censored*, as the VL is not observed continuously: we only know that $T$ lies between the last undetectable and the first detectable measurement

Simple model for the rebound time: $T \sim {\rm Gamma}(\alpha, \beta)$, with shape $\alpha$ and rate $\beta$ (mean $\alpha / \beta$)

In [None]:
def gen_interval_censored_data(alpha, beta, dtmax, dtmin=3, nobs=200):
    """Simulate the interval-censored rebound time of one participant.

    Observation times start at day 0 and are separated by ``nobs`` random
    integer gaps drawn uniformly from ``dtmin, ..., dtmax - 1`` days
    (``np.random.randint`` excludes ``high``). A rebound time
    ``T ~ Gamma(shape=alpha, rate=beta)`` is drawn, and the two consecutive
    observation times that bracket it are returned.

    Parameters
    ----------
    alpha, beta : float
        Shape and rate of the Gamma distribution of the rebound time.
    dtmax : int
        Exclusive upper bound of the gap between observations (days).
    dtmin : int, optional
        Inclusive lower bound of that gap (default 3).
    nobs : int, optional
        Number of gaps, i.e. observations after day 0 (default 200).

    Returns
    -------
    list
        ``[t1, t2]``, consecutive observation times with ``t1 < T <= t2``.

    Raises
    ------
    ValueError
        If ``T`` is not bracketed by the observation times (``T <= 0`` or ``T``
        after the last observation), instead of silently returning ``None``.
    """
    dtobs = np.random.randint(dtmin, high=dtmax, size=nobs)
    tobs = np.concatenate([[0], np.cumsum(dtobs)])
    T = sts.gamma.rvs(alpha, scale=1/beta)
    ## first index k with tobs[k] >= T, hence tobs[k-1] < T <= tobs[k]
    k = np.searchsorted(tobs, T, side='left')
    if k == 0 or k == len(tobs):
        raise ValueError(
            f"rebound time T={T:.1f} lies outside the observation window "
            f"(0, {tobs[-1]}]"
        )
    return [tobs[k - 1], tobs[k]]
                    

## ground-truth parameters of the rebound-time distribution T ~ Gamma(alpha, beta)
alpha_gt = 2
beta_gt = 1/7
dtmax = 14   ## gaps between observations are drawn from 3 .. dtmax-1 days
N = 100      ## number of participants

Ts = [gen_interval_censored_data(alpha_gt, beta_gt, dtmax) for _ in range(N)]

Ts.sort()  ## order participants by their rebound interval for a tidier plot

fig, ax = plt.subplots(1, 1, figsize=(7,7))

for i, T in enumerate(Ts):
    label = 'rebound interval' if i == 0 else None  ## label only once for the legend
    ax.plot(T, [i,i], color='k', label=label)

ax.set_xlabel("days post ATI")
ax.set_ylabel("participant")
ax.legend()  ## otherwise the label set above is never shown

In [None]:
## display the interval plot built in the previous cell
fig ## show rebound intervals

**What is the likelihood of interval censored data?**

The probability that rebound occurred in the interval $[L,U]$ is equal to 

\begin{equation}
P(T \in [L, U]) = \int_{L}^U f_T(t) dt
\end{equation}
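where $f_T$ is the PDF of the Gamma distribution. Stan cannot evaluate this integral directly; for the Gamma distribution we only have access to the log-density and the log (complementary) CDF: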
```
gamma_lpdf
gamma_lcdf
gamma_lccdf
```
However, we can write
\begin{equation}
P(T \in [L, U]) = P(T \in [0, U]) - P(T \in [0, L])
\end{equation}
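Hence, the desired probability is the difference between two CDFs. Stan works on the log scale, so we use `log_diff_exp(a, b)` $= \log(e^a - e^b)$, which computes this difference stably without leaving the log scale. The log of the interval probability is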
```cpp
log_diff_exp(
    gamma_lcdf(U | alpha, beta), 
    gamma_lcdf(L | alpha, beta)
);
```


In [None]:
## display the Stan program for the interval-censored Gamma model
util.show_stan_model("../stan-models/interval_censored.stan")

In [None]:
sm = cmdstanpy.CmdStanModel(stan_file="../stan-models/interval_censored.stan")

## each element of Ts is a [lower, upper] rebound interval; pass plain ints to Stan
lower_times = [int(interval[0]) for interval in Ts]
upper_times = [int(interval[1]) for interval in Ts]
data_dict = {"N": N, "TimesL": lower_times, "TimesU": upper_times}

sam = sm.sample(chains=1, data=data_dict)
## posterior predictive draws of the rebound time
Tsim = sam.stan_variable("Tsim")

In [None]:
fig, (ax, bx) = plt.subplots(2, 1, figsize=(7,7), sharex=True)

## top panel: the observed rebound intervals, one horizontal segment per participant
for i, T in enumerate(Ts):
    label = 'rebound interval' if i == 0 else None
    ax.plot(T, [i,i], color='k', label=label, linewidth=0.5)
ax.set_ylabel("participant")
ax.legend()  ## otherwise the label set above is never shown

## bottom panel: posterior predictive distribution of the rebound time
plots.density(bx, Tsim, label="Tsim")
bx.set_xlabel("days post ATI")
bx.set_ylabel("density")
bx.legend()

## 95% credible intervals of the Gamma parameters, to compare with the ground truth
CrI_alpha = np.percentile(sam.stan_variable("alpha"), q=[2.5, 97.5])
CrI_beta = np.percentile(sam.stan_variable("beta"), q=[2.5, 97.5])

est_alpha =f"true alpha: {alpha_gt}, 95% CrI: [{CrI_alpha[0]:0.2f}, {CrI_alpha[1]:0.2f}]"
est_beta = f"true beta: {beta_gt:0.2f}, 95% CrI: [{CrI_beta[0]:0.2f}, {CrI_beta[1]:0.2f}]"


In [None]:
## ground truth vs. 95% credible intervals of the Gamma parameters, one per line
print(est_alpha, est_beta, sep='\n')

fig ## show data and posterior predictive distribution