In [None]:
# Installation in case running over Colab
try:
    import google.colab
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.7
    %pip install pylatex==1.4.1
except:
    pass

In [None]:
from datetime import datetime
import os
from pathlib import Path
from random import sample
from typing import Optional

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import pylatex as pl
from pylatex.section import Section
from pylatex.utils import NoEscape, bold
import yaml

from summer2 import CompartmentalModel, StrainStratification
from summer2.parameters import Parameter, DerivedOutput
from summer2.utils import ref_times_to_dti

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

from tex_param_processing import (
    get_fixed_param_value_text, get_prior_dist_type, get_prior_dist_param_str, 
    get_prior_dist_support, add_param_table_to_doc, add_calib_table_to_doc,
    add_calib_metric_table_to_doc, add_parameter_progression_fig_to_doc,
)
from model_features import (
    add_age_stratification_to_model,
    adapt_gb_matrix_to_aust, add_strain_stratification_to_model,
    DocumentedModel,
)

# Use plotly for pandas' .plot() calls throughout the notebook
pd.options.plotting.backend = "plotly"

In [None]:
# Analysis period
start_date = datetime(2021, 8, 22)
end_date = datetime(2022, 6, 10)

# Length (days) of the trailing rolling-mean window applied to national cases
smoothing_window = 7

# To get latest data instead of our download, use: "https://raw.githubusercontent.com/M3IT/COVID-19_Data/master/Data/COVID_AU_state.csv"
state_data = pd.read_csv(
    "https://media.githubusercontent.com/media/monash-emu/AuTuMN/aust-simple-analysis/autumn/projects/austcovid/COVID_AU_state.csv", 
    index_col="date",
)
state_data.index = pd.to_datetime(state_data.index)

# DataFrame.truncate raises "truncate requires a sorted index" on an unsorted index. The date
# index repeats across rows (presumably one row per state), so sort explicitly rather than relying
# on the download's row order; a stable sort keeps the within-date row order unchanged.
state_data = state_data.sort_index(kind="stable").truncate(before=start_date, after=end_date)

# National daily totals: sum "confirmed" over all rows sharing a date
aust_cases = state_data.groupby(state_data.index)["confirmed"].sum()

# Trailing mean; the first (smoothing_window - 1) values are NaN
smoothed_aust_cases = aust_cases.rolling(smoothing_window).mean()

# Set up the supplementary material document: biblatex with citations listed in order of
# appearance (sorting=none), the project bibliography and the title, then the title itself.
# NOTE(review): the .bib path is relative to where the generated .tex is compiled — confirm.
supplement = pl.Document()
preamble_items = (
    pl.Package("biblatex", options=["sorting=none"]),
    pl.Command("addbibresource", arguments=["../austcovid.bib"]),
    pl.Command("title", "Supplemental Appendix"),
)
for preamble_item in preamble_items:
    supplement.preamble.append(preamble_item)
supplement.append(NoEscape(r"\maketitle"))

In [None]:
def build_aust_model(
    start_date: datetime,
    end_date: datetime,
    doc: Optional[pl.document.Document],
    add_documentation: bool=False,
) -> CompartmentalModel:
    """
    Build a fairly basic susceptible-infectious-recovered model, as described in the
    component functions called.

    Args:
        start_date: Start of the modelled period
        end_date: End of the modelled period
        doc: Document handed to DocumentedModel for the model description; None is passed
            when no documentation is wanted (e.g. when rebuilding during calibration)
        add_documentation: Whether to compile the model documentation once the model is built

    Returns:
        The model object
    """
    compartments = ["susceptible", "infectious", "recovered"]
    aust_model = DocumentedModel(doc, add_documentation)

    # The return value isn't needed: the finished model is read back from aust_model.model below
    aust_model.build_base_model(start_date, end_date, compartments)
    aust_model.set_model_starting_conditions()
    aust_model.add_infection_to_model()
    aust_model.add_recovery_to_model()
    aust_model.add_notifications_output_to_model()

    # Five-year age bands with lower bounds 0, 5, ..., 70
    age_strata = list(range(0, 75, 5))
    matrix = aust_model.build_polymod_britain_matrix(age_strata)

    if add_documentation:
        aust_model.compile_doc()

    # TODO: age and strain stratification are not yet enabled (`matrix` is only used by them):
    # adjusted_matrix = adapt_gb_matrix_to_aust(matrix, age_strata, doc)
    # add_age_stratification_to_model(aust_model.model, compartments, age_strata, adjusted_matrix, doc)
    # add_strain_stratification_to_model(aust_model.model)
    return aust_model.model

In [None]:
# Smoke-test that the model builds. Documentation is not requested here because the documented
# build used for the analysis happens in the next cell; documenting in both cells would
# presumably add the model description to the supplement twice.
build_aust_model(start_date, end_date, None)

In [None]:
# Baseline parameter values: used for the initial run here and also handed to the calibration below
parameters = dict(
    contact_rate=0.3,
    infectious_period=5.0,
    cdr=0.2,
)

# Documented build of the model used for the analysis, then a run at the baseline values
aust_model = build_aust_model(start_date, end_date, supplement, add_documentation=True)
aust_model.run(parameters=parameters)

In [None]:
# Quick look at the starting parameters: smoothed case data alongside modelled notifications
axis_labels = {"index": "time", "value": "cases"}
baseline_notifications = aust_model.get_derived_outputs_df()["notifications"]
baseline_comparison = pd.concat([smoothed_aust_cases, baseline_notifications], axis=1)
baseline_comparison.plot(labels=axis_labels)

