In [None]:
from matplotlib import pyplot
import pandas as pd
import warnings

from summer.utils import ref_times_to_dti

from autumn.core.plots.utils import REF_DATE
from autumn.core.project import get_project
from autumn.core.utils.display import pretty_print
from autumn import wpro_list

In [None]:
# Use the ggplot style sheet for every figure drawn in this notebook.
pyplot.style.use("ggplot")
# Suppress all warnings for the rest of the session (a blanket "ignore" filter
# for every Warning category, any message, any module).
warnings.simplefilter("ignore")

In [None]:

def simulate_WPRO_countries(WPR_country, end_time=1000.):
    """Run the baseline model for one WPRO region and gather what is needed to plot it.

    Args:
        WPR_country: Region name passed to ``get_project("WPRO", ...)``.
        end_time: Value written to the baseline parameters' ``time.end`` before
            running the model. Defaults to 1000., the value previously hard-coded.

    Returns:
        A tuple ``(derived_df, targets_dict, model_start_time, model_end_time)``:
        - ``derived_df``: the result of ``model.get_derived_outputs_df()``;
        - ``targets_dict``: maps each calibration target's ``data.name`` to a
          ``pd.Series`` of its values, indexed by the target times converted to
          dates with the model's reference date;
        - ``model_start_time`` / ``model_end_time``: the parameter ``time.start``
          and ``time.end`` values converted to dates with that same reference date.
    """
    project = get_project("WPRO", WPR_country)
    custom_params = project.param_set.baseline.update({"time": {"end": end_time}})
    model = project.run_baseline_model(custom_params)
    derived_df = model.get_derived_outputs_df()

    # Convert the start/end times with the same reference date used for the
    # target series below, so axis limits and target points share one time
    # origin (previously the module-level REF_DATE was used here instead).
    ref_date = model.ref_date
    model_start_time = ref_times_to_dti(ref_date, [custom_params["time"]["start"]])[0]
    model_end_time = ref_times_to_dti(ref_date, [custom_params["time"]["end"]])[0]

    targets_dict = {
        target.data.name: pd.Series(
            target.data.values, index=ref_times_to_dti(ref_date, target.data.index)
        )
        for target in project.calibration.targets
    }

    return derived_df, targets_dict, model_start_time, model_end_time

In [None]:
# Output names drawn by default, one subplot each.
outputs_to_plot = ["notifications", "infection_deaths", "cdr", "prop_ever_infected"]


def plot_calibration_results(
    derived_df, targets_dict, model_start_time, model_end_time, outputs=None, n_cols=2
):
    """Draw a grid of panels comparing model outputs with calibration targets.

    For each output name, the panel shows the target series as points (when the
    name is a key of ``targets_dict``) and the model output as a line (when the
    name is a column of ``derived_df``). The x-axis is clipped to the model run.

    Args:
        derived_df: DataFrame of model outputs, one column per output name.
        targets_dict: Mapping of output name to a ``pd.Series`` of target values.
        model_start_time: Left x-axis limit for every panel.
        model_end_time: Right x-axis limit for every panel.
        outputs: Output names to plot; defaults to the module-level
            ``outputs_to_plot``.
        n_cols: Number of panel columns; rows are added as needed.

    Returns:
        The matplotlib Figure containing the panels.
    """
    if outputs is None:
        outputs = outputs_to_plot

    # Ceiling division so any number of outputs fits (the old fixed 2x2 grid
    # raised once there were more than four outputs); at least one row.
    n_rows = max(1, -(-len(outputs) // n_cols))
    # Keep the original 7.5 x 6 inches per panel (15 x 12 for the default 2x2).
    fig = pyplot.figure(figsize=(7.5 * n_cols, 6 * n_rows))

    for i_out, output in enumerate(outputs):
        axis = fig.add_subplot(n_rows, n_cols, i_out + 1)
        if output in targets_dict:
            targets_dict[output].plot(ax=axis, style=".")
        if output in derived_df:
            derived_df[output].plot(ax=axis)
        axis.set_title(output.replace("_", " "))
        axis.set_xlim([model_start_time, model_end_time])

    return fig

In [None]:
# Run and plot the baseline model for every WPRO region listed in wpro_list.
wpro_countries = wpro_list["region"]

for wpro_project in wpro_countries:
    region_name = f"wpro_{wpro_project.upper()}"
    derived_df, targets_dict, model_start_time, model_end_time = simulate_WPRO_countries(region_name)
    fig = plot_calibration_results(derived_df, targets_dict, model_start_time, model_end_time)
    # Every country yields the same set of panels; label the figure so the
    # outputs under this cell can be told apart.
    fig.suptitle(region_name)
    # Render this country's figure now and release it, instead of keeping every
    # figure open until the whole loop finishes.
    pyplot.show()
    pyplot.close(fig)