In [None]:
from pathlib import Path

In [None]:
# Installation in case running over Colab
try:
    import google.colab
    on_colab = True
    %pip install estival
    %pip install pylatex==1.4.1
    %pip install kaleido
    ! git clone https://github.com/monash-emu/aust-covid
    %cd aust-covid
    %pip install -e ./
    PROJECT_PATH = Path().resolve()
    import multiprocessing as mp
    mp.set_start_method("forkserver")
except:
    PROJECT_PATH = Path().resolve().parent
    on_colab = False
optimise_model = True
new_calibration = False

DATA_PATH = PROJECT_PATH / "data"
OUTPUT_PATH = PROJECT_PATH / "outputs"
SUPPLEMENT_PATH = PROJECT_PATH / "supplement"
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

In [None]:
import pandas as pd
# Route DataFrame/Series .plot() calls in this notebook through plotly
pd.options.plotting.backend = "plotly"
from datetime import datetime
import pymc as pm
import arviz as az

from summer2.functions.time import get_linear_interpolation_function, Function
from summer2.parameters import Parameter
from estival.model import BayesianCompartmentalModel
from estival.optimization.nevergrad import optimize_model
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget, CustomTarget
from estival.calibration import pymc as epm

# Project-local modules (model construction, calibration helpers, supplement document building)
from aust_covid import model
from general_utils import calibration_utils
from aust_covid.inputs import load_calibration_targets
from general_utils.parameter_utils import load_param_info
from general_utils.doc_utils import TextElement, TableElement, FigElement, add_element_to_document, \
    save_pyplot_add_to_doc, save_plotly_add_to_doc, compile_doc, generate_doc

In [None]:
# Load the case-notification targets and the accompanying description text.
# NOTE(review): the date argument presumably sets where the target series starts — confirm in
# aust_covid.inputs.load_calibration_targets.
smoothed_composite_cases, targets_text = load_calibration_targets(datetime(2021, 12, 15))
targets_text

In [None]:
# Key dates.
# NOTE(review): model time presumably counts days from ref_date (the *_seed_time parameters
# are in that unit) — confirm against model.build_base_model.
ref_date = datetime(2019, 12, 31)  # Arbitrary reference date (origin of model time)
start_date, end_date = datetime(2021, 9, 1), datetime(2022, 10, 1)  # Analysis window
plot_start_date = datetime(2021, 12, 1)  # Left end for plots

In [None]:
# Baseline parameter values; later cells override some of these and the optimiser updates others
parameters = dict(
    contact_rate=0.0458,
    infectious_period=6.898,
    start_cdr=0.0997,
    ba1_seed_time=657.4,
    natural_immunity_period=23.5,
    ba2_seed_time=700.0,
    ba2_escape=0.8,
    ba5_seed_time=765.0,
    ba5_escape=1.0,
    latent_period=2.0,
    seed_rate=1.0,
    seed_duration=1.0,
    notifs_shape=2.0,
    notifs_mean=4.0,
    vacc_infect_protect=0.2,
    vacc_prop=0.5,
)
# Parameter metadata read from the YAML input file; used later when tabulating and plotting parameters
param_info = load_param_info(PROJECT_PATH / "inputs/parameters.yml", parameters)
# param_info

In [None]:
# Supplement content, filled via add_element_to_document and compiled at the end of the notebook
doc_sections = dict()
# Base (unstratified) compartments, in this order
compartments = "susceptible latent infectious recovered waned".split()
# Build the base model over the analysis window; ref_date is kept on the model (aust_model.ref_date is used later)
aust_model, build_text = model.build_base_model(ref_date, compartments, start_date, end_date)
# add_element_to_document("Model construction", TextElement(build_text), doc_sections)
# build_text

In [None]:
# Population data (used for starting conditions, matrix adjustment and the cumulative-notification plot)
pop_data, pop_text = model.get_pop_data()
# add_element_to_document("Model construction", TextElement(pop_text), doc_sections, subsection_name="Population")
# pop_text

