In [None]:
# Installation in case running over Colab
try:
    import google.colab
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.7
except:
    pass

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
import plotly.express as px
from datetime import datetime
from random import sample
import arviz as az
from pathlib import Path
import os
import pylatex as pl
from pylatex.utils import NoEscape

from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput
from summer2.utils import ref_times_to_dti

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
# Analysis period
start_date = datetime(2021, 8, 22)
end_date = datetime(2022, 3, 10)

# To get latest data instead of our download, use: "https://raw.githubusercontent.com/M3IT/COVID-19_Data/master/Data/COVID_AU_state.csv"
state_data = pd.read_csv(
    "https://media.githubusercontent.com/media/monash-emu/AuTuMN/aust-simple-analysis/notebooks/user/jtrauer/austcovid/COVID_AU_state.csv",
    index_col="date",
)
state_data.index = pd.to_datetime(state_data.index)
# DataFrame.truncate raises on an unsorted index; the file presumably holds several
# (per-state) rows per date, so sort by date before restricting to the analysis period
state_data = state_data.sort_index().truncate(before=start_date, after=end_date)
# National daily cases: sum "confirmed" over all rows sharing a date
aust_cases = state_data.groupby(state_data.index)["confirmed"].sum()
assert not aust_cases.empty, "no case data within the analysis period"
# 7-row rolling mean (daily data); the first six values are NaN
smoothed_aust_cases = aust_cases.rolling(7).mean()

# Document to write outputs to
file_name = Path(os.path.abspath("austcovid")).joinpath("parameters.tex")

In [None]:
def build_aust_model(
    start_date: datetime,
    end_date: datetime,
) -> CompartmentalModel:
    """
    Construct a minimal SIR model of COVID-19 transmission in which
    reported notifications are a fixed fraction of incident infections.

    Args:
        start_date: First date of the simulated period
        end_date: Last date of the simulated period

    Returns:
        The (not yet run) model object, which expects the parameters
        "contact_rate", "infectious_period" and "cdr"
    """
    ref_date = datetime(2019, 12, 31)
    # Model times are whole days elapsed since the reference date
    time_bounds = tuple((boundary - ref_date).days for boundary in (start_date, end_date))

    model = CompartmentalModel(
        times=time_bounds,
        compartments=("susceptible", "infectious", "recovered"),
        infectious_compartments=("infectious",),
        ref_date=ref_date,
    )

    # Population of 2.6e7 susceptibles, seeded with one infectious person
    model.set_initial_population({"susceptible": 2.6e7, "infectious": 1.0})

    # Frequency-dependent infection; recovery at rate 1 / infectious period
    model.add_infection_frequency_flow(
        "infection", Parameter("contact_rate"), "susceptible", "infectious"
    )
    recovery_rate = 1.0 / Parameter("infectious_period")
    model.add_transition_flow("recovery", recovery_rate, "infectious", "recovered")

    # Incidence is only an intermediate output (not saved); notifications are
    # incidence scaled by the case detection rate
    model.request_output_for_flow("onset", "infection", save_results=False)
    notifications = DerivedOutput("onset") * Parameter("cdr")
    model.request_function_output("notifications", func=notifications)

    return model

In [None]:
# Starting parameter values, used for the uncalibrated baseline run
parameters = dict(
    contact_rate=0.3,
    infectious_period=5.0,
    cdr=0.2,
)
aust_model = build_aust_model(start_date=start_date, end_date=end_date)
aust_model.run(parameters=parameters)

In [None]:
# Compare baseline notifications (starting parameters) with the smoothed case data
axis_labels = dict(index="time", value="cases")
pd.concat(
    [smoothed_aust_cases, aust_model.get_derived_outputs_df()["notifications"]],
    axis=1,
).plot(labels=axis_labels)

In [None]:
# Calibrated parameters with uniform prior ranges ("cdr" has no prior)
priors = [
    UniformPrior("contact_rate", (0.1, 0.5)),
    UniformPrior("infectious_period", (4.0, 8.0)),
]

# Fit modelled notifications to the smoothed case series, dropping the leading
# NaNs left by the rolling window.
# NOTE(review): 500.0 is presumably the negative binomial dispersion -- confirm
targets = [NegativeBinomialTarget("notifications", smoothed_aust_cases.dropna(), 500.0)]

