In [None]:
# Installation in case running over Colab
try:
    import google.colab
    %pip install summerepi2==1.0.4
    %pip install estival==0.1.7
    %pip install pylatex==1.4.1
except:
    pass

In [None]:
import pandas as pd
import numpy as np
pd.options.plotting.backend = "plotly"
import plotly.express as px
from datetime import datetime
from random import sample
import arviz as az
from pathlib import Path
import os
import pylatex as pl
from pylatex.utils import NoEscape, bold
from pylatex.section import Section
from tex_param_processing import (
    get_fixed_param_value_text, get_prior_dist_type, get_prior_dist_param_str, 
    get_prior_dist_support, add_param_table_to_doc,
)
from model_features import DocumentedAustModel
import yaml
import matplotlib.pyplot as plt

from summer2 import CompartmentalModel, StrainStratification
from summer2.parameters import Parameter, DerivedOutput
from summer2.utils import ref_times_to_dti

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
# Analysis period
start_date = datetime(2021, 8, 22)
end_date = datetime(2022, 6, 10)

# To get latest data instead of our download, use: "https://raw.githubusercontent.com/M3IT/COVID-19_Data/master/Data/COVID_AU_state.csv"
state_data = pd.read_csv(
    "https://media.githubusercontent.com/media/monash-emu/AuTuMN/aust-simple-analysis/autumn/projects/austcovid/COVID_AU_state.csv", 
    index_col="date",
)
# Convert the date index to datetimes and keep only rows inside the analysis period
state_data.index = pd.to_datetime(state_data.index)
state_data = state_data.truncate(before=start_date, after=end_date)
# Sum the "confirmed" column over all rows sharing the same date
# (presumably one row per state per date - confirm against the CSV)
aust_cases = state_data.groupby(state_data.index)["confirmed"].sum()
# Rolling mean over a window of 7 rows; the first six values are NaN because the window is incomplete
smoothed_aust_cases = aust_cases.rolling(7).mean()

# Set up supplementary material document
supplement = pl.Document()
# Bibliography handled through biblatex, reading from ../austcovid.bib
supplement.preamble.append(pl.Package("biblatex", options=["sorting=none"]))
supplement.preamble.append(pl.Command("addbibresource", arguments=["../austcovid.bib"]))
supplement.preamble.append(pl.Command("title", "Supplemental Appendix"))
supplement.append(NoEscape(r"\maketitle"))

In [None]:
def build_aust_model(
    start_date: datetime,
    end_date: datetime,
    doc: pl.document.Document,
    add_documentation: bool=False,
) -> CompartmentalModel:
    """
    Assemble the model by calling the DocumentedAustModel component methods in sequence:
    base model, starting conditions, infection, recovery, notifications output,
    age stratification and finally strain stratification.

    Args:
        start_date: Passed to build_base_model as the model start
        end_date: Passed to build_base_model as the model end
        doc: Document object handed to DocumentedAustModel
        add_documentation: Handed to DocumentedAustModel; when true, compile_doc is also called

    Returns:
        The model object held by the DocumentedAustModel instance
    """
    compartment_names = ["susceptible", "infectious", "recovered"]
    # Five-year age breakpoints from 0 to 70
    age_breakpoints = list(range(0, 75, 5))

    builder = DocumentedAustModel(doc, add_documentation)

    # Core structure, flows and output
    builder.build_base_model(start_date, end_date, compartment_names)
    builder.set_model_starting_conditions()
    builder.add_infection_to_model()
    builder.add_recovery_to_model()
    builder.add_notifications_output_to_model()

    # Age stratification, using the matrix returned by adapt_gb_matrix_to_aust
    gb_matrix = builder.build_polymod_britain_matrix(age_breakpoints)
    aust_matrix = builder.adapt_gb_matrix_to_aust(gb_matrix, age_breakpoints)
    builder.add_age_stratification_to_model(compartment_names, age_breakpoints, aust_matrix)

    # Strain stratification (only the stratification object itself is needed here)
    strain_stratification, _starting_strain, _other_strains = builder.get_strain_stratification()
    builder.model.stratify_with(strain_stratification)

    if add_documentation:
        builder.compile_doc()

    return builder.model

In [None]:
# Starting parameter values, keyed by the parameter names the model expects
parameters = {
    "contact_rate": 0.0255,
    "infectious_period": 5.0,
    "cdr": 0.2,
    "ba2_rel_infness": 2.0,
}
# Build once with documentation switched on, so that the model description is compiled into the supplement
aust_model = build_aust_model(start_date, end_date, supplement, add_documentation=True)
aust_model.run(parameters=parameters)

In [None]:
# Quick look at the starting parameters
# (smoothed case data alongside the modelled "notifications" output, as columns of one frame)
axis_labels = {"index": "time", "value": "cases"}
pd.concat(
    (
        smoothed_aust_cases, 
        aust_model.get_derived_outputs_df()["notifications"],
    ), 
    axis=1,
).plot(labels=axis_labels)

