In [1]:
import pymc as pm
import numpy as np
import pandas as pd 
import multiprocessing as mp
import platform
if platform.system() != "Windows":
    # `set_start_method` raises RuntimeError when a start method has already
    # been fixed, which happens whenever this cell is re-run in a live kernel.
    # Only switch when needed, and force the switch so a re-run is harmless.
    if mp.get_start_method(allow_none=True) != "forkserver":
        mp.set_start_method("forkserver", force=True)

from tb_incubator.constants import set_project_base_path, BURN_IN
from tb_incubator.input import load_param_info
from tb_incubator.calibrate import get_bcm, save_priors
from tb_incubator.utils import get_next_run_number

from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
import arviz as az

from estival.utils.sample import SampleTypes
import nevergrad as ng
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel

# Use plotly as the backend for pandas `.plot` calls.
pd.set_option("plotting.backend", "plotly")

# Register the project base path; the returned mapping is indexed by
# "OUT_PATH" to obtain the location calibration outputs are written to.
project_paths = set_project_base_path("../tb_incubator/")
calib_out = project_paths["OUT_PATH"]

In [2]:
# Scenario name -> keyword settings handed to `get_bcm` for that scenario.
# Both scenarios must use the same key ('xpert_improvement'); a misspelled key
# would leave the flag unset for the "improvement" scenario.
xpert_configs = {
    'no_xpert_utilisation_improvement': {
        'xpert_improvement': False
    },
    'xpert_utilisation_improvement': {
        'xpert_improvement': True
    }
}

# Toggles for COVID-related effects passed through to `get_bcm`.
covid_effects = {"detection_reduction": True}

# `load_param_info` is called twice: `params` is the "value" entry of its own
# load, while `param_info` keeps the complete object from a separate load.
params = load_param_info()["value"]
param_info = load_param_info()

