In [None]:
import pymc as pm
import numpy as np
import pandas as pd 
import multiprocessing as mp
import platform
# Use the "spawn" start method for worker processes (pm.sample below also passes
# mp_ctx="spawn"). Windows already defaults to spawn, so it is only set elsewhere.
# force=True keeps this cell re-runnable: without it a second call in the same
# kernel raises RuntimeError("context has already been set").
if platform.system() != "Windows":
    mp.set_start_method('spawn', force=True)

from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_param_info
from tb_incubator.calibrate import get_bcm, save_priors
from tb_incubator.utils import get_next_run_number

from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
import arviz as az

from estival.utils.sample import SampleTypes
import nevergrad as ng
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel

# Render pandas `.plot(...)` calls with plotly.
pd.set_option("plotting.backend", "plotly")

# Resolve project directories; calibration outputs are written under OUT_PATH.
project_paths = set_project_base_path("../tb_incubator/")
calib_out = project_paths["OUT_PATH"]


In [None]:
# Model variants to calibrate. Each entry's flags are forwarded to `get_bcm`:
#   xpert_improvement      -- enables/disables the Xpert-improvement option
#   apply_cdr_within_model -- True: CDR applied inside the detection function;
#                             False: CDR only used outside it, when calculating notifications
model_configs = {
    'xpert_only': {  # CDR used outside the detection function (only when calculating notifications)
        'xpert_improvement': True,
        'apply_cdr_within_model': False,
    },
    'no_xpert_no_cdr': {  # CDR used outside the detection function (only when calculating notifications)
        'xpert_improvement': False,
        'apply_cdr_within_model': False,
    },
    'xpert_cdr_inside': {  # CDR applied inside the detection function
        'xpert_improvement': True,
        'apply_cdr_within_model': True,
    },
    'cdr_inside': {  # CDR applied inside the detection function
        'xpert_improvement': False,
        'apply_cdr_within_model': True,
    },
}

# Load the parameter table once; `params` is its "value" column.
param_info = load_param_info()
params = param_info["value"]

# Forwarded to get_bcm(covid_effects=...).
covid_effects = {
    'detection_reduction': True,
}

In [None]:
# MCMC run lengths (per chain). PyMC itself discards the `tune` steps by default;
# `burn_in` additionally drops the first posterior draws before extracting samples.
tune = 10_000
draws = 25_000
burn_in = 10_000

In [None]:
def calibrate(out_path, draws, tune, xpert_improvement, covid_effects, apply_cdr_within_model, burn_in=None):
    """Calibrate a single model configuration and save its outputs.

    Steps:
      1. Build the calibration model with ``get_bcm`` from the module-level ``params``.
      2. Draw 16 LHS samples (``bcm.sample.lhs``), rank them best-first by
         log-likelihood and refine each with a nevergrad TwoPointsDE run (in parallel).
      3. Seed 8 DEMetropolisZ chains with the first 8 optimised points.
      4. Save the full trace (.nc), a 15-draw post-burn-in extract (.h5), model
         outputs for that extract ("spaghetti") and per-draw likelihoods (.hdf).

    Args:
        out_path: Output directory; must support ``/`` joining (e.g. ``pathlib.Path``).
        draws: Posterior draws per chain.
        tune: Tuning steps per chain.
        xpert_improvement, covid_effects, apply_cdr_within_model: Forwarded to ``get_bcm``.
        burn_in: Leading posterior draws to discard before extracting samples.
            Defaults to the module-level ``burn_in`` constant.

    Returns:
        ``(idata, file_suffix)``: the full InferenceData and the suffix used in filenames.

    Raises:
        ValueError: If ``burn_in`` is negative or leaves no draws (``burn_in >= draws``).
    """
    # Resolve and validate burn-in before the expensive optimisation/sampling so a
    # bad value fails fast rather than after the MCMC run.
    if burn_in is None:
        burn_in = globals()["burn_in"]  # module-level constant from the run-length cell
    if not 0 <= burn_in < draws:
        raise ValueError(f"burn_in ({burn_in}) must satisfy 0 <= burn_in < draws ({draws})")

    n_chains = 8  # number of MCMC chains; the first n_chains optimised points seed them

    bcm = get_bcm(params, xpert_improvement=xpert_improvement, covid_effects=covid_effects, apply_cdr_within_model=apply_cdr_within_model)
    run_number = get_next_run_number(out_path, draws, tune, model_configs)
    file_suffix = f'{draws}d{tune}t_{run_number}'
    results_path = str(out_path / f"results_{file_suffix}.hdf")

    def optimize_ng_with_idx(item):
        """Refine one ``(idx, sample)`` pair with TwoPointsDE; return ``(idx, rec.value[1])``
        (presumably the recommended parameter dict)."""
        idx, sample = item
        opt = eng.optimize_model(bcm, budget=500, opt_class=ng.optimizers.TwoPointsDE, suggested=sample, num_workers=8)
        rec = opt.minimize(500)
        return idx, rec.value[1]

    # Starting points: LHS samples ranked best-first by log-likelihood, each optimised.
    lhs_samples = bcm.sample.lhs(16, ci=0.67)
    lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
    lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
    opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
    best_opt_samps = bcm.sample.convert(opt_samples_idx)
    init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[0:n_chains]

    with pm.Model():
        variables = epm.use_model(bcm)
        idata = pm.sample(step=[pm.DEMetropolisZ(variables)], draws=draws, tune=tune, cores=16, chains=n_chains, initvals=init_samps, mp_ctx="spawn")
    idata.to_netcdf(str(out_path / f'calib_full_out_{file_suffix}.nc'))

    # PyMC already drops tuning steps by default; burn_in removes further early posterior draws.
    burnt_idata = idata.sel(draw=np.s_[burn_in:])
    idata_extract = az.extract(burnt_idata, num_samples=15)
    bcm.sample.convert(idata_extract).to_hdf5(out_path / f'calib_extract_out_{file_suffix}.h5')

    # Both result tables are stored in the same HDF file under different keys.
    spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
    spaghetti_res.results.to_hdf(path_or_buf=results_path, key="spaghetti")

    like_df = esamp.likelihood_extras_for_idata(idata, bcm)
    like_df.to_hdf(path_or_buf=results_path, key="likelihood")

    return idata, file_suffix

