In [None]:
from pathlib import Path

In [None]:
# Installation in case running over Colab
try:
    import google.colab
    on_colab = True
    %pip install estival
    %pip install pylatex==1.4.1
    %pip install kaleido
    ! git clone https://github.com/monash-emu/aust-covid
    %cd aust-covid
    %pip install -e ./
    PROJECT_PATH = Path().resolve()
    import multiprocessing as mp
    mp.set_start_method("forkserver")
    optimise_model = False
    new_calibration = True
except:
    PROJECT_PATH = Path().resolve().parent
    on_colab = False
    optimise_model = False
    new_calibration = False

DATA_PATH = PROJECT_PATH / "data"
OUTPUT_PATH = PROJECT_PATH / "outputs"
SUPPLEMENT_PATH = PROJECT_PATH / "supplement"
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
from datetime import datetime
import pylatex as pl
from pylatex.utils import NoEscape
import pymc as pm
import arviz as az
import yaml
import nevergrad as ng
import plotly.express as px

from estival.model import BayesianCompartmentalModel
from estival.optimization.nevergrad import optimize_model
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget, CustomTarget
from estival.calibration import pymc as epm

from aust_covid import model
from doc_utils import calib_doc_utils
from aust_covid.inputs import load_param_info
from doc_utils.general_doc_utils import TextElement, TableElement, FigElement, add_element_to_document, \
    save_pyplot_add_to_doc, save_plotly_add_to_doc, compile_doc, generate_doc

In [None]:
# Data inputs
aust_data = pd.read_csv(DATA_PATH / "Aus_covid_data.csv", index_col="date")
aust_data.index = pd.to_datetime(aust_data.index)

# Extract national: rows labelled "AUS", with a 7-row trailing rolling mean of cases
national_data = aust_data[aust_data["region"] == "AUS"]
smoothed_national_cases = national_data["cases"].rolling(window=7).mean().dropna()

# Extract non-WA: sum every region other than the national "AUS" rows and WA, date by date
non_wa_data = aust_data.loc[~aust_data["region"].isin(["AUS", "WA"])]
# numeric_only=True: without it the string "region" column is concatenated across regions
non_wa_data = non_wa_data.groupby(level=0).sum(numeric_only=True)
smoothed_non_wa_cases = non_wa_data["cases"].rolling(window=7).mean().dropna()

In [None]:
# Times
ref_date = datetime(2019, 12, 31)  # Arbitrary reference date, passed to the model builder
start_date, end_date = datetime(2021, 9, 1), datetime(2022, 10, 1)  # Analysis window (start, end)
plot_start_date = datetime(2021, 12, 1)  # Left end for plots

In [None]:
# Parameters: baseline values for the single run; the calibration priors below cover a subset of them
parameters = dict(
    contact_rate=0.048,
    infectious_period=5.0,
    latent_period=2.0,
    start_cdr=0.1,
    seed_rate=1.0,
    seed_duration=1.0,
    ba1_seed_time=660.0,
    ba2_seed_time=690.0,
    ba5_seed_time=780.0,
    ba2_escape=0.45,
    ba5_escape=0.6,
    notifs_shape=2.0,
    notifs_mean=4.0,
    natural_immunity_period=50.0,
)
# Load descriptive information for the parameters from the project's YAML file and show it
param_info = load_param_info(PROJECT_PATH / "inputs" / "parameters.yml", parameters)
param_info

In [None]:
doc_sections = dict()  # accumulates document elements via add_element_to_document (presumably keyed by section name)
compartments = "susceptible latent infectious recovered waned".split()
# Build the base model over the analysis window; its returned description goes under "Model construction"
aust_model, build_text = model.build_base_model(ref_date, compartments, start_date, end_date)
add_element_to_document("Model construction", TextElement(build_text), doc_sections)
build_text

In [None]:
# Load population data and file its description under the "Population" subsection
pop_data, pop_text = model.get_pop_data()
add_element_to_document("Model construction", TextElement(pop_text), doc_sections, subsection_name="Population")
pop_text

In [None]:
# Set the model's starting conditions from pop_data; the description shares the "Population" subsection
start_text = model.set_starting_conditions(aust_model, pop_data)
add_element_to_document("Model construction", TextElement(start_text), doc_sections, subsection_name="Population")
start_text