In [None]:
# Set the model's starting conditions from the population data
# (NOTE(review): adjuster=1.0 presumably means no population scaling — confirm in model.set_starting_conditions)
start_text = model.set_starting_conditions(aust_model, pop_data, adjuster=1.0)
# add_element_to_document("Model construction", TextElement(start_text), doc_sections, subsection_name="Population")
# start_text

In [None]:
# Add the infection process to the model; the returned text is for the supplement
infect_text = model.add_infection(aust_model)
# add_element_to_document("Model construction", TextElement(infect_text), doc_sections)
# infect_text

In [None]:
# Add the progression process to the model; the returned text is for the supplement
prog_text = model.add_progression(aust_model)
# add_element_to_document("Model construction", TextElement(prog_text), doc_sections)
# prog_text

In [None]:
# Add the recovery process to the model; the returned text is for the supplement
rec_text = model.add_recovery(aust_model)
# add_element_to_document("Model construction", TextElement(rec_text), doc_sections)
# rec_text

In [None]:
# Add the waning process to the model; the returned text is for the supplement
wane_text = model.add_waning(aust_model)
# add_element_to_document("Model construction", TextElement(wane_text), doc_sections)
# wane_text

In [None]:
# Age-group values 0, 5, ..., 75 (presumably lower bounds of 5-year bands, last band open-ended)
age_strata = [lower_bound for lower_bound in range(0, 80, 5)]

In [None]:
# raw_mob_df, raw_mob_fig, raw_mob_filename, raw_mob_text = model.get_raw_mobility(plot_start_date, aust_model)
# add_element_to_document("Mobility", TextElement(raw_mob_text), doc_sections)
# raw_mob_text

In [None]:
# add_element_to_document("Mobility", FigElement(raw_mob_filename), doc_sections)
# raw_mob_fig

In [None]:
# mean_smoothed_non_resi_mob, modelled_mob_fig, modelled_mob_filename = model.process_mobility(raw_mob_df, plot_start_date, aust_model)
# add_element_to_document("Mobility", FigElement(modelled_mob_filename), doc_sections)
# modelled_mob_fig

In [None]:
# mobility_adjuster, infection_func, mob_adj_text, mob_effect_fig, mob_effect_filename = model.calculate_mobility_effect(mean_smoothed_non_resi_mob, plot_start_date, aust_model)
# add_element_to_document("Mobility", TextElement(mob_adj_text), doc_sections)
# mob_adj_text

In [None]:
# Constant (always 1.0) stand-in for the mobility-based infection modifier, whose construction is
# commented out in the mobility cells; passed to model.add_reinfection below
infection_func = Function(lambda: 1.0)

In [None]:
# add_element_to_document("Mobility", FigElement(mob_effect_filename), doc_sections)
# mob_effect_fig

In [None]:
# # This is currently effectively turned off
# mobility_scaling, mob_map_text = model.get_mobility_mapper()
# add_element_to_document("Mobility", TextElement(mob_map_text), doc_sections)
# # mob_map_text

In [None]:
locations = ["school", "home", "work", "other_locations"]
# One raw contact matrix per location, read from data/<location>.csv (first column used as the index)
raw_location_matrices = {}
for location in locations:
    location_df = pd.read_csv(DATA_PATH / f"{location}.csv", index_col=0)
    raw_location_matrices[location] = location_df.to_numpy()
# Adapt the raw location matrices to the Australian population data. Returns documentation figures,
# captions and filenames, plus the adjusted matrices and the population splits used for age stratification.
input_pop_fig, input_pop_caption, input_pop_filename, modelled_pop_fig, modelled_pop_caption, modelled_pop_filename, \
    matrix_ref_pop_fig, matrix_ref_pop_caption, matrix_ref_pop_filename, \
    adjusted_matrices, pop_splits, mat_adj_text = model.adapt_gb_matrices_to_aust(age_strata, raw_location_matrices, pop_data)
