In [None]:
from datetime import datetime
from random import sample
from typing import Optional

import pandas as pd
pd.options.plotting.backend = "plotly"
import pylatex as pl
from pylatex.utils import NoEscape, bold
from model_features import DocumentedAustModel
from tex_param_processing import DocumentedCalibration
import yaml
from jax import numpy as jnp

from summer2 import CompartmentalModel, StrainStratification
from summer2.parameters import Function, Time

from estival.calibration.mcmc.adaptive import AdaptiveChain
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget

In [None]:
# Data inputs: national "confirmed" case series over the analysis window,
# aggregated from the state-level CSV.
start_date = datetime(2021, 8, 22)
end_date = datetime(2022, 6, 10)
# To get latest data instead of our download, use: "https://raw.githubusercontent.com/M3IT/COVID-19_Data/master/Data/COVID_AU_state.csv"
state_data = pd.read_csv(
    "https://media.githubusercontent.com/media/monash-emu/AuTuMN/aust-simple-analysis/autumn/projects/austcovid/COVID_AU_state.csv",
    index_col="date",
)
state_data.index = pd.to_datetime(state_data.index)
# The file presumably holds one row per state per date, so the date index is not
# guaranteed to be sorted. DataFrame.truncate raises ValueError on an unsorted
# index, so sort by date first.
state_data = state_data.sort_index().truncate(before=start_date, after=end_date)
# Sum across states for each date to get the national series.
aust_cases = state_data.groupby(state_data.index)["confirmed"].sum()
# Trailing 7-row mean; the first six values are NaN (rolling's default min_periods
# equals the window), which is why the calibration target calls .dropna().
smoothed_aust_cases = aust_cases.rolling(7).mean()

In [None]:
def build_aust_model(
    start_date: datetime,
    end_date: datetime,
    doc: Optional[pl.document.Document],
    add_documentation: bool = False,
) -> CompartmentalModel:
    """
    Build a fairly basic model, as described in the component functions called.

    The model has susceptible/infectious/recovered compartments, is stratified by
    age (stratum lower bounds 0, 5, ..., 70) using a mixing matrix built by
    build_polymod_britain_matrix and adapted by adapt_gb_matrix_to_aust, and is
    then stratified by strain.

    Args:
        start_date: Start of the model run period
        end_date: End of the model run period
        doc: Document the model description is written into; callers pass None
            when no documentation is wanted (e.g. the calibration runs)
        add_documentation: Whether to document the model and compile the document

    Returns:
        The model object
    """
    compartments = ["susceptible", "infectious", "recovered"]
    age_strata = list(range(0, 75, 5))  # stratum lower bounds: 0, 5, ..., 70

    aust_model = DocumentedAustModel(doc, add_documentation)
    aust_model.build_base_model(start_date, end_date, compartments)
    aust_model.set_model_starting_conditions()
    aust_model.add_infection_to_model()
    aust_model.add_recovery_to_model()
    aust_model.add_notifications_output_to_model()

    matrix = aust_model.build_polymod_britain_matrix(age_strata)
    adjusted_matrix = aust_model.adapt_gb_matrix_to_aust(matrix, age_strata)
    aust_model.add_age_stratification_to_model(compartments, age_strata, adjusted_matrix)

    # Only the stratification object is needed here; the other two values returned
    # (starting strain / other strains) are not used by this builder.
    strain_strat, _, _ = aust_model.get_strain_stratification()
    aust_model.model.stratify_with(strain_strat)

    if add_documentation:
        aust_model.compile_doc()
    return aust_model.model

In [None]:
# Set up for manual run with supplementary material document
supplement = pl.Document()
for preamble_item in (
    pl.Package("biblatex", options=["sorting=none"]),
    pl.Command("addbibresource", arguments=["../austcovid.bib"]),
    pl.Command("title", "Supplemental Appendix"),
):
    supplement.preamble.append(preamble_item)
supplement.append(NoEscape(r"\maketitle"))

# Parameter values for the manual run; the same dict is reused by the calibration cells.
parameters = dict(
    contact_rate=0.0255,
    infectious_period=5.0,
    cdr=0.2,
    ba2_rel_infness=2.0,
)
aust_model = build_aust_model(start_date, end_date, supplement, add_documentation=True)
aust_model.run(parameters=parameters)

In [None]:
def voc_seed_func(time, entry_rate, start_time, seed_duration):
    """
    Importation rate of a new variant at a given model time.

    Returns entry_rate while time lies strictly inside the open window
    (start_time, start_time + seed_duration), and 0.0 otherwise. Uses jnp.where
    rather than Python branching, presumably so summer2 can trace it with JAX.

    Args:
        time: Model time
        entry_rate: Rate returned inside the seeding window
        start_time: Time after which seeding begins
        seed_duration: Length of the seeding window

    Returns:
        The seeding rate at `time`
    """
    offset = time - start_time
    return jnp.where(offset > 0, jnp.where(offset < seed_duration, entry_rate, 0.0), 0.0)


