In [None]:
from autumn.tools.project import get_project, ParameterSet
from matplotlib import pyplot as plt
from autumn.tools.plots.utils import REF_DATE, change_xaxis_to_date
import pandas as pd
from autumn.tools.inputs.demography.queries import get_population_by_agegroup

In [None]:
# Load the project definition for the model/region pair analysed in this notebook.
# Alternative previously used here: get_project("covid_19", "sri_lanka").
MODEL_NAME = "sm_sir"
REGION_NAME = "malaysia"
project = get_project(MODEL_NAME, REGION_NAME)

In [None]:
# Run the baseline parameter set and keep its derived outputs.
# model_0 is reused below as the starting point passed to the scenario runs.
baseline_params = project.param_set.baseline
model_0 = project.run_baseline_model(baseline_params)
derived_df = model_0.get_derived_outputs_df()

In [None]:
# Run every scenario, each starting at the "time" -> "start" value declared
# in its own parameter set.
scenario_param_sets = project.param_set.scenarios
start_times = []
for scenario_params in scenario_param_sets:
    start_times.append(scenario_params.to_dict()["time"]["start"])

sc_models = project.run_scenario_models(
    model_0,
    scenario_param_sets,
    start_times=start_times,
)

In [None]:
# Derived outputs for each scenario, in the same order as sc_models
# (so derived_dfs[k] belongs to scenario k + 1).
derived_dfs = [scenario_model.get_derived_outputs_df() for scenario_model in sc_models]

In [None]:
# Derived-output columns compared between the baseline and the scenarios.
outputs = [
    "notifications",
    "hospital_occupancy",
    "icu_occupancy",
]

In [None]:
# Plot each output for the baseline (solid black) against a group of
# scenarios (dashed), one figure per output per scenario group.

# Line colours, indexed by scenario id - 1.
sc_colors = ["red", "green", "orange", "purple", "cornflowerblue"]

# Per-output y-axis caps.
# NOTE(review): not applied to the figures below. If fixed axes are wanted,
# add axis.set_ylim(top=max_y[output]). TODO confirm intent.
max_y = {
    "notifications": 6000,
    "hospital_occupancy": 7000,
    "icu_occupancy": 1300,
}

# Each inner list is one figure group of 1-based scenario ids
# (scenario id k reads derived_dfs[k - 1]).
scenario_lists = [[1, 2, 3, 4]]

# Legend labels, indexed by scenario id - 1.
scenario_names = [
    "VoC twice as transmissible as Omicron",
    "VoC more transmissible and completely escapes immunity",
    "VoC completely escapes immunity",
    "Waning vaccine-induced immunity",
]

label_font_size = 15

# Set the style once, before any figure is created, so that every figure
# and axes (including the first) is built with it.
plt.style.use("ggplot")

for scenario_list in scenario_lists:
    for output in outputs:
        # Explicit figure/axes: every series below is drawn on this axis,
        # rather than on whatever pyplot considers "current".
        fig, axis = plt.subplots(figsize=(12, 8))
        derived_df[output].plot(ax=axis, color="black", label="baseline")

        for sc_id in scenario_list:
            sc_df = derived_dfs[sc_id - 1]
            # Only plot scenarios whose derived outputs include this column.
            if output in sc_df.columns:
                sc_df[output].plot(
                    ax=axis,
                    color=sc_colors[sc_id - 1],
                    label=scenario_names[sc_id - 1],
                    linestyle="--",
                )

        axis.set_ylabel(output.replace("_", " "), fontsize=label_font_size)
        axis.tick_params(axis="x", labelrotation=45)
        plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
        plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
        axis.legend(fontsize=12, facecolor="white", markerscale=3)

        

## Verify that the age-specific proportions infected sum to the overall proportion infected

In [None]:
# NOTE(review): every line in this cell is commented out, so it does nothing
# on Restart & Run All.
# As written, it would create one figure per age group. Each figure plots the
# derived output "ever_infectedXagegroup_<age>" divided by that age group's
# population, taken from get_population_by_agegroup(..., 'MYS', None).
# It does NOT actually check the sum described in the header above:
# total_prop_ever_infected is initialised to 0 but never accumulated or
# compared with an overall proportion.
# Assumes age_distribution is ordered like AGEGROUP_STRATA (indexed in step via
# `index`). TODO confirm.
# TODO: either implement the summed comparison or delete this cell.
# plt.style.use("ggplot")
# AGEGROUP_STRATA = ['0', '15', '25', '50', '70']
# total_prop_ever_infected = 0
# age_distribution = get_population_by_agegroup(AGEGROUP_STRATA, 'MYS', None)
# index = 0
# for agegroup in AGEGROUP_STRATA:
#     agegroup_string = f"Xagegroup_{agegroup}"
#     name = f"ever_infected{agegroup_string}"
    
 
#     fig = plt.figure(figsize=(12, 8))
#     axis = fig.add_subplot()
#     plt.plot(derived_df[name]/age_distribution[index],label = "prop_ever_infected")
#     print(age_distribution[index])
#     index +=1

    