In [None]:
# Add infection to aust_model (model.add_infection) and record the returned description
infect_text = model.add_infection(aust_model)
add_element_to_document("Model construction", TextElement(infect_text), doc_sections)
infect_text

In [None]:
# Add progression to aust_model (model.add_progression) and record the returned description
prog_text = model.add_progression(aust_model)
add_element_to_document("Model construction", TextElement(prog_text), doc_sections)
prog_text

In [None]:
# Add recovery to aust_model (model.add_recovery) and record the returned description
rec_text = model.add_recovery(aust_model)
add_element_to_document("Model construction", TextElement(rec_text), doc_sections)
rec_text

In [None]:
# Add waning to aust_model (model.add_waning) and record the returned description
wane_text = model.add_waning(aust_model)
add_element_to_document("Model construction", TextElement(wane_text), doc_sections)
wane_text

In [None]:
age_strata = list(range(0, 75, 5))  # 0, 5, ..., 70 (presumably lower bounds of 5-year age bands)
(
    raw_matrix,
    age_text,
    raw_matrix_fig,
    raw_matrix_name,
    raw_matrix_caption,
) = model.build_polymod_britain_matrix(age_strata)
# Text and figure both go into the "Mixing" subsection of "Age stratification"
add_element_to_document("Age stratification", TextElement(age_text), doc_sections, subsection_name="Mixing")
add_element_to_document(
    "Age stratification", FigElement(raw_matrix_name, caption=raw_matrix_caption), doc_sections, subsection_name="Mixing"
)
age_text

In [None]:
# Title the figure with its document caption; update_layout returns the figure, so it is displayed
raw_matrix_fig.update_layout(title_text=raw_matrix_caption)

In [None]:
(
    adjusted_matrix, pop_splits, mat_adj_text,
    input_pop_filename, input_pop_caption, input_pop_fig,
    modelled_pop_filename, modelled_pop_caption, modelled_pop_fig,
    matrix_ref_pop_filename, matrix_ref_pop_caption, matrix_ref_pop_fig,
    adjusted_matrix_filename, adjusted_matrix_caption, adjusted_matrix_fig,
    modelled_pop,
) = model.adapt_gb_matrix_to_aust(age_strata, raw_matrix, pop_data)
add_element_to_document("Age stratification", TextElement(mat_adj_text), doc_sections, subsection_name="Mixing")
# The four supporting figures are filed into the "Mixing" subsection in this order
for fig_filename, fig_caption in (
    (input_pop_filename, input_pop_caption),
    (modelled_pop_filename, modelled_pop_caption),
    (matrix_ref_pop_filename, matrix_ref_pop_caption),
    (adjusted_matrix_filename, adjusted_matrix_caption),
):
    add_element_to_document(
        "Age stratification", FigElement(fig_filename, caption=fig_caption), doc_sections, subsection_name="Mixing"
    )
mat_adj_text

In [None]:
# Title the figure with its document caption; update_layout returns the figure, so it is displayed
input_pop_fig.update_layout(title_text=input_pop_caption)

In [None]:
# Title the figure with its document caption; update_layout returns the figure, so it is displayed
modelled_pop_fig.update_layout(title_text=modelled_pop_caption)

In [None]:
# Title the figure with its document caption; update_layout returns the figure, so it is displayed
matrix_ref_pop_fig.update_layout(title_text=matrix_ref_pop_caption)

In [None]:
# Build the age stratification from the adjusted mixing matrix and population splits, then apply it to aust_model
age_strat, agestrat_text = model.add_age_stratification(compartments, age_strata, pop_splits, adjusted_matrix)
aust_model.stratify_with(age_strat)
add_element_to_document("Age stratification", TextElement(agestrat_text), doc_sections)
agestrat_text

In [None]:
# Strain stratum keys mapped to their display names
strain_strata = dict(ba1="BA.1", ba2="BA.2", ba5="BA.5")
# Build the strain stratification for strain_strata and apply it to aust_model
strain_strat, strainstrat_text = model.get_strain_stratification(compartments, strain_strata)
aust_model.stratify_with(strain_strat)
add_element_to_document("Strain stratification", TextElement(strainstrat_text), doc_sections, subsection_name="Strain")
strainstrat_text

In [None]:
# Add variant seeding to aust_model (model.seed_vocs) and record its description
seed_text = model.seed_vocs(aust_model)
add_element_to_document("Strain stratification", TextElement(seed_text), doc_sections, subsection_name="Strain")
seed_text

