In [None]:
import arviz as az
import pandas as pd
from pathlib import Path
from tbh.paths import REPO_ROOT_PATH

import tbh.plotting as pl

In [None]:
# Echo the repository root imported from tbh.paths; the analysis path below is built from it.
REPO_ROOT_PATH

In [None]:
# Folder holding the parquet outputs that the cells below read.
analysis_path = (
    REPO_ROOT_PATH
    / "notebooks"
    / "test_outputs"
    / "test_full_analysis_scenarios_dfs"
)

In [None]:
# Non-baseline scenarios to load; "baseline" is always added for the uncertainty tables.
scenarios = ["scenario_1"]

# One uncertainty table per scenario (baseline included), keyed by scenario name.
unc_dfs = {
    name: pd.read_parquet(analysis_path / f"uncertainty_df_{name}.parquet")
    for name in ["baseline", *scenarios]
}

# One "diff_quantiles" table per non-baseline scenario, keyed by scenario name.
diff_outputs_dfs = {
    name: pd.read_parquet(analysis_path / f"diff_quantiles_df_{name}.parquet")
    for name in scenarios
}


In [None]:
# Human-readable axis labels for the model outputs plotted in this notebook.
title_lookup = dict(
    tb_incidence="TB incidence",
    tb_incidence_per100k="TB incidence (/100k)",
    tbi_prevalence_perc="TB infection prev. (%)",
)

# Named colours, one per scenario.
sc_colours = ["black", "crimson"]

# RGB triplets indexed by scenario position when drawing curves and bands.
unc_sc_colours = (
    (0.2, 0.2, 0.8),
    (0.8, 0.2, 0.2),
    (0.2, 0.8, 0.2),
    (0.8, 0.8, 0.2),
    (0.8, 0.2, 0.2),
    (0.2, 0.8, 0.2),
    (0.8, 0.8, 0.2),
)


def plot_two_scenarios(
    axis,
    uncertainty_dfs,
    output_name,
    include_unc=False,
    include_legend=True,
    scenarios=("baseline", "scenario_1"),
):
    """Draw the median trajectory of one output for each scenario on a single axis.

    Args:
        axis: Matplotlib-like axes object; must provide ``plot``, ``fill_between``,
            ``set_ylabel``, ``set_ylim`` and ``legend``.
        uncertainty_dfs: Mapping from scenario name to an object that, indexed by
            ``output_name``, yields a table with columns ``'0.25'``, ``'0.5'`` and
            ``'0.75'``. The table's index is used as the x values.
        output_name: Name of the output to plot. It is also used as the y-axis label
            unless an entry exists for it in the module-level ``title_lookup``.
        include_unc: If True, shade the band between the ``'0.25'`` and ``'0.75'``
            columns behind each median line.
        include_legend: If True, add a legend with one entry per scenario.
        scenarios: Scenario names to draw, in order. The position in this sequence
            selects the colour from the module-level ``unc_sc_colours``.

    Raises:
        KeyError: If a scenario, the output name, or one of the quantile columns is
            missing from ``uncertainty_dfs``.
    """
    ymax = 0.
    for i_sc, scenario in enumerate(scenarios):
        df = uncertainty_dfs[scenario][output_name]
        median_df = df['0.5']
        time = df.index

        # Wrap around the palette so a long scenario list cannot run off its end.
        colour = unc_sc_colours[i_sc % len(unc_sc_colours)]
        # The first scenario's band is drawn above the bands of the later ones.
        scenario_zorder = 10 if i_sc == 0 else i_sc + 2

        if include_unc:
            axis.fill_between(
                time,
                df['0.25'], df['0.75'],
                color=colour, alpha=0.7,
                edgecolor=None,
                zorder=scenario_zorder
            )
            ymax = max(ymax, df['0.75'].max())
        else:
            # Keep the running maximum across scenarios; a plain assignment here would
            # let the last scenario alone decide the y-range and clip higher curves.
            ymax = max(ymax, median_df.max())

        axis.plot(time, median_df, color=colour, label=scenario, lw=1.)

    # Leave 10% headroom above the tallest plotted element.
    plot_ymax = ymax * 1.1

    axis.set_ylabel(title_lookup.get(output_name, output_name))
    axis.set_ylim((0, plot_ymax))

    if include_legend:
        # Only mention the IQR when the band is actually drawn.
        legend_title = "(median and IQR)" if include_unc else "(median)"
        axis.legend(title=legend_title)


In [None]:
from matplotlib import pyplot as plt

# TB incidence per 100k for both scenarios, with the 0.25-0.75 quantile band shaded.
fig, ax = plt.subplots()
plot_two_scenarios(ax, unc_dfs, output_name="tb_incidence_per100k", include_unc=True)

# Fixed viewing window; this replaces the y-range set inside plot_two_scenarios.
ax.set_xlim(2000, 2050)
ax.set_ylim(0.0, 1200)

In [None]:
# TB infection prevalence (%) for both scenarios, median lines only.
fig, ax = plt.subplots()
plot_two_scenarios(ax, unc_dfs, output_name="tbi_prevalence_perc", include_unc=False)

# Fixed viewing window; this replaces the y-range set inside plot_two_scenarios.
ax.set_xlim(2000, 2050)
ax.set_ylim(0.0, 60)