In [None]:
import os

from autumn.core.project import get_project, load_timeseries
from autumn.settings import Region, Models
from matplotlib import pyplot as plt
from autumn.core.plots.utils import REF_DATE
import pandas as pd
from autumn.core.plots.uncertainty.plots import _get_target_values, _plot_targets_to_axis
from summer.utils import ref_times_to_dti
import numpy as np

### Retrieve parameter names

In [None]:

# Load the SM_COVID project for the Philippines and collect the names of its
# calibrated parameters, in the order given by `project.calibration.all_priors`.
project = get_project(Models.SM_COVID, Region.PHILIPPINES)
calibration_priors = project.calibration.all_priors
param_names = [prior["param_name"] for prior in calibration_priors]

### Read parameter values

In [None]:
# best_profile.csv has no header row (column names are supplied explicitly) and is
# space-separated: one column per calibrated parameter, in `param_names` order,
# followed by the objective value.
csv_columns = [*param_names, "objective"]
param_df = pd.read_csv(
    "best_profile.csv",
    sep=" ",
    names=csv_columns,
    index_col=False,
)
param_df

### Run model for selected iterations

In [None]:
# Number of rows of the parameter profile to re-simulate, spread evenly from the
# first row to the last.
n_selected_gens = 2

assert not param_df.empty, "best_profile.csv contains no parameter rows"

# Evenly spaced row ids. dict.fromkeys drops duplicates (these occur when the
# profile has fewer rows than n_selected_gens) while keeping their order, so the
# same parameter set is not simulated twice.
selected_idx = list(dict.fromkeys(
    round(g) for g in np.linspace(0, param_df.index[-1], num=n_selected_gens)
))

# Maps row (generation) id -> DataFrame of derived outputs from that model run.
derived_outputs = {}
for gen_id in selected_idx:
    # Calibrated parameter values for this row, layered onto the baseline parameters.
    calib_params = {name: param_df.at[gen_id, name] for name in param_names}
    params = project.param_set.baseline.update(calib_params, calibration_format=True)

    # Run the model and keep its derived outputs.
    model = project.run_baseline_model(params)
    derived_outputs[gen_id] = model.get_derived_outputs_df()


### Plot model outputs against data

In [None]:
# Plot the model outputs of each selected generation against the data used as calibration targets.
# NOTE: requires `import os` in the import cell (it was previously missing).
all_targets = load_timeseries(os.path.join(project.get_path(), "timeseries.json"))
for target in all_targets:
    # Convert each target's numeric time index to a DatetimeIndex relative to REF_DATE
    # (presumably so it shares an x-axis with the model outputs — confirm).
    all_targets[target].index = ref_times_to_dti(REF_DATE, all_targets[target].index)

plt.style.use("ggplot")

outputs_to_plot = ["infection_deaths", "cumulative_infection_deaths", "transformed_random_process"]
# Upper y-axis limit for each output panel (used with set_ylim, hence y_max).
y_max = {"infection_deaths": 250, "cumulative_infection_deaths": 60000, "transformed_random_process": 3}
n_outputs = len(outputs_to_plot)

for gen_id in selected_idx:
    # squeeze=False keeps `axes` a 2-D array even when only one output is plotted.
    fig, axes = plt.subplots(1, n_outputs, figsize=(8. * n_outputs, 6), squeeze=False)

    for axis, output in zip(axes[0], outputs_to_plot):
        derived_outputs[gen_id][output].plot(ax=axis)
        axis.set_title(output)
        axis.set_ylim([0, y_max[output]])
        # Overlay observed data only for outputs that have a series in timeseries.json.
        if output in all_targets:
            target_series = all_targets[output]
            axis.scatter(target_series.index, target_series, color="k", s=5, alpha=0.5, zorder=10)

    fig.suptitle(f"Generation {gen_id}", fontsize=15)
    fig.tight_layout()
    fig.savefig(f"gen_{gen_id}")
