Data downloaded from https://github.com/M3IT/COVID-19_Data/raw/master/Data/COVID_AU_state.csv
and saved in the `aust_covid` folder (as `aust_covid/COVID_AU_state.csv`).

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
import plotly.express as px
from datetime import datetime
import arviz as az

from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput
from summer2.utils import ref_times_to_dti

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
# Process the observations: load the state-level records, total them across
# states for each date, then smooth with a trailing rolling mean.
# NOTE(review): "confirmed" is presumably daily new confirmed cases (not a
# cumulative count) — verify against the dataset's column definitions.
state_data = pd.read_csv("./aust_covid/COVID_AU_state.csv", index_col="date")
state_data.index = pd.to_datetime(state_data.index)

# One row per date after grouping on the (date) index level.
aust_cases = state_data["confirmed"].groupby(level=0).sum()

# Trailing window of 7 entries (7 days when no dates are missing); the first
# 6 values are NaN because the window is not yet full.
smoothed_aust_cases = aust_cases.rolling(window=7).mean()

In [None]:
def build_aust_model() -> CompartmentalModel:
    """Construct the national Australian susceptible-infectious-recovered model.

    The model:
        * has compartments "susceptible", "infectious" and "recovered", with
          "infectious" as the only infectious compartment;
        * runs over model times 600 to 800, counted from a reference date of
          2019-12-31;
        * starts with 2.6e7 susceptible people and 1 infectious person
          ("recovered" is not set here).

    Flows:
        infection: susceptible -> infectious, a summer2 frequency-dependent
            infection flow governed by the "contact_rate" parameter.
        recovery: infectious -> recovered, at rate 1 / "infectious_period".

    Outputs:
        onset: the "infection" flow.
        notifications: onset multiplied by the "cdr" parameter (presumably a
            case detection ratio — confirm).

    Returns:
        The configured model. It has not been run; parameter values are
        supplied when it is run.
    """
    compartment_names = ("susceptible", "infectious", "recovered")
    model = CompartmentalModel(
        times=(600, 800),
        compartments=compartment_names,
        infectious_compartments=("infectious",),
        ref_date=datetime(2019, 12, 31),
    )

    # Seed the epidemic with a single infectious person.
    initial_population = {"susceptible": 2.6e7, "infectious": 1.0}
    model.set_initial_population(initial_population)

    # Transmission from susceptible to infectious.
    contact_rate = Parameter("contact_rate")
    model.add_infection_frequency_flow("infection", contact_rate, "susceptible", "infectious")

    # Recovery rate is the reciprocal of the infectious period.
    recovery_rate = 1.0 / Parameter("infectious_period")
    model.add_transition_flow("recovery", recovery_rate, "infectious", "recovered")

    # Track new infections, then scale them into modelled notifications.
    model.request_output_for_flow("onset", "infection")
    notifications_expr = DerivedOutput("onset") * Parameter("cdr")
    model.request_function_output("notifications", func=notifications_expr)

    return model

In [None]:
# Build the model, then run it once at hand-picked parameter values. The same
# dict is later used for the calibration set-up.
aust_model = build_aust_model()
parameters = dict(
    contact_rate=0.3,
    infectious_period=5.0,
    cdr=0.2,
)
aust_model.run(parameters=parameters)

In [None]:
# Derived outputs requested by build_aust_model ("onset", "notifications")
# from the manual run above, as a DataFrame.
outputs = aust_model.get_derived_outputs_df()

In [None]:
# Column-wise combination of observed (smoothed) cases and modelled
# notifications. pd.concat's default outer join keeps every index value from
# both series, with NaN where one side has no entry.
comparison_df = pd.concat([smoothed_aust_cases, outputs["notifications"]], axis=1)

In [None]:
# Plot observed (smoothed) cases against modelled notifications; rendered by
# the plotly backend selected via pd.options.plotting.backend.
comparison_df.plot()

In [None]:
# Priors for the calibrated parameters. "cdr" has no prior here.
# NOTE(review): presumably estival then keeps "cdr" at its value in
# `parameters` — confirm.
priors = [
    UniformPrior("contact_rate", (0.1, 0.5)),
    UniformPrior("infectious_period", (4.0, 8.0)),
]

# The model only simulates model days 600-800 (from 2019-12-31). The
# observations start in 2020, so restrict the target to the simulated window;
# dates outside it have no modelled notifications to compare against.
# Assumes `outputs` is indexed by date, as its concat with the dated
# observations for the comparison plot implies — TODO confirm.
model_window = slice(outputs.index.min(), outputs.index.max())
target_cases = smoothed_aust_cases.dropna().loc[model_window]

targets = [
    # NOTE(review): 500.0 is presumably the negative binomial dispersion
    # parameter — confirm against estival's NegativeBinomialTarget.
    NegativeBinomialTarget("notifications", target_cases, 500.0),
]

In [None]:
# Set up the adaptive MCMC calibration of build_aust_model against `targets`.
# NOTE(review): `parameters` is passed as both the 2nd and 5th positional
# arguments; presumably the 5th is the chain's starting values — confirm
# against the estival AdaptiveChain signature.
uncertainty_analysis = AdaptiveChain(build_aust_model, parameters, priors, targets, parameters)

In [None]:
# Run the MCMC chain with max_iter=10000.
uncertainty_analysis.run(max_iter=10000)

In [None]:
# Convert the chain's samples to an arviz-compatible object for plotting.
# NOTE(review): 1000 is presumably the number of burn-in iterations to
# discard — confirm against estival's to_arviz.
uncertainty_outputs = uncertainty_analysis.to_arviz(1000)

In [None]:
# Trace plots for the sampled parameters. The trailing semicolon stops the
# notebook from echoing the function's return value.
az.plot_trace(
    uncertainty_outputs,
    figsize=(16, 12),
);