# Time-varying (mobility-scaled) mixing, currently disabled in favour of the summed static matrix below:
# mixing_matrix = Function(
#     mobility_scaling,
#     [
#         adjusted_matrices,
#         infection_func,
#     ]
# )

In [None]:
# add_element_to_document("Mixing", TextElement(mat_adj_text), doc_sections)
# mat_adj_text

In [None]:
# input_pop_fig.layout.title.update({"text": input_pop_caption})
# input_pop_fig

In [None]:
# modelled_pop_fig.layout.title.update({"text": modelled_pop_caption})
# add_element_to_document("Mixing", FigElement(modelled_pop_filename, caption=modelled_pop_caption), doc_sections)
# modelled_pop_fig

In [None]:
# matrix_ref_pop_fig.layout.title.update({"text": matrix_ref_pop_caption})
# add_element_to_document("Mixing", FigElement(matrix_ref_pop_filename, caption=matrix_ref_pop_caption), doc_sections)
# matrix_ref_pop_fig

In [None]:
# raw_matrix_filename = "raw_matrices"
# raw_matrix_fig, raw_matrix_fig_text = model.plot_mixing_matrices(raw_location_matrices, locations, age_strata, raw_matrix_filename + ".jpg")
# add_element_to_document("Mixing", FigElement(raw_matrix_filename, caption=raw_matrix_fig_text), doc_sections)
# raw_matrix_fig

In [None]:
# Figure of the Australia-adjusted matrices: pass adjusted_matrices (not raw_location_matrices),
# otherwise this just repeats the raw-matrix figure from the previous cell.
# adjusted_matrix_filename = "adjusted_matrices"
# adjusted_matrix_fig, adjusted_matrix_fig_text = model.plot_mixing_matrices(adjusted_matrices, locations, age_strata, adjusted_matrix_filename + ".jpg")
# add_element_to_document("Mixing", FigElement(adjusted_matrix_filename, caption=adjusted_matrix_fig_text), doc_sections)
# adjusted_matrix_fig

In [None]:
from plotly import express as px

# Static mixing matrix: sum of the location-specific adjusted matrices
# (assumes the values are equally-shaped arrays that support element-wise +)
mixing_matrix = sum(adjusted_matrices[location] for location in adjusted_matrices)
px.imshow(mixing_matrix)

In [None]:
# Stratify the model by age using the population splits and the summed static mixing matrix,
# and add the description to the supplement
age_strat, agestrat_text = model.add_age_stratification(compartments, age_strata, pop_splits, mixing_matrix)
aust_model.stratify_with(age_strat)
add_element_to_document("Age stratification", TextElement(agestrat_text), doc_sections)
# agestrat_text

In [None]:
# This is effectively turned off by seeding BA.2 and BA.5 late
# (see the ba2_seed_time/ba5_seed_time overrides applied before optimisation)
strain_strata = ["ba1", "ba2", "ba5"]
strain_strat, strainstrat_text = model.get_strain_stratification(compartments, strain_strata)
aust_model.stratify_with(strain_strat)
# add_element_to_document("Strain stratification", TextElement(strainstrat_text), doc_sections)
# strainstrat_text

In [None]:
# Add variant seeding to the model (timings presumably from the *_seed_time parameters — confirm in model.seed_vocs)
seed_text = model.seed_vocs(aust_model)
# add_element_to_document("Strain stratification", TextElement(seed_text), doc_sections)
# seed_text

In [None]:
# Add reinfection across the strain strata, using the constant infection_func
reinfect_text = model.add_reinfection(aust_model, strain_strata, infection_func)
# add_element_to_document("Strain stratification", TextElement(reinfect_text), doc_sections)
# reinfect_text

In [None]:
# Flows that create new infections: first infection plus the two reinfection flows
infection_processes = ["infection", "early_reinfection", "late_reinfection"]
# Request incidence outputs built from the infection flows and add the description to the supplement
inc_text = model.add_incidence_output(aust_model, infection_processes)
add_element_to_document("Outputs", TextElement(inc_text), doc_sections)
# inc_text

