In [None]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
from tb_incubator.constants import set_project_base_path, MODEL_TIMES, image_path, ImplementCDR, QUANTILES
from tb_incubator.calibrate import plot_output_ranges, get_bcm
from tb_incubator.scenario_utils import get_quantile_outputs
from tb_incubator.plotting import get_combined_plot
from tb_incubator.input import load_targets
import plotly.graph_objects as go
import arviz as az
import numpy as np
from estival.sampling import tools as esamp
# Resolve the project's directory layout. The base path is given relative to the
# kernel's working directory, so presumably the notebook must be launched from its own folder.
project_paths = set_project_base_path("../tb_incubator/")
calib_out, output_path = project_paths["OUT_PATH"], project_paths["OUTPUTS"]

In [None]:
# Load calibration targets and the stored posterior for the run identified by `file_id`,
# then re-run the model over a subset of posterior draws and summarise at QUANTILES.
targets = load_targets()
# NOTE(review): `out_req` is not referenced in the cells shown here; remove it if nothing else uses it.
out_req = ["notification_log", "adults_prevalence_pulmonary_log"]
file_id = "25000d10000t_01rs"
file_suffix = [f'cdr_notif_{file_id}']

# Fixed parameter values handed to get_bcm.
params = {
    "start_population_size": 1.0,
    "seed_time": 1805.0,
    "seed_num": 1.0,
    "seed_duration": 1.0
}

# Posterior post-processing settings.
BURN_IN_DRAWS = 5000        # draws with coordinate label < 5000 are discarded (the first 5000 per chain with default 0-based coords)
N_POSTERIOR_SAMPLES = 1000  # number of posterior draws pushed back through the model
SAMPLE_SEED = 42            # fixes which draws az.extract picks, so re-running the notebook reproduces the outputs

idata_raw = az.from_netcdf(calib_out / f'calib_full_out_{file_suffix[0]}.nc')
burnt_idata = idata_raw.sel(draw=np.s_[BURN_IN_DRAWS:])

# Fail loudly rather than silently summarising fewer draws than intended
# (e.g. if the chains are shorter than the burn-in assumes).
n_kept = burnt_idata.posterior.sizes["chain"] * burnt_idata.posterior.sizes["draw"]
assert n_kept >= N_POSTERIOR_SAMPLES, (
    f"only {n_kept} post-burn-in draws available; need {N_POSTERIOR_SAMPLES}"
)

# Without an explicit rng, az.extract picks a different random subset on every run.
idata_extract = az.extract(
    burnt_idata,
    num_samples=N_POSTERIOR_SAMPLES,
    rng=np.random.default_rng(SAMPLE_SEED),
)
bcm = get_bcm(params, covid_effects=True, apply_diagnostic_capacity=True, xpert_improvement=True, apply_cdr=ImplementCDR.ON_NOTIFICATION)
# Run the model once per retained posterior sample, then reduce across samples to the requested quantiles.
base_results = esamp.model_results_for_samples(idata_extract, bcm).results
outputs = esamp.quantiles_for_results(base_results, QUANTILES)

In [None]:
# Persist the quantile summary to CSV, plus a transposed copy next to it.
outputs_transpose = outputs.T
quantile_csvs = {
    output_path / f'results_{file_suffix[0]}_quantile_outputs.csv': outputs,
    output_path / f'results_{file_suffix[0]}_quantile_outputs_transpose.csv': outputs_transpose,
}
for csv_path, frame in quantile_csvs.items():
    frame.to_csv(csv_path, index=True)

In [None]:
# One uncertainty-range figure per indicator, keyed by indicator name.
# NOTE(review): the positional args (1, 2010, 2035, 2013) are interpreted by plot_output_ranges;
# presumably a column count, an x-range and a further year — confirm against its signature.
indicators = ['notification', 'prevalence_smear_positive', 'incidence', 'mortality', 'total_population']
epi_plots = {
    indicator: plot_output_ranges(
        outputs,
        targets,
        [indicator],
        1,
        2010,
        2035,
        2013,
        show_legend=False,
        show_title=False,
        history=False,
    )
    for indicator in indicators
}

In [None]:
# Place the incidence and mortality panels side by side in a single figure.
inc_h = epi_plots['incidence']
mort_h = epi_plots['mortality']
plot_list = [inc_h, mort_h]

epi_plot = get_combined_plot(
    plot_list=plot_list,
    n_cols=2,
    shared_xaxes=False,
    shared_yaxes=False,
    horizontal_spacing=0.1,
)

# (subplot column, y-value) pairs for a dotted "2025 milestone" reference line on each panel.
# NOTE(review): the provenance of 163 and 10 is not shown here — cite the source.
milestone = [(1, 163), (2, 10)]
for col, val in milestone:
    epi_plot.add_hline(
        y=val,
        row=1,
        col=col,
        line_dash="dot",
        annotation_text="2025 milestone",
        annotation_position="bottom left",
        annotation_font_color="gray",
    )

# epi_plot.write_image(image_path / 'epi.png', scale=3)
epi_plot

In [None]:
# Full-model-horizon view (history=True) of four key outputs, laid out with 2 as the grid argument,
# written to disk as a high-resolution PNG.
hist = plot_output_ranges(
    outputs,
    targets,
    ['total_population', 'adults_prevalence_pulmonary', 'incidence', 'mortality'],
    2,
    MODEL_TIMES[0],
    MODEL_TIMES[1],
    2040,
    show_legend=False,
    show_title=False,
    history=True,
)
hist.write_image(image_path / 'hist.png', scale=3)

In [None]:
# Modelled total population ranges (targets legend-labelled as UN estimates),
# overlaid with census counts as black point markers.
tot_pop = plot_output_ranges(
    outputs,
    targets,
    ['total_population'],
    1,
    2009,
    2025,
    2040,
    show_legend=True,
    show_title=False,
    legend_name="UN World population estimates",
)

census_pop = targets["census_pop"]
tot_pop.add_trace(
    go.Scatter(
        x=census_pop.index,
        y=census_pop,
        mode="markers",
        name="Census data",
        marker={"symbol": "circle", "size": 4, "color": "black", "line": {"width": 1}},
    )
)

# tot_pop.write_image(image_path / 'total_pop.png', scale=3)
tot_pop.update_layout(
    height=300,
    width=550,
    showlegend=True,
    margin={"l": 50, "r": 50, "t": 30, "b": 30},  # left / right / top / bottom, in pixels
)