In [None]:
# Add reinfection between the strains in strain_strata (model.add_reinfection) and record its description
reinfect_text = model.add_reinfection(aust_model, strain_strata)
add_element_to_document("Strain stratification", TextElement(reinfect_text), doc_sections, subsection_name="Strain")
reinfect_text

In [None]:
# Flow names that together make up incidence
infection_processes = ["infection", "early_reinfection", "late_reinfection"]
# Add an incidence output built from infection_processes and document it under "Outputs"
inc_text = model.add_incidence_output(aust_model, infection_processes)
add_element_to_document("Outputs", TextElement(inc_text), doc_sections)
inc_text

In [None]:
# Add the notifications output; ratio_df is kept for the CDR-profile cell near the end (nothing is documented here)
ratio_df = model.add_notifications_output(aust_model)

In [None]:
# Calibration/optimisation settings: uniform priors given as (parameter name, (lower, upper))
priors = [
    UniformPrior(param_name, bounds)
    for param_name, bounds in (
        ("contact_rate", (0.03, 0.06)),
        ("infectious_period", (3.0, 7.0)),
        ("ba2_escape", (0.3, 0.7)),
        ("ba5_escape", (0.3, 0.7)),
        ("ba1_seed_time", (645.0, 675.0)),
        ("ba2_seed_time", (675.0, 735.0)),
        ("ba5_seed_time", (735.0, 825.0)),
        ("start_cdr", (0.05, 0.5)),
        ("natural_immunity_period", (30.0, 120.0)),
    )
]
# Same smoothed target, re-indexed by integer days since the model's reference date
smoothed_non_wa_cases_intindex = smoothed_non_wa_cases.set_axis(
    (smoothed_non_wa_cases.index - aust_model.ref_date).days
)
def least_squares(modelled, obs, parameters, time_weights):
    """Score a fit as the negated sum of squared residuals (higher is better; a perfect fit gives 0).

    ``parameters`` and ``time_weights`` are accepted but unused, presumably to match the
    callback signature CustomTarget expects.
    """
    residuals = modelled - obs
    squared_residuals = residuals ** 2.0
    return 0.0 - squared_residuals.sum()
# Least-squares target: defined for comparison but NOT passed to the calibration model below
targets = [CustomTarget("notifications", smoothed_non_wa_cases_intindex, least_squares)]
# Negative-binomial target actually used (500.0 is presumably the dispersion parameter - TODO confirm)
binom_targets = [NegativeBinomialTarget("notifications", smoothed_non_wa_cases_intindex, 500.0)]
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, binom_targets)

In [None]:
# Optional nevergrad optimisation: on success, `parameters` is updated in place and the model re-run
if optimise_model:
    print("Optimising with nevergrad \n Progression of loss function values:")
    optim_runner = optimize_model(calibration_model)
    # 10 rounds of minimize(100), printing the recommendation's loss after each round
    for i in range(10):
        rec = optim_runner.minimize(100)
        print(rec.loss)
    optim_params = rec.value[1]
    parameters.update(optim_params)
    aust_model.run(parameters=parameters)
    print("Best calibration parameters found:")
    # A bare expression inside an `if` block is never echoed by Jupyter, so print explicitly
    print(optim_params)

# Single run point

In [None]:
# Single run with the current parameters, compared against the smoothed non-WA case target
aust_model.run(parameters=parameters)
axis_labels = {"index": "time", "value": "cases"}
comparison_df = pd.concat((smoothed_non_wa_cases, aust_model.get_derived_outputs_df()["notifications"]), axis=1)
# Only call the outputs "optimised" when the optimisation cell actually updated `parameters`
plot_title = "Optimised parameter outputs" if optimise_model else "Manually set parameter outputs"
comparison_df.plot(labels=axis_labels, title=plot_title)

In [None]:
# Display the parameters used above (updated in place by the optimisation cell when optimise_model is True)
parameters

In [None]:
# assert False  # NOTE(review): leftover manual stop; uncommenting it halts "Run All" before the calibration cell - remove if no longer needed

In [None]:
# Main calibration loop: DEMetropolis sampling with no tuning phase; burn-in is discarded by hand
iterations = 500
burn_in = 100
n_chains = 10
assert 0 <= burn_in < iterations, "burn_in must leave at least one retained draw"
if new_calibration:
    with pm.Model() as pm_model:
        variables = epm.use_model(calibration_model)
        idata = pm.sample(step=[pm.DEMetropolis(variables)], draws=iterations, tune=0, cores=8, chains=n_chains)
    # The full trace (burn-in included) is saved, so a later reload reproduces the same object
    idata.to_netcdf(OUTPUT_PATH / "calibration_out.nc")
