In [None]:
from autumn.tools.project import get_project, ParameterSet
from matplotlib import pyplot as plt
from autumn.tools.plots.utils import REF_DATE
import pandas as pd

In [None]:
# Look up the autumn project identified by ("covid_19", "manila").
# NOTE(review): presumably (model name, region name) — confirm against
# get_project's signature.
project = get_project("covid_19", "manila")

In [None]:
# Run the baseline model with the project's baseline parameter set and keep
# its derived outputs as a DataFrame.
baseline_params = project.param_set.baseline
model_0 = project.run_baseline_model(baseline_params)
derived_df = model_0.get_derived_outputs_df()

In [None]:
# Run the scenario models from the baseline run, passing along the
# "time" -> "start" value read from each scenario's parameter dictionary.
scenario_param_sets = project.param_set.scenarios
start_times = [params.to_dict()["time"]["start"] for params in scenario_param_sets]
sc_models = project.run_scenario_models(
    model_0,
    scenario_param_sets,
    start_times=start_times,
)

In [None]:
# One derived-outputs DataFrame per scenario model, in the same order as sc_models.
derived_dfs = [scenario_model.get_derived_outputs_df() for scenario_model in sc_models]

In [None]:
# Names of the derived-output columns to plot.
outputs = [
    "new_hospital_admissions",
    "proportion_vaccinated",
]

In [None]:
# Draw one figure per requested output, overlaying every scenario that has
# that output, with a dashed vertical marker at 2021-11-01.
sc_colors = ["blue", "green"]

# Select the style once, before any figure exists, so that every figure
# (including the first) is created entirely under "ggplot".
plt.style.use("ggplot")

for output in outputs:
    fig = plt.figure(figsize=(12, 8))
    axis = fig.add_subplot()

    # Largest value actually drawn on this figure; stays None until some
    # scenario contributes a series for this output.
    peak = None
    for i, scenario_df in enumerate(derived_dfs):
        if output not in scenario_df.columns:
            continue
        series = scenario_df[output]
        # Target this figure's axes explicitly rather than the implicit
        # "current" axes, and wrap around the palette so that having more
        # scenarios than colours cannot raise IndexError.
        series.plot(ax=axis, color=sc_colors[i % len(sc_colors)])
        series_max = series.max()
        if pd.notna(series_max) and (peak is None or series_max > peak):
            peak = series_max

    # The marker height comes from the tallest plotted series. Reading the
    # inner loop variable after the loop would only look at the last scenario
    # and would raise KeyError whenever that scenario lacks this output.
    if peak is not None:
        axis.vlines(x=pd.Timestamp("2021-11-01"), ymin=0, ymax=peak, color="black", linestyle="dashed")
    axis.set_xlim((pd.Timestamp("2021-09-01"), pd.Timestamp("2022-06-01")))
    axis.set_ylabel(output.replace("_", " "), fontsize=15)

# Cumulative TTS calculations

In [None]:
# Cumulative differences between the first two scenarios: for each listed
# output, the total over all rows of derived_dfs[0] minus the total over all
# rows of derived_dfs[1].
# NOTE(review): presumably scenario 0 is the comparator and scenario 1 the
# scenario of interest — confirm against the project's scenario definitions.
diff_output = {}
for output in ["incidence", "new_hospital_admissions", "new_icu_admissions", "infection_deaths"]:
    # Series.sum is vectorised (no Python-level iteration over rows).
    # skipna=False keeps the builtin sum() behaviour of propagating NaN, and
    # float() keeps the stored values plain Python floats.
    cum_output = [float(d[output].sum(skipna=False)) for d in derived_dfs]
    diff_output[output] = cum_output[0] - cum_output[1]

# TTS totals are taken from derived_dfs[1] alone (not differenced), summed
# over the columns named "<output>Xagegroup_<k>" for k = 0, 5, ..., 75.
agegroups = [f"agegroup_{lower}" for lower in range(0, 80, 5)]
for tts_output in ("tts_cases", "tts_deaths"):
    tts_columns = [f"{tts_output}X{agegroup}" for agegroup in agegroups]
    # The first sum collapses rows per column, the second collapses the columns.
    diff_output[tts_output] = float(
        derived_dfs[1][tts_columns].sum(skipna=False).sum(skipna=False)
    )

In [None]:
# Print every cumulative figure rounded to the nearest integer. A plain loop
# is used rather than a list comprehension so that no throwaway list is built
# and the cell does not also echo a list of None values as its output.
for name, value in diff_output.items():
    print(f"{name}: {round(value)}")


In [None]:
# Horizontal bar chart, on a log-scaled x axis, of the figures collected in
# diff_output: one bar per entry, stacked top to bottom in insertion order.
plt.style.use("ggplot")  # set before creating the figure so it is styled too
fig = plt.figure(figsize=(12, 8))
axis = fig.add_subplot()

# Despite their names, these hold the y-axis tick positions and labels.
xticks_vals, xticks_labs = [], []
for i, (output, val) in enumerate(diff_output.items()):
    # Entries whose key starts with "tts" get their own colour and a bare
    # label; every other entry is labelled as an averted count.
    if output.startswith("tts"):
        col, pref = "blueviolet", ""
    else:
        col, pref = "coral", "averted "

    # NOTE(review): a zero or negative value cannot be represented on a log
    # axis, so such a bar and its label would not be shown — confirm all
    # entries are expected to be positive.
    axis.hlines(y=-i, xmin=0, xmax=val, linewidth=20, color=col)
    # Place the rounded value just past the end of the bar (x1.2 is a constant
    # visual offset on a log axis).
    axis.text(x=1.2 * val, y=-i, s=str(round(val)))

    xticks_vals.append(-i)
    xticks_labs.append(pref + output.replace("_", " "))

axis.set_xscale('log')
axis.set_xlabel("log(N)")
axis.set_yticks(xticks_vals)
axis.set_yticklabels(xticks_labs, fontsize=13)
plt.show()