In [None]:
import numpy as np
import pymc as pm
from jax import numpy as jnp
import arviz as az
from plotly import express as px
import multiprocessing as mp

from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.utils.sample import SampleTypes
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng

from inputs.constants import INPUTS_PATH, SUPPLEMENT_PATH
from autumn.infrastructure.remote import springboard
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from aust_covid.calibration import get_priors, get_targets
from emutools.tex import StandardTexDoc, DummyTexDoc
from emutools.inputs import load_param_info
from emutools.calibration import param_table_to_tex, round_sigfig, tabulate_priors

In [None]:
# LaTeX supplement document; passed to the model/target builders and used for the tables below
app_doc = StandardTexDoc(
    './',
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)

In [None]:
param_info = load_param_info(INPUTS_PATH / 'parameters.yml')

# Overwrite the matching entries of the 'value' column with the IFRs (assumes param_info is a
# pandas DataFrame). Calling Series.update directly on param_info['value'] is chained assignment:
# under pandas Copy-on-Write (the default from pandas 3.0) it only modifies a temporary copy and
# param_info would silently keep the old values. Update an explicit copy, then assign it back.
ifrs = get_ifrs(app_doc)
param_values = param_info['value'].copy()
param_values.update(ifrs)
param_info['value'] = param_values

parameters = param_info['value'].to_dict()

In [None]:
# Build the model with the mobility sensitivity option switched on
aust_model = build_model(
    app_doc,
    mobility_sens=True,
)

In [None]:
priors = get_priors()
# Names of the parameters that have priors (used for the supplement parameter table)
prior_names = [prior.name for prior in priors]
targets = get_targets(app_doc)

In [None]:
# Bayesian wrapper combining the model, its parameter values, the priors and the targets
bcm = BayesianCompartmentalModel(
    aust_model,
    parameters,
    priors,
    targets,
)

In [None]:
def get_bcm():
    """Construct a fresh BayesianCompartmentalModel.

    Called inside the remote calibration function so that the whole BCM does not have to be
    pickled up and shipped; only this recipe is. A ``DummyTexDoc`` stands in for the real
    supplement document (presumably discarding anything written to it).

    Returns:
        A new ``BayesianCompartmentalModel`` built with ``mobility_sens=True`` and the
        notebook-level ``parameters``, ``priors`` and ``targets``.
    """
    model = build_model(DummyTexDoc(), mobility_sens=True)
    return BayesianCompartmentalModel(model, parameters, priors, targets)

In [None]:
# Supplement tables: all model parameters, then the calibration priors
app_doc.include_table(
    param_table_to_tex(param_info, prior_names),
    section='Parameters',
    col_splits=[0.17, 0.15, 0.15, 0.53],
    longtable=True,
)
app_doc.include_table(
    tabulate_priors(priors, param_info),
    section='Calibration',
    col_splits=[0.25, 0.25, 0.25, 0.25],
)

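The following cell defines a helper that draws Latin hypercube starting samples and keeps only those whose `ll_seropos_ceiling` likelihood term is zero. Running it also checks that the calibration runs without error.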

In [None]:
def get_acceptable_start_params(n_params_target, ci=1.0, max_rounds=None):
    """Draw LHS samples from ``bcm`` until enough satisfy the seropositivity ceiling.

    Samples are drawn in batches of ``mp.cpu_count()``. A sample is kept only if its
    ``ll_seropos_ceiling`` likelihood extra is exactly zero (presumably meaning the ceiling
    constraint is not violated — confirm).

    Args:
        n_params_target: Number of acceptable parameter sets to collect.
        ci: Passed through to ``bcm.sample.lhs`` (presumably the prior interval sampled from).
        max_rounds: Optional cap on the number of batches drawn. ``None`` (the default) keeps
            sampling until the target is reached, which may never terminate if the constraint
            can't be met.

    Returns:
        The accepted samples, passed through ``bcm.sample.convert``.

    Raises:
        RuntimeError: If ``max_rounds`` batches are drawn without collecting enough samples.
    """
    params = []
    n_cores = mp.cpu_count()
    n_rounds = 0
    while len(params) < n_params_target:
        if max_rounds is not None and n_rounds >= max_rounds:
            raise RuntimeError(
                f'Only {len(params)} of {n_params_target} acceptable start samples found '
                f'after {n_rounds} rounds of {n_cores} LHS samples'
            )
        n_rounds += 1
        new_samples = bcm.sample.lhs(n_cores, ci=ci)
        lle = esamp.likelihood_extras_for_samples(new_samples, bcm)
        for sidx, val in lle["ll_seropos_ceiling"].items():
            if len(params) >= n_params_target:
                break
            if val == 0.0:
                params.append(new_samples[sidx])
    return bcm.sample.convert(params)

In [None]:
# 'ci' setting shared by the starting-sample LHS here and the optimisation below
CI = 0.67

start_lhs = get_acceptable_start_params(n_params_target=8, ci=CI)

In [None]:
# Heatmap of pairwise distances between the starting samples
px.imshow(
    bcm.sample.distance_matrix(start_lhs),
)

In [None]:
# Objective evaluations per optimisation. Both the optimiser's budget and the number of
# minimisation steps use this, so the two cannot drift apart.
OPT_BUDGET = 100


def optimize_ng(idx_sample, budget=OPT_BUDGET):
    """Refine one starting sample with a short nevergrad TwoPointsDE optimisation of the log posterior.

    Args:
        idx_sample: ``(index, sample)`` pair, as yielded by ``start_lhs.iterrows()``.
        budget: Optimiser budget, also used as the number of ``minimize`` steps.

    Returns:
        ``(index, params)``, where ``params`` is ``rec.value[1]`` of the recommendation
        (presumably the keyword-parameter dict of the nevergrad instrumentation — confirm).
    """
    idx, sample = idx_sample
    opt = eng.optimize_model(
        bcm,
        budget=budget,
        opt_class=ng.optimizers.TwoPointsDE,
        obj_function=bcm.logposterior,
        suggested=sample,
        num_workers=4,
        ci=CI,
    )
    rec = opt.minimize(budget)
    # Keep the sample index with the result so each optimum can be matched to its start
    return idx, rec.value[1]


opt_samples = map_parallel(optimize_ng, start_lhs.iterrows(), n_workers=2, mode='process')

In [None]:
# Turn the list of (index, params) results from map_parallel into bcm's sample representation
opt_samples = bcm.sample.convert(
    opt_samples,
)

In [None]:
# Heatmap of pairwise distances between the optimised samples
px.imshow(
    bcm.sample.distance_matrix(opt_samples),
)

In [None]:
def run_calibration(
    bridge: springboard.task.TaskBridge,
    init_samples,
    n_chains=8,
    draws=10000,
    tune=10000,
    burn_in=5000,
    n_result_samples=100,
):
    """Run the DEMetropolisZ calibration as a springboard task and store its outputs.

    The BCM is rebuilt inside the function via ``get_bcm`` rather than passed in as an
    argument.

    Args:
        bridge: Springboard task bridge; ``out_path`` receives outputs, ``logger`` reports completion.
        init_samples: Starting parameter sets, one per chain. Converted with
            ``bcm.sample.convert(..., 'list_of_dicts')`` and passed to PyMC as ``initvals``.
        n_chains: Number of chains; also used as the number of cores.
        draws: Post-tuning draws per chain.
        tune: Tuning draws per chain (kept in the saved output, not discarded).
        burn_in: Post-tuning draws dropped before sampling model results.
        n_result_samples: Number of posterior samples for which full model results are stored.

    Outputs written to ``bridge.out_path``:
        calibration_out.nc: the raw InferenceData, including tuning draws.
        results.hdf: key 'sampled_results' (model results for the extracted samples) and
            key 'likelihood_extras'.

    Raises:
        ValueError: If ``burn_in`` is outside ``[0, draws)`` or the number of initial
            samples does not equal ``n_chains``.
    """
    import multiprocessing as mp
    # force=True: without it this raises RuntimeError if a start method is already set in this process
    mp.set_start_method('forkserver', force=True)

    if not 0 <= burn_in < draws:
        raise ValueError(f'burn_in ({burn_in}) must be in [0, draws={draws})')

    # Rebuild the BCM here instead of passing it in as an argument
    bcm = get_bcm()

    init_samples = bcm.sample.convert(init_samples, 'list_of_dicts')
    # PyMC needs exactly one initval dict per chain; fail before any sampling starts
    if len(init_samples) != n_chains:
        raise ValueError(f'Expected {n_chains} initial samples (one per chain), got {len(init_samples)}')

    with pm.Model():
        variables = epm.use_model(bcm)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables)],
            draws=draws,
            tune=tune,
            cores=n_chains,
            discard_tuned_samples=False,
            chains=n_chains,
            progressbar=False,
            initvals=init_samples,
        )

    idata_raw.to_netcdf(str(bridge.out_path / 'calibration_out.nc'))
    burnt_idata = idata_raw.sel(draw=np.s_[burn_in:])
    sds = az.extract(burnt_idata, num_samples=n_result_samples)
    spaghetti_res = esamp.model_results_for_samples(sds, bcm)
    spaghetti_res.results.to_hdf(str(bridge.out_path / 'results.hdf'), 'sampled_results')
    like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
    like_df.to_hdf(str(bridge.out_path / 'results.hdf'), 'likelihood_extras')
    bridge.logger.info('Calibration complete')

# Remote machine: EC2MachineSpec(8, 2, 'compute'). Presumably minimum cores, minimum RAM and
# instance category — confirm against springboard.
mspec = springboard.EC2MachineSpec(8, 2, 'compute')

# The task kwargs must use run_calibration's parameter name ('init_samples'); the previous
# 'opt_samples' key did not match it. The chains start from the optimised samples (as the
# 'opt100' run name indicates), one for each of run_calibration's 8 chains. They are converted
# to plain dicts so they can be shipped to the remote task without the BCM.
calib_init_samples = bcm.sample.convert(opt_samples, 'list_of_dicts')[:8]
tspec = springboard.TaskSpec(run_calibration, {'init_samples': calib_init_samples})
run_path = springboard.launch.get_autumn_project_run_path('aust_covid', 'opt_experiments', 'lhs_constrained_mob_sens0.67_opt100_10k10k_demz')

In [None]:
# Extra shell commands for the launched task (presumably run on the remote instance) so the
# aust-covid package is installed alongside autumn
aust_covid_commands = [
    'git clone --branch main https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]
runner = springboard.launch.launch_synced_autumn_task(
    tspec,
    mspec,
    run_path,
    branch=None,
    extra_commands=aust_covid_commands,
)