In [None]:
import numpy as np
import pymc as pm
import arviz as az
from plotly import express as px
import multiprocessing as mp

from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng

from inputs.constants import DATA_PATH
from autumn.infrastructure.remote import springboard
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from aust_covid.calibration import get_priors, get_targets
from emutools.tex import DummyTexDoc
from emutools.parameters import load_param_info

In [None]:
def get_bcm(sens):
    """Build the estival BayesianCompartmentalModel for one sensitivity configuration.

    Args:
        sens: Dict with 'mob' and 'vacc' entries, forwarded to build_model as
            mobility_sens and vacc_sens respectively.

    Returns:
        BayesianCompartmentalModel wrapping the built model, the parameter values
        (from load_param_info, with the IFR values from get_ifrs overlaid), the
        priors from get_priors and the targets from get_targets.
    """
    # Stand-in document object passed wherever the builders expect a TeX document
    # (presumably a no-op, given the name — confirm in emutools.tex).
    app_doc = DummyTexDoc()

    aust_model = build_model(app_doc, mobility_sens=sens['mob'], vacc_sens=sens['vacc'])

    param_info = load_param_info()
    ifrs = get_ifrs(app_doc)
    # Overlay the IFR values onto the loaded 'value' column before flattening to a dict.
    param_info['value'].update(ifrs)
    parameters = param_info['value'].to_dict()

    priors = get_priors()
    targets = get_targets(app_doc)

    return BayesianCompartmentalModel(aust_model, parameters, priors, targets)

# Baseline configuration (no mobility or vaccination sensitivity), used locally below
# to draw and inspect start points.
local_bcm = get_bcm(sens=dict(mob=False, vacc=False))

In [None]:
def get_acceptable_start_params(bcm, n_params_target, ci=1.0, max_rounds=None):
    """Collect LHS prior samples whose 'll_seropos_ceiling' likelihood extra is exactly 0.0.

    Samples are drawn in batches of mp.cpu_count() via bcm.sample.lhs and scored with
    esamp.likelihood_extras_for_samples; batches are drawn until n_params_target
    acceptable samples have been collected.

    Args:
        bcm: BayesianCompartmentalModel to sample from and evaluate.
        n_params_target: Number of acceptable samples to collect.
        ci: Forwarded to bcm.sample.lhs (presumably the central prior interval to
            sample within — confirm against estival).
        max_rounds: Optional maximum number of LHS batches. None (default) keeps
            sampling until enough samples are found, which never terminates if the
            criterion cannot be met.

    Returns:
        The accepted samples, passed through bcm.sample.convert.

    Raises:
        RuntimeError: If max_rounds batches are drawn without collecting
            n_params_target acceptable samples.
    """
    accepted = []
    batch_size = mp.cpu_count()
    n_rounds = 0
    while len(accepted) < n_params_target:
        if max_rounds is not None and n_rounds >= max_rounds:
            raise RuntimeError(
                f'Found only {len(accepted)} of {n_params_target} acceptable start points '
                f'after {n_rounds} LHS batches'
            )
        n_rounds += 1
        new_samples = bcm.sample.lhs(batch_size, ci=ci)
        lle = esamp.likelihood_extras_for_samples(new_samples, bcm)
        for sidx, val in lle['ll_seropos_ceiling'].items():
            if val == 0.0:
                accepted.append(new_samples[sidx])
                # Stop scanning this batch once the target is reached.
                if len(accepted) == n_params_target:
                    break
    return bcm.sample.convert(accepted)

# Interval setting shared by the start-point sampling here and the optimiser in `calibrate`.
CI = 0.67
start_lhs = get_acceptable_start_params(local_bcm, n_params_target=8, ci=CI)
# Visualise pairwise distances between the accepted start points.
px.imshow(local_bcm.sample.distance_matrix(start_lhs))

In [None]:
def calibrate(out_path, sens, draws, tune, init_samples=None):
    """Run the calibration for one sensitivity configuration and write the outputs to out_path.

    Steps:
      1. Build a fresh BCM for `sens` via get_bcm.
      2. If `init_samples` is None, optimise each row of the notebook-level `start_lhs`
         with nevergrad (TwoPointsDE on the log-posterior) and use the first `n_chains`
         optimised points as chain starting values.
      3. Sample with DEMetropolisZ, keeping the tuning draws, and save the full trace.
      4. Drop the first `burn` posterior draws, extract `n_samples` draws, and save them
         together with model outputs and likelihood components.

    Args:
        out_path: Output directory; must support `/` joins (e.g. pathlib.Path).
        sens: Dict with 'mob' and 'vacc' entries, forwarded to get_bcm.
        draws: Posterior draws per chain.
        tune: Tuning draws per chain (retained in the saved trace).
        init_samples: Starting values passed to pm.sample as `initvals`. When None,
            optimised start points are computed here.

    Returns:
        The raw InferenceData returned by pm.sample.
    """
    remote_bcm = get_bcm(sens)

    n_opti_draws = 100
    n_chains = 8
    n_samples = 100

    if init_samples is None:
        # Only optimise when no starting values were supplied. Previously the optimised
        # points were always computed and then discarded in favour of init_samples.
        def optimize_ng(idx_sample):
            # NOTE(review): `CI` is a notebook-level global.
            idx, sample = idx_sample
            opt = eng.optimize_model(remote_bcm, budget=n_opti_draws, opt_class=ng.optimizers.TwoPointsDE, obj_function=remote_bcm.logposterior, suggested=sample, num_workers=4, ci=CI)
            rec = opt.minimize(n_opti_draws)
            return idx, rec.value[1]

        # NOTE(review): `start_lhs` is a notebook-level global.
        opt_samples = map_parallel(optimize_ng, start_lhs.iterrows(), n_workers=2, mode='process')
        opt_samples = remote_bcm.sample.convert(opt_samples)
        init_samples = opt_samples.iloc[0: n_chains].convert('list_of_dicts')

    with pm.Model():
        variables = epm.use_model(remote_bcm)
        idata_raw = pm.sample(step=[pm.DEMetropolisZ(variables)], draws=draws, tune=tune, cores=8, discard_tuned_samples=False, chains=n_chains, progressbar=False, initvals=init_samples)
    idata_raw.to_netcdf(str(out_path / 'calib_full_out.nc'))

    # NOTE(review): `burn` is a notebook-level global defined in a later cell.
    burnt_idata = idata_raw.sel(draw=np.s_[burn:])
    idata_extract = az.extract(burnt_idata, num_samples=n_samples)

    remote_bcm.sample.convert(idata_extract).to_hdf5(out_path / "calib_extract_out.h5")

    spaghetti_res = esamp.model_results_for_samples(idata_extract, remote_bcm)
    spaghetti_res.results.to_hdf(str(out_path / 'results.hdf'), 'spaghetti')

    like_df = esamp.likelihood_extras_for_idata(idata_raw, remote_bcm)
    like_df.to_hdf(str(out_path / 'results.hdf'), 'likelihood')

    return idata_raw

In [None]:
def run_calibration(bridge: springboard.task.TaskBridge, sens, draws, tune, init_samples):
    """Springboard task entry point: run calibrate() for one sensitivity analysis.

    Args:
        bridge: Task bridge supplying the output directory (`out_path`) and `logger`.
        sens: Sensitivity flags dict, forwarded to calibrate().
        draws: Posterior draws per chain, forwarded to calibrate().
        tune: Tuning draws per chain, forwarded to calibrate().
        init_samples: Chain starting values, forwarded to calibrate().
    """
    import multiprocessing as mp

    # Use 'forkserver' worker processes. set_start_method raises RuntimeError when a
    # start method has already been fixed (unless force=True), so only set it when needed.
    if mp.get_start_method(allow_none=True) != 'forkserver':
        mp.set_start_method('forkserver', force=True)
    calibrate(bridge.out_path, sens, draws, tune, init_samples)
    bridge.logger.info('Calibration complete')

In [None]:
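# Run the following to check that the basic calibration algorithm runs locally.
# NOTE: `draws` and `tune` are defined in the next cell, and `opt_samples` is not defined at
# notebook level (it only exists inside `calibrate`), so substitute suitable starting points
# (e.g. derived from `start_lhs`) before uncommenting.
# calibrate(DATA_PATH, {'mob': False, 'vacc': False}, draws, tune, init_samples=opt_samples.iloc[0: 8].convert('list_of_dicts'))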

In [None]:
# Sensitivity analyses: each toggles the mobility ('mob') and vaccination ('vacc')
# model variants passed to get_bcm.
sens_analyses = {
    'none': {'mob': False, 'vacc': False},
    'mob': {'mob': True, 'vacc': False},
    'vacc': {'mob': False, 'vacc': True},
    'both': {'mob': True, 'vacc': True},
}
runners = {}
draws = 20000
tune = 10000
burn = 5000  # also read as a global inside calibrate()

aust_covid_commands = [
    'git clone --branch revise-calibration-code https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]

# `opt_samples` is never defined at notebook level (it is local to calibrate), so the previous
# reference raised NameError on a fresh kernel. Start the 8 chains from the accepted LHS start
# points instead; the list is the same for every analysis, so build it once.
init_samples = start_lhs.iloc[0: 8].convert('list_of_dicts')

mspec = springboard.EC2MachineSpec(8, 2, 'compute')
for sens_name, analysis in sens_analyses.items():
    run_str = f'{sens_name}-d{int(draws / 1000)}k-t{int(tune / 1000)}k-b{int(burn / 1000)}k'
    tspec = springboard.TaskSpec(
        run_calibration,
        {'draws': draws, 'tune': tune, 'init_samples': init_samples, 'sens': analysis},
    )
    run_path = springboard.launch.get_autumn_project_run_path('aust_covid', 'alternate_analyses', run_str)
    runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=aust_covid_commands)
    runners[sens_name] = runner