# The estival library for calibration of summer2 models

In [None]:
# If we are running in google colab, pip install the required packages, 
# but do not modify local environments
try:
    import google.colab
    IN_COLAB = True
    %pip install summerepi2
    %pip install estival
except:
    IN_COLAB = False

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

In [None]:
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter

## Build a simple model

In [None]:
def build_model():
    """Build the three-compartment (S, I, R) model used for calibration.

    The model is simulated over times 0..100 with a reference date of
    2000-01-01. It starts with 1000 susceptible and 10 infected individuals.
    The rates `contact_rate` and `recovery_rate` are left as named
    `Parameter`s so they can be supplied on each `run` call. Both flows are
    also requested as derived outputs of the same name.

    Returns:
        A new, un-run `CompartmentalModel`.
    """
    model = CompartmentalModel(
        [0, 100],
        ["S", "I", "R"],
        "I",  # presumably the infectious compartment(s) -- confirm in summer2 docs
        ref_date=datetime(2000, 1, 1),
    )
    model.set_initial_population({"S": 1000, "I": 10.0})

    # S -> I, driven by the calibrated contact rate
    model.add_infection_frequency_flow("infection", Parameter("contact_rate"), "S", "I")
    # I -> R, driven by the calibrated recovery rate
    model.add_transition_flow("recovery", Parameter("recovery_rate"), "I", "R")

    # Expose each flow as a derived output under the same name as the flow.
    for flow_name in ("infection", "recovery"):
        model.request_output_for_flow(flow_name, flow_name)

    return model

In [None]:
# A single model instance: it is run below at the reference parameters and
# later re-run for each posterior sample.
m = build_model()

In [None]:
# Reference parameter values: the run at these values generates the synthetic
# target data, and they are also handed to the MCMC chain below.
parameters = dict(contact_rate=0.2, recovery_rate=0.02)

In [None]:
# Simulate at the reference parameters; the resulting derived outputs are
# used as the calibration data.
m.run(parameters)

In [None]:
# Plot the model's compartment outputs over time (presumably S, I and R sizes).
m.get_outputs_df().plot()

In [None]:
# Plot the derived outputs requested in build_model: the infection and
# recovery flows.
m.get_derived_outputs_df().plot()

In [None]:
# Use the reference run's flow outputs as synthetic "observed" data for the
# calibration targets. Build the derived-outputs DataFrame once rather than
# re-generating it for each column.
derived_outputs = m.get_derived_outputs_df()
idata = derived_outputs["infection"]
rdata = derived_outputs["recovery"]

## Defining our calibration

In [None]:
from estival.calibration.mcmc.adaptive import AdaptiveChain

In [None]:
# One uniform prior per calibrated parameter (names match `parameters`).
# Each tuple is presumably the (lower, upper) bound of the prior.
mcmc_priors = [
    UniformPrior(param_name, bounds)
    for param_name, bounds in (
        ("contact_rate", (0.01, 0.5)),
        ("recovery_rate", (0.001, 0.1)),
    )
]

In [None]:
# One negative-binomial target per derived output, matched by output name to
# the synthetic series extracted above. The third argument is presumably the
# negative-binomial dispersion -- NOTE(review): confirm against estival docs.
mcmc_targets = [
    NegativeBinomialTarget(output_name, observed, dispersion)
    for output_name, observed, dispersion in (
        ("infection", idata, 5.0),
        ("recovery", rdata, 1.0),
    )
]

In [None]:
# Build the adaptive MCMC sampler from the model factory, the reference
# parameters, the priors and the targets. `parameters` is passed twice:
# presumably once as the base parameter set and once as the chain's starting
# values. NOTE(review): confirm against the AdaptiveChain signature.
mcmc = AdaptiveChain(build_model, parameters, mcmc_priors, mcmc_targets, parameters)

In [None]:
# Run the sampler; max_iter presumably caps the number of MCMC iterations.
mcmc.run(max_iter=20000)

## Examining results with ArviZ

In [None]:
import arviz as az

In [None]:
# Convert the chain to an ArviZ-compatible object (it exposes `.posterior`,
# used below). NOTE(review): 5000 is presumably the number of initial
# burn-in iterations to discard -- confirm in estival docs.
infdata = mcmc.to_arviz(5000)

In [None]:
# Display the converted inference data in the notebook output.
infdata

In [None]:
# Trace plots for each parameter; the trailing ';' suppresses echoing the
# returned axes array.
az.plot_trace(infdata,figsize=(16,12));

In [None]:
# Posterior distribution of each calibrated parameter.
az.plot_posterior(infdata);

In [None]:
# Effective sample size as a function of the number of draws, to check that
# the sampler keeps accumulating independent information.
az.plot_ess(infdata, kind="evolution");

In [None]:
# Display the results object stored on the sampler.
mcmc.results

## Basic uncertainty sampling

In [None]:
# Thin the posterior: from the first chain (index 0) keep every 50th draw of
# each calibrated parameter, giving one DataFrame column per parameter.
sample_data = {
    param_name: infdata.posterior[param_name][0, ::50].to_numpy()
    for param_name in parameters
}
samples = pd.DataFrame(columns=parameters, data=sample_data)
    

In [None]:
# Re-run the model once per thinned posterior sample and keep each run's
# derived outputs (infection and recovery flows).
outputs = []
for _, sample_row in samples.iterrows():
    m.run(sample_row.to_dict())
    outputs.append(m.get_derived_outputs_df())

In [None]:
# Quantile bands of the infection flow across posterior runs, with the
# synthetic target data overlaid explicitly on the same axes (the original
# assigned `ax` but never used it, relying on implicit current-axes state).
q = (0.1, 0.25, 0.5, 0.75, 0.9)
ax = pd.DataFrame(
    index=idata.index,
    columns=q,
    # rows = runs, columns = times -> quantiles over runs, then transpose
    data=np.quantile([o["infection"] for o in outputs], q, axis=0).T,
).plot()
idata.plot(ax=ax, style='.')

In [None]:
# Quantile bands of the recovery flow across posterior runs, indexed by the
# recovery data's own index (previously idata.index), with the synthetic
# target data overlaid explicitly on the same axes.
q = (0.1, 0.25, 0.5, 0.75, 0.9)
ax = pd.DataFrame(
    index=rdata.index,
    columns=q,
    # rows = runs, columns = times -> quantiles over runs, then transpose
    data=np.quantile([o["recovery"] for o in outputs], q, axis=0).T,
).plot()
rdata.plot(ax=ax, style='.')