In [None]:
import numpy as np
import pymc as pm
from jax import numpy as jnp
import arviz as az

from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.utils.sample import SampleTypes

from inputs.constants import INPUTS_PATH, SUPPLEMENT_PATH
from autumn.infrastructure.remote import springboard
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from aust_covid.calibration import get_priors, get_targets
from emutools.tex import StandardTexDoc
from emutools.inputs import load_param_info
from emutools.calibration import param_table_to_tex, round_sigfig, tabulate_priors

In [None]:
# LaTeX supplement document; later cells pass it to get_ifrs, build_model and get_targets,
# and add the parameter/prior tables to it.
app_doc = StandardTexDoc(SUPPLEMENT_PATH, 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')

In [None]:
# Load the parameter table, overwrite its 'value' column with the IFRs from get_ifrs,
# and flatten the values into a plain dict for the model.
param_info = load_param_info(INPUTS_PATH / 'parameters.yml')
ifrs = get_ifrs(app_doc)
# Update a copy of the column and assign it back rather than calling .update() on
# param_info['value'] directly: that is chained in-place modification, which pandas
# deprecates and which leaves param_info unchanged under Copy-on-Write.
# NOTE(review): Series.update aligns on index, so any IFR name not already present in
# parameters.yml is silently ignored — confirm all IFR parameters are declared there.
param_values = param_info['value'].copy()
param_values.update(ifrs)
param_info['value'] = param_values
parameters = param_info['value'].to_dict()

In [None]:
# Build the model; app_doc is passed in too (presumably so the builder can write its
# own supplement text — confirm in aust_covid.model).
aust_model = build_model(app_doc)

In [None]:
# Calibration priors and targets. Prior names are kept separately because the
# parameter table below needs to know which parameters are calibrated.
priors = get_priors()
prior_names = [p.name for p in priors]
targets = get_targets(app_doc)

In [None]:
# Combine the model, baseline parameters, priors and targets into estival's
# calibratable model object.
bcm = BayesianCompartmentalModel(aust_model, parameters, priors, targets)

In [None]:
# Add the parameter table (calibrated parameters identified by prior_names) to the
# 'Parameters' section, then the prior summary table to the 'Calibration' section.
params_tex_table = param_table_to_tex(param_info, prior_names)
app_doc.include_table(
    params_tex_table,
    section='Parameters',
    col_splits=[0.17, 0.15, 0.15, 0.53],
    longtable=True,
)
priors_tex_table = tabulate_priors(priors, param_info)
app_doc.include_table(priors_tex_table, section='Calibration', col_splits=[0.25, 0.25, 0.25, 0.25])

The following cell is a quick local check that the calibration runs without error (it is commented out so it does not run by default).

In [None]:
# Local smoke test (disabled): a short DEMetropolis run (100 draws, no tuning,
# 18 chains on 3 cores) saved to calibration_out.nc. Uncomment to run; the remote
# calibration further down does not depend on it.
# with pm.Model() as pm_model:
#     variables = epm.use_model(bcm)
#     idata_local = pm.sample(step=[pm.DEMetropolis(variables)], draws=100, tune=0, cores=3, chains=18, progressbar=True)
# idata_local.to_netcdf('calibration_out.nc')

In [None]:
def get_acceptable_start_params(n_params_target, calib_model=None, base_params=None, max_rounds=None):
    """Draw Latin hypercube samples until enough of them pass the seropositivity-ceiling check.

    A candidate is kept only when running the model with it gives a first
    'seropos_ceiling' entry of exactly 0.0 in the run's ``ll_components`` extras.

    Args:
        n_params_target: Number of acceptable parameter sets to return.
        calib_model: Model used to sample and run candidates; defaults to the
            notebook-level ``bcm`` (looked up at call time).
        base_params: Baseline parameters that each candidate overrides; defaults to
            the notebook-level ``parameters`` (looked up at call time).
        max_rounds: Maximum number of LHS batches to draw before giving up. ``None``
            (default) keeps sampling indefinitely.

    Returns:
        List of ``n_params_target`` dicts as produced by ``calib_model.sample.lhs``.

    Raises:
        RuntimeError: If ``max_rounds`` batches are drawn without reaching the target.
    """
    calib_model = bcm if calib_model is None else calib_model
    base_params = parameters if base_params is None else base_params

    accepted = []
    n_rounds = 0
    while len(accepted) < n_params_target:
        if max_rounds is not None and n_rounds >= max_rounds:
            raise RuntimeError(
                f'Only {len(accepted)} of {n_params_target} acceptable start points '
                f'found after {n_rounds} LHS rounds'
            )
        n_rounds += 1
        # Only sample as many candidates as are still missing, so a batch can never overshoot.
        n_needed = n_params_target - len(accepted)
        for candidate in calib_model.sample.lhs(n_needed, SampleTypes.LIST_OF_DICTS):
            run = calib_model.run(base_params | candidate, include_extras=True)
            if run.extras['ll_components']['seropos_ceiling'][0] == 0.0:
                accepted.append(candidate)
    return accepted

In [None]:
def run_calibration(
    bridge: springboard.task.TaskBridge,
    bcm: BayesianCompartmentalModel,
    n_chains: int = 8,
    n_cores: int = 8,
    draws: int = 5000,
    tune: int = 2000,
    burn_in: int = 1000,
    n_output_samples: int = 100,
):
    """Run the DEMetropolisZ calibration (intended as a springboard task) and save its outputs.

    Args:
        bridge: springboard task bridge; its ``out_path`` receives the outputs and its
            ``logger`` reports completion.
        bcm: Model to calibrate.
        n_chains: Number of chains, and number of LHS starting points drawn.
        n_cores: Number of cores passed to ``pm.sample``.
        draws: Post-tuning draws per chain (tuning draws are discarded by ``pm.sample``).
        tune: Tuning iterations per chain.
        burn_in: Additional leading draws dropped before sampling parameter sets for outputs.
        n_output_samples: Number of posterior parameter sets re-run for the model outputs.

    Writes into ``bridge.out_path``:
        calibration_out.nc: the full sampled trace.
        results.hdf, key 's': model results for ``n_output_samples`` burnt-in draws.
        results.hdf, key 'l': likelihood extras computed on the full (un-burnt) trace.
    """
    import multiprocessing as mp
    # force=True: without it set_start_method raises RuntimeError when the start
    # method has already been set in this process (e.g. if the task is re-run).
    mp.set_start_method('forkserver', force=True)

    # NOTE: get_acceptable_start_params draws start points using the notebook-level
    # ``bcm`` rather than this function's ``bcm`` argument.
    starting_params = get_acceptable_start_params(n_chains)

    with pm.Model():
        variables = epm.use_model(bcm)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables)],
            draws=draws,
            tune=tune,
            cores=n_cores,
            chains=n_chains,
            progressbar=False,
            initvals=starting_params,
        )

    idata_raw.to_netcdf(str(bridge.out_path / 'calibration_out.nc'))

    burnt_idata = idata_raw.sel(draw=np.s_[burn_in:])
    sds = az.extract(burnt_idata, num_samples=n_output_samples)
    spaghetti_res = esamp.model_results_for_samples(sds, bcm)
    spaghetti_res.results.to_hdf(str(bridge.out_path / 'results.hdf'), 's')

    like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
    like_df.to_hdf(str(bridge.out_path / 'results.hdf'), 'l')
    bridge.logger.info('Calibration complete')

# Remote machine spec, task spec (run_calibration with this notebook's bcm) and run path.
# NOTE(review): the first EC2MachineSpec argument (8) presumably matches the 8 cores
# used in run_calibration — confirm the positional argument meanings in springboard.
mspec = springboard.EC2MachineSpec(8, 2, 'compute')
tspec = springboard.TaskSpec(run_calibration, dict(bcm=bcm))
run_path = springboard.launch.get_autumn_project_run_path(
    'aust_covid',
    'base_case_analysis',
    'lhs_in_range_DEMZ_try_again',
)

In [None]:
# Extra shell commands handed to the launcher (presumably run on the remote machine
# before the task) to install aust-covid from the 'preliminary-optimisation' branch.
# NOTE(review): the branch is not pinned to a commit, so reruns may pick up different
# code — consider pinning a commit for reproducibility.
aust_covid_commands = [
    'git clone --branch preliminary-optimisation https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]
# Launch run_calibration with the machine spec, task spec and run path defined above.
runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=aust_covid_commands)