In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
import pandas as pd
import matplotlib.patches as mpatches

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis

from autumn.dashboards.calibration_results.plots import get_uncertainty_df

In [None]:
# Which calibration run to analyse: the model, the region, and the name of the
# calibration output folder to load for that model/region.
model, region, dirname = (
    Models.TB,
    Region.MARSHALL_ISLANDS,
    "2021-10-11",
)

In [None]:
# Resolve the project and the folder holding the chosen calibration run
project = get_project(model, region)
project_calib_dir = os.path.join(OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name)
calib_path = os.path.join(project_calib_dir, dirname)

# Load the MCMC tables and MCMC parameter tables stored in that folder
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Uncertainty outputs for the run, and the distinct scenario values they contain
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df['scenario'].unique()

# Create outputs/<model>_<region>_<dirname>/ and one sub-folder per category
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
dirs_to_make = ["calibration", "MLE", "median", "uncertainty", "csv_files"]
os.makedirs(base_dir, exist_ok=True)
for sub_dir in dirs_to_make:
    os.makedirs(os.path.join(base_dir, sub_dir), exist_ok=True)

In [None]:
# Keep only the rows for time == 2050, presumably so that each
# (type, quantile, scenario) lookup in the summary cell matches a single row.
year_mask = uncertainty_df["time"].eq(2050)
uncertainty_df = uncertainty_df.loc[year_mask]

In [None]:
# Scenario label -> value of the "scenario" column in uncertainty_df for that scenario
scenarios = {
    "counterfactual": 1,
    "10": 8,
    "5": 7,
    "2": 6,
}

# Cumulative absolute-difference outputs to tabulate. Every entry must start with
# OUTPUT_PREFIX, which is stripped to form the short column names (e.g. "deaths").
# An output such as "percentage_latentXlocation_majuro" does not share this prefix
# and cannot be added here without changing how column names are derived.
outputs = [
    "abs_diff_cumulative_diseased",
    "abs_diff_cumulative_deaths",
    "abs_diff_cumulative_pt",
    "abs_diff_cumulative_pt_sae",
]
# Median first, then the bounds of the central 95% interval
quantiles = [0.5, 0.025, .975]

OUTPUT_PREFIX = "abs_diff_cumulative_"

# Column names of the summary table: "scenario" plus "<short_name>_<quantile>"
columns = ["scenario"]
for output in outputs:
    # Check the prefix rather than slicing a fixed number of characters, so an
    # output with a different name fails loudly instead of yielding a garbled column.
    if not output.startswith(OUTPUT_PREFIX):
        raise ValueError(f"Output {output!r} does not start with {OUTPUT_PREFIX!r}")
    short_name = output[len(OUTPUT_PREFIX):]
    columns += [f"{short_name}_{q}" for q in quantiles]


In [None]:
# Build the summary table: one row per scenario with the median and 2.5% / 97.5%
# quantiles of each output, read from uncertainty_df.
#
# Each value is multiplied by +1 or -1 (output_sign). Negating a distribution reverses
# the order of its quantiles, so when output_sign is -1 the "_0.025" column is filled
# from the 0.975 quantile (and vice versa) via q_swap, keeping "_0.025" the lower bound.
q_swap = {
    0.5: 0.5,
    0.975: 0.025,
    0.025: 0.975
}

output_prefix = "abs_diff_cumulative_"

summary_rows = []
for sc, sc_id in scenarios.items():
    row = {"scenario": sc}
    sc_sign = 1 if sc == "counterfactual" else -1
    for output in outputs:
        # Outputs whose short name starts with "pt" (pt and pt_sae) take the
        # opposite sign to the other outputs for the same scenario.
        output_sign = -sc_sign if output.startswith(output_prefix + "pt") else sc_sign
        short_name = output[len(output_prefix):]
        for q in quantiles:
            q_mask = q if output_sign == 1 else q_swap[q]
            mask = (
                (uncertainty_df["type"] == output)
                & (uncertainty_df["quantile"] == q_mask)
                & (uncertainty_df["scenario"] == sc_id)
            )
            # .item() requires exactly one matching row and raises a clear error
            # otherwise; calling float() directly on a Series is deprecated in pandas.
            value = float(uncertainty_df.loc[mask, "value"].item())
            row[f"{short_name}_{q}"] = round(value) * output_sign
    summary_rows.append(row)

# DataFrame.append was removed in pandas 2.0; build the frame once from the row dicts.
df = pd.DataFrame(summary_rows, columns=columns)
    

In [None]:
# Display the per-scenario summary table (median and 0.025 / 0.975 quantiles per output)
df