In [None]:
# Add the notifications output to the model. Also returns ratio_df (used below for the CDR profile)
# and survey/ratio figures with names and captions for the supplement.
ratio_df, survey_fig, survey_fig_name, survey_fig_caption, ratio_fig, ratio_fig_name, ratio_fig_caption, cdr_description = model.add_notifications_output(aust_model)
# add_element_to_document("Outputs", TextElement(cdr_description), doc_sections, subsection_name="Notifications")
# add_element_to_document("Outputs", FigElement(survey_fig_name, caption=survey_fig_caption), doc_sections, subsection_name="Notifications")
# add_element_to_document("Outputs", FigElement(ratio_fig_name, caption=ratio_fig_caption), doc_sections, subsection_name="Notifications")
# cdr_description

In [None]:
# Overwrite every row of ratio_df with its first row, freezing the ratio at its initial value
# (presumably making case detection constant over time — confirm). This mutates ratio_df in place;
# re-run the previous cell to restore the original values.
# NOTE(review): whether the notifications output built in the previous cell sees this change depends on
# whether it keeps a reference to ratio_df — confirm in model.add_notifications_output.
ratio_df.iloc[:] = ratio_df.iloc[0]

In [None]:
# survey_fig.layout.title.update({"text": survey_fig_caption})
# survey_fig

In [None]:
# ratio_fig.layout.title.update({"text": ratio_fig_caption})
# ratio_fig

In [None]:
# Track seroprevalence; presumably the source of the "seropos_prop" derived output plotted below
model.track_sero_prevalence(compartments, aust_model)

In [None]:
# Track the proportion of infections by strain; the returned text is for the supplement
strain_prop_text = model.track_strain_prop(strain_strata, aust_model)
# add_element_to_document("Outputs", TextElement(strain_prop_text), doc_sections, subsection_name="Other outputs")
# strain_prop_text

### Vaccination stratification (currently disabled: the code below is commented out)

In [None]:
# from summer2 import Stratification, Multiply
# vacc_strat = Stratification("vaccination", ["vacc", "unvacc"], compartments)
# for infection_process in infection_processes:
#     vacc_strat.set_flow_adjustments(
#         infection_process,
#         {
#             "vacc": Multiply(1.0 - Parameter("vacc_infect_protect")),
#             "unvacc": None,
#         },
#     )
# vacc_strat.set_population_split(
#     {
#         "vacc": Parameter("vacc_prop"),
#         "unvacc": 1.0 - Parameter("vacc_prop"),
#     }
# )
# aust_model.stratify_with(vacc_strat)

In [None]:
from summer2.parameters import DerivedOutput

In [None]:
# One output per base compartment (each compartment name is also the output name)
for compartment_name in compartments:
    aust_model.request_output_for_compartments(compartment_name, [compartment_name])
# Infection flow outputs: intermediates for the total below, so not saved in the results
for flow_name in infection_processes:
    aust_model.request_output_for_flow(flow_name, flow_name, save_results=False)

total_infection = DerivedOutput("infection") + DerivedOutput("early_reinfection") + DerivedOutput("late_reinfection")
aust_model.request_function_output("all_infection", total_infection)
# Approximate reproduction number: new infections per infectious person, times the infectious period
approx_reproduction = DerivedOutput("all_infection") / DerivedOutput("infectious") * Parameter("infectious_period")
aust_model.request_function_output("reproduction_number", approx_reproduction)

In [None]:
# Truncate targets to dates up to and including 1 Feb 2022 (label-based .loc slicing is inclusive).
# This rebinds the global, so all later cells (re-indexed targets, fit plots) use the truncated series.
smoothed_composite_cases = smoothed_composite_cases.loc[:datetime(2022, 2, 1)]