else:
    idata = az.from_netcdf(OUTPUT_PATH / "calibration_out.nc")
# Discard burn-in in both branches. Previously only a fresh calibration defined burnt_idata.
# A label slice also works if the reloaded trace has a different number of draws.
burnt_idata = idata.sel(draw=slice(burn_in, None))

In [None]:
# Check parameter starting points by chain: the first recorded draw (draw=0) of every chain.
# With tune=0, as in a fresh calibration, this is each chain's first sampled position.
idata.posterior.isel(draw=0).to_dataframe()

In [None]:
# Report acceptance ratios by chain: fraction of accepted steps, averaged over the named "draw" dimension
idata.sample_stats["accepted"].mean(dim="draw").to_dataframe()

In [None]:
# Tabulate the calibration priors with their parameter info and add the table to "Calibration"
priors_table = calib_doc_utils.tabulate_priors(priors, param_info)
add_element_to_document("Calibration", TableElement(priors_table), doc_sections)
priors_table

In [None]:
# Tabulate posterior results for the calibrated parameters and add the table to "Calibration".
# NOTE(review): this passes the full `idata`, including the draws the calibration cell's `burnt_idata`
# drops as burn-in - confirm whether tabulate_param_results discards burn-in itself.
calib_table = calib_doc_utils.tabulate_param_results(idata, priors, param_info)
add_element_to_document("Calibration", TableElement(calib_table), doc_sections)
calib_table

In [None]:
# Tabulate the values in `parameters` alongside the priors and parameter info, for the "Parameters" section.
# col_requests presumably sets relative column widths - TODO confirm in TableElement. The table is not echoed here.
param_table = calib_doc_utils.tabulate_parameters(parameters, priors, param_info)
add_element_to_document("Parameters", TableElement(param_table, col_requests=[0.25, 0.25, 0.5]), doc_sections)

In [None]:
# Plot parameter progression by chain over the full `idata` trace (burn-in included) and save it into "Calibration"
chains_plot = calib_doc_utils.plot_param_progression(idata, param_info)
save_pyplot_add_to_doc(chains_plot, "chains", "Calibration", doc_sections, caption="Parameter progression and posterior by chain.")

In [None]:
# Plot the parameter posteriors and save the figure into "Calibration".
# NOTE(review): uses the full `idata` rather than the burn-in-trimmed `burnt_idata` - confirm this is intended.
posterior_plot = calib_doc_utils.plot_param_posterior(idata, param_info)
save_pyplot_add_to_doc(posterior_plot, "posterior", "Calibration", doc_sections, caption="Final estimated parameter posteriors.")

In [None]:
# Plot "notifications" for runs sampled from `idata` (the 5 is presumably the number of samples) against the target.
# sampled_df is reused by the CDR-profile cell below.
# NOTE(review): samples are drawn from the full `idata`, so burn-in draws may be included - confirm.
sample_plot, sampled_df = calib_doc_utils.plot_sampled_outputs(idata, 5, "notifications", calibration_model, smoothed_non_wa_cases, plot_start_date, end_date)
save_plotly_add_to_doc(sample_plot, "calibration_fit", "Calibration", doc_sections, caption="Sampled model run fits to calibration targets")
sample_plot

In [None]:
# Look at CDR profiles for sampled runs - has to come later down here, because it needs
# sampled_df (from the sampled-outputs cell) as well as ratio_df (from the notifications-output cell)
model.show_cdr_profiles(sampled_df["start_cdr"], ratio_df)

In [None]:
# Complete the documentation process
supplement = generate_doc("Supplemental Appendix", "austcovid")
compile_doc(doc_sections, supplement)

In [None]:
if on_colab:
    # To build a PDF, we need the appropriate tex packages installed
    ! apt-get -y install texlive-latex-base texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra texlive-bibtex-extra biber
    # To avoid clutter and navigate to the right directory
    import os
    os.chdir("supplement")
    # And finally build the formatted PDF, repeated commands are necessary
    ! pdflatex supplement
    ! biber supplement
    ! pdflatex supplement
    ! pdflatex supplement