In [None]:
# Uniform priors for the two parameters varied during calibration
priors = [
    UniformPrior("contact_rate", (0.02, 0.05)),
    UniformPrior("infectious_period", (4.0, 8.0)),
]
# Calibrate the "notifications" output against the smoothed case series
# (NaNs from the incomplete rolling windows are dropped first)
targets = [
    NegativeBinomialTarget("notifications", smoothed_aust_cases.dropna(), 500.0),
]
# NOTE(review): `parameters` is passed twice - presumably once as the base parameter set and
# once as the chain's initial values; confirm against the AdaptiveChain signature.
# "doc" is None here, so models built during calibration get no document to write to.
uncertainty_analysis = AdaptiveChain(
    build_aust_model, 
    parameters, 
    priors, 
    targets, 
    parameters,
    build_model_kwargs={"start_date": start_date, "end_date": end_date, "doc": None},
)

In [None]:
# Maximum number of iterations for the chain
iterations = 500
# Passed to to_arviz later on (presumably the number of initial iterations to discard - confirm)
burn_in = 100
uncertainty_analysis.run(max_iter=iterations)

In [None]:
# Load the parameter metadata used when documenting the model and calibration.
# The encoding is given explicitly so that reading does not depend on the platform default.
with open("parameters.yml", "r", encoding="utf-8") as param_file:
    param_info = yaml.safe_load(param_file)
param_descriptions = param_info["descriptions"]
param_units = param_info["units"]
param_evidence = param_info["evidence"]

# Names of the parameters that have priors, i.e. those varied through calibration
prior_names = [prior.name for prior in priors]


In [None]:
from model_features import DocumentedModel, FigElement

In [None]:
class DocumentedCalibration(DocumentedModel):
    """
    Runs an AdaptiveChain calibration and writes a description of its inputs
    and results (tables and a trace figure) into a pylatex document.
    """

    def __init__(self, doc=None, add_documentation=False):
        """
        Args:
            doc: The pylatex document that the table and figure methods write to
            add_documentation: Stored on the instance; not consulted by the methods defined in this class
        """
        # NOTE(review): DocumentedModel.__init__ is not called here - confirm that this is intended
        self.doc = doc
        self.add_documentation = add_documentation
        self.doc_sections = {}
    
    def get_inputs(
        self, 
        priors, 
        targets, 
        iterations, 
        burn_in,
        descriptions,
        units,
        evidence,
    ):
        """
        Store the calibration settings and parameter metadata on the instance.

        Args:
            priors: The prior objects for the calibrated parameters (each must have a name attribute)
            targets: The calibration targets
            iterations: Maximum number of iterations to run the chain for
            burn_in: Value passed to to_arviz when extracting the calibration outputs
            descriptions: Parameter descriptions, looked up by parameter name
            units: Parameter units (stored only)
            evidence: Parameter evidence (stored only)
        """
        self.iterations = iterations
        self.burn_in = burn_in
        self.priors = priors
        self.targets = targets
        self.descriptions = descriptions
        self.units = units
        self.evidence = evidence
        self.prior_names = [prior.name for prior in priors]
        
    def get_analysis(self, model, params, start, end):
        """
        Run the calibration and keep its outputs in arviz format as self.uncertainty_outputs.

        Args:
            model: The model-building function handed to AdaptiveChain
            params: Parameter values, passed to AdaptiveChain as both its second and fifth arguments
            start: Passed to the model-building function as start_date
            end: Passed to the model-building function as end_date
        """
        uncertainty_analysis = AdaptiveChain(
            model, params, self.priors, self.targets, params,
            build_model_kwargs={"start_date": start, "end_date": end, "doc": None},
        )
        uncertainty_analysis.run(max_iter=self.iterations)
        self.uncertainty_outputs = uncertainty_analysis.to_arviz(self.burn_in)
    
    def graph_param_progression(self):
        """
        Plot the arviz trace figure for the calibration outputs, retitle each panel
        with the parameter description and add the figure to the document.
        """
        axes = az.plot_trace(self.uncertainty_outputs, figsize=(16, 12))
        # The first panel in each row is titled as the posterior, the second as the trace
        column_names = ["posterior", "trace"]
        for i_prior, prior_name in enumerate(self.prior_names):
            for col, column_name in enumerate(column_names):
                ax = axes[i_prior][col]
                ax.set_title(f"{self.descriptions[prior_name]}, {column_name}", fontsize=20)
                ax.xaxis.set_tick_params(labelsize=15)
                ax.yaxis.set_tick_params(labelsize=15)
        with self.doc.create(pl.Figure()) as plot:
            plot.add_plot(width=NoEscape(r"1\textwidth"))
            plot.add_caption("Parameter posteriors and progression.")

    def add_calib_table_to_doc(self):
        """
        Add a table to the document with one row per prior, giving the parameter
        description, distribution type, distribution parameters and support.
        """
        self.doc.append("Input parameters varied through calibration with uncertainty distribution parameters and support.\n")
        calib_headers = ["Name", "Distribution", "Distribution parameters", "Support"]
        with self.doc.create(pl.Tabular("p{2.7cm} " * 4)) as calibration_table:
            calibration_table.add_hline()
            calibration_table.add_row([bold(i) for i in calib_headers])
            for prior in self.priors:
                prior_desc = self.descriptions[prior.name]
                dist_type = get_prior_dist_type(prior)
                dist_params = get_prior_dist_param_str(prior)
                dist_range = get_prior_dist_support(prior)
                calibration_table.add_hline()
                calib_table_row = (prior_desc, dist_type, dist_params, dist_range)
                calibration_table.add_row(calib_table_row)
            calibration_table.add_hline()
            
    def table_param_results(self):
        """
        Add a table of the arviz summary statistics to the document,
        with one row for each entry in the summary's index.
        """
        calib_summary = az.summary(self.uncertainty_outputs)
        headers = ["Para-meter", "Mean (SD)", "3-97% high-density interval", "MCSE mean (SD)", "ESS bulk", "ESS tail", "R_hat"]
        # Diagnostics are selected by label rather than by position,
        # so that they stay aligned with the last three headers regardless of column order
        diagnostic_columns = ["ess_bulk", "ess_tail", "r_hat"]
        with self.doc.create(pl.Tabular("p{1.3cm} " * 7)) as calib_metrics_table:
            calib_metrics_table.add_hline()
            calib_metrics_table.add_row([bold(i) for i in headers])
            for param in calib_summary.index:
                calib_metrics_table.add_hline()
                summary_row = calib_summary.loc[param]
                name = self.descriptions[param]
                mean_sd = f"{summary_row['mean']} ({summary_row['sd']})"
                hdi = f"{summary_row['hdi_3%']} to {summary_row['hdi_97%']}"
                mcse = f"{summary_row['mcse_mean']} ({summary_row['mcse_sd']})"
                diagnostics = [str(summary_row[column]) for column in diagnostic_columns]
                calib_metrics_table.add_row([name, mean_sd, hdi, mcse] + diagnostics)
            calib_metrics_table.add_hline()
            