In [None]:
# Plot layout / styling constants
title_gap = .5   # vertical spacing used around the two reference titles
bar_w = .3       # height of each horizontal bar
title_ft = 13    # font size for titles, tick labels and legend
colors = {
    "diseased": "cornflowerblue",
    "deaths": "coral",
    "pt_sae": "seagreen"
}
# Vertical offset of each output's bar relative to its scenario's y tick
y_offset = {
    "diseased": - bar_w,
    "deaths": 0,
    "pt_sae": - bar_w/2
}
legend_names = {
    "diseased": "Active TB episodes averted",
    "deaths": "TB deaths averted",
    "pt_sae": "Serious adverse events"
}
# Legend entries, top to bottom
legend_order = ["deaths", "diseased", "pt_sae"]

x_max = 6500
y_ticks = [1, 2, 3, 3+2*title_gap+.5]
y_labels = ["TB and LTBI screening\nevery 2 years", "TB and LTBI screening\nevery 5 years", "TB and LTBI screening\nevery 10 years", "Status-quo including\n2017-2018 interventions"]
# Scenario labels (rows of df), in the same order as y_ticks / y_labels.
# Deliberately NOT named `scenarios` / `outputs`: earlier cells use those names for a
# dict and for the full output names, and overwriting them here breaks re-running them.
plot_scenarios = ["2", "5", "10", "counterfactual"]
plot_outputs = ["diseased", "deaths", "pt", "pt_sae"]

x_ticks = [-6000, -4000, -2000, 0, 2000, 4000, 6000]
x_labels = [abs(x) for x in x_ticks]


def _signed_value(scenario, column, sign):
    """Return `sign` times the value of `column` in df's row for `scenario`.

    `.item()` raises unless exactly one row of df has that scenario label.
    """
    return sign * float(df.loc[df["scenario"] == scenario, column].item())


fig = plt.figure(figsize=(12, 8))
# NOTE(review): the style is reset after the figure is created, so figure-level
# settings come from whatever style was active before — confirm this is intended.
plt.style.use("default")
axis = fig.add_subplot()
axis.grid(axis="x", lw=.3, zorder=0)

summary_rows = []
for y, scenario in zip(y_ticks, plot_scenarios):
    row = {"scenario": scenario}
    for output in plot_outputs:
        # pt_sae keeps the sign stored in df; all other outputs are negated
        sign = 1 if output == "pt_sae" else -1

        median = _signed_value(scenario, f"{output}_0.5", sign)
        low = _signed_value(scenario, f"{output}_0.025", sign)
        high = _signed_value(scenario, f"{output}_0.975", sign)

        # "pt" goes into the summary table but is not drawn
        if output != "pt":
            bar_bottom = y + y_offset[output]
            axis.add_patch(
                mpatches.Rectangle((0, bar_bottom), width=median, height=bar_w, facecolor=colors[output], zorder=2)
            )
            # Interval between the 0.025 and 0.975 quantiles, through the bar's centre
            axis.hlines(y=bar_bottom + bar_w/2, xmin=low, xmax=high, linewidth=1.5, zorder=4, color="black")

        # Formatted "median (low-high)" cell for the CSV summary
        row[output] = f"{abs(round(median))} ({abs(round(low))}-{abs(round(high))})"
    summary_rows.append(row)

# DataFrame.append was removed in pandas 2.0; build the frame once from the row dicts.
new_df = pd.DataFrame(summary_rows, columns=["scenario"] + plot_outputs)

# Dashed zero lines, one for each reference group
axis.vlines(x=0, ymin=0, ymax=3 + 2*bar_w, linestyles="dashed")
axis.vlines(x=0, ymin=3 + 2*title_gap, ymax=3 + 2*title_gap + .5 + 2*bar_w, linestyles="dashed")

axis.text(x=0, y=3 + title_gap*3 + .7, s="Reference: No-screening counterfactual", ha="center", fontsize=title_ft)
axis.text(x=0, y=3 + title_gap+.2, s="Reference: Status-quo including 2017-2018 interventions", ha="center", fontsize=title_ft)

axis.set_ylim(0.5, 3 + title_gap*3 + .8)
axis.set_xlim(-x_max, x_max)

axis.set_xticks(x_ticks)
axis.set_xticklabels(x_labels, fontsize=title_ft)
axis.set_yticks(y_ticks)
axis.set_yticklabels(y_labels, fontsize=title_ft)
for side in ("top", "right", "left"):
    axis.spines[side].set_visible(False)
axis.spines["bottom"].set_visible(True)
axis.tick_params(left=False)

# Explicit legend handles in a named order, instead of indexing auto-collected handles
legend_handles = [
    mpatches.Patch(facecolor=colors[key], label=legend_names[key]) for key in legend_order
]
axis.legend(handles=legend_handles, facecolor="whitesmoke", fontsize=title_ft, loc=(.637, .4))

filename = "screening_benefit_risk.png"
fig.tight_layout()
fig.savefig(filename, dpi=300)

In [None]:
# Save the formatted "median (low-high)" summary. new_df only has a default
# 0..n-1 row index, so it is omitted and "scenario" becomes the first CSV column.
filename = "screening_benefit_risk.csv"
new_df.to_csv(filename, index=False)