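# Calibration

Calibrate the `tb_incubator` model with estival in three steps:
- Run nevergrad optimisations starting from Latin-hypercube samples.
- Use the optimised points to start PyMC DEMetropolis sampling.
- Save the traces, extracts and likelihoods to `OUT_PATH`.

The highest-posterior run and the posterior "spaghetti" runs are then compared with the notification and prevalence data.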

In [1]:
import multiprocessing as mp

# The calibration below runs work in parallel, so fix the process start method before any
# workers are created. set_start_method() raises RuntimeError once a method has been set, which
# made this cell fail on re-run; only (re)set it when it is not already "forkserver".
if mp.get_start_method(allow_none=True) != "forkserver":
    mp.set_start_method("forkserver", force=True)

In [2]:
import pymc as pm
import numpy as np

from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_param_info
from tb_incubator.calibrate import get_bcm, tabulate_calib_results, plot_posterior_comparison, plot_spaghetti_calib_comparison
from tb_incubator.plotting import plot_model_vs_actual, display_plot

from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.sampling.tools import likelihood_extras_for_idata
import arviz as az


# Resolve the project directories; every file this notebook writes goes under OUT_PATH.
project_paths = set_project_base_path("../tb_incubator/")
out_path = project_paths["OUT_PATH"]


In [3]:
# Load the parameter table once (it was previously read twice). Its "value" column supplies
# the parameter values used to build the model; the full table is used for the results summary.
param_info = load_param_info()
params = param_info["value"]

In [4]:
# Build the calibration model object from the parameter values. Later cells use its
# `.sample`, `.run` and `.targets` members.
bcm = get_bcm(params)

In [5]:
from estival.utils.sample import SampleTypes
import nevergrad as ng
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel

def calibrate(draws, tune, *, n_chains=16, cores=8, opt_budget=500, burn=25000, n_extract=15):
    """Calibrate ``bcm`` (optimisation + DEMetropolis MCMC) and write the results to ``out_path``.

    Workflow:
      1. Draw ``n_chains`` Latin-hypercube samples (``ci=0.67``) and rank them by
         log-likelihood.
      2. From each sample, run a nevergrad ``TwoPointsDE`` optimisation with a budget of
         ``opt_budget`` evaluations.
      3. Start ``n_chains`` PyMC DEMetropolis chains from the optimised points.
      4. Save the following under ``out_path``:
         - the full trace (``calib_full_out.nc``);
         - ``n_extract`` draws taken after the burn-in (``calib_extract_out.h5``);
         - model runs for those draws (``results.hdf``, key "spaghetti");
         - per-draw likelihoods for the full trace (``results.hdf``, key "likelihood").

    Relies on the notebook-level globals ``bcm`` and ``out_path``.

    Args:
        draws: Draws per chain, passed to ``pm.sample``.
        tune: Tuning steps per chain, passed to ``pm.sample``.
        n_chains: Number of LHS start points, optimisations and MCMC chains.
        cores: Number of processes ``pm.sample`` may use.
        opt_budget: Evaluation budget for each nevergrad optimisation.
        burn: Number of leading draws per chain discarded before extracting samples.
        n_extract: Number of posterior draws extracted for the spaghetti model runs.

    Returns:
        The full ``InferenceData`` returned by ``pm.sample``.

    Raises:
        ValueError: if ``burn`` is not in ``[0, draws)``; the post-burn-in extract would
            otherwise be empty. This is checked before any expensive work starts.
    """
    if not 0 <= burn < draws:
        raise ValueError(f"burn ({burn}) must be in [0, draws) with draws={draws}")

    def optimize_ng_with_idx(item):
        """Optimise from one (index, parameter row) pair; return (index, optimised parameters)."""
        idx, sample = item
        opt = eng.optimize_model(
            bcm,
            budget=opt_budget,
            opt_class=ng.optimizers.TwoPointsDE,
            suggested=sample,
            num_workers=8,
        )
        rec = opt.minimize(opt_budget)
        # presumably nevergrad's (args, kwargs) value; index 1 holds the keyword parameters
        return idx, rec.value[1]

    # Latin-hypercube start points, ranked best-first by log-likelihood.
    lhs_samples = bcm.sample.lhs(n_chains, ci=0.67)
    lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
    lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
    # The likelihood table holds likelihood components (e.g. "loglikelihood"), not parameter
    # values (verify against estival). Iterating it directly handed likelihood rows to the
    # optimiser as "suggested" parameters, so reorder the parameter samples by its index instead.
    lhs_ranked = lhs_samples.loc[lhs_sorted.index]
    opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_ranked.iterrows())
    best_opt_samps = bcm.sample.convert(opt_samples_idx)
    init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[:n_chains]

    with pm.Model():
        variables = epm.use_model(bcm)
        idata = pm.sample(
            step=[pm.DEMetropolis(variables)],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=n_chains,
            initvals=init_samps,
        )
    idata.to_netcdf(str(out_path / "calib_full_out.nc"))

    # Drop the first `burn` draws of every chain, then take a random subset for model runs.
    burnt_idata = idata.sel(draw=np.s_[burn:])
    idata_extract = az.extract(burnt_idata, num_samples=n_extract)
    bcm.sample.convert(idata_extract).to_hdf5(out_path / "calib_extract_out.h5")

    results_path = str(out_path / "results.hdf")
    spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
    spaghetti_res.results.to_hdf(results_path, "spaghetti")

    # Likelihood components for every draw of the full trace (burn-in included).
    like_df = esamp.likelihood_extras_for_idata(idata, bcm)
    like_df.to_hdf(results_path, "likelihood")

    return idata

In [None]:
# Long-running: 16 chains of DEMetropolis sampling plus likelihood evaluation for every draw.
idata = calibrate(draws=50_000, tune=10_000)

In [7]:
import pandas as pd

# Per-draw likelihood components written by `calibrate`, ranked best-first by log-posterior.
likelihood_df = pd.read_hdf(out_path / "results.hdf", "likelihood")
ldf_sorted = likelihood_df.sort_values("logposterior", ascending=False)

# Parameter set of the sampled draw with the highest log-posterior, and its model run.
map_params = (
    idata.posterior
    .to_dataframe()
    .loc[ldf_sorted.index[0]]
    .to_dict()
)
map_res = bcm.run(map_params)

In [None]:
# Log-posterior trace with one line per chain, to judge how the chains mix over the draws.
ldf_pivot = likelihood_df.reset_index(level="chain").pivot(columns=["chain"])
ldf_pivot["logposterior"].plot(title="Log-posterior by chain", ylabel="logposterior");

In [None]:
# Plot the posterior distributions of the calibrated parameters (presumably against their priors).
# NOTE(review): meaning of 0.995 (presumably a quantile/interval level) — confirm in tb_incubator.calibrate.
plot_posterior_comparison(idata, 0.995)

In [None]:
# Show the highest-log-posterior parameter set selected from the likelihood table.
map_params

In [None]:
from IPython.display import Markdown

# Summary table of the calibrated parameters, rendered as Markdown in the notebook.
results_table = tabulate_calib_results(idata, param_info)
Markdown(
    results_table.to_markdown()
)

In [None]:
# Highest-posterior model run vs the notification target data (full x-axis range).
variable = "notification"
fig = plot_model_vs_actual(
    map_res.derived_outputs,
    bcm.targets[variable].data,
    variable,
    variable,
    "",
    "Actual data",
)
display_plot(fig, "calib_notification", "svg")



In [None]:
# Zoom the notification figure from the previous cell to 1998-2023 and export it.
# Note: this modifies that `fig` in place, so run it directly after the notification cell.
fig.update_xaxes(range=[1998, 2023])
display_plot(fig, "calib_notification_2000", "svg")

In [None]:
# Highest-posterior model run vs the prevalence target data (full x-axis range).
variable = "prevalence"
fig = plot_model_vs_actual(
    map_res.derived_outputs,
    bcm.targets[variable].data,
    variable,
    variable,
    "",
    "Actual data",
)
display_plot(fig, "calib_prevalence", "svg")

In [None]:
# Model runs for the extracted posterior draws (saved by `calibrate`) for the requested outputs.
# The stray bare `fig` mid-cell was a no-op (only a cell's last expression is displayed); removed.
spaghetti = pd.read_hdf(out_path / 'results.hdf', 'spaghetti')
out_req = ["notification", "prevalence"]
fig = plot_spaghetti_calib_comparison(spaghetti, out_req)
display_plot(fig, "calib_spaghetti", "svg")

In [None]:
# Zoom the spaghetti figure from the previous cell to 1998-2023 (modifies `fig` in place) and export it.
fig.update_xaxes(range=[1998, 2023])
display_plot(fig, "calib_spaghetti_2000", "svg")