In [None]:
import warnings
warnings.filterwarnings("ignore")
import copy
import pandas as pd
from tb_incubator.constants import set_project_base_path, QUANTILES, image_path, ImplementCDR
from tb_incubator.calibrate import (
    plot_output_ranges
)
from tb_incubator.scenario_utils import (
    get_quantile_outputs, 
    load_idata, 
    extract_and_save_idata, 
    run_model_for_scenario, calculate_waic_comparison
)
from tb_incubator.plotting import get_combined_plot, overlay_plots
from tb_incubator.input import load_targets
import arviz as az
# Resolve project directories from the relative project base path.
# calib_out is handed to the scenario helpers below; out_path receives result tables.
project_paths = set_project_base_path("../tb_incubator/")
calib_out, out_path = project_paths["OUT_PATH"], project_paths["OUTPUTS"]

In [None]:
# Target data; passed to plot_output_ranges (show_target_data=True) in the plotting cell.
targets = load_targets()

# NOTE(review): out_req is not referenced elsewhere in this notebook as shown — confirm before removing.
out_req = ["notification_log", "adults_prevalence_pulmonary_log"]

# Parameter values passed to run_model_for_scenario for every model configuration.
params = dict(
    start_population_size=1.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

## Model comparison by WAIC (widely applicable information criterion)

In [None]:
# Suffix appended to every model configuration name below.
file_suffix = "25000d10000t_01nr"

# Name prefix -> how CDR (presumably case detection rate) is applied in that variant.
# Dict insertion order is preserved in model_configs.
cdr_variants = {
    'no_cdr': ImplementCDR.NONE,
    'cdr_notif': ImplementCDR.ON_NOTIFICATION,
    'cdr_detect': ImplementCDR.ON_DETECTION,
}

# All variants share the same diagnostic settings and differ only in 'apply_cdr'.
model_configs = {
    f'{prefix}_{file_suffix}': {
        'apply_diagnostic_capacity': True,
        'xpert_improvement': True,
        'apply_cdr': cdr_mode,
    }
    for prefix, cdr_mode in cdr_variants.items()
}

In [None]:
# Uncomment and run once to create and save the inference data (idata) on first use
#inference_data_dict = load_idata(calib_out, model_configs)
#extract_and_save_idata(inference_data_dict, calib_out, num_samples=1000)

In [None]:
# Run the model for every configuration in model_configs, using params and the outputs stored in calib_out.
# The result is later indexed by configuration name, e.g. assump_outputs[name]['indicator_outputs'].
assump_outputs = run_model_for_scenario(params, calib_out, model_configs, QUANTILES)

In [None]:
# Compare the model configurations by WAIC, save the table, then display it.
# The bare expression must be the cell's LAST line: Jupyter only renders the final expression,
# so a bare name in the middle of the cell is silently discarded.
waic_results = calculate_waic_comparison(assump_outputs)
waic_results.to_csv(out_path / 'results_table_waic.csv', index=True)
waic_results

## Plots

In [None]:
# Scenario names in plotting order, taken directly from model_configs so the two cannot drift apart
# (dict insertion order gives no_cdr, cdr_notif, cdr_detect).
file_names = list(model_configs)

In [None]:
# Alternatively, previously saved quantiles can be loaded from disk:
#     quantile_outputs = get_quantile_outputs(file_names, calib_out)
quantile_outputs = assump_outputs

# Subplot row key -> list of model outputs drawn in that row.
plot_indicators = {
    'notification': ['notification'],
    'adults_prev': ['adults_prevalence_pulmonary'],
}

# One colour string per scenario, in file_names order. Scenarios beyond the palette
# reuse the last entry, matching the previous if/elif/else fallback.
scenario_colours = ["217,95,2", "27, 158, 119", "100, 150, 200"]

plots = {}
for indicator_key, indicator_list in plot_indicators.items():
    plots[indicator_key] = {}
    for name_index, name in enumerate(file_names):
        color = scenario_colours[min(name_index, len(scenario_colours) - 1)]
        # NOTE(review): the positional arguments 1, 2010, 2025, 2013 are passed through unchanged;
        # their meaning is defined by plot_output_ranges — confirm before renaming/changing.
        plots[indicator_key][name] = plot_output_ranges(
            quantile_outputs[name]['indicator_outputs'], targets, indicator_list, 1, 2010, 2025, 2013,
            show_legend=False, show_target_data=True, show_title=False, colour=color,
        )

## Compare all CDR assumptions side by side

In [None]:
# (model configuration name, subplot title) per column; names are built from file_suffix so
# changing the suffix in the configuration cell does not break the lookup into `plots`.
scenarios = [
    (f'no_cdr_{file_suffix}', 'No CDR'),
    (f'cdr_notif_{file_suffix}', 'CDR on notification'),
    (f'cdr_detect_{file_suffix}', 'CDR on treatment commencement'),
]

# One subplot row per plotted indicator, in the order defined in plot_indicators.
plot_types = list(plot_indicators)

# Row-major list of figures: all scenarios for the first indicator, then for the next.
notif_prev = [
    plots[plot_type][scenario_key]
    for plot_type in plot_types
    for scenario_key, _ in scenarios
]

# Titles repeat the scenario names on every row, aligned with notif_prev.
plot_titles = [title for _row in plot_types for _key, title in scenarios]

comb_plot = get_combined_plot(
    plot_list=notif_prev,
    n_cols=len(scenarios),
    subplot_titles=plot_titles,
    shared_yaxes=True,
    shared_xaxes=False,
    horizontal_spacing=0.03,
    vertical_spacing=0.15
)

In [None]:
# Every (row, col) subplot of the combined grid (one column per scenario), derived from the
# grid contents instead of a hard-coded 2x3 list so it stays in sync with the layout.
grid_cols = len(scenarios)
grid_rows = len(notif_prev) // grid_cols
positions = [
    (row, col)
    for row in range(1, grid_rows + 1)
    for col in range(1, grid_cols + 1)
]

# Years to mark with a dotted vertical line, with the label shown next to each line.
years_data = [
    (2017, "1st inventory study"),
    (2023, "2nd inventory study")
]

for row, col in positions:
    for year, text in years_data:
        comb_plot.add_vline(x=year, line_dash="dot", row=row, col=col,
                            line_color="rgba(148, 145, 145, 0.5)", line_width=2,
                            annotation_text=text, annotation_textangle=-90,
                            annotation_position="bottom right",
                            annotation_font_size=8, annotation_font_color="#4E4B4B")

comb_plot

In [None]:
# Export the combined assumption-comparison figure to the project image directory.
# (The f-prefix was removed: the filename has no placeholders.)
comb_plot.write_image(image_path / 'assump_compare.png', width=1200, height=600, scale=3)