In [None]:
from pathlib import Path

In [None]:
# Project root, taken as the parent of the kernel's working directory -- this assumes the
# notebook is launched from a direct subdirectory of the project (TODO confirm layout).
PROJECT_PATH = Path().resolve().parent
DATA_PATH = PROJECT_PATH / "data"
OUTPUT_PATH = PROJECT_PATH / "outputs"
SUPPLEMENT_PATH = PROJECT_PATH / "supplement"
# OUTPUT_PATH is already a Path, so create it directly (no re-wrapping needed)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
from datetime import datetime
import pylatex as pl
from pylatex.utils import NoEscape
import pymc as pm
import arviz as az
import yaml
import nevergrad as ng

from estival.model import BayesianCompartmentalModel
from estival.optimization.nevergrad import optimize_model
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget, CustomTarget
from estival.calibration import pymc as epm

from aust_covid.inputs import load_household_impacts_data
from aust_covid import model
from documentation import calibration_outputs
from aust_covid.inputs import load_param_info
from documentation.doc_utils import TextElement, TableElement, add_element_to_document, \
    save_pyplot_add_to_doc, save_plotly_add_to_doc, compile_doc, generate_doc

In [None]:
# Data inputs: per-region case records, indexed by date
aust_data = pd.read_csv(DATA_PATH / "Aus_covid_data.csv", index_col="date")
aust_data.index = pd.to_datetime(aust_data.index)

# National series (region "AUS"), smoothed with a trailing 7-row rolling mean
# (a 7-day window only if there is exactly one row per day -- TODO confirm)
national_data = aust_data[aust_data["region"] == "AUS"]
smoothed_national_cases = national_data["cases"].rolling(window=7).mean().dropna()

# Sum over every region except the national aggregate ("AUS") and "WA".
# numeric_only=True drops the string "region" column before summing. Without it, pandas >= 2.0
# concatenates the region names on each date, and older versions warned and dropped the column,
# so non_wa_data would carry a meaningless, version-dependent column.
non_wa_data = aust_data.loc[~aust_data["region"].isin(["AUS", "WA"])]
non_wa_data = non_wa_data.groupby(level=0).sum(numeric_only=True)
smoothed_non_wa_cases = non_wa_data["cases"].rolling(window=7).mean().dropna()

In [None]:
# Times
ref_date = datetime(2019, 12, 31)  # Arbitrary reference date
start_date, end_date = datetime(2021, 9, 1), datetime(2022, 10, 1)  # Analysis window
plot_start_date = datetime(2021, 12, 1)  # Left end for plots

In [None]:
# Parameters
parameters = {
    # Transmission and natural history
    "contact_rate": 0.048,
    "infectious_period": 5.0,
    "latent_period": 2.0,
    # Case detection
    "cdr": 0.1,
    # Variant seeding
    "seed_rate": 1.0,
    "seed_duration": 1.0,
    "ba1_seed_time": 660.0,
    "ba2_seed_time": 688.0,
    "ba5_seed_time": 720.0,
    # Variant immune escape
    "ba2_escape": 0.45,
    "ba5_escape": 0.38,
    # Delay distributions (shape, mean) for notifications and deaths
    "notifs_shape": 2.0,
    "notifs_mean": 4.0,
    "deaths_shape": 2.0,
    "deaths_mean": 20.0,
    # Waning
    "natural_immunity_period": 50.0,
    # IFR keyed by age stratum 0, 5, ..., 70 (the same values as age_strata later on):
    # zero for every stratum below 70, non-zero only for the oldest one
    **{f"ifr_{lower_age}": 0.0 for lower_age in range(0, 70, 5)},
    "ifr_70": 0.01,
}
# Parameter information read from inputs/parameters.yml for the parameters above, displayed below
param_info = load_param_info(PROJECT_PATH / "inputs" / "parameters.yml", parameters)
param_info

In [None]:
# Supplement content, accumulated per section by add_element_to_document as the model is built
doc_sections = {}
compartments = ["susceptible", "latent", "infectious", "recovered", "waned"]
aust_model, build_text = model.build_base_model(ref_date, compartments, start_date, end_date)
add_element_to_document("Model construction", TextElement(build_text), doc_sections)
build_text

In [None]:
pop_data, pop_text = model.get_pop_data()
# Description is only displayed here for now -- it can't be added to the supplement yet
pop_text

In [None]:
# Starting conditions derived from the population data
start_text = model.set_starting_conditions(aust_model, pop_data)
add_element_to_document(
    "Model construction",
    TextElement(start_text),
    doc_sections,
)
start_text

In [None]:
# Infection process
infect_text = model.add_infection(aust_model)
add_element_to_document(
    "Model construction",
    TextElement(infect_text),
    doc_sections,
)
infect_text

In [None]:
# Progression process
prog_text = model.add_progression(aust_model)
add_element_to_document(
    "Model construction",
    TextElement(prog_text),
    doc_sections,
)
prog_text

In [None]:
# Recovery process
rec_text = model.add_recovery(aust_model)
add_element_to_document(
    "Model construction",
    TextElement(rec_text),
    doc_sections,
)
rec_text

In [None]:
# Waning process
wane_text = model.add_waning(aust_model)
add_element_to_document(
    "Model construction",
    TextElement(wane_text),
    doc_sections,
)
wane_text

In [None]:
# Age strata 0, 5, ..., 70 (presumably lower bounds of 5-year bands -- confirm in model module)
age_strata = [*range(0, 75, 5)]
raw_matrix, age_text = model.build_polymod_britain_matrix(age_strata)
add_element_to_document(
    "Model construction",
    TextElement(age_text),
    doc_sections,
)
age_text