In [None]:
# Manual overrides applied on top of the baseline parameters before optimisation.
# NOTE(review): if model time is days from ref_date, day 1000 is ~26 Sep 2022, still inside the run
# (end_date is ~day 1005), so BA.2/BA.5 are seeded very late rather than strictly never.
parameters.update(
    contact_rate=0.05,
    ba1_seed_time=650.0,
    vacc_infect_protect=0.0,
    vacc_prop=0.0,
    start_cdr=0.5,
    ba2_seed_time=1000.0,  # Switch off BA.2
    ba5_seed_time=1000.0,  # Switch off BA.5
    natural_immunity_period=1e10,  # Switch off waning
)

In [None]:
        # Inactive alternative parameter values (the same values as the commented-out
        # parameters.update cell further down). As bare dict entries these lines were an
        # IndentationError that halts "Run All", so they are kept as comments:
        # "contact_rate": 0.07,
        # "infectious_period": 2.0,
        # "start_cdr": 0.25,
        # "ba1_seed_time": 640.0,
        # "latent_period": 1.5,

In [None]:
# Calibration/optimisation settings
priors = [
    UniformPrior("ba1_seed_time", (600.0, 700.0)),
    UniformPrior("contact_rate", (0.03, 0.1)),
    UniformPrior("infectious_period", (2.0, 3.0)),
    UniformPrior("start_cdr", (0.2, 0.8)),
    UniformPrior("latent_period", (1.0, 2.0)),
    # Inactive alternatives (the ba1_seed_time entry would duplicate the active one above):
    # UniformPrior("ba2_escape", (0.3, 1.0)),
    # UniformPrior("ba5_escape", (0.3, 1.0)),
    # UniformPrior("ba1_seed_time", (640.0, 670.0)),
    # UniformPrior("ba2_seed_time", (680.0, 720.0)),
    # UniformPrior("ba5_seed_time", (750.0, 800.0)),
    # UniformPrior("natural_immunity_period", (20.0, 100.0)),
]

# Re-index the targets from dates to integer days relative to the model's reference date
target_day_index = (smoothed_composite_cases.index - aust_model.ref_date).days
smoothed_composite_cases_intindex = smoothed_composite_cases.copy()
smoothed_composite_cases_intindex.index = target_day_index
def least_squares(modelled, obs, parameters, time_weights):
    """Return the negative sum of squared residuals between modelled and observed values.

    Larger (closer to zero) means a better fit. ``parameters`` and ``time_weights`` are
    accepted but unused (presumably required by estival's CustomTarget call signature).
    """
    residuals = modelled - obs
    return 0.0 - (residuals ** 2.0).sum()
# Least-squares target: defined for experimentation, but not passed to the model below
least_squares_target = CustomTarget("notifications", smoothed_composite_cases_intindex, least_squares)
targets = [least_squares_target]
# Negative binomial target actually used by the optimisation and MCMC cells
# (NOTE(review): 500.0 is presumably the dispersion parameter — confirm in estival.targets)
binom_target = NegativeBinomialTarget("notifications", smoothed_composite_cases_intindex, 500.0)
binom_targets = [binom_target]
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, binom_targets)

In [None]:
if optimise_model:
    print("Optimising with nevergrad \n Progression of loss function values:")
    optim_runner = optimize_model(calibration_model)
    # Repeated rounds on the same runner, printing the current recommendation's loss after each
    n_optim_rounds, evals_per_round = 10, 100
    for optim_round in range(n_optim_rounds):
        rec = optim_runner.minimize(evals_per_round)
        print(rec.loss)
    # rec.value is presumably nevergrad's (args, kwargs) pair; [1] is the dict of named parameters
    optim_params = rec.value[1]
    parameters.update(optim_params)
    aust_model.run(parameters=parameters)
    print("Best calibration parameters found:")
    print(optim_params)

In [None]:
# For a video of the scaling of the mixing matrix over time
# model.video_mixing_matrices(aust_model)

