In [None]:
import warnings
warnings.filterwarnings("ignore")

import arviz as az
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown

import tb_incubator.constants as const
from tb_incubator.constants import set_project_base_path
from tb_incubator.input import load_targets, load_param_info
from tb_incubator.model import build_model
from tb_incubator.plotting import plot_tracked_outputs
from tb_incubator.utils import get_param_table
az.style.use("arviz-doc")

# Route pandas .plot() calls through plotly, so they return plotly figures.
pd.set_option("plotting.backend", "plotly")

# Project directory layout; the calibration netCDF/HDF files are read from OUT_PATH.
project_paths = set_project_base_path("../tb_incubator/")
calib_out = project_paths["OUT_PATH"]

## Build the epidemiological model

In [None]:
targets = load_targets()
indicator_names = const.indicator_names
compartments = const.COMPARTMENTS
param_info = load_param_info()

# Fixed parameter values. Take a copy: later cells overlay calibrated values
# onto fixed_params, and without a copy those updates would also overwrite
# param_info["value"] (which the fixed-parameter table is built from).
fixed_params = param_info["value"].copy()

# COVID-19 effect switches passed to build_model.
covid_effects = {
    "detection_reduction": True,
}

In [None]:
# Load the best parameter set (highest log-posterior) from calibration.
file_suffix = "xpert_utilisation_improvement_p17_96"
idata = az.from_netcdf(calib_out / f'calib_full_out_{file_suffix}.nc')
likelihood_df = pd.read_hdf(calib_out / f'results_{file_suffix}.hdf', 'likelihood')
ldf_sorted = likelihood_df.sort_values(by="logposterior", ascending=False)

# The first row of the descending sort identifies the best draw; look it up in
# the flattened posterior.
# NOTE(review): assumes likelihood_df shares its index with
# idata.posterior.to_dataframe() (presumably (chain, draw)) — confirm.
map_params = idata.posterior.to_dataframe().loc[ldf_sorted.index[0]].to_dict()

# Drop "*_dispersion" variables before handing values to the model
# (presumably likelihood-only terms rather than model parameters — verify).
map_params_filtered = {k: v for k, v in map_params.items() if "_dispersion" not in k}

# Rebind instead of fixed_params.update(...): an in-place update would also
# modify the object fixed_params was loaded from (param_info["value"]).
fixed_params = fixed_params | map_params_filtered

In [None]:
# Scenario overrides merged over fixed_params when the model is run (none here).
params = dict()

# Screening-rate schedule keyed by (fractional) year, passed to build_model:
# 0.0 in 2024-2025, 0.22 in 2026-2027, and back to 0.0 at 2027.01.
# NOTE(review): how values between keys are interpolated is up to build_model.
acf_screening_rate = dict(
    [
        (2024.0, 0.0),
        (2025.0, 0.0),
        (2026.0, 0.22),
        (2027.0, 0.22),
        (2027.01, 0.0),
    ]
)

In [None]:
# Build the model with the Xpert-improvement switch on, the COVID-19 effects
# and the screening-rate schedule defined in the cells above.
model, desc = build_model(
    fixed_params,
    xpert_improvement=True,
    covid_effects=covid_effects,
    acf_screening_rate=acf_screening_rate,
)

In [None]:
# Run with fixed/calibrated values; any key in params overrides fixed_params.
model.run({**fixed_params, **params})

## Outputs

In [None]:
# Derived outputs of the run as a DataFrame; the cells below plot its columns.
outs = model.get_derived_outputs_df()

In [None]:
# Total modelled population over time.
fig = outs["total_population"].plot()
fig.update_layout(title="Total population", xaxis_title="Year", yaxis_title="Population")

In [None]:
# Population per age stratum; keys match the "total_populationXage_<age>" outputs.
age_strata = [0, 5, 15, 35, 50]
fig = outs[[f'total_populationXage_{age}' for age in age_strata]].plot()
fig.update_layout(
    title="Total population by age stratum",
    xaxis_title="Year",
    yaxis_title="Population",
)

In [None]:
# Modelled incidence over time (units as produced by the model's derived output).
fig = outs['incidence'].plot()
fig.update_layout(title="Incidence", xaxis_title="Year", yaxis_title="Incidence")

In [None]:
# Modelled pulmonary prevalence over time (units as produced by the model).
fig = outs['prevalence_pulmonary'].plot()
fig.update_layout(
    title="Pulmonary prevalence",
    xaxis_title="Year",
    yaxis_title="Prevalence",
)

In [None]:
# Modelled notifications, zoomed to 2000-2035.
fig = outs['notification'].plot()
fig.update_layout(title="Notifications", xaxis_title="Year", yaxis_title="Notifications")
fig.update_xaxes(range=[2000, 2035])

In [None]:
# Detection-related derived outputs, drawn by the project's plotting helper.
tracked_outputs = [
    "base_detection",
    "diagnostic_capacity",
    "diagnostic_improvement",
    "final_detection",
]
plot_tracked_outputs(outs, tracked_outputs)

In [None]:
# Stacked share of the population in each model compartment ("prop_<compartment>"
# outputs), zoomed to 1970-2025. px is imported in the top import cell.
compartment_cols = [f"prop_{comp}" for comp in const.COMPARTMENTS]
compartment_props = outs[compartment_cols]
fig = px.area(compartment_props, x=compartment_props.index, y=compartment_cols)
fig.update_layout(
    title="Population proportion by compartment",
    xaxis_title="Year",
    yaxis_title="Proportion",
)
fig.update_xaxes(range=[1970, 2025])

In [None]:
#| label: tbl-params
#| tbl-cap: Model parameters
#| tbl-cap-location: top

# NOTE(review): this cell is disabled. Every statement below is commented out,
# so the Quarto label/caption above have no table to attach to.
# `get_all_priors` is not imported anywhere in this notebook. It must be
# imported before these lines can be re-enabled. TODO: restore or remove this cell.
#prior_names = [p.name for p in get_all_priors()]
#fixed_param_table = get_param_table(param_info, prior_names)
#fixed_param_table
#Markdown(fixed_param_table.to_markdown())
