In [None]:
from pathlib import Path

In [None]:
# Installation in case running over Colab
try:
    import google.colab
    %pip install estival
    %pip install pylatex==1.4.1
    %pip install kaleido
    ! git clone https://github.com/monash-emu/aust-covid
    %cd aust-covid
    %pip install -e ./
    PROJECT_PATH = Path().resolve()
    import multiprocessing as mp
    mp.set_start_method("forkserver")
except:
    PROJECT_PATH = Path().resolve().parent

DATA_PATH = PROJECT_PATH / "data"
OUTPUT_PATH = PROJECT_PATH / "outputs"
SUPPLEMENT_PATH = PROJECT_PATH / "supplement"

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
from datetime import datetime
import pylatex as pl
from aust_covid.doc_utils import TableElement, add_element_to_document
from pylatex.utils import NoEscape
import pymc as pm
import arviz as az
import yaml
import nevergrad as ng

from estival.model import BayesianCompartmentalModel
from estival.optimization.nevergrad import optimize_model
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget, CustomTarget
from estival.calibration import pymc as epm

from aust_covid.calibration_outputs import \
    plot_param_progression, plot_param_posterior, plot_sampled_outputs, \
    tabulate_param_results, tabulate_priors, tabulate_parameters
from aust_covid.doc_utils import save_pyplot_add_to_doc, \
    save_plotly_add_to_doc, compile_doc

In [None]:
# What do you want to do in this notebook?
#   optimise_model  -> run the nevergrad optimisation cell and copy its best values into `parameters`
#   new_calibration -> run MCMC sampling afresh; when False, load the saved calibration from OUTPUT_PATH
optimise_model, new_calibration = False, False

In [None]:
# Data inputs: case data with a "region" column, indexed by date
aust_data = pd.read_csv(DATA_PATH / "Aus_covid_data.csv", index_col="date")
aust_data.index = pd.to_datetime(aust_data.index)

# Extract national ("AUS" rows) and take a trailing 7-row rolling mean of cases
# (assumes one row per day — TODO confirm the series has no gaps)
national_data = aust_data[aust_data["region"] == "AUS"]
smoothed_national_cases = national_data["cases"].rolling(window=7).mean().dropna()

# Extract non-WA: every region except the "AUS" series and WA, summed per date.
# numeric_only=True stops the string "region" column from being concatenated into a
# meaningless per-date string during the sum.
non_wa_data = aust_data.loc[~aust_data["region"].isin(["AUS", "WA"])]
non_wa_data = non_wa_data.groupby(non_wa_data.index).sum(numeric_only=True)
smoothed_non_wa_cases = non_wa_data["cases"].rolling(window=7).mean().dropna()

In [None]:
# Set up for manual run with supplementary material document:
# make sure the output directory exists, then start a pylatex Document with a
# biblatex bibliography, a title and the \maketitle command.
SUPPLEMENT_PATH.mkdir(parents=True, exist_ok=True)
supplement = pl.Document()
for preamble_item in (
    pl.Package("biblatex", options=["sorting=none"]),
    pl.Command("addbibresource", arguments=["austcovid.bib"]),
    pl.Command("title", "Supplemental Appendix"),
):
    supplement.preamble.append(preamble_item)
supplement.append(NoEscape(r"\maketitle"))

In [None]:
# Simulation window; plot_start_date is only used to crop plotted outputs
start_date = datetime(2021, 9, 1)
plot_start_date = datetime(2021, 12, 1)
end_date = datetime(2022, 10, 1)

# Base parameter values for the model. Calibrated parameters are given priors later,
# and the optimisation cell (if enabled) updates this dict in place.
parameters = {
    "contact_rate": 0.048,
    "infectious_period": 5.0,
    "latent_period": 2.0,
    "cdr": 0.1,
    "seed_rate": 1.0,
    "seed_duration": 1.0,
    "ba1_seed_time": 660.0,
    "ba2_seed_time": 688.0,
    "ba5_seed_time": 720.0,
    "ba2_escape": 0.45,
    "ba5_escape": 0.38,
    "notifs_shape": 2.0,
    "notifs_mean": 4.0,
    "deaths_shape": 2.0,
    "deaths_mean": 20.0,
    "natural_immunity_period": 50.0,
    # IFR per age stratum (suffixes 0, 5, ..., 70 — presumably matching age_strata):
    # zero for every stratum except the last one
    **{f"ifr_{age}": 0.0 for age in range(0, 70, 5)},
    "ifr_70": 0.01,
}

