In [None]:
import numpy as np
import pymc as pm
from jax import numpy as jnp
import arviz as az

from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm

from inputs.constants import INPUTS_PATH, SUPPLEMENT_PATH
from autumn.infrastructure.remote import springboard
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from aust_covid.calibration import get_priors, get_targets
from aust_covid.plotting import plot_single_run_outputs
from emutools.tex import DummyTexDoc, StandardTexDoc
from emutools.parameters import load_param_info
from emutools.calibration import param_table_to_tex, round_sigfig, tabulate_priors

In [None]:
# Assemble the calibration inputs: baseline parameters, the compartmental model, priors and targets.
# A fresh DummyTexDoc() is passed wherever the project functions take a TeX document argument;
# presumably this discards the TeX write-up so nothing is written from this notebook (confirm).
param_info = load_param_info()
ifrs = get_ifrs(DummyTexDoc())
# Merge the IFR values into the 'value' column in place; this must happen before .to_dict()
# below so the IFRs reach `parameters`.
# NOTE(review): if param_info['value'] is a pandas Series, .update() only overwrites labels that
# already exist in its index - TODO confirm every IFR key is present in param_info.
param_info['value'].update(ifrs)
parameters = param_info['value'].to_dict()
epi_model = build_model(DummyTexDoc())
# NOTE(review): the meaning of the leading positional False flag is not visible here -
# consider passing it by keyword for readability.
priors = get_priors(False, param_info['abbreviations'], DummyTexDoc())
prior_names = [p.name for p in priors]
targets = get_targets(DummyTexDoc())
# Bundle the model, baseline parameters, priors and targets into the estival object used for calibration.
bcm = BayesianCompartmentalModel(epi_model, parameters, priors, targets)

In [None]:
# Find the maximum a posteriori (MAP) parameter set with PyMC, then run the model at it.
# The optimiser starts from the current `parameters`, each clipped into the interval returned by
# the matching prior's bounds(0.99) (presumably its central 99% range - confirm in estival).
with pm.Model() as pmc_model:
    start_params = {k: np.clip(v, *bcm.priors[k].bounds(0.99)) for k, v in parameters.items() if k in bcm.priors}
    variables = epm.use_model(bcm)
    map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
    # find_MAP returns array values; convert to plain floats so they can be used as model parameters.
    map_params = {k: float(v) for k, v in map_params.items()}

print('Best calibration parameters found:')
for param, value in map_params.items():
    # Parameters with '_dispersion' in their name are not reported here.
    if '_dispersion' in param:
        continue
    # Look up the prior by name rather than by position in `priors`: positional pairing reports
    # the wrong bounds whenever find_MAP's key order differs from the order of the priors list.
    # Assumes every reported key is a prior name in bcm.priors, as the variables come from epm.use_model(bcm).
    print(f'   {param}: {round_sigfig(value, 4)} (within bound {bcm.priors[param].bounds()})')

# NOTE: `parameters` is updated in place, so re-running this cell starts the optimiser from the
# previous MAP estimate rather than the baseline values loaded earlier.
parameters.update(map_params)
epi_model.run(parameters=parameters)

In [None]:
# Plot the outputs of the most recent epi_model run (the MAP run, when the cells are executed in
# order) alongside the calibration targets. As the last expression of the cell, the function's
# return value is what the notebook displays.
plot_single_run_outputs(epi_model, targets)