In [None]:
from matplotlib import pyplot
import pandas as pd
import warnings
import copy

from summer.utils import ref_times_to_dti

from autumn.tools.project import get_project, load_timeseries, build_rel_path
from autumn.settings.constants import COVID_BASE_DATETIME
from autumn.tools.utils.pretty import pretty_print

In [None]:
# Region handled by this notebook; passed to get_project in the model cell.
region = "bangladesh"

# Notebook-wide display settings: ggplot look for all figures, and silence warnings.
pyplot.style.use("ggplot")
warnings.filterwarnings("ignore")

In [None]:
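# Be careful: this imports `project` straight from the Bangladesh module rather than via `get_project`; the name is rebound by `get_project(..., reload=True)` in a later cell.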
from autumn.projects.sm_sir.bangladesh.bangladesh import project

In [None]:
def convert_ts_index_to_date(ts):
    """Return a copy of ``ts`` re-indexed from model reference times to dates.

    The index of ``ts`` is treated as times relative to ``COVID_BASE_DATETIME``
    and converted with ``ref_times_to_dti``. The argument itself is not
    modified, so calling this again on the same object (e.g. re-running the
    notebook cell) yields the same result.

    Args:
        ts: Object with an ``index`` of reference times and a ``copy()`` method.
            Presumably a pandas Series from ``project.ts_set`` (TODO confirm).

    Returns:
        A copy of ``ts`` whose index has been replaced by the converted dates.
    """
    # Work on a copy: assigning ``.index`` on the argument would mutate the
    # caller's object (here the project's own timeseries). A second call would
    # then try to convert an index that is already dates.
    converted = ts.copy()
    converted.index = ref_times_to_dti(COVID_BASE_DATETIME, ts.index)
    return converted

In [None]:
# Timeseries from `project.ts_set`, re-indexed from reference times to dates for plotting.
ts_sets = project.ts_set
ts_set_dates = dict(
    (ts_name, convert_ts_index_to_date(ts_values))
    for ts_name, ts_values in ts_sets.items()
)

In [None]:
# Fetch the sm_sir project for this region (reload=True presumably bypasses any
# cached copy; TODO confirm) and run its baseline parameter set.
project = get_project("sm_sir", region, reload=True)
baseline_params = project.param_set.baseline
model = project.run_baseline_model(baseline_params)
derived_df = model.get_derived_outputs_df()

# Model start/end expressed as dates (times are relative to COVID_BASE_DATETIME).
time_params = baseline_params["time"]
model_start_time = ref_times_to_dti(COVID_BASE_DATETIME, [time_params["start"]])[0]
model_end_time = ref_times_to_dti(COVID_BASE_DATETIME, [time_params["end"]])[0]

death_string = "infection_deaths"
outputs_to_plot = [
    "notifications",
    "infection_deaths",
    "prop_ever_infected",
    "hospital_admissions",
]

# Right-hand x-limit shared by every panel, in the same reference-time units
# as the model's time parameters.
plot_end_ref_time = 1000.
plot_end_time = ref_times_to_dti(COVID_BASE_DATETIME, [plot_end_ref_time])[0]

fig = pyplot.figure(figsize=(15, 12))
for panel_number, output in enumerate(outputs_to_plot, start=1):
    axis = fig.add_subplot(2, 2, panel_number)
    # Timeseries data as points, when the project has a series for this output.
    if output in ts_set_dates:
        ts_set_dates[output].plot(ax=axis, style=".")
    # Model result as a line, when the model produced this derived output.
    if output in derived_df:
        derived_df[output].plot(ax=axis)
    axis.set_title(output.replace("_", " "))
    axis.set_xlim([model_start_time, plot_end_time])