In [None]:
from aust_covid.model import build_base_model, get_pop_data, \
    set_model_starting_conditions, add_infection_to_model, \
    add_progression_to_model, add_recovery_to_model, \
    add_waning_to_model, build_polymod_britain_matrix, \
    adapt_gb_matrix_to_aust, add_incidence_output_to_model, \
    add_age_stratification_to_model, get_strain_stratification, \
    seed_vocs, add_reinfection_to_model, add_notifications_output_to_model, \
    track_age_specific_incidence, add_death_output_to_model

In [None]:
# Basic model construction
# Builds `aust_model` step by step. Several steps depend on earlier ones (see the
# "must come after" notes), so statement order in this cell matters.
ref_date = datetime(2019, 12, 31)  # presumably becomes aust_model.ref_date, used later to convert target dates to integer days
compartments = [
    "susceptible",
    "latent",
    "infectious",
    "recovered",
    "waned",
]
# NOTE(review): the meaning of the trailing positional False is defined in aust_covid.model — consider passing it by keyword
aust_model = build_base_model(ref_date, compartments, start_date, end_date, False)
pop_data = get_pop_data()
set_model_starting_conditions(aust_model, pop_data)
# NOTE(review): add_infection_to_model is called again below with infection_processes —
# confirm this first call is intended and does not register duplicate infection flows.
add_infection_to_model(aust_model)
add_progression_to_model(aust_model)
add_recovery_to_model(aust_model)
add_waning_to_model(aust_model)
# Contact matrix built for Britain, then adapted to Australian age structure
raw_matrix = build_polymod_britain_matrix()
age_strata = list(range(0, 75, 5))  # 0, 5, ..., 70 (presumably lower bounds of 5-year age groups)
adjusted_matrix, pop_splits = adapt_gb_matrix_to_aust(age_strata, raw_matrix, pop_data)
infection_processes = [
    "infection", 
    "early_reinfection",
    "late_reinfection",
]
add_infection_to_model(aust_model, infection_processes)

# Age stratification
add_age_stratification_to_model(aust_model, compartments, age_strata, pop_splits, adjusted_matrix)

# Strain stratification (model-internal strain keys mapped to display names)
strain_strata = {
    "ba1": "BA.1",
    "ba2": "BA.2",
    "ba5": "BA.5",
}
aust_model.stratify_with(get_strain_stratification(compartments, strain_strata))
seed_vocs(aust_model)

# Reinfection (must come after strain stratification)
add_reinfection_to_model(aust_model, strain_strata)

# Outputs (must come after infection and reinfection)
add_incidence_output_to_model(aust_model, infection_processes)
add_notifications_output_to_model(aust_model)
track_age_specific_incidence(aust_model, infection_processes)
add_death_output_to_model(aust_model)

In [None]:
# Run the freshly built model once with the base parameter values defined above
aust_model.run(parameters=parameters)

In [None]:
# Calibration/optimisation settings
# Parameter metadata (descriptions, units, evidence) used when tabulating for the supplement
with open(PROJECT_PATH / "inputs/parameters.yml", "r") as yaml_stream:
    param_info = yaml.safe_load(yaml_stream)
param_descriptions, param_units, param_evidence = (
    param_info[section] for section in ("descriptions", "units", "evidence")
)

