In [None]:
from autumn.tools.project import get_project, ParameterSet
from matplotlib import pyplot as plt
import pandas as pd
from autumn.projects.tuberculosis.marshall_islands.project import ANALYSIS
from autumn.projects.tuberculosis.marshall_islands.utils import SA_PARAM_VALUES

from autumn.projects.tuberculosis.marshall_islands.outputs.utils import OUTPUT_TITLES

In [None]:
# This notebook only supports the sensitivity-analysis configurations (the
# keys used with SA_PARAM_VALUES / sc_colors below), so check the selected
# analysis before paying for loading the project.
assert ANALYSIS != "main", (
    f"ANALYSIS is {ANALYSIS!r}; this notebook needs a sensitivity-analysis "
    "configuration (e.g. 'sa_importation' or 'sa_screening'), set via ANALYSIS "
    "in autumn.projects.tuberculosis.marshall_islands.project"
)
project = get_project("tuberculosis", "marshall-islands")

In [None]:
# run baseline model
model_0 = project.run_baseline_model(project.param_set.baseline)
derived_df = model_0.get_derived_outputs_df()

In [None]:
# run scenarios
start_times = [
    sc_params.to_dict()["time"]["start"] for sc_params in project.param_set.scenarios
]
sc_models = project.run_scenario_models(model_0, project.param_set.scenarios, start_times=start_times)

In [None]:
# One derived-outputs table per scenario model, kept in scenario order so that
# position i lines up with the i-th sensitivity value when plotting.
derived_dfs = [
    scenario_model.get_derived_outputs_df() for scenario_model in sc_models
]

In [None]:
# Derived outputs to plot, one figure panel each (filled row by row).
outputs = [
    "incidence",
    "mortality",
    "percentage_latent",
    "notifications",
]

In [None]:
# One colour per sensitivity-analysis value, taken from a single-hue colormap.
n_scenarios = len(SA_PARAM_VALUES[ANALYSIS])

# Integer inputs index directly into the colormap's lookup table, so keep every
# shade inside [low, high] by spreading the scenarios evenly across that slice.
# (Previously each scenario stepped by the full high - low, overshooting `high`.)
low, high = 100, 150
step = (high - low) / max(n_scenarios - 1, 1)  # max(): a single scenario must not divide by zero
gradient_indices = [low + int(step * i) for i in range(n_scenarios)]

sc_colors = {
    "sa_importation": [plt.get_cmap("Greens")(j) for j in gradient_indices],
    "sa_screening": [plt.get_cmap("Reds")(j) for j in gradient_indices],
}

# Legend heading describing which parameter each analysis varies.
legend_titles = {
    "sa_importation": "% LTBI among immigrants:",
    "sa_screening": "Screening sensitivity:",
}

# Apply the style before creating the figure so the figure and every axes pick it up.
plt.style.use("ggplot")

# Grid of equally sized panels, one per output, filled row by row.
n_col = 2
n_row = -(-len(outputs) // n_col)  # ceiling division: enough rows for every output
panel_h = 5
panel_w = 7
widths = [panel_w] * n_col
heights = [panel_h] * n_row
fig = plt.figure(constrained_layout=True, figsize=(sum(widths), sum(heights)))  # (w, h)
spec = fig.add_gridspec(ncols=n_col, nrows=n_row, width_ratios=widths, height_ratios=heights)

# NOTE(review): assumes derived_dfs[i] is the scenario run with value
# SA_PARAM_VALUES[ANALYSIS][i] — TODO confirm the scenario ordering.
sa_values = SA_PARAM_VALUES[ANALYSIS]

for i_output, output in enumerate(outputs):
    i_row, i_col = divmod(i_output, n_col)
    axis = fig.add_subplot(spec[i_row, i_col])
    max_val = 0
    for i, d in enumerate(derived_dfs):
        if output in d.columns:
            # Round rather than truncate: int(100 * 0.29) == 28 because of float error.
            label = f"{int(round(100. * sa_values[i]))}%"
            d[output].plot(ax=axis, color=sc_colors[ANALYSIS][i], lw=3, label=label)
            max_val = max(max_val, d[output].max())  # Series.max() skips NaN

    # Skip the y-limit when nothing positive was drawn: ylim (0, 0) makes matplotlib warn.
    if max_val > 0:
        axis.set_ylim([0, max_val * 1.1])
    axis.set_xlim([2017, 2050])  # presumably calendar years on the time index — confirm
    axis.set_ylabel(OUTPUT_TITLES[output], fontsize=15)

    # Draw a single shared legend, on the mortality panel only.
    if output == "mortality":
        handles, labels = axis.get_legend_handles_labels()
        if ANALYSIS == "sa_importation":
            handles.reverse()
            labels.reverse()
        axis.legend(handles, labels, title=legend_titles[ANALYSIS], facecolor="white", fontsize=12, title_fontsize=12)

png_path = f"{ANALYSIS}_.png"
fig.savefig(png_path, dpi=300)