# Session 4: Bayesian Model Evaluation and Workflow

In this session, we will learn to evaluate the quality of our models using statistical and visual diagnostics. We'll also discuss a comprehensive Bayesian workflow that promotes model development through an iterative process.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import warnings
warnings.simplefilter("ignore")

RANDOM_SEED = 20090425

## MCMC Output Processing and Model Checking with ArviZ

[ArviZ](https://python.arviz.org/en/stable/) is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, model checking, comparison and diagnostics. ArviZ is designed to work with output from a wide range of Bayesian inference libraries, including PyMC, emcee, Stan, Pyro, and TensorFlow Probability.

ArviZ is built on top of the popular libraries xarray, matplotlib, and bokeh. It is also built with the same design principles as PyMC, so if you are familiar with PyMC, you will find ArviZ easy to use.

### Example: Effect of coaching on SAT scores

This example was taken from Gelman *et al.* (2013):

> A study was performed for the Educational Testing Service to analyze the effects of special coaching programs on test scores. Separate randomized experiments were performed to estimate the effects of coaching programs for the SAT-V (Scholastic Aptitude Test- Verbal) in each of eight high schools. The outcome variable in each study was the score on a special administration of the SAT-V, a standardized multiple choice test administered by the Educational Testing Service and used to help colleges make admissions decisions; the scores can vary between 200 and 800, with mean about 500 and standard deviation about 100. The SAT examinations are designed to be resistant to short-term efforts directed specifically toward improving performance on the test; instead they are designed to reflect knowledge acquired and abilities developed over many years of education. Nevertheless, each of the eight schools in this study considered its short-term coaching program to be successful at increasing SAT scores. Also, there was no prior reason to believe that any of the eight programs was more effective than any other or that some were more similar in effect to each other than to any other.

We are given the estimated coaching effects (`y`) and their standard errors (`s`). The estimates were obtained by independent experiments, with relatively large sample sizes (over thirty students in each school), so it can be assumed that they have approximately normal sampling distributions with known variances.

In [None]:
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
s = np.array([15, 10, 16, 11, 9, 11, 10, 18])
schools = np.array(
    [
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "Phillips Exeter",
        "Hotchkiss",
        "Lawrenceville",
        "St. Paul's",
        "Mt. Hermon",
    ]
)

with pm.Model(coords={'school': schools}) as schools_model:
    
    mu = pm.Normal("mu", 0, sigma=1e6)
    tau = pm.HalfCauchy("tau", 5)

    theta = pm.Normal("theta", mu, sigma=tau, dims='school')

    obs = pm.Normal("obs", theta, sigma=s, observed=y)

Let's run a short sample and look at the results.

In [None]:
with schools_model:
    # Model fitting
    schools_trace = pm.sample(1000, tune=1000, random_seed=RANDOM_SEED)

After running an MCMC simulation, `sample` returns an `arviz.InferenceData` object containing the samples for all the stochastic and named deterministic random variables. 

In [None]:
schools_trace

Data corresponding to each type of sampling is available as an `InferenceData` attribute.

In [None]:
post = schools_trace.posterior
post

## Model Checking

The final step in Bayesian computation is model checking, which ensures that inferences derived from your sample are valid.

There are **two components** to model checking:

1. Convergence diagnostics
2. Goodness of fit

Convergence diagnostics are intended to detect **lack of convergence** in the Markov chain Monte Carlo sample; they are used to ensure that you have not halted your sampling too early. However, a converged model is not guaranteed to be a good model.

The second component of model checking, goodness of fit, is used to check the **internal validity** of the model, by comparing predictions from the model to the data used to fit the model. 

### Convergence Diagnostics

Valid inferences from sequences of MCMC samples are based on the
assumption that the samples are derived from the true posterior
distribution of interest. Theory guarantees this condition as the number
of iterations approaches infinity. It is important, therefore, to
determine the **minimum number of samples** required to ensure a reasonable
approximation to the target posterior density. Unfortunately, no
universal threshold exists across all problems, so convergence must be
assessed independently each time MCMC estimation is performed. The
procedures for verifying convergence are collectively known as
*convergence diagnostics*.

There are a handful of easy-to-use methods for checking convergence. Since you cannot prove convergence, but only show lack of convergence, there is no single method that is foolproof. So, it's best to look at a suite of diagnostics together.

We will cover the canonical set of checks:

- Sampler statistics
- Variable plotting
- Divergences
- R-hat
- Effective Sample Size

### Sampler Statistics

When checking for convergence or when debugging a badly behaving sampler, it is often helpful to take a closer look at what the sampler is doing. For this purpose some samplers export statistics for each generated sample.

NUTS provides several metrics related to the performance of the sampler.

In [None]:
schools_trace.sample_stats

The sample statistics variables are defined as follows:

- `process_time_diff`: The time it took to draw the sample, as defined by the Python standard library function `time.process_time`. This counts all the CPU time, including worker processes in BLAS and OpenMP.

- `step_size`: The current integration step size.

- `diverging`: (boolean) Indicates that a leapfrog transition showed a large energy deviation from the starting point, which terminates the trajectory. "Large" is defined as `max_energy_error` exceeding a threshold.

- `lp`: The joint log posterior density for the model (up to an additive constant).

- `energy`: The value of the Hamiltonian energy for the accepted proposal (up to an additive constant).

- `energy_error`: The difference in the Hamiltonian energy between the initial point and the accepted proposal.

- `perf_counter_diff`: The time it took to draw the sample, as defined by the Python standard library function `time.perf_counter` (wall time).

- `perf_counter_start`: The value of `time.perf_counter` at the beginning of the computation of the draw.

- `n_steps`: The number of leapfrog steps computed. It is related to `tree_depth` with `n_steps <= 2^tree_depth`.

- `max_energy_error`: The maximum absolute difference in Hamiltonian energy between the initial point and all possible samples in the proposed tree.

- `acceptance_rate`: The average acceptance probabilities of all possible samples in the proposed tree.

- `step_size_bar`: The current best known step-size. After the tuning samples, the step size is set to this value. This should converge during tuning.

- `tree_depth`: The number of tree doublings in the balanced binary tree.

It can be helpful to plot some of these variables, rather than staring at vectors of numbers!

#### Tree Depth

In the No-U-Turn Sampler (NUTS), each proposal constructs a balanced binary tree
of candidate states by recursively doubling the number of leapfrog steps.

- At tree depth *d*, the sampler can take up to `2**d` leapfrog steps.
- Doubling stops when:
  - A U-turn condition is detected (the trajectory reverses), or
  - A maximum tree depth limit (`max_tree_depth`) is reached.

Tree depth controls the maximum path length of each proposal:
- Higher depths allow the sampler to explore further along the posterior
  geometry.
- Consistently hitting the maximum depth often indicates tuning or geometry issues,
  such as a step size that is too small.



In [None]:
schools_trace.sample_stats["tree_depth"].plot(col="chain", ls="none", marker=".", alpha=0.3);

We can also see if the acceptance rate is close to the target.

In [None]:
schools_trace.sample_stats["acceptance_rate"].plot.hist(bins=20, density=True);

Recall that NUTS generates a binary tree of samples, so here the acceptance rate is the average of the acceptance probabilities of all the samples in the tree.

## Output Visualization with ArviZ

[ArviZ](https://arviz-devs.github.io/arviz/) is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, model checking, comparison and diagnostics, and is designed to work with a range of Bayesian inference libraries (not just PyMC).

ArviZ is built on top of the popular libraries xarray and matplotlib. It is also built with the same design principles as PyMC, so if you are familiar with PyMC, you will find ArviZ easy to use.

### Traceplot 

Perhaps the most-used ArviZ plot is the traceplot, obtained via the `plot_trace` function. This is a simple plot that is a good quick check to make sure nothing is obviously wrong, and is usually the first diagnostic step you will take. You've seen these already: just the time series of samples for an individual variable.

By default, the `plot_trace` function from ArviZ generates a kernel density plot and a trace plot, with a different color for each chain of the simulation.

In [None]:
az.plot_trace(schools_trace, var_names=['mu', 'tau']);
plt.tight_layout();

This sample is deliberately inadequate. Looking at the trace plot, the problems should be apparent.

Can you identify the issues, based on what you learned in the previous section?

### Divergences

As we have seen, Hamiltonian Monte Carlo (and NUTS) performs numerical integration in order to explore the posterior distribution of a model. When the integration goes wrong, it can go dramatically wrong. 

For example, here are some Hamiltonian trajectories on the distribution of two correlated variables. Can you spot the divergent path?

![diverging HMC](images/diverging_hmc.png)

The reason that this happens is that there may be parts of the posterior which are **hard to explore** for geometric reasons. Two ways of solving divergences are

1. **Set a higher "target accept" rate**: As with Metropolis-Hastings (though the mechanism differs), larger integrator steps lead to lower acceptance rates. A higher `target_accept` will generally result in a smaller step size and more accurate integration.
2. **Reparametrize**: If you can write your model in a different way that has the same joint probability density, you might do that. A lot of work is being done to automate this, since it requires careful work, and one goal of a probabilistic programming language is to iterate quickly. See [Hoffman, Johnson, Tran (2018)](https://arxiv.org/abs/1811.11926), [Gorinova, Moore, Hoffman (2019)](https://arxiv.org/abs/1906.03028).

You should be wary of a trace that contains many divergences (particularly those clustered in particular regions of the parameter space), and give thought to how to fix them.
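
To make the link between step size and integration error concrete, here is a minimal sketch of a leapfrog integrator on a one-dimensional standard normal target. The target is an assumption chosen for simplicity (it is not the distribution shown in the figure), and the function name `leapfrog_energy_error` and the two step sizes are illustrative, not part of PyMC.

```python
import numpy as np

def leapfrog_energy_error(step_size, n_steps=20, q0=1.0, p0=1.0):
    """Largest absolute energy error along one leapfrog trajectory.

    Assumed target: standard normal, so U(q) = q**2 / 2 and dU/dq = q.
    Hamiltonian: H(q, p) = U(q) + p**2 / 2.
    """
    def hamiltonian(q, p):
        return 0.5 * q**2 + 0.5 * p**2

    q, p = q0, p0
    h0 = hamiltonian(q, p)
    max_error = 0.0
    for _ in range(n_steps):
        p -= 0.5 * step_size * q   # half step for the momentum
        q += step_size * p         # full step for the position
        p -= 0.5 * step_size * q   # second half step for the momentum
        max_error = max(max_error, abs(hamiltonian(q, p) - h0))
    return max_error

small_step_error = leapfrog_energy_error(step_size=0.1)
large_step_error = leapfrog_energy_error(step_size=2.5)
print(f"step 0.1: {small_step_error:.4f}, step 2.5: {large_step_error:.3g}")
```

For this particular target the leapfrog scheme is only stable for step sizes below 2, so the second trajectory's energy error grows without bound. That runaway energy error is the kind of behaviour that gets flagged as a divergence, and it is why a smaller step size tends to help.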

### Divergence example

The trajectories above are from a famous example of a difficult geometry: Neal's funnel. It is problematic because the geometry is very different in some regions of the state space relative to others. Specifically, for hierarchical models, as the scale parameter changes in size so do the values of the parameters it is constraining. When the variance is close to zero, the parameter space is very constrained relative to the majority of the support.

In [None]:
def neals_funnel(dims=2):
    with pm.Model() as funnel:
        v = pm.Normal('v', 0, 3)
        x_vec = pm.MvNormal('x_vec', mu=pt.zeros(dims), cov=2 * pt.exp(v) * pt.eye(dims), shape=dims)
    return funnel

with neals_funnel():
    funnel_trace = pm.sample(random_seed=RANDOM_SEED)

PyMC provides us feedback on divergences, including a count and a recommendation on how to address them. 

In [None]:
diverging_ind = funnel_trace.sample_stats['diverging'].values[0].nonzero()
diverging_ind

In [None]:
ax = az.plot_pair(funnel_trace)
# Overlay chain 0's divergent draws on the v vs. x_vec[0] panel
ax[0][0].plot(funnel_trace.posterior['v'].sel(chain=0).values[diverging_ind], 
              funnel_trace.posterior['x_vec'].sel(chain=0).values[diverging_ind][:, 0], 'y.');

Notice that the divergent samples are clustered toward the narrow part of the funnel.

### Potential Scale Reduction: $\hat{R}$

Roughly, $\hat{R}$ (*R-Hat*, or the *Gelman-Rubin statistic*) is the ratio of between-chain variance to within-chain variance. This diagnostic uses multiple chains to
check for lack of convergence, and is based on the notion that if
multiple chains have converged, by definition they should appear very
similar to one another; if not, one or more of the chains has failed to
converge.

$\hat{R}$ uses an analysis of variance approach to
assessing convergence. That is, it calculates both the between-chain
variance (B) and within-chain variance (W), and assesses whether they
are different enough to worry about convergence. Assuming $m$ chains,
each of length $n$, quantities are calculated by:

$$\begin{align}B &= \frac{n}{m-1} \sum_{j=1}^m (\bar{\theta}_{.j} - \bar{\theta}_{..})^2 \\
W &= \frac{1}{m} \sum_{j=1}^m \left[ \frac{1}{n-1} \sum_{i=1}^n (\theta_{ij} - \bar{\theta}_{.j})^2 \right]
\end{align}$$

for each scalar estimand $\theta$. Using these values, an estimate of
the marginal posterior variance of $\theta$ can be calculated:

$$\hat{\text{Var}}(\theta | y) = \frac{n-1}{n} W + \frac{1}{n} B$$

Assuming $\theta$ was initialized to arbitrary starting points in each
chain, this quantity will overestimate the true marginal posterior
variance. At the same time, $W$ will tend to underestimate the marginal
posterior variance early in the sampling run, because the individual chains
have not yet had time to explore the whole target distribution. However, in
the limit as $n \rightarrow \infty$, both quantities will converge to the
true variance of $\theta$. In light of this, $\hat{R}$ monitors convergence
using the ratio:

$$\hat{R} = \sqrt{\frac{\hat{\text{Var}}(\theta | y)}{W}}$$

This is called the **potential scale reduction**, since it is an estimate of
the potential reduction in the scale of $\theta$ as the number of
simulations tends to infinity. In practice, we look for values of
$\hat{R}$ close to one (say, less than 1.1) to be confident that a
particular estimand has converged. 
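
The formulas above translate almost line by line into NumPy. The sketch below is a plain implementation of this classic version of $\hat{R}$ applied to synthetic chains; the function name `rhat` and the simulated data are illustrative assumptions. ArviZ reports a rank-normalized split variant by default, so its values will generally differ somewhat from this calculation.

```python
import numpy as np

def rhat(chains):
    """Classic potential scale reduction for `chains` of shape (m, n)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    # Between-chain variance B: n times the variance of the chain means
    B = n * chains.mean(axis=1).var(ddof=1)
    # Within-chain variance W: average of the per-chain sample variances
    W = chains.var(axis=1, ddof=1).mean()
    # Pooled estimate of the marginal posterior variance
    var_hat = (n - 1) / n * W + B / n
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                       # four chains, same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain shifted away

rhat_mixed = rhat(mixed)
rhat_stuck = rhat(stuck)
print(f"well mixed: {rhat_mixed:.3f}, one chain stuck: {rhat_stuck:.3f}")
```

When all chains sample the same distribution, $B/n$ is negligible and the ratio sits near one; shifting a single chain inflates $B$ and pushes $\hat{R}$ well above the 1.1 guideline.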

In [None]:
az.summary(schools_trace)

### Effective Sample Size

In general, samples drawn from MCMC algorithms will be autocorrelated. Unless the autocorrelation is very severe, this is not a big deal, other than the fact that autocorrelated chains may require longer sampling in order to adequately characterize posterior quantities of interest. The calculation of autocorrelation is performed for each lag $i=1,2,\ldots,k$ (the correlation at lag 0 is, of course, 1) by: 

$$\hat{\rho}_i = 1 - \frac{V_i}{2\hat{\text{Var}}(\theta | y)}$$

where $\hat{\text{Var}}(\theta | y)$ is the same estimated variance as calculated for the Gelman-Rubin statistic, and $V_i$ is the variogram at lag $i$ for $\theta$:

$$\text{V}_i = \frac{1}{m(n-i)}\sum_{j=1}^m \sum_{k=i+1}^n (\theta_{jk} - \theta_{j(k-i)})^2$$

This autocorrelation can be visualized using the `plot_autocorr` function in ArviZ:

In [None]:
az.plot_autocorr(schools_trace, var_names=['mu', 'tau'], combined=True);

You can see very severe autocorrelation in both variables, which is not surprising given the trace that we observed earlier.

The amount of correlation in an MCMC sample influences the **effective sample size** (ESS) of the sample. The ESS estimates how many *independent* draws contain the same amount of information as the *dependent* sample obtained by MCMC sampling.

Given a series of samples $x_j$, the empirical mean is

$$
\hat{\mu} = \frac{1}{n}\sum_{j=1}^n x_j
$$

and the variance of the estimate of the empirical mean is 

$$
\operatorname{Var}(\hat{\mu}) = \frac{\sigma^2}{n},
$$
where $\sigma^2$ is the true variance of the underlying distribution.

Then the effective sample size is defined as the denominator that makes this relationship still be true:

$$
\operatorname{Var}(\hat{\mu}) = \frac{\sigma^2}{n_{\text{eff}}}.
$$

The effective sample size is estimated using the partial sum:

$$\hat{n}_{eff} = \frac{n}{1 + 2\sum_{i=1}^T \hat{\rho}_i}$$

where $T$ is the first odd integer such that $\hat{\rho}_{T+1} + \hat{\rho}_{T+2}$ is negative.
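
As a rough sketch of how the variogram, the autocorrelation estimates and the truncated sum fit together, the NumPy code below applies them to synthetic AR(1) chains. Two assumptions are made here: the numerator is taken to be the total number of draws across chains ($m \times n$), and the helper names `ess_variogram` and `ar1_chains` are illustrative. ArviZ's own ESS estimator is more elaborate, so its numbers will not match exactly.

```python
import numpy as np

def ess_variogram(chains):
    """Effective sample size for `chains` of shape (m, n), via the variogram."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape

    # Same pooled variance estimate as used for R-hat
    W = chains.var(axis=1, ddof=1).mean()
    B = n * chains.mean(axis=1).var(ddof=1)
    var_hat = (n - 1) / n * W + B / n

    def rho(lag):
        # Variogram at this lag, averaged over chains and time points
        diffs = chains[:, lag:] - chains[:, :-lag]
        variogram = (diffs**2).sum() / (m * (n - lag))
        return 1.0 - variogram / (2.0 * var_hat)

    # Partial sum up to the first odd T with rho(T+1) + rho(T+2) < 0
    rho_sum = rho(1)
    T = 1
    while T + 2 < n:
        pair = rho(T + 1) + rho(T + 2)
        if pair < 0:
            break
        rho_sum += pair
        T += 2

    return m * n / (1.0 + 2.0 * rho_sum)

def ar1_chains(phi, m=4, n=2000, seed=0):
    """Synthetic stationary AR(1) chains with lag-1 autocorrelation `phi`."""
    rng = np.random.default_rng(seed)
    x = np.empty((m, n))
    x[:, 0] = rng.normal(scale=1.0 / np.sqrt(1.0 - phi**2), size=m)
    for t in range(1, n):
        x[:, t] = phi * x[:, t - 1] + rng.normal(size=m)
    return x

ess_independent = ess_variogram(ar1_chains(phi=0.0))
ess_correlated = ess_variogram(ar1_chains(phi=0.9))
print(f"independent: {ess_independent:.0f}, correlated: {ess_correlated:.0f}")
```

Both sets of chains contain 8000 draws, but the strongly autocorrelated ones carry the information of only a small fraction of that many independent draws.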

In [None]:
az.plot_ess(schools_trace, var_names=['tau'])
plt.tight_layout();

Using ArviZ, we can visualize the evolution of ESS as the MCMC sample accumulates. When the model is converging properly, both lines in this plot should be approximately linear.

The standard ESS estimate, which mainly assesses how well the centre of the distribution is resolved, is referred to as **bulk-ESS**. In order to estimate intervals reliably, it is also important to consider the **tail-ESS**.

In [None]:
az.plot_ess(schools_trace, var_names=['mu'], kind='evolution');

### Bayesian Fraction of Missing Information

The Bayesian fraction of missing information (BFMI) is a measure of how hard it is to
sample level sets of the posterior at each iteration. Specifically, it quantifies **how well momentum resampling matches the marginal energy distribution**. 

$$\text{BFMI} = \frac{\mathbb{E}_{\pi}[\text{Var}_{\pi_{E|q}}(E|q)]}{\text{Var}_{\pi_{E}}(E)}$$

$$\widehat{\text{BFMI}} = \frac{\sum_{n=1}^N (E_n - E_{n-1})^2}{\sum_{n=0}^N (E_n - \bar{E})^2}$$

BFMI is essentially a measure of the association between the energy of a state and the energy of the next state, or more precisely, it compares the average squared change in energy between successive samples to the overall variance of the energy across all samples. The "missing information" refers to the information that the sampler fails to gain about the posterior because it cannot efficiently traverse the energy landscape.

A small value indicates that the adaptation phase of the sampler was unsuccessful, and invoking the central limit theorem may not be valid. It indicates whether the sampler is able to *efficiently* explore the posterior distribution.

Though there is not an established rule of thumb for an adequate threshold, values close to one are optimal. Reparameterizing the model is sometimes helpful for improving this statistic.

BFMI calculation is only available in samples that were simulated using HMC or NUTS.
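
The estimator is simple enough to compute by hand. The sketch below applies it to two synthetic "energy" series generated as AR(1) processes; these are stand-ins chosen for illustration and do not come from a real sampler, and the names `bfmi_estimate` and `ar1_series` are hypothetical helpers.

```python
import numpy as np

def bfmi_estimate(energy):
    """Ratio of summed squared energy changes to summed squared deviations."""
    energy = np.asarray(energy, dtype=float)
    numerator = np.sum(np.diff(energy) ** 2)
    denominator = np.sum((energy - energy.mean()) ** 2)
    return numerator / denominator

def ar1_series(phi, n=5000, seed=0):
    """Synthetic stationary AR(1) series with lag-1 autocorrelation `phi`."""
    rng = np.random.default_rng(seed)
    e = np.empty(n)
    e[0] = rng.normal(scale=1.0 / np.sqrt(1.0 - phi**2))
    for t in range(1, n):
        e[t] = phi * e[t - 1] + rng.normal()
    return e

bfmi_fast = bfmi_estimate(ar1_series(phi=0.5))    # energy changes freely
bfmi_slow = bfmi_estimate(ar1_series(phi=0.95))   # energy changes slowly
print(f"fast-moving energy: {bfmi_fast:.2f}, slow-moving energy: {bfmi_slow:.2f}")
```

When the energy barely changes from one draw to the next, the numerator is small relative to the overall spread of energies and the estimate drops toward zero.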

In [None]:
az.bfmi(schools_trace)

Another way of diagnosing this phenomenon is by comparing the overall distribution of 
energy levels with the *change* of energy between successive samples. Ideally, they should be very similar.

If the distribution of energy transitions is narrow relative to the marginal energy distribution, this is a sign of inefficient sampling, as many transitions are required to completely explore the posterior. On the other hand, if the energy transition distribution is similar to that of the marginal energy, this is evidence of efficient sampling, resulting in near-independent samples from the posterior.

As an example, if we look at the energy plot of our eight schools model, the low BFMI values (which result in poor overlap in the energy distributions) suggest that the sampler is having trouble exploring different energy levels.

In [None]:
az.plot_energy(schools_trace);

We can attempt to improve the efficiency of the sampler by reparameterizing the random effect to be non-centered.

In [None]:
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
s = np.array([15, 10, 16, 11, 9, 11, 10, 18])
schools = np.array(
    [
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "Phillips Exeter",
        "Hotchkiss",
        "Lawrenceville",
        "St. Paul's",
        "Mt. Hermon",
    ]
)

with pm.Model(coords={'school': schools}) as schools_uncentered:
    
    mu = pm.Normal("mu", 0, sigma=1e6)
    tau = pm.HalfCauchy("tau", 5)

    z = pm.Normal('z', dims='school')
    theta = pm.Deterministic("theta", mu + tau*z, dims='school')

    obs = pm.Normal("obs", theta, sigma=s, observed=y)

    trace_uncentered = pm.sample()

In [None]:
az.plot_energy(trace_uncentered)

In [None]:
az.plot_trace(trace_uncentered, var_names=['mu', 'tau']);
plt.tight_layout();
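
The non-centered model works because `mu + tau * z` with `z ~ Normal(0, 1)` is distributed as `Normal(mu, tau)`, while `z` itself no longer depends on `tau`. The NumPy sketch below illustrates that difference using prior draws only; the value of `mu` and the number of draws are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws = 100_000
mu = 4.0                                          # arbitrary illustrative value

tau = 5.0 * np.abs(rng.standard_cauchy(n_draws))  # HalfCauchy(5) draws
z = rng.normal(size=n_draws)                      # standardized effects
theta = mu + tau * z                              # implied school effects

# Centered view: the spread of theta around mu tracks tau (the funnel)
corr_centered = np.corrcoef(np.log(tau), np.log(np.abs(theta - mu)))[0, 1]
# Non-centered view: the sampled quantity z is unrelated to tau
corr_noncentered = np.corrcoef(np.log(tau), np.log(np.abs(z)))[0, 1]

print(f"centered: {corr_centered:.2f}, non-centered: {corr_noncentered:.2f}")
```

In the centered form the sampler has to move through a region whose width changes with `tau`; in the non-centered form it samples `z` and `tau` on scales that do not depend on each other, at least a priori.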

## Goodness of Fit

As noted at the beginning of this section, convergence diagnostics are only the first step in the evaluation
of MCMC model outputs. It is possible for an entirely unsuitable model to converge, so additional steps are needed to ensure that the estimated model adequately fits the data. 

One intuitive way of evaluating model fit is to compare model predictions with the observations used to fit
the model. In other words, the fitted model can be used to simulate data, and the distribution of the simulated data should resemble the distribution of the actual data.

Fortunately, simulating data from the model is a natural component of the Bayesian modelling framework. Recall, from the discussion on prediction, the posterior predictive distribution:

$$p(\tilde{y}|y) = \int p(\tilde{y}|\theta) f(\theta|y) d\theta$$

Here, $\tilde{y}$ represents some hypothetical new data that would be expected, taking into account the posterior uncertainty in the model parameters. 
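
In practice the integral is approximated by simulation: for each posterior draw of $\theta$, simulate one new observation from $p(\tilde{y}|\theta)$. The sketch below does this for a deliberately simple case that is assumed for illustration only (a normal mean with known standard deviation and a flat prior), where the predictive standard deviation is known analytically.

```python
import numpy as np

rng = np.random.default_rng(42)

sigma = 2.0                                 # known observation sd (assumed)
y_obs = rng.normal(5.0, sigma, size=25)     # synthetic data
n = y_obs.size

# Posterior of the mean under a flat prior: Normal(ybar, sigma / sqrt(n))
theta_draws = rng.normal(y_obs.mean(), sigma / np.sqrt(n), size=50_000)

# One replicated observation per posterior draw: y_tilde ~ p(y_tilde | theta)
y_tilde = rng.normal(theta_draws, sigma)

predictive_sd = y_tilde.std()
analytic_sd = sigma * np.sqrt(1 + 1 / n)    # sampling noise plus posterior uncertainty
print(f"simulated: {predictive_sd:.3f}, analytic: {analytic_sd:.3f}")
```

The replicated data are more spread out than `sigma` alone would imply, because they also carry the posterior uncertainty in the mean.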

Sampling from the posterior predictive distribution is easy in PyMC. The `sample_posterior_predictive` function draws posterior predictive samples from all of the observed variables in the model. 

In [None]:
with schools_uncentered:
    pm.sample_posterior_predictive(trace_uncentered, extend_inferencedata=True)

The degree to which simulated data correspond to observations can be evaluated visually. This allows for a qualitative comparison of model-based replicates and observations. If there is poor fit, the true value of the data may appear in the tails of the histogram of replicated data, while a good fit will tend to show the true data in high-probability regions of the posterior predictive distribution.

In [None]:
az.plot_ppc(trace_uncentered);

In [None]:
az.plot_ppc(trace_uncentered, kind='cumulative', mean=False);

We can also look at the predictive performance of our model by examining the residuals:

In [None]:
y_pred = trace_uncentered.posterior_predictive["obs"].mean(["chain", "draw"]).values
residuals = y - y_pred

go.Figure().add_trace(
    go.Scatter(
        x=y_pred,
        y=residuals,
        mode='markers',
        marker=dict(
            color='royalblue',
            size=8,
            opacity=0.6
        ),
        name='Residuals'
    )
).add_shape(
    type="line",
    x0=float(y_pred.min()),
    x1=float(y_pred.max()),
    y0=0,
    y1=0,
    line=dict(
        color="red",
        width=2
    )
).update_layout(
    title='Residual Plot',
    xaxis_title='Predicted values',
    yaxis_title='Residuals',
    width=600,
    plot_bgcolor='white',
    showlegend=False
)

## Bayesian Workflow

Bayesian statistics has several strengths that are critical for applied modeling:
* Great flexibility to build statistical models quickly and iteratively
* A principled way of dealing with uncertainty
* Inference yields a distribution over all possible outcomes, not just the most likely one
* Expert information can guide the model through informative priors

The Bayesian workflow addresses:
* How to go from data to a model idea
* How to find priors for your model
* How to evaluate a model
* How to iteratively improve a model
* How to forecast into the future
* How powerful generative modeling can be

### COVID-19 Case Study

Let's apply the Bayesian workflow to a real-world problem: modeling COVID-19 cases.

In [None]:
import load_covid_data

df = load_covid_data.load_data(drop_states=True, filter_n_days_100=2)
countries = df.country.unique()
n_countries = len(countries)
df = df.loc[lambda x: (x.days_since_100 >= 0)]
df.head()

### Bayesian Workflow

A good workflow to adopt when developing models is:

1. Plot the data
2. Build model
3. Run prior predictive check
4. Fit model
5. Assess convergence
6. Run posterior predictive check
7. Improve model

#### 1. Plot the data

We will look at German COVID-19 cases. At first, we will only look at the first 30 days after Germany crossed 100 cases.

In [None]:
# Extract data
country = 'Germany'
date = '2020-07-31'
df_country = df.query(f'country=="{country}"').loc[:date].iloc[:30]

# Create the figure
fig = go.Figure()

# Add confirmed cases line
fig.add_trace(
    go.Scatter(
        x=df_country.index,
        y=df_country.confirmed,
        mode='lines+markers',
        name='Confirmed cases',
        line=dict(color='royalblue', width=2),
        marker=dict(size=6)
    )
)

# Update layout
fig.update_layout(
    title=dict(
        text=f'COVID-19 Cases in {country}',
        x=0.5
    ),
    xaxis=dict(
        title='Date',
        tickformat='%b %d\n%Y'
    ),
    yaxis=dict(
        title='Confirmed cases',
        gridcolor='lightgray'
    ),
    width=1000,
    height=800,
    plot_bgcolor='white'
)

fig.show()

#### 2. Build an initial model

The curve above looks exponential. This matches what we know from epidemiology: early in an epidemic, case counts grow exponentially.

In [None]:
# Get time-range of days since 100 cases were crossed
t = df_country.days_since_100.values
# Get number of confirmed cases for Germany
confirmed = df_country.confirmed.values

with pm.Model() as model_exp1:
    # Intercept
    a = pm.Normal('a', mu=0, sigma=100)

    # Slope
    b = pm.Normal('b', mu=0.3, sigma=0.3)

    # Exponential regression
    growth = a * (1 + b) ** t

    # Error term
    eps = pm.HalfNormal('eps', 100)

    # Likelihood
    pm.Normal('obs',
              mu=growth,
              sigma=eps,
              observed=confirmed)

#### 3. Run prior predictive check

Without even fitting the model to our data, we generate new potential data from our priors. Usually we have less intuition about the parameter space, where we define our priors, and more intuition about what data we might expect to see. A prior predictive check thus allows us to make sure the model can generate the types of data we expect to see.

In [None]:
with model_exp1:
    prior_pred = pm.sample_prior_predictive()

In [None]:
obs_samples = prior_pred.prior_predictive["obs"].values.squeeze()

def plot_prior_predictive(obs_samples, y_range=None, x_range=None, 
                         title="Prior predictive", x_label="Days since 100 cases", 
                         y_label="Positive cases", log_y=False):
    num_samples, num_timesteps = obs_samples.shape
    
    fig = go.Figure()
    for i in range(num_samples):
        fig.add_trace(go.Scatter(
            x=np.arange(num_timesteps),
            y=obs_samples[i],
            mode='lines',
            line=dict(color='rgba(128,128,128,0.1)'),
            showlegend=False
        ))
    fig.update_layout(
        yaxis_range=y_range,
        xaxis_range=x_range,
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        template="plotly_white",
        yaxis_type="log" if log_y else "linear"
    )
    return fig

In [None]:
plot_prior_predictive(obs_samples=obs_samples, y_range=(-1000, 1000))

#### What's wrong with this model?

There are several issues with this model:
1. Cases can't be negative
2. Cases cannot start at 0, since the series begins once 100 cases have been crossed
3. Case counts can't go down
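
These problems can be quantified by replicating the first model's prior predictive simulation directly in NumPy. This is a sketch under two assumptions: the time axis is taken to be 30 consecutive days starting at 0, and 500 prior draws are used; neither value is taken from the data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 500
t = np.arange(30)   # assumed: days 0..29 since crossing 100 cases

# Priors of the first exponential model
a = rng.normal(0.0, 100.0, size=n_draws)
b = rng.normal(0.3, 0.3, size=n_draws)
eps = np.abs(rng.normal(0.0, 100.0, size=n_draws))   # HalfNormal(100)

# Exponential growth curve and Normal observation noise
growth = a[:, None] * (1.0 + b[:, None]) ** t
simulated = rng.normal(growth, eps[:, None])

frac_negative = (simulated < 0).mean()
frac_decreasing = (np.diff(simulated, axis=1) < 0).mean()
print(f"negative values: {frac_negative:.0%}, day-to-day decreases: {frac_decreasing:.0%}")
```

A large share of the simulated values is negative, mostly because the prior on `a` is centered at zero, and many simulated trajectories fall from one day to the next.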

Let's improve our model. The negative case counts arise because we used a Normal likelihood. Instead, let's use a `NegativeBinomial`, which is similar to the `Poisson` distribution commonly used for count data, but has an extra dispersion parameter that allows more flexibility in modeling the variance of the data. We also center the prior for the intercept `a` at 100 and tighten the prior on the growth rate `b`.
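
To see what the extra dispersion parameter does, the sketch below compares Poisson and negative binomial draws with the same mean using NumPy. It assumes the mean/dispersion parameterization, in which the variance is `mu + mu**2 / alpha`; the values of `mu` and `alpha` are illustrative, and NumPy's `(n, p)` arguments are derived from them.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, alpha = 100.0, 6.0
n_draws = 200_000

# Convert (mu, alpha) to NumPy's (n, p) parameterization
p = alpha / (alpha + mu)
nb_draws = rng.negative_binomial(alpha, p, size=n_draws)
poisson_draws = rng.poisson(mu, size=n_draws)

nb_expected_var = mu + mu**2 / alpha
print(f"Poisson variance:           {poisson_draws.var():.0f}")
print(f"Negative binomial variance: {nb_draws.var():.0f} (expected {nb_expected_var:.0f})")
```

Both distributions produce non-negative integers with the same mean, but the negative binomial allows far more spread around that mean, and smaller values of `alpha` give more overdispersion.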

In [None]:
t = df_country.days_since_100.values
confirmed = df_country.confirmed.values

with pm.Model() as model_exp2:
    # Intercept
    a = pm.Normal('a', mu=100, sigma=25)

    # Slope
    b = pm.Normal('b', mu=0.3, sigma=0.1)

    # Exponential regression
    growth = a * (1 + b) ** t

    # Likelihood
    pm.NegativeBinomial('obs',
                 growth,
                 alpha=pm.Gamma("alpha", mu=6, sigma=1),
                 observed=confirmed)

In [None]:
with model_exp2:
    prior_pred = pm.sample_prior_predictive()

plot_prior_predictive(obs_samples=prior_pred.prior_predictive['obs'].values.squeeze(), y_range=(-100, 1000))

#### 4. Fit model

In [None]:
with model_exp2:
    trace_exp2 = pm.sample(chains=4, cores=4, tune=2000)

#### 5. Assess convergence

In [None]:
az.plot_trace(trace_exp2, var_names=['a', 'b', 'alpha'])
plt.tight_layout();

In [None]:
az.summary(trace_exp2, var_names=['a', 'b', 'alpha'])

In [None]:
az.plot_energy(trace_exp2);

#### 6. Run posterior predictive check

In [None]:
with model_exp2:
    # Draw samples from posterior predictive
    post_pred = pm.sample_posterior_predictive(trace_exp2.posterior)

In [None]:

posterior_samples = post_pred.posterior_predictive['obs'].sel(chain=0).values.squeeze()

def plot_posterior_predictive(posterior_samples, confirmed):
    fig = go.Figure()
    for i in range(posterior_samples.shape[0]):
        fig.add_trace(go.Scatter(
            x=list(range(posterior_samples.shape[1])),
            y=posterior_samples[i],
            mode='lines',
            line=dict(color='rgba(128,128,128,0.05)', width=1),
            showlegend=False,
            hoverinfo='skip'
        ))
    fig.add_trace(go.Scatter(
        x=list(range(len(confirmed))),
        y=confirmed,
        mode='lines',
        line=dict(color='red', width=2),
        name='data'
    ))

    fig.update_layout(
        title=country,
        xaxis_title="Days since 100 cases",
        yaxis_title="Confirmed cases (log scale)",
        yaxis_type="log",
        height=800,
        width=1000,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
    )
    return fig

plot_posterior_predictive(posterior_samples, confirmed)

In [None]:
resid = post_pred.posterior_predictive["obs"].sel(chain=0).values - confirmed
fig = go.Figure()
for i in range(resid.shape[0]):
    fig.add_trace(go.Scatter(
        x=list(range(resid.shape[1])),
        y=resid[i],
        mode='lines',
        line=dict(color='rgba(128,128,128,0.01)', width=1),
        showlegend=False,
        hoverinfo='skip'
    ))
fig.update_layout(
    yaxis=dict(range=[-50_000, 200_000]),
    yaxis_title="Residual",
    xaxis_title="Days since 100 cases",
    height=800,
    width=1000
)
fig.show()

#### 7. Improve model - Logistic Growth Model

The exponential model doesn't capture the plateau in cases that we expect to see over time. Let's implement a logistic growth model which has an S-shaped curve that better represents epidemic dynamics with a carrying capacity.

![Logistic Growth](images/logistic_growth.jpg)
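
As a sketch of the S-shaped curve, the NumPy function below parameterizes logistic growth by its value at `t = 0` (`intercept`), a growth rate `b`, and a `carrying_capacity`. The transform `a = carrying_capacity / intercept - 1` is what pins the curve to `intercept` at the start. The parameter values are arbitrary illustrations, not estimates.

```python
import numpy as np

def logistic_growth(t, intercept, b, carrying_capacity):
    """Logistic curve that equals `intercept` at t = 0 and saturates at `carrying_capacity`."""
    a = carrying_capacity / intercept - 1
    return carrying_capacity / (1 + a * np.exp(-b * t))

t = np.arange(100)
curve = logistic_growth(t, intercept=100.0, b=0.2, carrying_capacity=50_000.0)

print(f"day 0: {curve[0]:.0f}, day 30: {curve[30]:.0f}, day 99: {curve[-1]:.0f}")
```

Early on the curve grows roughly exponentially at rate `b`; as it approaches the carrying capacity the growth slows and the curve flattens.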

In [None]:
df_country = df.query(f'country=="{country}"').loc[:date]

with pm.Model() as logistic_model:
    t_data = pm.Data('t', df_country.days_since_100.values)
    confirmed_data = pm.Data('confirmed', df_country.confirmed.values)

    # Intercept
    a0 = pm.HalfNormal('a0', sigma=25)
    intercept = pm.Deterministic('intercept', a0 + 100)

    # Slope
    b = pm.HalfNormal('b', sigma=0.2)
    
    carrying_capacity = pm.Uniform('carrying_capacity',
                                   lower=1_000,
                                   upper=80_000_000)
    # Transform carrying_capacity to a
    a = carrying_capacity / intercept - 1

    # Logistic
    growth = carrying_capacity / (1 + a * pm.math.exp(-b * t_data))

    # Likelihood
    pm.NegativeBinomial('obs',
                 growth,
                 alpha=pm.Gamma("alpha", mu=6, sigma=1),
                 observed=confirmed_data)

In [None]:
with logistic_model:
    prior_pred = pm.sample_prior_predictive()

In [None]:
plot_prior_predictive(prior_pred.prior_predictive['obs'].squeeze().values, log_y=True)

In [None]:
with logistic_model:
    # Inference
    trace_logistic = pm.sample(chains=4, cores=4, tune=2000, target_accept=0.9)
    
    # Sample posterior predictive
    pm.sample_posterior_predictive(trace_logistic, extend_inferencedata=True)

In [None]:
az.plot_trace(trace_logistic)
plt.tight_layout();

In [None]:
plot_posterior_predictive(
    trace_logistic.posterior_predictive['obs'].sel(chain=0).squeeze().values,
    confirmed=df_country.confirmed.values,
)

### Forecasting

A key strength of Bayesian modeling is that predictions come with quantified uncertainty. Let's extend the time axis beyond the observed data to forecast future cases.
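
In practice, a prediction with uncertainty is a summary of many simulated trajectories rather than a single curve. As a minimal sketch on synthetic data (the array below is randomly generated and merely stands in for posterior predictive draws with dimensions chain, draw, and day; the growth rates are made up), a mean forecast and a 95% interval can be computed like this:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for posterior predictive draws: (chain, draw, day)
n_chains, n_draws, n_days = 4, 500, 30
days = np.arange(n_days)
# Each draw gets its own growth rate, so trajectories fan out over time
rates = rng.normal(0.15, 0.02, size=(n_chains, n_draws, 1))
samples = rng.poisson(100 * np.exp(rates * days))

# Collapse the chain and draw dimensions, keeping one summary per day
mean = samples.mean(axis=(0, 1))
lower, upper = np.percentile(samples, [2.5, 97.5], axis=(0, 1))

print(f"day {days[-1]}: mean={mean[-1]:.0f}, "
      f"95% interval=({lower[-1]:.0f}, {upper[-1]:.0f})")
```

The interval widens with the forecast horizon because uncertainty in the growth rate compounds over time.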

In [None]:
# Create a forecast window
forecast_days = 60  # roughly 2 months
n_observed = len(df_country.days_since_100.values)
# Assumes days_since_100 runs 0, 1, ..., n_observed - 1
future_days = np.arange(n_observed, n_observed + forecast_days)

with logistic_model:
    # Extend the data containers to cover the forecast window. The zeros appended
    # to 'confirmed' are placeholders that only give 'obs' the right shape;
    # posterior predictive sampling does not condition on them.
    pm.set_data({'t': np.concatenate([df_country.days_since_100.values, future_days]),
                 'confirmed': np.concatenate([df_country.confirmed.values,
                                              np.zeros(forecast_days, dtype='int')])})

    forecast = pm.sample_posterior_predictive(trace_logistic)

In [None]:
import plotly.graph_objects as go
import numpy as np

historical_days = len(df_country.days_since_100.values)
all_days = np.arange(historical_days + forecast_days)

forecast_samples = forecast.posterior_predictive['obs'].values
forecast_mean = forecast_samples.mean(axis=(0, 1))
forecast_lower = np.percentile(forecast_samples, 2.5, axis=(0, 1))
forecast_upper = np.percentile(forecast_samples, 97.5, axis=(0, 1))

go.Figure().add_trace(
    go.Scatter(
        x=np.arange(historical_days),
        y=df_country.confirmed.values,
        mode='lines+markers',
        name='Observed cases',
        line=dict(color='black', width=2),
        marker=dict(size=6)
    )
).add_trace(
    go.Scatter(
        x=all_days,
        y=forecast_mean,
        mode='lines',
        name='Mean forecast',
        line=dict(color='blue', width=2)
    )
).add_trace(
    go.Scatter(
        x=np.concatenate([all_days, all_days[::-1]]),
        y=np.concatenate([forecast_upper, forecast_lower[::-1]]),
        fill='toself',
        fillcolor='rgba(0, 0, 255, 0.2)',
        line=dict(color='rgba(255, 255, 255, 0)'),
        name='95% credible interval'
    )
).add_shape(
    type="line",
    x0=historical_days-1, x1=historical_days-1,
    y0=0, y1=forecast_mean[historical_days-1],
    line=dict(color='gray', width=2, dash='dash')
).update_layout(
    title=f'COVID-19 Cases Forecast for {country}',
    xaxis=dict(
        title='Days since 100 cases'
    ),
    yaxis=dict(
        title='Confirmed cases'
    ),
    width=1200,
    height=800,
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    )
)

---

## References

Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B. (2003). Bayesian Data Analysis (2nd ed.). Chapman and Hall/CRC.

Gelman, A., & Rubin, D. B. (1992). Inference from iterative simulation using multiple sequences. Statistical Science, 7(4), 457–472.

Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P.-C. (2019). Rank-normalization, folding, and localization: An improved $\hat{R}$ for assessing convergence of MCMC. arXiv preprint arXiv:1903.08008.

Gelman, A., Hwang, J., & Vehtari, A. (2014). Understanding predictive information criteria for Bayesian models. Statistics and Computing, 24(6), 997–1016.

Betancourt, M. (2016). Diagnosing Suboptimal Cotangent Disintegrations in Hamiltonian Monte Carlo. arXiv preprint arXiv:1604.00695.

Gabry, J., Simpson, D., Vehtari, A., Betancourt, M., & Gelman, A. (2019). Visualization in Bayesian workflow. Journal of the Royal Statistical Society: Series A (Statistics in Society), 182(2), 389–402.

In [None]:
%load_ext watermark
%watermark -n -u -v -iv -w