# Uniform priors for the calibrated parameters: name -> (lower, upper) bounds
prior_bounds = {
    "contact_rate": (0.03, 0.06),
    "infectious_period": (3.0, 7.0),
    "ba2_escape": (0.3, 0.7),
    "ba5_escape": (0.3, 0.7),
    "ba1_seed_time": (645.0, 665.0),
    "ba2_seed_time": (675.0, 700.0),
    "ba5_seed_time": (705.0, 730.0),
    "cdr": (0.05, 0.5),
}
priors = [UniformPrior(name, bounds) for name, bounds in prior_bounds.items()]

# Copy of the smoothed non-WA cases re-indexed by whole days since the model's reference
# date (presumably to match the model's integer time axis for the targets)
smoothed_non_wa_cases_intindex = smoothed_non_wa_cases.copy()
days_since_ref = (smoothed_non_wa_cases.index - aust_model.ref_date).days
smoothed_non_wa_cases_intindex.index = days_since_ref
def least_squares(modelled, obs, parameters, time_weights):
    """Return the negated sum of squared residuals between modelled and observed values.

    Higher is better (0.0 would be a perfect fit). `parameters` and `time_weights` are
    unused; presumably they exist to match the callback signature CustomTarget expects.
    """
    squared_residuals = (modelled - obs) ** 2.0
    return 0.0 - squared_residuals.sum()
# Least-squares target on the integer-indexed notifications series.
# NOTE(review): this list is not passed to calibration_model below.
targets = [CustomTarget("notifications", smoothed_non_wa_cases_intindex, least_squares)]

# Negative binomial target on the same series; this is the target list the Bayesian
# calibration model is actually built with.
# NOTE(review): 500.0 is presumably the dispersion parameter — confirm against estival.
binom_targets = [NegativeBinomialTarget("notifications", smoothed_non_wa_cases_intindex, 500.0)]

calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, binom_targets)

In [None]:
if optimise_model:
    print("Optimising with nevergrad \n Progression of loss function values:")
    optim_runner = optimize_model(calibration_model)
    # 10 successive minimize() rounds (argument presumably an evaluation budget),
    # printing the recommendation's loss after each round
    for _ in range(10):
        rec = optim_runner.minimize(100)
        print(rec.loss)
    # Second element of the recommendation's value holds the named parameter values
    # (presumably nevergrad's (args, kwargs) pair)
    optim_params = rec.value[1]
    parameters.update(optim_params)
    aust_model.run(parameters=parameters)
    print("Best calibration parameters found:")
    # A bare expression inside an `if` block is not echoed by Jupyter, so print explicitly
    print(optim_params)

In [None]:
# Compare modelled notifications with the smoothed non-WA case series
aust_model.run(parameters=parameters)
axis_labels = {"index": "time", "value": "cases"}
comparison_df = pd.concat((smoothed_non_wa_cases, aust_model.get_derived_outputs_df()["notifications"]), axis=1)
# Only label the outputs as "optimised" when the optimisation cell was enabled
comparison_title = "Optimised parameter outputs" if optimise_model else "Manual parameter outputs"
comparison_df.plot(labels=axis_labels, title=comparison_title)

In [None]:
# Main calibration loop
# DEMetropolis sampling with no tuning phase; burn-in is discarded manually below.
# Sampling is expensive, so results are saved to OUTPUT_PATH and reloaded unless new_calibration is set.
iterations = 500  # draws per chain
burn_in = 100  # leading draws per chain to discard
n_chains = 10
calibration_file = OUTPUT_PATH / "calibration_out.nc"
if new_calibration:
    with pm.Model() as model:
        variables = epm.use_model(calibration_model)
        idata = pm.sample(step=[pm.DEMetropolis(variables)], draws=iterations, tune=0, cores=8, chains=n_chains)
    OUTPUT_PATH.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists before writing
    idata.to_netcdf(calibration_file)  # full chains, including burn-in, are saved
else:
    idata = az.from_netcdf(calibration_file)

# Discard burn-in in both branches, so burnt_idata exists whether the chains were sampled or loaded.
# The label-based slice keeps every draw from burn_in onwards, whatever length the saved chains have.
# NOTE(review): later cells still use `idata` (including burn-in) — consider switching them to burnt_idata.
burnt_idata = idata.sel(draw=slice(burn_in, None))