In [None]:
# Calibrated parameters, each given a uniform prior over the (lower, upper) bounds listed
prior_bounds = (
    ("contact_rate", (0.2, 0.5)),
    ("infectious_period", (4.0, 8.0)),
)
priors = [UniformPrior(param_name, bounds) for param_name, bounds in prior_bounds]

# Fit modelled notifications to the smoothed case series (leading NaNs from the rolling mean dropped).
# NOTE(review): third positional argument of NegativeBinomialTarget, presumably its dispersion — confirm.
notifications_dispersion = 500.0
targets = [
    NegativeBinomialTarget("notifications", smoothed_aust_cases.dropna(), notifications_dispersion),
]

# The baseline `parameters` dict is passed both as the model parameters and as the chain's initial values
uncertainty_analysis = AdaptiveChain(
    build_aust_model,
    parameters,
    priors,
    targets,
    parameters,
    build_model_kwargs=dict(start_date=start_date, end_date=end_date, doc=None),
)

In [None]:
# MCMC length, and the number of initial iterations treated as burn-in in later cells
iterations, burn_in = 500, 100
uncertainty_analysis.run(max_iter=iterations)

In [None]:
# Convert the chain to an ArviZ-compatible object (its .posterior is indexed by later cells).
# NOTE(review): burn_in is passed so the burn-in iterations are presumably excluded — confirm
# whether estival's to_arviz drops them or only records them, since later positional indexing depends on it.
uncertainty_outputs = uncertainty_analysis.to_arviz(burn_in)

In [None]:
# Trace and marginal density plots for the sampled variables; trailing ";" hides the axes repr
az.plot_trace(uncertainty_outputs, figsize=(16, 12));

In [None]:
# How many parameter samples to run through again (warnings may need suppressing if 100+)
n_samples = 50

# Fixed seed so the sampled runs, and the figures/tables built from them, are reproducible
sample_seed = 0
sample_rng = np.random.default_rng(sample_seed)

# NOTE(review): `samples` are used as positional draw indices into the posterior from
# to_arviz(burn_in), and the `iterations - 200` upper bound is inherited from the original analysis.
# If to_arviz has already dropped the burn-in draws, starting at `burn_in` discards a second
# burn-in window — confirm the intended range.
samples = sorted(
    sample_rng.choice(np.arange(burn_in, iterations - 200), size=n_samples, replace=False).tolist()
)

# Parameter values from sampled runs
sample_params = pd.DataFrame(
    {p.name: uncertainty_outputs.posterior[p.name][0, samples].to_numpy() for p in priors},
    index=samples,
)

# Model outputs from sampled parameter sets. A separate, undocumented model instance is used and
# each run gets a merged copy of the parameters, so neither `aust_model` nor the baseline
# `parameters` dict (later written to the supplement's parameter table) is altered.
sample_model = build_aust_model(start_date, end_date, None)
sample_notifications = {}
for i_param_set in samples:
    run_params = {**parameters, **sample_params.loc[i_param_set].to_dict()}
    sample_model.run(parameters=run_params)
    sample_notifications[i_param_set] = sample_model.get_derived_outputs_df()["notifications"]
sample_outputs = pd.DataFrame(sample_notifications, columns=samples)

In [None]:
# Smoothed case data overlaid with the notifications from every sampled parameter set
posterior_comparison = pd.concat([smoothed_aust_cases, sample_outputs], axis=1)
posterior_comparison.plot(labels=axis_labels)

In [None]:
# Parameter metadata for the supplement's tables, read from the local YAML file
with open("parameters.yml", "r") as param_file:
    param_info = yaml.safe_load(param_file)
param_descriptions, param_units, param_evidence = (
    param_info[section] for section in ("descriptions", "units", "evidence")
)

In [None]:
# Write the parameter and calibration sections into the supplement.
# NOTE: re-running this cell appends the sections to `supplement` again.
prior_names = [uncertainty_analysis.priors[idx].name for idx in range(len(priors))]
calib_summary = az.summary(uncertainty_outputs)

# (section title, callable that fills that section) in document order
supplement_sections = [
    (
        "Parameter values",
        lambda: add_param_table_to_doc(
            aust_model, supplement, parameters, param_descriptions, param_units, param_evidence, prior_names,
        ),
    ),
    (
        "Calibration algorithm",
        lambda: add_calib_table_to_doc(supplement, uncertainty_analysis.priors, param_descriptions),
    ),
    (
        "Calibration metrics",
        lambda: add_calib_metric_table_to_doc(supplement, calib_summary, param_descriptions),
    ),
]
for section_title, write_section in supplement_sections:
    with supplement.create(Section(section_title)):
        write_section()

add_parameter_progression_fig_to_doc(uncertainty_outputs, supplement, prior_names, param_descriptions)

# Bibliography on its own page at the end
for closing_item in (pl.NewPage(), pl.Command("printbibliography")):
    supplement.append(closing_item)

In [None]:
# Write the LaTeX source to supplement/aust_supp.tex (pylatex appends the ".tex" suffix itself).
# The output folder is created first, because pylatex opens the file directly and writing
# into a missing directory would raise FileNotFoundError.
supplement_dir = Path("supplement")
supplement_dir.mkdir(parents=True, exist_ok=True)
supplement.generate_tex(str(supplement_dir / "aust_supp"))