### Single model run at the current parameter set

In [None]:
# Display the current parameter values (including any optimised updates)
parameters

In [None]:
# The model epoch converts datetimes to model time; show end_date in model-time units
epoch = aust_model.get_epoch()
epoch.datetime_to_number(end_date)

In [None]:
# parameters.update(
#     {
#         "contact_rate": 0.07,
#         "infectious_period": 2.0,
#         "start_cdr": 0.25,
#         "ba1_seed_time": 640.0,
#         "latent_period": 1.5,
#     }
# )

In [None]:
# Run at the current parameters and overlay modelled notifications on the (truncated) targets
aust_model.run(parameters=parameters)
derived_outputs = aust_model.get_derived_outputs_df()
comparison_df = pd.concat((smoothed_composite_cases, derived_outputs["notifications"]), axis=1)
axis_labels = {"index": "time", "value": "cases"}
fig = comparison_df.plot(labels=axis_labels, title="Optimised parameter outputs")
fig.update_xaxes(range=(start_date, end_date))
fig

In [None]:
# Plot the approximate reproduction number output requested earlier
aust_model.get_derived_outputs_df()["reproduction_number"].plot()

In [None]:
# Cumulative notified cases as a share of the national population (notified, not all infections).
# The targets were truncated to Feb 2022, so the line stops there even though the x-axis runs to end_date.
australia_population = pop_data["Australia"].sum()
cumulative_cases = smoothed_composite_cases.cumsum() / australia_population
fig = cumulative_cases.plot(title="Cumulative proportion ever notified")
fig.update_xaxes(range=(start_date, end_date))

In [None]:
# Stacked area plot of the per-compartment outputs requested earlier
aust_model.get_derived_outputs_df()[compartments].plot.area(title="Distribution of population by compartment")

In [None]:
# Split the population into ever-infected (seropositive) and never-infected shares, stacked to 1
seropos_prop = aust_model.get_derived_outputs_df()["seropos_prop"]
ever_infected_stats = pd.DataFrame({"seropos": seropos_prop, "seroneg": 1.0 - seropos_prop})
fig = ever_infected_stats.plot.area(title="Ever infected")
fig.update_xaxes(range=(start_date, end_date))
fig.update_yaxes(range=(0.0, 1.0))

In [None]:
# Strain-proportion figure from 1 Nov 2021: add it to the supplement, then display it with its caption as title
strain_prop_fig, strain_prop_fig_name, strain_prop_fig_caption = model.show_strain_props(strain_strata, datetime(2021, 11, 1), aust_model)
add_element_to_document("Outputs", FigElement(strain_prop_fig_name, caption=strain_prop_fig_caption), doc_sections, subsection_name="Sub-variants")
strain_prop_fig.layout.title.update({"text": strain_prop_fig_caption})
strain_prop_fig

In [None]:
# Derive the CDR curve parameter (an exponential-plateau parameter, per the function name) from
# ratio_df[0] and start_cdr, then plot the resulting CDR values.
# NOTE(review): ratio_df[0] selects the item labelled 0 (presumably the first ratio column) — confirm.
exp_param = model.get_param_to_exp_plateau(ratio_df[0], parameters["start_cdr"])
model.get_cdr_values(exp_param, ratio_df).plot()

In [None]:
# Deliberate stop: halts "Run All" before the MCMC calibration and documentation cells below.
# Raised explicitly rather than via `assert False`, so it carries a message and is not stripped under `python -O`.
raise RuntimeError("Stopping before calibration: skip or delete this cell to run the cells below.")

In [None]:
# Main calibration loop: either run DEMetropolis MCMC and cache it to disk, or reload the cached run
iterations = 10000  # draws per chain
burn_in = 2000  # leading draws discarded from each chain
n_chains = 10
calibration_file = OUTPUT_PATH / "calibration_out.nc"
if not new_calibration:
    idata_raw = az.from_netcdf(calibration_file)