In [None]:
# Collects supplement content: filled by add_element_to_document / save_*_add_to_doc
# (presumably keyed by section name), then consumed by compile_doc at the end
doc_sections = dict()

In [None]:
# Check parameter starting points by chain: the first posterior draw of every chain
# (sampling ran with tune=0, so draw 0 is each chain's initial sample)
idata.posterior.isel(draw=0).to_dataframe()

In [None]:
# Report acceptance ratios by chain: the proportion of accepted steps over all draws.
# Reduces over the named "draw" dimension rather than positional axis=1, so the result does not
# depend on dimension order; the mean over draws equals the sum divided by the number of draws.
idata.sample_stats["accepted"].mean(dim="draw").to_dataframe()

In [None]:
# Tabulate the calibration priors and add the table to the "Calibration" supplement section
priors_table = tabulate_priors(priors, param_descriptions)
add_element_to_document(
    "Calibration",
    TableElement("p{2cm} " * 4, priors_table),  # four equal-width LaTeX paragraph columns
    doc_sections,
)
priors_table

In [None]:
# Summarise the calibration results for each prior and add them to the "Calibration" section.
# NOTE(review): uses `idata`, which still contains the burn-in draws.
calib_table = tabulate_param_results(idata, priors, param_descriptions)
add_element_to_document(
    "Calibration",
    TableElement("p{1.3cm} " * 7, calib_table),  # seven equal-width LaTeX paragraph columns
    doc_sections,
)
calib_table

In [None]:
# Tabulate model parameters (built from values, units, priors, descriptions and evidence)
# for the "Parameters" section of the supplement
param_table = tabulate_parameters(parameters, param_units, priors, param_descriptions, param_evidence)
add_element_to_document("Parameters", TableElement("p{2.5cm} p{2.5cm} p{5cm} ", param_table), doc_sections)
# Display the table, like the priors and calibration tables, so it can be checked before compiling
param_table

In [None]:
# Sanity check before figures are added: `supplement` must still be a pylatex Document
assert isinstance(supplement, pl.Document), f"unexpected supplement type: {type(supplement)}"

In [None]:
# Per-chain progression of each calibrated parameter, saved into the "Calibration" section
chains_plot = plot_param_progression(idata, param_descriptions)
save_pyplot_add_to_doc(
    chains_plot,
    "chains",
    "Calibration",
    doc_sections,
    caption="Parameter progression and posterior by chain.",
)

In [None]:
# Posterior plots for the calibrated parameters in a two-column grid. The row count is derived
# from the priors (presumably one panel per prior) instead of being hard-coded, so adding or
# removing a prior keeps the grid in sync; with the current 8 priors this is (4, 2) as before.
n_posterior_rows = -(-len(priors) // 2)  # ceiling division
posterior_plot = plot_param_posterior(idata, param_descriptions, grid_request=(n_posterior_rows, 2))
save_pyplot_add_to_doc(posterior_plot, "posterior", "Calibration", doc_sections, caption="Final estimated parameter posteriors.")

In [None]:
# Sampled model runs against the smoothed non-WA notifications, over the plotting window
sample_plot = plot_sampled_outputs(
    idata,
    5,  # NOTE(review): presumably the number of sampled runs — confirm in calibration_outputs
    "notifications",
    calibration_model,
    smoothed_non_wa_cases,
    plot_start_date,
    end_date,
)
save_plotly_add_to_doc(
    sample_plot,
    "calibration_fit",
    "Calibration",
    doc_sections,
    caption="Sampled model run fits to calibration targets",
)
sample_plot

In [None]:
# Finish up the supplement document with bibliography:
# add the collected sections, start a new page for the bibliography, then write the .tex file.
compile_doc(doc_sections, supplement)
supplement.append(pl.NewPage())
supplement.append(pl.Command("printbibliography"))
# Reuse SUPPLEMENT_PATH (created during setup) rather than re-deriving the same directory
supplement.generate_tex(str(SUPPLEMENT_PATH / "supplement"))