# NOTE(review): `parameters` is passed twice -- presumably as the model parameters
# and as the chain's initial values; confirm against the estival signature
uncertainty_analysis = AdaptiveChain(
    build_aust_model,
    parameters,
    priors,
    targets,
    parameters,
    build_model_kwargs=dict(start_date=start_date, end_date=end_date),
)

In [None]:
# MCMC length, and the number of initial iterations later treated as burn-in
iterations = 5_000
burn_in = 1_000
uncertainty_analysis.run(max_iter=iterations)

In [None]:
# Convert the chain to ArviZ format (presumably dropping the first `burn_in` draws)
uncertainty_outputs = uncertainty_analysis.to_arviz(burn_in)
az.summary(data=uncertainty_outputs)

In [None]:
# Trace and marginal density of each calibrated parameter
az.plot_trace(data=uncertainty_outputs, figsize=(16, 12));

In [None]:
# How many parameter samples to run through again (suppress warnings if 100+)
n_samples = 50
# `to_arviz(burn_in)` already dropped the burn-in, so sample positions among the
# retained draws (0 .. n_draws - 1) rather than absolute iteration numbers
n_draws = uncertainty_outputs.posterior.sizes["draw"]
samples = sorted(sample(range(n_draws), min(n_samples, n_draws)))

# Parameter values from sampled runs (first chain)
sample_params = pd.DataFrame(
    {p.name: uncertainty_outputs.posterior[p.name][0, samples].to_numpy() for p in priors},
    index=samples,
)

# Model outputs from sampled parameter sets.
# Each run gets its own copy of the parameters so the global `parameters` keeps its
# starting values for later cells (and for re-running the calibration).
sample_outputs = pd.DataFrame(
    index=aust_model.get_derived_outputs_df().index,
    columns=samples,
)
for i_param_set in samples:
    sample_run_params = {**parameters, **sample_params.loc[i_param_set, :].to_dict()}
    aust_model.run(parameters=sample_run_params)
    sample_outputs[i_param_set] = aust_model.get_derived_outputs_df()["notifications"]

In [None]:
# Overlay every resampled notification trajectory on the smoothed case data
pd.concat([smoothed_aust_cases, sample_outputs], axis=1).plot(labels=axis_labels)

In [None]:
# Parameter descriptions, units and supporting evidence for the parameter tables
param_descriptions = {
    "contact_rate": "Rate of effective contacts",
    "infectious_period": "Infectious period",
    "cdr": "Case detection rate (proportion of infections captured through surveillance)",
}
param_units = {
    "contact_rate": "contacts per person per day",
    "infectious_period": "days",
    "cdr": "",
}
param_evidence = {
    "contact_rate": "Calibrated within plausible range",
    # TODO: extend with the fuller justification below once citations render in
    # both table outputs:
    # "given that identified cases are typically quarantined. "
    # "Studies in settings of high case ascertainment and an "
    # "effective public health response have suggested a duration of greater than 5.5 days \cite{bi2020}. "
    # "PCR positivity, which may continue for up to two to three weeks from the point of symptom onset "
    # "\cite{he2020} \cite{byrne2020} does not necessarily indicate infectiousness. "
    # "The duration infectious for asymptomatic persons has also been estimated "
    # "at 6.5 to 9.5 days \cite{byrne2020}."
    "infectious_period": "This quantity is difficult to estimate.",
    "cdr": "Assumed",
}

# Names of the calibrated parameters (those given a prior)
prior_names = [prior.name for prior in uncertainty_analysis.priors]

In [None]:
# Some functions to get string description of parameters
def get_fixed_param_value_text(
    param,
    parameters,
    param_units,
    prior_names,
    decimal_places=2,
    calibrated_string="Calibrated, see priors table",
) -> str:
    """
    Describe a parameter's value for the fixed-parameters table.

    Args:
        param: Name of the parameter
        parameters: Mapping from parameter names to numeric values
        param_units: Mapping from parameter names to unit strings ("" if unitless)
        prior_names: Names of the calibrated parameters
        decimal_places: Number of decimal places the value is rounded to
        calibrated_string: Text returned for calibrated parameters

    Returns:
        calibrated_string for calibrated parameters, otherwise the rounded value
        followed by its unit (with no trailing space when there is no unit)
    """
    # A calibrated parameter has no single value to report, so skip the lookup
    if param in prior_names:
        return calibrated_string
    param_number = round(parameters[param], decimal_places)
    param_unit = param_units.get(param, "")
    return f"{param_number} {param_unit}" if param_unit else f"{param_number}"