In [None]:
# Run the calibration through the documenting class, writing into the supplement
documented_calib = DocumentedCalibration(supplement, True)
documented_calib.get_inputs(priors, targets, iterations, burn_in, param_descriptions, param_units, param_evidence)
# get_analysis must come before the results table and figure, because it sets uncertainty_outputs
documented_calib.get_analysis(build_aust_model, parameters, start_date, end_date)
documented_calib.add_calib_table_to_doc()
documented_calib.table_param_results()
documented_calib.graph_param_progression()

In [None]:
# Outputs of the stand-alone chain run above (not the one run inside documented_calib), in arviz format
uncertainty_outputs = uncertainty_analysis.to_arviz(burn_in)

In [None]:
# How many parameter samples to run through again (suppress warnings if 100+)
n_samples = 50
# NOTE(review): the draw is unseeded, so the selected indices differ from run to run
samples = sorted(sample(range(burn_in, iterations - 200), n_samples))

# Parameter values from sampled runs
# (the leading 0 index presumably selects the first chain - confirm against the posterior's dimensions)
sample_params = pd.DataFrame(
    {p.name: uncertainty_outputs.posterior[p.name][0, samples].to_numpy() for p in priors},
    index=samples,
)

# Model outputs from sampled parameter sets
sample_outputs = pd.DataFrame(
    index=aust_model.get_derived_outputs_df().index, 
    columns=samples,
)
for i_param_set in samples:
    # Merge the sampled values into a fresh dictionary, so that the shared starting
    # `parameters` are left untouched and this cell gives the same inputs when re-run
    run_parameters = {**parameters, **sample_params.loc[i_param_set, :].to_dict()}
    aust_model.run(parameters=run_parameters)
    sample_outputs[i_param_set] = aust_model.get_derived_outputs_df()["notifications"]

In [None]:
# Smoothed case data alongside the notifications from each sampled parameter set
pd.concat(
    (
        smoothed_aust_cases, 
        sample_outputs,
    ), 
    axis=1,
).plot(labels=axis_labels)

In [None]:
# Add the parameter table as its own section, then start a new page for the bibliography
with supplement.create(Section("Parameter values")):
    add_param_table_to_doc(aust_model, supplement, parameters, param_descriptions, param_units, param_evidence, prior_names)
supplement.append(pl.NewPage())
supplement.append(pl.Command("printbibliography"))

In [None]:
# Write the assembled supplement out as TeX source under the path "supplement/aust_supp"
supplement.generate_tex("supplement/aust_supp")