# Introduction to calibration and uncertainty propagation
Up until now we have been creating models that may accurately represent the local epidemic but (at best) only provide one possible epidemic profile that would be consistent with the observations. In this notebook, we extend this to obtain a range of parameter values and epidemic trajectories that would be consistent with the local observations, and thereby quantify the uncertainty in our simulations.

In this notebook, we will learn how to use a Markov Chain Monte Carlo (MCMC) algorithm to calibrate an SIR model to epidemic data.
That is, we will use a Bayesian sampling approach to estimate model parameters and to project the epidemic with uncertainty.

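The Metropolis algorithm is the classic example of MCMC (see the pre-reading below). In this notebook, rather than implementing it by hand, we use numpyro's Sample Adaptive (SA) kernel, another gradient-free MCMC sampler.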

Recommended pre-reading:
- Wikipedia page on Metropolis–Hastings algorithm [here](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
- Some example implementations with discussion of common tuning issues [here](https://jellis18.github.io/post/2018-01-02-mcmc-part1/).


There is also a great interactive demo of several Bayesian sampling algorithms [here](https://chi-feng.github.io/mcmc-demo/).



In [None]:
# pip install the required packages if running in Colab
try:
  import google.colab
  IN_COLAB = True
  %pip install summerepi2
  %pip install numpyro
except:
  IN_COLAB = False


In [None]:
# Standard imports, plotting option and constant definition
import pandas as pd
from scipy import stats
import numpy as np
from jax import numpy as jnp

from summer2 import CompartmentalModel
from summer2.parameters import Parameter

pd.options.plotting.backend = "plotly"

import numpyro
from numpyro import distributions as dist

from numpyro.infer.initialization import init_to_median, init_to_value, init_to_sample
from numpyro.infer import MCMC
from jax import random

## Create some dummy data for our model to fit

In [None]:
# Synthetic observations: a single "active_cases" column indexed by model time.
# These are the values the calibration below tries to reproduce.
data = pd.DataFrame(
    data={
        "active_cases": {
            60.0: 3000.0,
            80.0: 8500.0,
            100.0: 21000.0,
            120.0: 40000.0,
            140.0: 44000.0,
            160.0: 30000.0,
            180.0: 16000.0,
            200.0: 7000.0,
        }
    }
)
data.active_cases.plot(kind="scatter")

# Model

## Define a simple SIR model

In [None]:
def build_sir_model(model_config: dict) -> CompartmentalModel:
    """
    Create a compartmental model, with the minimal compartmental structure needed to run and produce some sort of
    meaningful outputs.

    Args:
        model_config: Fixed values that determine structural and numerical properties of the model.
            Must provide "start_time", "end_time", "initial_population" and "infectious_seed".
    Returns:
        A compartmental model currently without stratification applied. It has S, I and R compartments,
        an infection flow driven by the "contact_rate" parameter, a recovery flow driven by the
        "infection_duration" parameter, and an "active_cases" derived output tracking compartment I.
    Raises:
        ValueError: If the infectious seed is larger than the initial population.
    """
    infectious_seed = model_config["infectious_seed"]
    initial_population = model_config["initial_population"]
    # Validate with an explicit exception: an assert would be stripped under `python -O`.
    # Written as `not (pop >= seed)` so that NaN inputs are still rejected.
    if not initial_population >= infectious_seed:
        raise ValueError(
            f"Initial population size ({initial_population}) must be at least as large as "
            f"the infectious seed ({infectious_seed})"
        )

    model = CompartmentalModel(
        times=(model_config["start_time"], model_config["end_time"]),
        compartments=["S", "I", "R"],
        infectious_compartments=["I"],
    )

    # Everyone not in the infectious seed starts susceptible
    model.set_initial_population(
        distribution={
            "S": initial_population - infectious_seed,
            "I": infectious_seed,
        }
    )

    # Set up flows with summer2 Parameter objects - these are placeholders
    # whose actual values will be looked up in a dictionary when we run the model

    # Susceptible people can get infected
    model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="S",
        dest="I",
    )

    # Note that you can perform arithmetic and other transforms on Parameter objects -
    # their final values will be computed later

    # Infectious people recover at a rate equal to the reciprocal of the infection duration
    model.add_transition_flow(
        name="recovery",
        fractional_rate=1.0 / Parameter("infection_duration"),
        source="I",
        dest="R",
    )

    model.request_output_for_compartments("active_cases", "I")

    return model

## Run the model with some example parameters

In [None]:
# Fixed configuration options that define the structure and behaviour of the model
model_config = dict(
    initial_population=1.0e6,
    infectious_seed=100.0,
    start_time=0,
    end_time=365,
)
# Build the SIR model object from this fixed configuration
sir_model = build_sir_model(model_config=model_config)


In [None]:
# Free parameters for calibration: the names declared as Parameter objects
# when the model was built, with example values
parameters = dict(contact_rate=0.3, infection_duration=9.0)

# Run the model with these example values
sir_model.run(parameters)

# Compare the modelled infectious compartment with the observations
output_df = pd.DataFrame(
    data={
        "modelled": sir_model.get_outputs_df()["I"],
        "observed": data.active_cases,
    }
)
output_df.plot(kind="scatter")


In [None]:
# Display the model's input parameters - presumably the names declared via Parameter(...) in build_sir_model
sir_model.get_input_parameters()

In [None]:
def evaluate_log_priors(proposed_parameters: dict) -> float:
    """
    Evaluate the joint log prior density of a proposed parameter set.

    The priors on the two parameters are treated as independent, so their log densities are summed.
    NOTE(review): these scipy priors are not identical to the numpyro `priors` used for calibration
    in this notebook (e.g. that contact-rate prior starts at 0.01 and its duration prior is truncated).

    Args:
        proposed_parameters: Mapping containing "contact_rate" and "infection_duration".
    Returns:
        The summed log prior density. It is -inf when contact_rate lies outside [0, 0.5].
    """
    # Working on the log scale: a prior probability of 1 corresponds to a log value of 0
    prior_log_proba = 0.0

    # Uniform prior on [0, 0.5] for the contact rate
    prior_log_proba += stats.uniform.logpdf(x=proposed_parameters["contact_rate"], loc=0, scale=0.5)

    # Normal prior for the infection duration, with mean = 7 days and sd = 0.5
    prior_log_proba += stats.norm.logpdf(x=proposed_parameters["infection_duration"], loc=7, scale=0.5)

    return prior_log_proba

In [None]:
# Build a new model object and obtain a runner from it, passing the example parameters
sir_model = build_sir_model(model_config=model_config)
runner = sir_model.get_runner(parameters)

In [None]:
# numpyro prior distributions for the calibrated parameters, keyed by parameter name.
# Alternatives considered for the infection duration:
#   dist.Normal(7.0, 0.5)      (infinite support)
#   dist.Uniform(3.0, 15.0)
priors = dict(
    contact_rate=dist.Uniform(0.01, 0.5),
    infection_duration=dist.TruncatedNormal(7.0, 0.5, low=2.0, high=50.0),
)

In [None]:
# Choosing priors has a few caveats which can give confusing results.
# A Normal prior has infinite support, so it can propose implausible (e.g. non-positive)
# infection durations and generate NaNs in the model outputs.
# A truncated normal restricts the support, which makes more sense in this context.

In [None]:
#d0 = dist.TruncatedNormal(7.0, 0.5, low=0.0, high=np.inf)
#d1 = dist.Normal(7.0, 0.5)

In [None]:
#d0.support

In [None]:
#d1.support

In [None]:
#pd.Series(d0.sample(rng_key, (10000,))).hist()

In [None]:
#pd.Series(d1.sample(rng_key, (10000,))).hist()

In [None]:
def get_target_indices(target_series, model_times):
    """
    Find the position in ``model_times`` of each time in ``target_series.index``.

    Args:
        target_series: Observations indexed by model time (e.g. a pandas Series).
        model_times: 1-D array of the model's output times.
    Returns:
        A jax integer array of positions, in the same order as ``target_series.index``, suitable for
        indexing model outputs. If a time occurs more than once in ``model_times``, its first position is used.
    Raises:
        ValueError: If any target time does not exactly match one of the model times.
    """
    # Map each model time to its first position once, rather than scanning model_times per target
    position_of = {}
    for position, time in enumerate(np.asarray(model_times).tolist()):
        position_of.setdefault(time, position)

    target_times = list(target_series.index)
    missing = [t for t in target_times if t not in position_of]
    if missing:
        raise ValueError(f"Target times not found in model times: {missing}")

    # dtype=int keeps the result usable as an index even when there are no targets
    return jnp.array([position_of[t] for t in target_times], dtype=int)

In [None]:
# Sanity check: positions in the model's time grid that correspond to the observation times
get_target_indices(data["active_cases"], sir_model.times)

In [None]:
# NOTE(review): `s` appears unused in the rest of the visible notebook
s = data.active_cases

In [None]:
# Observed values for each data column as jax arrays, for use inside the numpyro likelihood
targets = {name: jnp.array(column.to_numpy()) for name, column in data.items()}
targets

In [None]:
# For each data column, the positions in the model's time grid matching its observation times
target_indices = {
    name: get_target_indices(column, sir_model.times)
    for name, column in data.items()
}
target_indices

In [None]:
# Display the jax arrays of observed values, keyed by target name
targets

In [None]:
# Rebuild the model and obtain the runner used inside the numpyro model below
sir_model = build_sir_model(model_config=model_config)
runner = sir_model.get_runner(parameters)

In [None]:
# Priors used for calibration; this definition replaces the earlier `priors` dictionary.
# Alternatives considered for the infection duration:
#   dist.Normal(7.0, 0.5)      (infinite support)
#   dist.Uniform(3.0, 15.0)
priors = dict(
    contact_rate=dist.Uniform(0.01, 0.5),
    infection_duration=dist.TruncatedNormal(7.0, 0.5, low=0.0, high=50.0),
)

In [None]:
def model():
    """
    numpyro model: sample the SIR parameters from their priors and score the outputs against the targets.

    Reads the module-level ``priors``, ``runner``, ``targets`` and ``target_indices``. It registers one
    sample site per prior and one factor site per target. It also registers a ``<target>_ll``
    deterministic site recording that target's log-likelihood.
    """
    # numpyro.sample draws from each prior using the PRNG key threaded in by the
    # inference algorithm. The sites are created by looping over the priors dictionary
    # rather than written out one by one.
    param_updates = {}
    for name in priors:
        param_updates[name] = numpyro.sample(name, priors[name])

    # runner._run_func is a pure compiled jax function that bypasses any additional
    # CPython code, so it can be called directly inside a numpyro model.
    # NOTE(review): _run_func is a private summer2 attribute and may change between releases.
    results = runner._run_func(param_updates)

    for name, observed in targets.items():
        # Modelled values at the observation times; these are jax types, so ordinary
        # indexing remains traceable
        modelled = results["derived_outputs"][name][target_indices[name]]

        # Gaussian log-likelihood whose sd is the spread of the observations themselves.
        # Using factor (rather than sample with obs=) allows a customised likelihood, and
        # deterministic stores its value with the samples.
        # (A TruncatedNormal(modelled, sd, low=0.0) likelihood is a possible alternative.)
        log_lik = dist.Normal(modelled, jnp.std(observed)).log_prob(observed).sum()

        numpyro.factor(name, log_lik)
        numpyro.deterministic(f"{name}_ll", log_lik)

In [None]:
# Explicit starting values; only referenced by the commented-out init_to_value alternatives
initial_parameters = dict(
    contact_rate=jnp.array((0.2,)),
    infection_duration=jnp.array((7.0,)),
)

# Gradient-free Sample Adaptive (SA) kernel, initialised at the prior medians.
# Alternative: init_strategy=init_to_value(values=initial_parameters)
sa_kernel = numpyro.infer.SA(
    model,
    adapt_state_size=8,
    dense_mass=True,
    init_strategy=init_to_median,
)

In [None]:
# Two chains, each with 10000 warmup steps and 10000 retained draws; the acceptance
# probability is recorded as an extra field
mcmc = MCMC(sa_kernel, num_warmup=10000, num_samples=10000, num_chains=2)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Plot the recorded acceptance probability per draw (chains flattened, since group_by_chain defaults to False)
pd.Series(mcmc.get_extra_fields()['accept_prob']).plot()

In [None]:
# Posterior summary statistics and convergence diagnostics (e.g. n_eff, r_hat), including deterministic sites
mcmc.print_summary(exclude_deterministic=False)

In [None]:
import arviz as az

In [None]:
# Convert the fitted MCMC object to an ArviZ InferenceData for plotting
ad = az.from_numpyro(posterior=mcmc)

In [None]:
# Posterior density plots for the variables in the posterior group
az.plot_posterior(ad)

In [None]:
# Trace and density plots per variable; assigned to `_` to suppress the returned axes' repr in the cell output
_ = az.plot_trace(ad, compact=False, figsize=(15,20))

In [None]:
class SampleWrapper:
    """
    Index a mapping of posterior sample arrays by draw, keeping only model parameters.

    ``wrapper[i]`` returns ``{name: samples[name][i]}`` for every site name that is in ``params``.
    The result can therefore be passed straight to a model runner.
    """

    def __init__(self, samples, params):
        """
        Args:
            samples: Mapping from site name to an indexable sequence of draws
                (e.g. the output of ``MCMC.get_samples()``).
            params: Collection of parameter names to keep. Other sites, such as deterministic
                log-likelihoods, are dropped when indexing.
        """
        self.samples = samples
        self.params = params

    def __getitem__(self, idx):
        # Read from the wrapped samples (previously this used a global named `samples`)
        return {
            name: values[idx]
            for name, values in self.samples.items()
            if name in self.params
        }

In [None]:
# All retained draws with the chains flattened into a single leading axis
samples = mcmc.get_samples(group_by_chain=False)

In [None]:
# Wrap the samples so sw[i] yields only the model's input parameters for draw i
sw = SampleWrapper(samples=samples, params=sir_model.get_input_parameters())

In [None]:
# Draw with the highest active-case log-likelihood (idxmax skips NaNs by default).
# NOTE(review): this is the maximum-likelihood draw, not the posterior mode.
best_idx = pd.Series(sw.samples['active_cases_ll']).idxmax()

In [None]:
# Re-run the model at the highest-likelihood draw
runner.run(sw[best_idx])

# Overlay that trajectory on the observations
output_df = pd.DataFrame(
    data={
        "modelled": runner.get_derived_outputs_df()["active_cases"],
        "observed": data["active_cases"],
    }
)
output_df.plot(kind="scatter")

In [None]:
# Run the model at 100 posterior draws and collect the active-case trajectories (one column each).
# The flattened samples hold the chains back to back, so the draws are spread evenly over
# all of them. A fixed stride of 100 from index 0 would only ever reach the first chain.
n_draws = len(sw.samples["contact_rate"])
draw_indices = np.linspace(0, n_draws - 1, 100).astype(int)

trajectories = {}
for i, draw in enumerate(draw_indices):
    runner.run(sw[int(draw)])
    trajectories[i] = runner.get_derived_outputs_df()["active_cases"]

# Build the frame in one step (avoids growing a DataFrame column by column)
sampled_df = pd.DataFrame(trajectories, index=runner.get_derived_outputs_df().index)

In [None]:
# Plot the sampled posterior trajectories together
sampled_df.plot()