Data were downloaded from https://github.com/M3IT/COVID-19_Data/raw/master/Data/COVID_AU_state.csv
and saved in the `aust_covid` folder.

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
import plotly.express as px
from datetime import datetime
import arviz as az

from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput
from summer2.utils import ref_times_to_dti

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
# Load the observations: read the state-level CSV with its "date" column
# as the index, then convert that index to datetimes.
state_data = pd.read_csv("./aust_covid/COVID_AU_state.csv", index_col="date")
state_data.index = pd.to_datetime(state_data.index)

# Sum "confirmed" over all rows that share a date, then take a
# 7-row rolling mean (the first six entries are therefore NaN).
aust_cases = state_data["confirmed"].groupby(state_data.index).sum()
smoothed_aust_cases = aust_cases.rolling(window=7).mean()

In [None]:
def build_aust_model() -> CompartmentalModel:
    """
    Construct a minimal COVID-19 transmission model with three
    compartments (susceptible, infectious, recovered) in which
    cases are only partially ascertained.

    The model reads three parameters when it is run:
    "contact_rate", "infectious_period" and "cdr".

    Returns:
        The constructed model object (not yet run)
    """

    compartment_names = ("susceptible", "infectious", "recovered")
    starting_population = {"susceptible": 2.6e7, "infectious": 1.0}

    # NOTE(review): times look like days counted from ref_date - confirm
    model = CompartmentalModel(
        times=(600, 800),
        compartments=compartment_names,
        infectious_compartments=("infectious",),
        ref_date=datetime(2019, 12, 31),
    )
    model.set_initial_population(starting_population)

    # Flows: susceptible -> infectious -> recovered
    model.add_infection_frequency_flow(
        "infection", Parameter("contact_rate"), "susceptible", "infectious"
    )
    recovery_rate = 1.0 / Parameter("infectious_period")
    model.add_transition_flow("recovery", recovery_rate, "infectious", "recovered")

    # "onset" tracks the infection flow and is requested with
    # save_results=False; "notifications" is that output scaled by "cdr".
    model.request_output_for_flow("onset", "infection", save_results=False)
    notified_onsets = DerivedOutput("onset") * Parameter("cdr")
    model.request_function_output("notifications", func=notified_onsets)

    return model

In [None]:
# Baseline parameter values used for the initial (uncalibrated) run
parameters = dict(
    contact_rate=0.3,
    infectious_period=5.0,
    cdr=0.2,
)

# Build the model once and run it with the baseline values
aust_model = build_aust_model()
aust_model.run(parameters=parameters)

In [None]:
# Put the smoothed observations and the modelled notifications side by side.
comparison_df = pd.concat(
    (
        smoothed_aust_cases,
        aust_model.get_derived_outputs_df()["notifications"],
    ),
    axis=1,
)
# Keep the plot call as the cell's final bare expression: binding the
# result of .plot() to `comparison_df` (as before) stored a figure under a
# DataFrame name and left the notebook cell with nothing to display.
comparison_df.plot()

In [None]:
# Uniform priors for the two parameters being calibrated.
# No prior is supplied for "cdr" (presumably it stays at its baseline
# value during calibration - confirm against estival's behaviour).
priors = [
    UniformPrior(name, bounds)
    for name, bounds in (
        ("contact_rate", (0.1, 0.5)),
        ("infectious_period", (4.0, 8.0)),
    )
]

# Calibrate the "notifications" output against the smoothed observations,
# dropping the NaN entries produced by the rolling mean.
# NOTE(review): the third argument (500.0) is presumably the dispersion
# parameter of the negative binomial - confirm.
targets = [NegativeBinomialTarget("notifications", smoothed_aust_cases.dropna(), 500.0)]

# `parameters` is passed twice (second and fifth positional arguments).
# NOTE(review): presumably as the base parameter set and as the chain's
# starting point - confirm against the AdaptiveChain signature.
uncertainty_analysis = AdaptiveChain(
    build_aust_model, parameters, priors, targets, parameters
)

In [None]:
# Run the calibration chain with max_iter set to ten thousand
uncertainty_analysis.run(max_iter=10_000)

In [None]:
# Convert the chain to an arviz object.
# NOTE(review): the argument 1000 is presumably the burn-in length - confirm.
uncertainty_outputs = uncertainty_analysis.to_arviz(1000)

# Trace plots; the trailing semicolon suppresses the cell's text output
az.plot_trace(uncertainty_outputs, figsize=(16, 12));

In [None]:
# For each calibrated parameter, take position 0 on the first axis of the
# posterior array and every 50th entry along the second axis, giving one
# DataFrame column per parameter.
sample_params = pd.DataFrame(
    {
        prior.name: uncertainty_outputs.posterior[prior.name][0, ::50].to_numpy()
        for prior in priors
    }
)

In [None]:
# Re-run the model once for every sampled parameter set and collect the
# resulting "notifications" series, one column per sample.
sample_runs = {}
for i_param_set, sampled_row in sample_params.iterrows():
    # Merge into a fresh dict rather than calling parameters.update(), so
    # the shared baseline `parameters` is not left holding the last sample.
    run_parameters = {**parameters, **sampled_row.to_dict()}
    aust_model.run(parameters=run_parameters)
    # Copy so the stored series cannot be affected by later runs.
    sample_runs[i_param_set] = aust_model.get_derived_outputs_df()["notifications"].copy()

# Build the frame in one step instead of inserting columns one at a time
# into a pre-allocated (object-dtype) frame.
sample_outputs = pd.DataFrame(
    sample_runs,
    index=aust_model.get_derived_outputs_df().index,
)

In [None]:
# Join the smoothed observations and the sampled model runs column-wise
sample_comparison_df = pd.concat(
    [smoothed_aust_cases, sample_outputs],
    axis="columns",
)

In [None]:
import warnings

# Silence pandas' PerformanceWarning (for the rest of the session) before
# plotting the observations together with every sampled run.
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
sample_comparison_df.plot()