In [1]:
import arviz as az
import numpy as np
import pandas as pd
from IPython.display import Markdown

from tb_incubator.calibrate import get_bcm, tabulate_calib_results, plot_posterior_comparison, plot_spaghetti_calib_comparison
from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_param_info
from tb_incubator.plotting import plot_model_vs_actual, display_plot


# Optional: uncomment to switch pandas' plotting backend to plotly.
# pd.options.plotting.backend = "plotly"

# Relative path to the project package; all calibration artefacts below are read from OUT_PATH.
PROJECT_BASE = "../tb_incubator/"
project_paths = set_project_base_path(PROJECT_BASE)
calib_out = project_paths["OUT_PATH"]


In [2]:
# Load the parameter table once and reuse it; previously it was loaded twice.
param_info = load_param_info()
# The "value" column supplies the parameter values passed to get_bcm.
params = param_info["value"]
# Build the calibration model with the xpert_sensitivity and covid_effects options enabled.
bcm = get_bcm(params, xpert_sensitivity=True, covid_effects=True)

In [3]:
# Suffix identifying which calibration run's output files to load.
file_suffix = "p06_54"
idata_path = calib_out / f'calib_full_out_{file_suffix}.nc'
idata = az.from_netcdf(idata_path)

In [4]:
# Select the draw with the highest log-posterior and run the model at those parameters.
likelihood_df = pd.read_hdf(calib_out / f'results_{file_suffix}.hdf', key='likelihood')
ldf_sorted = likelihood_df.sort_values("logposterior", ascending=False)
best_draw = ldf_sorted.index[0]
# NOTE(review): assumes likelihood_df's index matches the posterior frame's index (e.g. chain/draw) — confirm.
posterior_df = idata.posterior.to_dataframe()
map_params = posterior_df.loc[best_draw].to_dict()
map_res = bcm.run(map_params)

In [None]:
# Show the MAP parameter set as a table (rich display) rather than a raw dict repr.
pd.Series(map_params, name="MAP estimate")

In [None]:
# Per-parameter trace plots for every chain; the trailing ";" hides the returned array-of-Axes repr.
az.plot_trace(idata, figsize=(15, 22), compact=False, legend=False);

In [None]:
# Summarise calibration results with tabulate_calib_results and render them as a Markdown table.
results_table = tabulate_calib_results(idata, param_info)
Markdown(results_table.to_markdown())

In [None]:
# Posterior comparison plot for the calibrated parameters.
# NOTE(review): 0.995 is a magic second argument to plot_posterior_comparison; its meaning
# (presumably a quantile/coverage threshold) is not visible here — confirm and name it.
plot_posterior_comparison(idata, 0.995)

In [None]:
variable = "notification"
fig = plot_model_vs_actual(map_res.derived_outputs, 
                     np.exp(bcm.targets["notification_log"].data),
                     variable,
                     variable,
                     "",
                     "Target data")
fig.update_xaxes(range=[1998, 2024])

In [None]:
variable = "prevalence"
fig = plot_model_vs_actual(map_res.derived_outputs, 
                     np.exp(bcm.targets["prevalence_log"].data),
                     variable,
                     variable,
                     "",
                     "Target data")
#fig.update_xaxes(range=[1998, 2023])
fig

In [None]:
# Compare the stored "spaghetti" runs against both log-scale calibration targets.
out_req = ["notification_log", "prevalence_log"]
spaghetti = pd.read_hdf(calib_out / f'results_{file_suffix}.hdf', key='spaghetti')
fig = plot_spaghetti_calib_comparison(spaghetti, out_req)
fig

# To export: display_plot(fig, "calib_spaghetti", "svg")

In [None]:
# Zoomed (1998-2024) view of the spaghetti comparison. Build a separate figure instead of
# mutating `fig` in place, so this cell does not depend on which cell last assigned `fig`
# and does not alter the figure displayed by the previous cell.
spaghetti_fig_zoomed = plot_spaghetti_calib_comparison(spaghetti, out_req)
spaghetti_fig_zoomed.update_xaxes(range=[1998, 2024])
spaghetti_fig_zoomed
# To export: display_plot(spaghetti_fig_zoomed, "calib_spaghetti_2000", "svg")

In [None]:
# Re-display the MAP parameter set at the end of the notebook.
# NOTE(review): duplicates the earlier `map_params` cell — consider removing one of them.
map_params