In [None]:
def calibrate_with_configs(out_path, params, model_configs, covid_effects, draws, tune, burn_in=None):
    """Run the calibration pipeline once for every entry in ``model_configs``.

    For each ``(config_name, model_config)`` the model is built with
    ``get_bcm(params, **model_config, covid_effects=covid_effects)``. Then:
    16 LHS starting points are ranked by log-likelihood, each is refined with a
    nevergrad TwoPointsDE run, and the first 8 seed 8 DEMetropolisZ chains.
    Outputs are saved with the suffix ``{config_name}_{draws}d{tune}t_{run_number}``:
    the full trace (.nc), a 15-draw post-burn-in extract (.h5), and model outputs
    plus per-draw likelihoods (.hdf).

    The run number is computed once per call, so every configuration calibrated
    together shares the same run number and their files can be matched up later.

    Args:
        out_path: Output directory; must support ``/`` joining (e.g. ``pathlib.Path``).
        params: Parameter values passed to ``get_bcm``.
        model_configs: Mapping of config name -> keyword arguments for ``get_bcm``.
        covid_effects: Forwarded to ``get_bcm``.
        draws: Posterior draws per chain.
        tune: Tuning steps per chain.
        burn_in: Leading posterior draws to discard before extracting samples.
            Defaults to the module-level ``burn_in`` constant.

    Returns:
        dict mapping each config name to the file suffix its outputs were saved under.

    Raises:
        ValueError: If ``burn_in`` is negative or leaves no draws (``burn_in >= draws``).
    """
    # Validate up front: a bad burn-in would otherwise only fail after the first MCMC run.
    if burn_in is None:
        burn_in = globals()["burn_in"]  # module-level constant from the run-length cell
    if not 0 <= burn_in < draws:
        raise ValueError(f"burn_in ({burn_in}) must satisfy 0 <= burn_in < draws ({draws})")

    n_chains = 8  # number of MCMC chains; the first n_chains optimised points seed them

    # Arguments are loop-invariant, so compute once: calling this per config lets files
    # written by earlier configs bump the number for later ones within the same batch.
    run_number = get_next_run_number(out_path, draws, tune, model_configs)

    suffixes = {}
    for config_name, model_config in model_configs.items():
        bcm = get_bcm(params, **model_config, covid_effects=covid_effects)
        file_suffix = f'{config_name}_{draws}d{tune}t_{run_number}'
        results_path = str(out_path / f"results_{file_suffix}.hdf")

        def optimize_ng_with_idx(item):
            """Refine one ``(idx, sample)`` pair with TwoPointsDE; return ``(idx, rec.value[1])``
            (presumably the recommended parameter dict)."""
            idx, sample = item
            opt = eng.optimize_model(bcm, budget=500, opt_class=ng.optimizers.TwoPointsDE, suggested=sample, num_workers=8)
            rec = opt.minimize(500)
            return idx, rec.value[1]

        # Starting points: LHS samples ranked best-first by log-likelihood, each optimised.
        lhs_samples = bcm.sample.lhs(16, ci=0.67)
        lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
        lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
        opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
        best_opt_samps = bcm.sample.convert(opt_samples_idx)
        init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[0:n_chains]

        with pm.Model():
            variables = epm.use_model(bcm)
            idata_raw = pm.sample(step=[pm.DEMetropolisZ(variables)], draws=draws, tune=tune, cores=16, chains=n_chains, initvals=init_samps, mp_ctx="spawn")
        idata_raw.to_netcdf(str(out_path / f'calib_full_out_{file_suffix}.nc'))

        # PyMC already drops tuning steps by default; burn_in removes further early posterior draws.
        burnt_idata = idata_raw.sel(draw=np.s_[burn_in:])
        idata_extract = az.extract(burnt_idata, num_samples=15)
        bcm.sample.convert(idata_extract).to_hdf5(out_path / f'calib_extract_out_{file_suffix}.h5')

        # Both result tables are stored in the same HDF file under different keys.
        spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
        spaghetti_res.results.to_hdf(path_or_buf=results_path, key="spaghetti")

        like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
        like_df.to_hdf(path_or_buf=results_path, key="likelihood")

        suffixes[config_name] = file_suffix

    return suffixes

In [None]:
# Calibrate every configuration in `model_configs`.
# Single-configuration alternative (current signature; the run number is chosen automatically):
# idata, file_suffix = calibrate(calib_out, draws, tune, xpert_improvement=True, covid_effects=covid_effects, apply_cdr_within_model=False)
calibrate_with_configs(calib_out, params, model_configs, covid_effects, draws, tune)