In [None]:
from matplotlib import pyplot as plt

from summer.utils import ref_times_to_dti

from autumn.settings.region import Region
from autumn.tools.inputs.covid_au.queries import get_both_vacc_coverage
from autumn.tools.project import get_project
from autumn.models.covid_19.constants import BASE_DATETIME

## Run all the cluster models
The North-East Metro region can't be run yet, so it is excluded below (any region whose name contains "north" is skipped) until other outstanding code issues are resolved.

In [None]:
# Regions that currently run; the substring match on "north" is what excludes North-East Metro.
regions = [
    region_name
    for region_name in Region.VICTORIA_SUBREGIONS
    if "north" not in region_name
]

# Run each region's baseline scenario and keep its derived outputs, keyed by region name.
derived_dfs = {}
for region in regions:
    project = get_project("covid_19", region)
    baseline_params = project.param_set.baseline
    model = project.run_baseline_model(baseline_params)
    derived_dfs[region] = model.get_derived_outputs_df()

## Plot the vaccination coverage

In [None]:
# One panel per region: stacked areas of modelled vaccination status (from the
# derived outputs), overlaid with the coverage returned by get_both_vacc_coverage,
# both as reported and shifted by IMMUNITY_LAG_DAYS.
N_COLS = 2
PANEL_HEIGHT = 5  # figure inches per row of panels (4 rows -> original 20in height)
IMMUNITY_LAG_DAYS = 14  # shift applied to reported coverage for the "lagged for immunity" line

# Size the grid from the number of regions instead of assuming exactly seven.
n_rows = max(1, -(-len(regions) // N_COLS))  # ceiling division, at least one row
fig, axes = plt.subplots(
    n_rows, N_COLS, figsize=(12, PANEL_HEIGHT * n_rows), sharex="all", squeeze=False
)
flat_axes = axes.ravel()

for i_ax, (region_name, axis) in enumerate(zip(regions, flat_axes)):
    region_df = derived_dfs[region_name]
    dates = region_df.index
    fully_vacc = region_df["proportion_fully_vaccinated"]
    one_dose_plus = region_df["at_least_one_dose_prop"]

    # Stacked bands: [0, fully] fully vaccinated, [fully, >=1 dose] one dose only, [>=1 dose, 1] unvaccinated.
    axis.fill_between(dates, 0., fully_vacc, label="fully vaccinated")
    axis.fill_between(dates, fully_vacc, one_dose_plus, label="one dose only")
    axis.fill_between(dates, one_dose_plus, 1., label="unvaccinated")

    vacc_times, vacc_coverage = get_both_vacc_coverage(region_name.upper().replace("-", "_"))
    # Convert the (presumably numpy) integer times to plain ints before building dates.
    vacc_days = [int(t) for t in vacc_times]
    vacc_dates = ref_times_to_dti(BASE_DATETIME, vacc_days)
    lagged_dates = ref_times_to_dti(BASE_DATETIME, [day + IMMUNITY_LAG_DAYS for day in vacc_days])
    axis.plot(vacc_dates, vacc_coverage, color="k", label="actual vaccination")
    axis.plot(lagged_dates, vacc_coverage, color="k", linestyle="--", label="lagged for immunity")

    axis.tick_params(axis="x", labelrotation=45)
    if i_ax + N_COLS >= len(regions):
        # sharex hides tick labels on non-bottom rows; re-show them on the lowest
        # occupied panel of each column, because any panel beneath it is hidden below.
        axis.tick_params(axis="x", labelbottom=True)
    axis.set_title(region_name)
    if i_ax == len(regions) - 1:
        axis.legend()

# Hide panels left over when the number of regions doesn't fill the grid.
for spare_axis in flat_axes[len(regions):]:
    spare_axis.set_visible(False)

# Lay out after titles and rotated tick labels exist so they are accounted for.
fig.tight_layout(w_pad=1.5, h_pad=3.5)