# Calibration

In [None]:
import multiprocessing as mp
import platform

# Windows only supports the "spawn" start method, so "forkserver" is only
# requested on other platforms (used by PyMC's multi-process sampling below).
if platform.system() != "Windows":
    # force=True makes this cell safe to re-run in the same kernel: without it,
    # set_start_method raises RuntimeError once the context has been set.
    mp.set_start_method("forkserver", force=True)


from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_param_info
from tb_incubator.calibrate import get_bcm, tabulate_calib_results, plot_posterior_comparison
from tb_incubator.plotting import plot_model_vs_actual

from estival.wrappers import pymc as epm
import pymc as pm

import arviz as az
from estival.sampling.tools import likelihood_extras_for_idata

# Configure tb_incubator's project base path, given relative to this
# notebook's working directory.
project_paths = set_project_base_path("../tb_incubator/")


In [2]:
# Load the parameter table once; the "value" column supplies the parameter
# values for the model, while the full table is used later for tabulation.
# The copy keeps `params` independent of `param_info`, as it was when each
# came from its own load_param_info() call.
param_info = load_param_info()
params = param_info["value"].copy()

In [3]:
# Build the calibration model from the parameter values. It is used below for
# sampling (epm.use_model), likelihood extras, bcm.run and bcm.targets.
# NOTE(review): presumably an estival BayesianCompartmentalModel — confirm.
bcm = get_bcm(params)

In [4]:
def calibrate(draws, tune, *, cores=8, chains=16, random_seed=None, calib_model=None):
    """Sample the calibration posterior with PyMC's DEMetropolis sampler.

    Args:
        draws: Number of posterior draws to keep per chain.
        tune: Number of tuning iterations per chain (discarded by default).
        cores: Number of chains to run in parallel.
        chains: Total number of chains to sample.
        random_seed: Optional seed forwarded to ``pm.sample`` for reproducible
            runs; ``None`` keeps PyMC's default (unseeded) behaviour.
        calib_model: Calibration model to wrap with ``epm.use_model``; defaults
            to the notebook-level ``bcm``.

    Returns:
        The result of ``pm.sample`` (used below as InferenceData via
        ``idata.posterior``).
    """
    target = bcm if calib_model is None else calib_model
    with pm.Model():
        variables = epm.use_model(target)
        # DEMetropolis is population-based (proposals are built from the other
        # chains), so it benefits from a comfortably large number of chains.
        step = pm.DEMetropolis(variables)
        idata = pm.sample(
            step=[step],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=chains,
            random_seed=random_seed,
        )
    return idata

In [5]:
# Run the sampler: 1500 kept draws and 1000 tuning iterations per chain.
idata = calibrate(draws=1500, tune=1000)

KeyboardInterrupt: 

In [6]:
# Rank every posterior sample by log-posterior and re-run the model at the
# best one (the maximum a posteriori sample).
likelihood_df = likelihood_extras_for_idata(idata, bcm)
ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)
best_sample = ldf_sorted.index[0]
posterior_table = idata.posterior.to_dataframe()
map_params = posterior_table.loc[best_sample].to_dict()
map_res = bcm.run(map_params)

In [None]:
# Plot a posterior comparison for the calibrated parameters.
# NOTE(review): 0.995 is presumably a quantile / interval bound used by the
# helper — confirm against tb_incubator.calibrate.
plot_posterior_comparison(idata, 0.995)

In [None]:
# Tabulate the calibration results, using param_info for parameter metadata
# (presumably names/priors — confirm against tb_incubator.calibrate).
tabulate_calib_results(idata, param_info)

In [None]:
# Overlay the MAP run's derived output on the observed "notification" target.
variable = "notification"
observed = bcm.targets[variable].data
fig = plot_model_vs_actual(
    map_res.derived_outputs,
    observed,
    variable,
    variable,
    "",
    "Actual data",
)
# To zoom in, e.g.: fig.update_xaxes(range=[1998, 2023])
fig



In [None]:
# Overlay the MAP run's derived output on the observed "prevalence" target.
variable = "prevalence"
observed = bcm.targets[variable].data
fig = plot_model_vs_actual(
    map_res.derived_outputs,
    observed,
    variable,
    variable,
    "",
    "Actual data",
)
# To zoom in, e.g.: fig.update_xaxes(range=[1998, 2023])
fig