In [None]:
from pathlib import Path
from autumn.core.runs import ManagedRun
import arviz as az
import numpy as np
from estival import priors as esp
from general_utils.tex import StandardTexDoc
from aust_covid.inputs import get_ifrs
from general_utils.inputs import load_param_info
from arviz.labels import MapLabeller
# Project root: taken to be the parent of the directory the notebook is launched from
# (NOTE(review): assumes the kernel's working directory is one level below the root).
PROJECT_PATH = Path.cwd().resolve().parent

In [None]:
# Identifier of the autumn run whose stored outputs are downloaded and loaded below.
run_path = (
    'projects/aust_covid/base_case_analysis/'
    '2023-08-29T1613-base_case_try_5000'
)

In [None]:
mr = ManagedRun(run_path)
# Download the final item of the remote listing, then open the final item of the
# local listing as ArviZ InferenceData.
# NOTE(review): assumes the last remote item is the netCDF trace and that it also
# sorts last in list_local() -- confirm the ordering of both listings.
remote_files = mr.remote.list_contents()
mr.remote.download(remote_files[-1])
local_files = mr.list_local()
idata = az.from_netcdf(local_files[-1])

In [None]:
# Drop draws labelled below BURN_IN_DRAWS (presumably warm-up / burn-in).
# `.sel` selects by coordinate label, so re-running this cell keeps the same draws.
BURN_IN_DRAWS = 100
idata = idata.sel(draw=slice(BURN_IN_DRAWS, None))

In [None]:
app_doc = StandardTexDoc(PROJECT_PATH / 'supplement', 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')
# Values produced by get_ifrs (which is also handed the supplement document).
ifrs = get_ifrs(app_doc)

# Point values for the model parameters. The `**ifrs` entries come last, so
# any key they share with the literal above takes the value from `ifrs`.
parameters = {
    'contact_rate': 0.065,
    'latent_period': 1.8,
    'infectious_period': 2.5,
    'natural_immunity_period': 60.0,
    'start_cdr': 0.3,
    'imm_prop': 0.4,
    'imm_infect_protect': 0.4,
    'ifr_adjuster': 3.0,
    'ba1_seed_time': 619.0,
    'ba2_seed_time': 659.0,
    'ba5_seed_time': 715.0,
    'ba2_escape': 0.4,
    'ba5_escape': 0.54,
    'ba2_rel_ifr': 0.5,
    'wa_reopen_period': 50.0,
    'seed_duration': 10.0,
    'seed_rate': 1.0,
    'notifs_mean': 4.0,
    'notifs_shape': 2.0,
    'deaths_mean': 15.93,
    'deaths_shape': 5.0,
    **ifrs,
}

In [None]:
# Parameter metadata from the project YAML; its 'descriptions' entry supplies plot labels.
# `parameters` already holds the `ifrs` values; the `| ifrs` merge is redundant but
# keeps this cell correct even if the parameters cell stops merging them.
param_yaml = PROJECT_PATH / 'inputs' / 'parameters.yml'
param_info = load_param_info(param_yaml, parameters | ifrs)

In [None]:
# Priors for the calibrated parameters, grouped for readability.
# Order matters: the density-plot cell pairs each prior with a plotted axis by
# position, so the groups are concatenated in the original sequence.
core_priors = [
    esp.UniformPrior('contact_rate', (0.02, 0.15)),
    esp.GammaPrior.from_mode('latent_period', 2.5, 5.0),
    esp.GammaPrior.from_mode('infectious_period', 3.5, 6.0),
    esp.GammaPrior.from_mode('natural_immunity_period', 180.0, 1000.0),
    esp.UniformPrior('start_cdr', (0.1, 0.6)),
    esp.UniformPrior('imm_prop', (0.0, 1.0)),
    esp.UniformPrior('imm_infect_protect', (0.0, 1.0)),
    esp.TruncNormalPrior('ifr_adjuster', 1.0, 2.0, (0.2, np.inf)),
]
# Parameters named after the BA.1 / BA.2 / BA.5 variants.
variant_priors = [
    esp.UniformPrior('ba1_seed_time', (580.0, 620.0)),
    esp.UniformPrior('ba2_seed_time', (620.0, 660.0)),
    esp.UniformPrior('ba5_seed_time', (660.0, 740.0)),
    esp.BetaPrior.from_mean_and_ci('ba2_escape', 0.4, (0.2, 0.6)),
    esp.BetaPrior.from_mean_and_ci('ba5_escape', 0.4, (0.2, 0.6)),
    esp.TruncNormalPrior('ba2_rel_ifr', 0.7, 0.15, (0.2, np.inf)),
]
other_priors = [
    esp.UniformPrior('wa_reopen_period', (30.0, 90.0)),
    esp.GammaPrior.from_mean('deaths_mean', 15.93, 18.79),
]
priors = [*core_priors, *variant_priors, *other_priors]

In [None]:
# Posterior densities (restricted to the 99% HDI) for every calibrated parameter,
# each overlaid with its prior density as a grey fill.
N_PRIOR_POINTS = 100  # resolution of each prior density curve

figure = az.plot_density(
    idata,
    var_names=[p.name for p in priors],
    shade=0.5,
    labeller=MapLabeller(var_name_map=param_info['descriptions'].to_dict()),
    point_estimate=None,
    hdi_prob=0.99,
)

# plot_density may hand back a single Axes or an array of Axes; flatten either form.
# NOTE(review): axes are paired with priors by position, which assumes one axis per
# (scalar) parameter in `var_names` order -- a vector-valued parameter would shift this.
density_axes = np.atleast_1d(figure).ravel()[:len(priors)]
for ax, prior in zip(density_axes, priors):
    x_min, x_max = ax.get_xlim()
    x_vals = np.linspace(x_min, x_max, N_PRIOR_POINTS)
    ax.fill_between(x_vals, prior.pdf(x_vals), color='k', alpha=0.2, linewidth=2)
    # fill_between requests an autoscale, and default axis margins would widen the
    # x-range beyond the posterior-driven limits; pin them back so framing is unchanged.
    ax.set_xlim(x_min, x_max)