In [None]:
from autumn.core.project import get_project, ParameterSet
from matplotlib import pyplot as plt
from autumn.core.plots.utils import REF_DATE, change_xaxis_to_date
import pandas as pd
from autumn.core.inputs.demography.queries import get_population_by_agegroup

In [None]:
# Look up the project registered under the "sm_sir" / "malaysia" pair.
model_name, project_key = "sm_sir", "malaysia"
project = get_project(model_name, project_key)

In [None]:
# Run the baseline model with the project's baseline parameter set, then
# extract its derived outputs as a DataFrame (used as the reference curve
# in the plots below).
baseline_params = project.param_set.baseline
model_0 = project.run_baseline_model(baseline_params)
derived_df = model_0.get_derived_outputs_df()

In [None]:
# Run every scenario of the project. Each scenario's start time is read from
# the ["time"]["start"] entry of its own parameter dictionary and handed to
# run_scenario_models together with the baseline model.
scenario_param_sets = project.param_set.scenarios
start_times = [params.to_dict()["time"]["start"] for params in scenario_param_sets]
sc_models = project.run_scenario_models(
    model_0,
    scenario_param_sets,
    start_times=start_times,
)

In [None]:
# One derived-outputs DataFrame per scenario model, in the order of sc_models.
derived_dfs = list(
    map(lambda scenario_model: scenario_model.get_derived_outputs_df(), sc_models)
)

In [None]:
# Derived outputs compared between the baseline and the scenario runs;
# the plotting cell below draws one figure per entry.
outputs = [
    "notifications",
    "cumulative_infection_deaths",
    "hospital_occupancy",
    "strain_propXstrain_omicron",
]

In [None]:
# Legend label for each scenario: the "description" entry of its parameter
# dictionary, in the order of project.param_set.scenarios.
scenario_names = [params.to_dict()["description"] for params in project.param_set.scenarios]

In [None]:
# Compare the baseline run against selected scenarios: for every requested
# output, draw one figure with the baseline as a solid black line and each
# selected scenario as a dashed coloured line.

# Line colour per scenario, looked up as sc_colors[sc_id - 1]
# (scenario ids are used as 1-based positions into the per-scenario lists).
sc_colors = [
    "red", "green", "orange", "purple", "cornflowerblue",
    "red", "green", "orange", "purple",
    "red", "green", "orange", "purple",
]

# NOTE(review): nothing in this cell reads max_y. Presumably these are
# intended y-axis upper limits per output — confirm whether they should be
# applied (e.g. via set_ylim) or whether the dict can be removed.
max_y = {
    "notifications": 6000,
    "hospital_occupancy": 7000,
    "icu_occupancy": 1300,
}

# Each inner list is one group of scenario ids overlaid on the baseline.
scenario_lists = [[3, 4, 5]]

label_font_size = 18

# Select the style before any figure exists so that every figure, including
# the very first one, is created entirely under it.
plt.style.use("ggplot")

for scenario_list in scenario_lists:
    for output in outputs:
        fig = plt.figure(figsize=(12, 8))
        axis = fig.add_subplot()
        # Draw on the axes created above explicitly instead of relying on
        # pandas picking up pyplot's implicit "current axes".
        derived_df[output].plot(ax=axis, color="black", label="baseline")

        for sc_id in scenario_list:
            scenario_df = derived_dfs[sc_id - 1]
            # Scenarios that did not produce this output are skipped.
            if output in scenario_df.columns:
                scenario_df[output].plot(
                    ax=axis,
                    color=sc_colors[sc_id - 1],
                    label=scenario_names[sc_id - 1],
                    linestyle="--",
                )

        # Only the notifications figure carries a legend; build it once,
        # after all of its lines have been added.
        if output == "notifications":
            axis.legend(fontsize=label_font_size, facecolor="white", markerscale=3)

        axis.set_ylabel(output.replace("_", " "), fontsize=22)
        axis.tick_params(axis="x", labelrotation=45)
        plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
        plt.setp(axis.get_yticklabels(), fontsize=label_font_size)

## Plot VoC (variant of concern) proportions

In [None]:
# Strain-proportion derived outputs, in the order they are stacked
# (first entry at the bottom) in the area chart below.
voc_outputs = [f"strain_propXstrain_{strain}" for strain in ("beta", "delta", "omicron")]

In [None]:
# Stacked area chart of the strain proportions in the baseline run: each
# output in voc_outputs is drawn as a band sitting on top of the running
# total of the outputs before it.
label_font_size = 18

# A style only affects what is created after it is selected, so choose it
# before building the figure rather than at the end of the cell. This also
# keeps the cell correct when it is run on its own.
plt.style.use("ggplot")

fig = plt.figure(figsize=(12, 8))
axis = fig.add_subplot()

# Running lower edge of the stack, held as a Series on the outputs' own index
# so that every addition below is an explicit element-wise sum.
lower_value = pd.Series(0.0, index=derived_df.index)
for voc_output in voc_outputs:
    working_value = derived_df[voc_output]
    upper_value = lower_value + working_value
    axis.fill_between(derived_df.index, lower_value, upper_value, label=voc_output)
    # The next band starts where this one ends.
    lower_value = upper_value

axis.set_ylabel("VoC proportions", fontsize=22)
axis.tick_params(axis="x", labelrotation=45)
plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
axis.legend(loc='upper right')

## Plot immune proportions

In [None]:
# Immunity-proportion derived outputs, in the order they are stacked
# (first entry at the bottom) in the area chart below.
immune_outputs = [f"prop_immune_{level}" for level in ("none", "low", "high")]

In [None]:
# Stacked area chart of the immunity proportions in the baseline run: each
# output in immune_outputs is drawn as a band sitting on top of the running
# total of the outputs before it.
label_font_size = 18

# A style only affects what is created after it is selected, so choose it
# before building the figure rather than at the end of the cell. This also
# keeps the cell correct when it is run on its own.
plt.style.use("ggplot")

fig = plt.figure(figsize=(12, 8))
axis = fig.add_subplot()

# Running lower edge of the stack, held as a Series on the outputs' own index
# so that every addition below is an explicit element-wise sum.
lower_value = pd.Series(0.0, index=derived_df.index)
for immune_output in immune_outputs:
    working_value = derived_df[immune_output]
    upper_value = lower_value + working_value
    axis.fill_between(derived_df.index, lower_value, upper_value, label=immune_output)
    # The next band starts where this one ends.
    lower_value = upper_value

axis.set_ylabel("immune proportions", fontsize=22)
axis.tick_params(axis="x", labelrotation=45)
plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
axis.legend(loc='upper right')