In [None]:
# Adapt the British contact matrix to the Australian population
adjusted_matrix, pop_splits, mat_adj_text = model.adapt_gb_matrix_to_aust(
    age_strata, raw_matrix, pop_data
)
add_element_to_document("Model construction", TextElement(mat_adj_text), doc_sections)
mat_adj_text

In [None]:
# Stratify every compartment by age, using the adjusted matrix and population splits
age_strat, agestrat_text = model.add_age_stratification(
    compartments, age_strata, pop_splits, adjusted_matrix
)
aust_model.stratify_with(age_strat)
add_element_to_document("Model stratification", TextElement(agestrat_text), doc_sections)
agestrat_text

In [None]:
# Strain keys mapped to their (presumably display) labels
strain_strata = dict(ba1="BA.1", ba2="BA.2", ba5="BA.5")
strain_strat, strainstrat_text = model.get_strain_stratification(compartments, strain_strata)
aust_model.stratify_with(strain_strat)
add_element_to_document(
    "Model stratification",
    TextElement(strainstrat_text),
    doc_sections,
)
strainstrat_text

In [None]:
# Seeding of the variants of concern
seed_text = model.seed_vocs(aust_model)
add_element_to_document(
    "Model stratification",
    TextElement(seed_text),
    doc_sections,
)
seed_text

In [None]:
# Reinfection processes across the strain strata
reinfect_text = model.add_reinfection(aust_model, strain_strata)
add_element_to_document(
    "Model stratification",
    TextElement(reinfect_text),
    doc_sections,
)
reinfect_text

In [None]:
# Names of the processes passed to the incidence output (presumably the flows summed into it)
infection_processes = ["infection", "early_reinfection", "late_reinfection"]
inc_text = model.add_incidence_output(aust_model, infection_processes)
add_element_to_document("Outputs", TextElement(inc_text), doc_sections)
inc_text

In [None]:
import numpy as np
import jax.numpy as jnp

from summer2.parameters import Parameter, Function, DerivedOutput, Time, Data
from summer2.functions.time import get_linear_interpolation_function

from aust_covid.model_utils import build_gamma_dens_interval_func, convolve_probability
from aust_covid.model import get_param_to_exp_plateau

In [None]:
# Model epoch, used below to convert the survey dates into model time values
aust_epoch = aust_model.get_epoch()

In [None]:
# Household impacts survey: ratio of the proportion tested to the proportion symptomatic
# (NOTE(review): any zero in sympt_prop would give inf here -- confirm the data has none)
hh_impact = load_household_impacts_data()
hh_test_ratio = hh_impact["test_prop"].div(hh_impact["sympt_prop"])

In [None]:
# Symbolic reference to the "cdr" parameter (presumably case detection rate); its value comes
# from the parameters dict passed to aust_model.run
start_cdr = Parameter("cdr")

In [None]:
# Exponent scale for the saturating CDR curve, anchored (presumably) so that the CDR equals the
# "cdr" parameter at the first survey observation.
# .iloc[0] makes the first-row lookup explicit. hh_test_ratio has a datetime index (its index
# is converted with datetime_to_number below), and on such a Series pandas treats an integer
# in [] only through a deprecated positional fallback.
exp_param = get_param_to_exp_plateau(hh_test_ratio.iloc[0], start_cdr)

In [None]:
# Reference only -- closed form obtained by solving 1 - exp(-k * r0) = cdr for k, where r0 is the
# first survey ratio (cf. the cdr_values formula in the next cells):
#   exp_param = -log(1 - cdr) / hh_test_ratio.iloc[0]
# NOTE(review): presumably superseded by get_param_to_exp_plateau because cdr is a symbolic
# Parameter here; confirm, then delete this commented-out cell.

In [None]:
# CDR at each survey time point: cdr = 1 - exp(-exp_param * test_ratio), which saturates as the
# ratio grows. exp_param is derived from the symbolic Parameter("cdr"), so the exponential is
# built as a summer2 Function evaluated with jax at run time (as convolve_probability is for
# notifications), rather than passing a graph object to the numpy ufunc np.exp.
cdr_values = 1.0 - Function(jnp.exp, [0.0 - exp_param * hh_test_ratio.to_numpy()])

In [None]:
# Survey dates expressed as numeric model times via the model epoch
time_points = jnp.array(
    aust_epoch.datetime_to_number(hh_test_ratio.index)
)

In [None]:
# Linear interpolation of the CDR values over the survey time points
# (despite the name, this interpolates CDR, not the raw test ratio)
ratio_interp = get_linear_interpolation_function(
    time_points, cdr_values
)

In [None]:
# Register the interpolated CDR as a tracked value so the notifications output can use it
tracked_ratio_interp = aust_model.request_track_modelled_value(
    "ratio_interp", ratio_interp
)

In [None]:
# Notifications: incidence convolved with a delay distribution built from notifs_shape and
# notifs_mean (presumably gamma-distributed), then scaled by the time-varying CDR
delay = build_gamma_dens_interval_func(
    Parameter("notifs_shape"),
    Parameter("notifs_mean"),
    aust_model.times,
)
notif_dist_rel_inc = (
    Function(convolve_probability, [DerivedOutput("incidence"), delay])
    * tracked_ratio_interp
)
aust_model.request_function_output(name="notifications", func=notif_dist_rel_inc)

In [None]:
# Run with the manual parameter set and overlay modelled notifications on smoothed non-WA cases
aust_model.run(parameters=parameters)
axis_labels = {"index": "time", "value": "cases"}
comparison_df = pd.concat(
    [
        smoothed_non_wa_cases,
        aust_model.get_derived_outputs_df()["notifications"],
    ],
    axis=1,
)
comparison_df.plot(labels=axis_labels, title="Parameter outputs")