# Stable private reference to the raw rate function. A later cell rebinds the public
# name `voc_seed_func` to the Function object this factory returns; looking the
# global up at call time would then wrap a Function inside a Function on re-run.
_voc_seed_rate = voc_seed_func


def make_voc_seed_func(entry_rate: float, start_time: float, seed_duration: float):
    """
    Wrap the seeding rate as a summer2 Function of model time with fixed parameters.

    Args:
        entry_rate: Rate of importation during the seeding window
        start_time: Time after which seeding begins
        seed_duration: Length of the seeding window

    Returns:
        A summer2 Function calling the seeding rate with (Time, entry_rate, start_time, seed_duration)
    """
    return Function(_voc_seed_rate, [Time, entry_rate, start_time, seed_duration])

In [None]:
# Seeding function used for the ba1 strain below: rate 1.0 over a window of length
# 1.0 starting at model time 600.0 (time units/reference date not shown here — confirm).
# NOTE: this rebinds `voc_seed_func` from the plain rate function to the wrapped Function.
voc_seed_func = make_voc_seed_func(entry_rate=1.0, start_time=600.0, seed_duration=1.0)

In [None]:
# Rebuild the model with BA.1 seeding added and re-run it. Documentation is off here:
# the manual-run cell above already built into the same `supplement` with
# add_documentation=True, and doing so again would presumably repeat the model
# description in the supplement (and recompile it).
aust_model = build_aust_model(start_date, end_date, supplement, add_documentation=False)
aust_model.add_importation_flow(
    "seed_ba1",
    voc_seed_func,
    "infectious",
    dest_strata={"strain": "ba1"},
    split_imports=True,
)
aust_model.run(parameters=parameters)

In [None]:
# Look at results of manual run: smoothed reported cases against modelled notifications
axis_labels = {"index": "time", "value": "cases"}
modelled_notifications = aust_model.get_derived_outputs_df()["notifications"]
comparison_df = pd.concat([smoothed_aust_cases, modelled_notifications], axis=1)
comparison_df.plot(labels=axis_labels)

In [None]:
# Calibration settings
# Parameter metadata used for the supplement tables, read from YAML.
with open("parameters.yml", "r") as param_file:
    param_info = yaml.safe_load(param_file)
param_descriptions, param_units, param_evidence = (
    param_info[section] for section in ("descriptions", "units", "evidence")
)

# MCMC run length
iterations, burn_in = 500, 100

priors = [
    UniformPrior("contact_rate", (0.02, 0.05)),
    UniformPrior("infectious_period", (4.0, 8.0)),
]
# The smoothed series starts with NaNs from the 7-row rolling window; drop them.
# NOTE(review): meaning of 500.0 (presumably the dispersion parameter) — confirm.
targets = [NegativeBinomialTarget("notifications", smoothed_aust_cases.dropna(), 500.0)]

# NOTE(review): models built here come straight from build_aust_model, so they do not
# have the "seed_ba1" importation flow that the manual run adds after building — confirm intended.
# NOTE(review): `parameters` is passed twice positionally (presumably as defaults and as
# initial values — check the AdaptiveChain signature); `uncertainty_analysis` is not
# referenced again in this notebook as shown.
uncertainty_analysis = AdaptiveChain(
    build_aust_model,
    parameters,
    priors,
    targets,
    parameters,
    build_model_kwargs=dict(start_date=start_date, end_date=end_date, doc=None),
)

In [None]:
# Run and document the calibration
documented_calib = DocumentedCalibration(
    priors, targets,
    iterations, burn_in,
    parameters,
    param_descriptions, param_units, param_evidence,
    supplement,
)
documented_calib.get_analysis(build_aust_model, parameters, start_date, end_date)

# Add the calibration tables/figures to the supplement (per the method names), then compile it
documented_calib.add_calib_table_to_doc()
documented_calib.table_param_results()
documented_calib.graph_param_progression()
documented_calib.add_param_table_to_doc(aust_model, parameters)
documented_calib.compile_doc()

In [None]:
# Look at a subset of the results of calibration
n_sample_runs = 50
sample_outputs = documented_calib.show_sample_outputs(
    n_sample_runs, aust_model, smoothed_aust_cases, parameters
)
sample_outputs.plot(labels=axis_labels)

In [None]:
# Finish up the supplement document: bibliography on a new page, then write the LaTeX source.
# NOTE(review): compile_doc() ran in earlier cells before these lines were appended, so any
# document compiled there will not include the bibliography — confirm whether a recompile is needed.
for closing_item in (pl.NewPage(), pl.Command("printbibliography")):
    supplement.append(closing_item)
supplement.generate_tex("supplement/aust_supp")