else:
    with pm.Model() as pm_model:
        variables = epm.use_model(calibration_model)
        idata_raw = pm.sample(
            step=[pm.DEMetropolis(variables)],
            draws=iterations,
            tune=0,
            cores=8,
            chains=n_chains,
        )
    idata_raw.to_netcdf(calibration_file)

# Discard burn-in (label-based; with a cached file this assumes it holds `iterations` draws per chain)
idata = idata_raw.sel(draw=range(burn_in, iterations))

In [None]:
# Check parameter starting points by chain (the first raw draw of each chain, before burn-in removal)
idata_raw.posterior.isel(draw=0).to_dataframe()

In [None]:
# Report acceptance ratios by chain: share of post-burn-in draws accepted, averaged over the named draw dimension
acceptance_by_chain = idata.sample_stats.accepted.mean(dim="draw")
acceptance_by_chain.to_dataframe()

In [None]:
# Table of the prior distributions, added to the Calibration section of the supplement
priors_table = calibration_utils.tabulate_priors(priors, param_info)
add_element_to_document("Calibration", TableElement(priors_table), doc_sections)
priors_table

In [None]:
# Table of calibration results from the post-burn-in draws, added to the Calibration section
calib_table = calibration_utils.tabulate_param_results(idata, priors, param_info)
add_element_to_document("Calibration", TableElement(calib_table), doc_sections)
calib_table

In [None]:
# Table of parameter values for the Parameters section
# (col_requests presumably sets relative column widths — confirm in general_utils.doc_utils)
param_table = calibration_utils.tabulate_parameters(parameters, priors, param_info)
add_element_to_document("Parameters", TableElement(param_table, col_requests=[0.25, 0.25, 0.5]), doc_sections)

In [None]:
# Parameter traces by chain (post burn-in), saved and added to the Calibration section
chains_plot = calibration_utils.plot_param_progression(idata, param_info)
save_pyplot_add_to_doc(chains_plot, "chains", "Calibration", doc_sections, caption="Parameter progression and posterior by chain.")

In [None]:
# Posterior distributions (post burn-in), saved and added to the Calibration section
posterior_plot = calibration_utils.plot_param_posterior(idata, param_info)
save_pyplot_add_to_doc(posterior_plot, "posterior", "Calibration", doc_sections, caption="Final estimated parameter posteriors.")

In [None]:
# Notifications from 30 sampled posterior draws against the targets; sampled_df also carries the
# sampled parameter values (its "start_cdr" column is used in the next cell)
sample_plot, sampled_df = calibration_utils.plot_sampled_outputs(idata, 30, "notifications", calibration_model, smoothed_composite_cases, plot_start_date, end_date)
save_plotly_add_to_doc(sample_plot, "calibration_fit", "Calibration", doc_sections, caption="Sampled model run fits to calibration targets")
sample_plot

In [None]:
# CDR profiles implied by the sampled start_cdr values, saved and added to the Calibration section
modelled_cdr_fig, modelled_cdr_fig_name, modelled_cdr_fig_caption = model.show_cdr_profiles(sampled_df["start_cdr"], ratio_df)
save_plotly_add_to_doc(modelled_cdr_fig, modelled_cdr_fig_name, "Calibration", doc_sections, caption=modelled_cdr_fig_caption)
modelled_cdr_fig

In [None]:
# Complete the documentation process: create the document and compile all collected doc_sections into it
supplement = generate_doc("Supplemental Appendix", "austcovid")
compile_doc(doc_sections, supplement)

In [None]:
if on_colab:
    # To build a PDF, we need the appropriate tex packages installed
    ! apt-get -y install texlive-latex-base texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra texlive-bibtex-extra biber
    # To avoid clutter and navigate to the right directory
    import os
    os.chdir("supplement")
    # And finally build the formatted PDF, repeated commands are necessary
    ! pdflatex supplement
    ! biber supplement
    ! pdflatex supplement
    ! pdflatex supplement