In [None]:
import multiprocessing as mp
import platform

# Select the "forkserver" start method for worker processes everywhere
# except Windows, where the call is skipped.
# force=True keeps this cell re-runnable: without it, executing the cell a
# second time in the same kernel raises
# RuntimeError("context has already been set").
if platform.system() != "Windows":
    mp.set_start_method("forkserver", force=True)

import pandas as pd

from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_targets, load_param_info
from tb_incubator.calibrate import get_bcm

from estival.wrappers import pymc as epm
import pymc as pm

import arviz as az
from estival.sampling.tools import likelihood_extras_for_idata
from estival.utils.parallel import map_parallel

# Optional toggle: uncomment the next line to switch pandas plotting to plotly.
#pd.options.plotting.backend = "plotly"  
# Set the project base path (given as a relative path) and keep whatever
# set_project_base_path returns for later use.
project_paths = set_project_base_path("../tb_incubator/")


In [2]:
# Load the calibration inputs: the "value" entry of the parameter info
# is what gets passed on to the model builder; targets are loaded in full.
param_info = load_param_info()
params = param_info["value"]
all_targets = load_targets()

In [3]:
# Build the `bcm` object from the loaded parameter values. It is reused below
# for sampling, likelihood evaluation and for running the model.
bcm = get_bcm(params)

In [None]:
# Sampler settings, gathered in one place so they are easy to tune.
N_CHAINS = 16
N_DRAWS = 2000
N_TUNE = 2000
# Fixed seed so that re-running the notebook reproduces the same sampling run.
RANDOM_SEED = 1234
# Never request more worker processes than the machine has CPUs.
N_CORES = min(N_CHAINS, mp.cpu_count())

with pm.Model() as model:
    # use_model returns the variables that are handed to the DEMetropolis step.
    variables = epm.use_model(bcm)
    idata = pm.sample(
        step=[pm.DEMetropolis(variables)],
        draws=N_DRAWS,
        tune=N_TUNE,
        cores=N_CORES,
        chains=N_CHAINS,
        random_seed=RANDOM_SEED,
    )

In [None]:
# Tabular summary of the sampling results held in `idata`.
az.summary(idata)

In [None]:
# Scale the figure height with the number of entries in idata.posterior
# (3.2 height units per entry) while keeping a fixed width of 16.
n_posterior_entries = len(idata.posterior)
trace_figsize = (16, 3.2 * n_posterior_entries)
az.plot_trace(idata, figsize=trace_figsize, compact=False)

In [None]:
# Posterior plots for the sampled variables in `idata`.
az.plot_posterior(idata)

In [None]:
# Evaluate likelihood-related quantities for the samples in `idata` using `bcm`.
# The cells below rely on the "logposterior" and "loglikelihood" columns and
# on an index level named "chain".
likelihood_df = likelihood_extras_for_idata(idata, bcm)
likelihood_df

In [None]:
# Examine the performance of chains over time.
# Move "chain" out of the index and spread it across the columns, so that
# plotting "logposterior" draws one line per chain.
chain_as_column = likelihood_df.reset_index(level="chain")
ldf_pivot = chain_as_column.pivot(columns=["chain"])
ldf_pivot["logposterior"].plot()

In [None]:
# Rank all samples from highest to lowest logposterior.
ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)

# Take the index label of the top-ranked sample and use it to look up that
# sample's parameter values in the posterior, as a plain dict.
best_sample_label = ldf_sorted.index[0]
posterior_df = idata.posterior.to_dataframe()
map_params = posterior_df.loc[best_sample_label].to_dict()

map_params

In [None]:
# Show the log-likelihood evaluated at map_params next to the "loglikelihood"
# value stored for the top-ranked sample, so the two can be compared.
bcm.loglikelihood(**map_params), ldf_sorted.iloc[0]["loglikelihood"]

In [12]:
# Run the model with the selected parameter set; the result's derived_outputs
# are plotted against the targets below.
map_res = bcm.run(map_params)

In [None]:
variable = "prevalence"

# map_res was produced from the highest-logposterior sample, so the curve is
# the MAP run; the title is labelled accordingly (it previously said "MLE").
ax = pd.Series(map_res.derived_outputs[variable]).plot(title=f"{variable} (MAP)")
# Overlay the target data as points on the same axes, passed explicitly
# rather than relying on the implicit "current axes".
bcm.targets[variable].data.plot(style='.', ax=ax)

In [None]:
variable = "notification"

# map_res was produced from the highest-logposterior sample, so the curve is
# the MAP run; the title is labelled accordingly (it previously said "MLE").
ax = pd.Series(map_res.derived_outputs[variable]).plot(title=f"{variable} (MAP)")
# Overlay the target data as points on the same axes, passed explicitly
# rather than relying on the implicit "current axes".
bcm.targets[variable].data.plot(style='.', ax=ax)