In [None]:
from autumn.tools.project import get_project
from matplotlib import pyplot
from autumn.tools.plots.utils import REF_DATE
#from autumn.tools.calibration.targets import get_targets
import pandas as pd
from autumn.tools.utils.pretty import pretty_print
from autumn.models.covid_19.constants import BASE_DATETIME

from summer.utils import ref_times_to_dti

from autumn.models.covid_19.strat_processing.vaccination import find_vacc_strata
from autumn.tools.inputs.covid_lka.queries import get_lka_vac_coverage

In [None]:
# Load the Sri Lanka COVID-19 project; its baseline parameter set and
# calibration targets are used by the cells below.
model_name, region_name = "covid_19", "sri_lanka"
project = get_project(model_name, region_name)

In [None]:
# Run the model once with the project's baseline parameters (unmodified) and
# collect its derived outputs as a DataFrame.
# NOTE(review): to explore a scenario (e.g. a different contact_rate), build an
# altered parameter set here, presumably via the param set's `update` method —
# confirm that API before relying on it.
custom_params = project.param_set.baseline
model = project.run_baseline_model(custom_params)
derived_df = model.get_derived_outputs_df()

In [None]:
# Convert the model's configured start and end times (model-time values relative
# to REF_DATE) into timestamps with one call, unpacking the two results.
time_params = custom_params["time"]
model_start_time, model_end_time = ref_times_to_dti(
    REF_DATE, [time_params["start"], time_params["end"]]
)

In [None]:
# Map each calibration target's name to its observed values as a Series indexed
# by date; target times are converted using the model's own reference date.
targets_dict = {
    target.data.name: pd.Series(
        target.data.data,
        index=ref_times_to_dti(model.ref_date, target.data.index),
    )
    for target in project.calibration.targets
}


In [None]:
# Compare model outputs against calibration targets, one panel per output.
outputs_to_plot = ["notifications", "infection_deaths"]

# Right-hand x-axis limit, given as a model time relative to REF_DATE.
PLOT_END_REF_TIME = 1000.
plot_end_time = ref_times_to_dti(REF_DATE, [PLOT_END_REF_TIME])[0]

# Two panels per row, with only as many rows as the outputs need (a fixed 2x2
# grid wastes space for two outputs and fails for more than four).
n_cols = 2
n_rows = max(1, (len(outputs_to_plot) + n_cols - 1) // n_cols)
fig, axes = pyplot.subplots(n_rows, n_cols, figsize=(15, 6 * n_rows), squeeze=False)
for axis, output in zip(axes.flat, outputs_to_plot):
    if output in targets_dict:
        targets_dict[output].plot(ax=axis, style=".", label="target")
    if output in derived_df:
        derived_df[output].plot(ax=axis, label="model")
    axis.set_title(output.replace("_", " "))
    axis.set_xlim([model_start_time, plot_end_time])
    # Only draw a legend when something labelled was actually plotted.
    if axis.get_legend_handles_labels()[0]:
        axis.legend()

# Hide grid cells left over when the outputs don't fill the last row.
for unused_axis in list(axes.flat)[len(outputs_to_plot):]:
    unused_axis.set_visible(False)

In [None]:
# Hospital occupancy over time. The ggplot style is scoped to this cell so it
# does not leak into (and silently restyle) figures drawn by other cells.
with pyplot.style.context("ggplot"):
    fig, axis = pyplot.subplots(figsize=(12, 8))
    # Pass the axes explicitly rather than relying on pyplot's "current axes".
    derived_df["hospital_occupancy"].plot(ax=axis)
    axis.set_title("hospital occupancy")

In [None]:
# Side-by-side panels: "cdr" (presumably the case detection rate) on the left,
# proportion ever infected on the right. Style is scoped to this cell.
with pyplot.style.context("ggplot"):
    fig, (cdr_axis, infected_axis) = pyplot.subplots(1, 2, figsize=(15, 6))
    # Pass each axes explicitly rather than relying on pyplot's "current axes".
    derived_df["cdr"].plot(ax=cdr_axis)
    cdr_axis.set_title("cdr")
    derived_df["prop_ever_infected"].plot(ax=infected_axis)
    infected_axis.set_title("prop ever infected")

In [None]:
# Larger single-panel view of the proportion ever infected. Style is scoped to
# this cell so it does not leak into other figures.
# TODO(review): this series is also plotted in the right-hand panel of the
# cdr cell; consider dropping one of the two.
with pyplot.style.context("ggplot"):
    fig, axis = pyplot.subplots(figsize=(12, 8))
    # Pass the axes explicitly rather than relying on pyplot's "current axes".
    derived_df["prop_ever_infected"].plot(ax=axis)
    axis.set_title("prop ever infected")

In [None]:
# Vaccination strata names used to build the "proportion_<stratum>" output keys.
# Index the result rather than unpacking into `_`: assigning `_` in the IPython
# namespace stops IPython from updating its last-output variable.
# NOTE(review): the meaning of the three boolean flags is not visible here —
# confirm against find_vacc_strata's signature.
vacc_strata = find_vacc_strata(True, True, False)[0]

In [None]:
# Vaccination check: stack the modelled proportion in each vaccination stratum,
# then overlay the reported coverage, both as reported and shifted later by an
# assumed immunity lag.
IMMUNITY_LAG = 14  # shift applied to reported times; presumably days — NOTE(review): confirm
VACC_COVERAGE_ARG = 15  # argument passed to get_lka_vac_coverage — NOTE(review): meaning not visible here

vacc_times, vacc_coverage = get_lka_vac_coverage(VACC_COVERAGE_ARG)
# NOTE(review): element [1] presumably selects one coverage series (e.g. a dose) — confirm.
# Cast the numpy integers returned to built-in ints before converting to dates.
vacc_ref_times = [int(i) for i in vacc_times[1]]
vacc_dates = ref_times_to_dti(BASE_DATETIME, vacc_ref_times)
lagged_dates = ref_times_to_dti(BASE_DATETIME, [t + IMMUNITY_LAG for t in vacc_ref_times])

# ggplot is applied explicitly and scoped here instead of being inherited from
# whichever earlier cell last changed the global style.
with pyplot.style.context("ggplot"):
    fig, axis = pyplot.subplots(1, 1, figsize=(12, 8))

    # Stack strata in reverse order; each band starts where the previous one ended.
    # An index-aligned zero Series keeps the running sum a proper Series throughout.
    lower_value = pd.Series(0., index=derived_df.index)
    for stratum in vacc_strata[::-1]:
        upper_value = lower_value + derived_df[f"proportion_{stratum}"]
        axis.fill_between(derived_df.index, lower_value, upper_value, label=stratum)
        lower_value = upper_value

    axis.plot(vacc_dates, vacc_coverage[1], color="k", label="actual vaccination")
    axis.plot(lagged_dates, vacc_coverage[1], color="k", linestyle="--", label="lagged for immunity")
    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title("vaccination check")
    axis.legend()