def get_prior_dist_type(prior) -> str:
    """
    Describe the distribution family of a prior from its class name.

    The "Prior" suffix is dropped, so e.g. a ``UniformPrior`` instance gives
    "Uniform distribution".
    """
    dist_type = prior.__class__.__name__.replace("Prior", "")
    return f"{dist_type} distribution"


def get_prior_dist_param_str(prior):
    """
    Summarise a prior's distribution parameters as "name: value" pairs,
    separated by single spaces, in the iteration order of ``prior.distri_params``.
    """
    distri_params = prior.distri_params
    pairs = (f"{key}: {distri_params[key]}" for key in distri_params)
    return " ".join(pairs)


def get_prior_dist_support(prior):
    """Join the values from ``prior.bounds()`` with " to " (e.g. "0.1 to 0.5")."""
    return " to ".join(map(str, prior.bounds()))

In [None]:
# Write the parameters and the priors to the outputs files for Overleaf.
# The output is a LaTeX fragment; the including document must load the tabularx package.
file_name.parent.mkdir(parents=True, exist_ok=True)
with open(file_name, "w", encoding="utf-8") as tex_file:
    # Every line starts with "\n": without it the header text would fuse with the
    # preceding "\hline" into the undefined macro "\hlineParameter"
    tex_file.write(
        "\\section{Fixed parameters}"
        "\n\\begin{tabularx}{\\textwidth}{X X X}"
        "\n\\hline"
        "\nParameter & Value & Evidence \\\\"
        "\n\\hline"
    )
    for param in aust_model.get_input_parameters():
        param_value_text = get_fixed_param_value_text(
            param,
            parameters,
            param_units,
            prior_names,
        )
        param_row_in_tex = f"\n{param_descriptions[param]} & {param_value_text} & {param_evidence[param]} \\\\"
        tex_file.write(param_row_in_tex)
    # Close the first table with a bottom rule, then open the priors table
    tex_file.write(
        "\n\\hline"
        "\n\\end{tabularx}"
        "\n\n\\section{Prior distributions}"
        "\n\\begin{tabularx}{\\textwidth}{X X X X}"
        "\n\\hline"
        "\nPrior & Distribution & Parameters & Support \\\\"
        "\n\\hline"
    )
    for prior in uncertainty_analysis.priors:
        prior_desc = param_descriptions[prior.name]
        dist_type = get_prior_dist_type(prior)
        dist_params = get_prior_dist_param_str(prior)
        dist_range = get_prior_dist_support(prior)
        prior_row_in_tex = f"\n{prior_desc} & {dist_type} & {dist_params} & {dist_range} \\\\"
        tex_file.write(prior_row_in_tex)
    tex_file.write(
        "\n\\hline"
        "\n\\end{tabularx}\n"
    )

In [None]:
# Stand-alone pylatex document with the fixed-parameter table and a bibliography
doc = pl.Document()
doc.preamble.append(
    pl.Package(
        "biblatex",
        options=["sorting=none"],
    ),
)
doc.preamble.append(
    pl.Command(
        "addbibresource",
        arguments=["austcovid/austcovid.bib"],
    ),
)
doc.append("Text to reference")
doc.append(
    pl.Command(
        "cite",
        arguments=["bi2020"],
    ),
)
doc.append("\n")
with doc.create(pl.Tabular("p{2.5cm} p{2.5cm} p{4cm} p{3cm}")) as table:
    for param in aust_model.get_input_parameters():
        param_value_text = get_fixed_param_value_text(
            param,
            parameters,
            param_units,
            prior_names,
        )
        table.add_hline()
        table.add_row(
            (
                param_descriptions[param],
                param_value_text,
                param_evidence[param],
                # Raw string: "\c" is not a valid Python escape sequence.
                # NOTE(review): the same citations are attached to every row -- placeholder?
                NoEscape(r" \cite{bi2020,byrne2020} "),
            ),
        )
    # Bottom rule after the last row
    table.add_hline()
doc.append(
    pl.Command("printbibliography"),
)

In [None]:
# Compile to full.pdf, keeping the intermediate full.tex for inspection
doc.generate_pdf(filepath="full", clean_tex=False)