In [3]:
def calibrate_with_configs(out_path, params, xpert_configs, covid_effects, draws, tune, num_params, burn_in=12500):
    """Calibrate the model once per Xpert configuration and save the outputs.

    For every entry of ``xpert_configs`` this function:

    1. builds a calibration model with ``get_bcm``;
    2. draws Latin-hypercube samples, optimises each of them with nevergrad
       and uses the first ``n_chains`` optimised samples as chain start points;
    3. samples with ``pm.DEMetropolisZ`` and writes the full inference data to
       ``calib_full_out_<suffix>.nc``;
    4. discards ``burn_in`` draws, extracts a small random subset and writes it
       to ``calib_extract_out_<suffix>.h5``;
    5. writes model results for that subset (key ``"spaghetti"``) and the
       likelihood components of the full chain (key ``"likelihood"``) to
       ``results_<suffix>.hdf``.

    Args:
        out_path: Output directory; must support the ``/`` operator
            (e.g. ``pathlib.Path``).
        params: Parameter values forwarded to ``get_bcm``.
        xpert_configs: Mapping of configuration name to the Xpert settings
            forwarded to ``get_bcm``. Must not be empty.
        covid_effects: COVID effect settings forwarded to ``get_bcm``.
        draws: Number of posterior draws per chain.
        tune: Number of tuning iterations per chain.
        num_params: Number of calibrated parameters; only used for the run
            number lookup and the output file suffix.
        burn_in: Number of leading draws dropped before extracting samples.
            Defaults to 12500, the value previously hard-coded here.

    Returns:
        Tuple ``(idata_raw, file_suffix)`` for the LAST configuration that was
        iterated; results of earlier configurations are only available on disk.

    Raises:
        ValueError: If ``xpert_configs`` is empty or ``burn_in`` is not in
            ``[0, draws)``.
    """
    # Fail before any expensive work instead of after sampling has finished.
    if not xpert_configs:
        raise ValueError("xpert_configs must contain at least one configuration")
    if not 0 <= burn_in < draws:
        raise ValueError(
            f"burn_in ({burn_in}) must be non-negative and smaller than draws ({draws})"
        )

    n_chains = 8        # chains sampled; also the number of initial points needed
    n_lhs_samples = 16  # Latin-hypercube candidates optimised per configuration
    opt_budget = 500    # nevergrad budget per candidate
    n_extract = 15      # samples drawn from the post-burn-in chain

    for config_name, xpert_config in xpert_configs.items():
        bcm = get_bcm(params, xpert_config, covid_effects=covid_effects)

        run_number = get_next_run_number(out_path, num_params)
        file_suffix = f'{config_name}_p{num_params:02d}_{run_number}'

        def optimize_ng_with_idx(item):
            """Optimise one (index, sample) pair and return (index, optimised parameters)."""
            idx, sample = item
            opt = eng.optimize_model(bcm, budget=opt_budget, opt_class=ng.optimizers.TwoPointsDE, suggested=sample, num_workers=8)
            rec = opt.minimize(opt_budget)
            # Presumably element 1 of the recommendation value is the
            # keyword-argument dict of parameters; verify against nevergrad.
            return idx, rec.value[1]

        # Candidate start points, ordered from highest to lowest log-likelihood.
        lhs_samples = bcm.sample.lhs(n_lhs_samples, ci=0.67)
        lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
        lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
        opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
        best_opt_samps = bcm.sample.convert(opt_samples_idx)
        # One initial point per chain.
        init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[0:n_chains]

        with pm.Model() as model:
            variables = epm.use_model(bcm)
            idata_raw = pm.sample(step=[pm.DEMetropolisZ(variables)], draws=draws, tune=tune, cores=4, chains=n_chains, initvals=init_samps)
        idata_raw.to_netcdf(str(out_path / f'calib_full_out_{file_suffix}.nc'))

        # Drop the burn-in period, then keep a small subset for model re-runs.
        burnt_idata = idata_raw.sel(draw=slice(burn_in, None))
        idata_extract = az.extract(burnt_idata, num_samples=n_extract)
        bcm.sample.convert(idata_extract).to_hdf5(out_path / f'calib_extract_out_{file_suffix}.h5')

        results_path = str(out_path / f"results_{file_suffix}.hdf")

        # `key` is passed by keyword: pandas deprecated passing it positionally.
        spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
        spaghetti_res.results.to_hdf(results_path, key="spaghetti")

        like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
        like_df.to_hdf(results_path, key='likelihood')

    return idata_raw, file_suffix

In [4]:
# Run the calibration for every Xpert configuration; the returned objects
# belong to the last configuration processed.
idata, file_suffix = calibrate_with_configs(
    out_path=calib_out,
    params=params,
    xpert_configs=xpert_configs,
    covid_effects=covid_effects,
    draws=25_000,
    tune=10_000,
    num_params=17,
)

Multiprocess sampling (8 chains in 4 jobs)
DEMetropolisZ: [contact_rate, rr_infection_latent, rr_infection_recovered, smear_positive_death_rate, smear_negative_death_rate, smear_positive_self_recovery, smear_negative_self_recovery, screening_scaleup_shape, screening_inflection_time, time_to_screening_end_asymp, notif_start_time, base_diagnostic_capacity, initial_notif_rate, latest_notif_rate, mid_notif_rate, incidence_props_smear_positive_among_pulmonary, genexpert_sensitivity, detection_reduction, notification_dispersion, prevalence_dispersion]


Sampling 8 chains for 10_000 tune and 25_000 draw iterations (80_000 + 200_000 draws total) took 1815 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
Multiprocess sampling (8 chains in 4 jobs)
DEMetropolisZ: [contact_rate, rr_infection_latent, rr_infection_recovered, smear_positive_death_rate, smear_negative_death_rate, smear_positive_self_recovery, smear_negative_self_recovery, screening_scaleup_shape, screening_inflection_time, time_to_screening_end_asymp, notif_start_time, base_diagnostic_capacity, initial_notif_rate, latest_notif_rate, mid_notif_rate, incidence_props_smear_positive_among_pulmonary, genexpert_sensitivity, detection_reduction, notification_dispersion, prevalence_dispersion]


Sampling 8 chains for 10_000 tune and 25_000 draw iterations (80_000 + 200_000 draws total) took 1476 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
