## Check inputs to the Victoria regional model before running
Work through all the aspects of the model that can be checked without actually running it, before calibrating.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.cm as cm
import numpy as np
from datetime import datetime
from datetime import date

from summer.utils import ref_times_to_dti

from autumn.settings import Region, Models
from autumn.tools.plots.utils import REF_DATE
from autumn.tools.inputs.covid_au.queries import get_vic_testing_numbers
from autumn.models.covid_19.constants import AGEGROUP_STRATA, BASE_DATETIME
from autumn.tools.project import get_project, run_project_locally, TimeSeriesSet
from autumn.tools import inputs
from autumn.tools.inputs.mobility.queries import get_mobility_data
from autumn.tools.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.tools.calibration.targets import get_target_series
from autumn.tools.inputs.demography.queries import get_population_by_agegroup
from autumn.tools.utils.utils import apply_moving_average
from autumn.models.covid_19.detection import create_cdr_function
from autumn.tools.curve.scale_up import scale_up_function
from autumn.tools.inputs.database import get_input_db
from autumn.tools.inputs.covid_au.queries import get_both_vacc_coverage
from autumn.models.covid_19.mixing_matrix.mobility import weight_mobility_data

In [None]:
# Age-group strata labels converted to integers (used as bar x-positions in the population plots)
age_integers = list(map(int, AGEGROUP_STRATA))
model = Models.COVID_19

## Population

In [None]:
# Absolute population by age group for each sub-region, on a shared y-axis so regions can be
# compared against each other; each region's total population is also printed.
fig_absolute, axes_absolute = plt.subplots(4, 2, figsize=(12, 15), sharey="all")
for panel_idx, ax in enumerate(axes_absolute.flat):
    region = Region.VICTORIA_SUBREGIONS[panel_idx]
    project = get_project(model, region)
    baseline = project.param_set.baseline
    total_pops = inputs.get_population_by_agegroup(
        AGEGROUP_STRATA,
        baseline["country"]["iso3"],
        baseline["population"]["region"],
        year=baseline["population"]["year"],
    )
    total_in_millions = round(sum(total_pops) / 1e6, 3)
    print(f"total population of {region} is: {total_in_millions} million")
    ax.bar(age_integers, total_pops, width=4)
    ax.set_title(region)
fig_absolute.suptitle("population comparison across Victoria")
fig_absolute.tight_layout(pad=2)

In [None]:
# Age structure of each sub-region as a proportion of that region's own population, so the
# shape of each distribution can be compared regardless of the region's overall size.
fig_relative, axes_relative = plt.subplots(4, 2, figsize=(12, 15))
for i_ax, ax in enumerate(axes_relative.reshape(-1)):
    region = Region.VICTORIA_SUBREGIONS[i_ax]
    project = get_project(model, region)
    total_pops = inputs.get_population_by_agegroup(
        AGEGROUP_STRATA,
        project.param_set.baseline["country"]["iso3"],
        project.param_set.baseline["population"]["region"],
        year=project.param_set.baseline["population"]["year"]
    )
    region_total = sum(total_pops)
    pop_proportions = [group_pop / region_total for group_pop in total_pops]
    ax.bar(age_integers, pop_proportions, width=4)
    ax.set_ylabel("proportion of population")
    ax.set_title(region)
fig_relative.suptitle("population comparison within region")
# Compute the layout once, after all artists and titles exist
fig_relative.tight_layout(pad=2)

## Mobility

In [None]:
# Raw Google mobility domains for each sub-region, read directly from the inputs database.
input_db = get_input_db()
mob_fig, mob_axes = plt.subplots(4, 2, figsize=(14, 18))
plot_left_date = date(2021, 6, 1)
raw_domains = ("grocery_and_pharmacy", "residential", "parks", "retail_and_recreation", "transit_stations")

for panel_idx, ax in enumerate(mob_axes.flat):
    region = Region.VICTORIA_SUBREGIONS[panel_idx]
    db_region = region.upper().replace("-", "_")
    mob_df = input_db.query("mobility", conditions={"iso3": "AUS", "region": db_region})
    times = [datetime.strptime(date_str, "%Y-%m-%d") for date_str in mob_df["date"]]
    # Right-hand x-limit is the last date with data; the next cell also reuses this value
    # (i.e. the last sub-region's final date) to crop the model mobility plots.
    plot_right_date = times[-1]
    for domain in raw_domains:
        ax.plot(times, mob_df[domain], label=domain)
    ax.set_ylim((0., 1.4))
    ax.set_xlim(left=plot_left_date, right=plot_right_date)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_title(region)
ax.legend()
mob_fig.suptitle("raw Google mobility domains")
mob_fig.tight_layout()

In [None]:
# Mobility as used by the model: Google mobility data weighted into model locations via the
# project's google_mobility_locations parameter, plotted for each sub-region.
# NOTE: relies on plot_left_date and plot_right_date set by the raw-mobility cell above.
fig, axes = plt.subplots(4, 2, figsize=(12, 18), sharey="all")

for i_ax, ax in enumerate(axes.reshape(-1)):
    region = Region.VICTORIA_SUBREGIONS[i_ax]
    project = get_project(model, region)
    mobility_locations = project.param_set.baseline["mobility"]["google_mobility_locations"]

    mobility_values, mobility_days = get_mobility_data(
        project.param_set.baseline["country"]["iso3"],
        project.param_set.baseline["population"]["region"],
        BASE_DATETIME,
    )
    google_mobility_values = weight_mobility_data(mobility_values, mobility_locations)

    times = ref_times_to_dti(BASE_DATETIME, mobility_days)
    for location in mobility_locations.keys():
        ax.plot(times, google_mobility_values[location], label=location)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_title(region)
    ax.set_ylim((0., 1.3))
    ax.set_xlim(left=plot_left_date, right=plot_right_date)

ax.legend(loc="lower right")
fig.suptitle("mobility as implemented in the model")
# Compute the layout once, after all artists and titles exist
fig.tight_layout(w_pad=1.5, h_pad=3.5)

## Mixing matrix
### Check how the mixing matrix is specified for each region

In [None]:
# Print how each sub-region's mixing matrix is configured.
# After the loop, `project` is the last sub-region's project, which the next cell uses.
for region in Region.VICTORIA_SUBREGIONS:
    project = get_project(model, region)
    baseline = project.param_set.baseline
    print(f"\nFor cluster: {region}")
    print(f"\tModelled country: {baseline['country']['iso3']}")
    print(f"\tModelled sub-region: {baseline['population']['region']}")
    print(f"\tProxy country: {baseline['mixing_matrices']['source_iso3']}")
    print(f"\tWhether age adjusted: {baseline['mixing_matrices']['age_adjust']}")

### Display the mixing matrix and its location-specific components (for the last region listed above)

In [None]:
# Build and display the synthetic mixing matrices for the project left bound by the previous
# cell (the last sub-region it looped over). The region is shown in the figure title, because
# the plot would otherwise not say which region's matrices it contains.
baseline = project.param_set.baseline
mixing_matrix = build_synthetic_matrices(
    baseline["country"]["iso3"],
    baseline["mixing_matrices"]["source_iso3"],
    AGEGROUP_STRATA,
    baseline["mixing_matrices"]["age_adjust"],
    baseline["population"]["region"]
)

fig = plt.figure(figsize=(12, 8))
# Subplot slots in a 2 x 3 grid; slot 4 is skipped. zip stops at the shorter sequence,
# so at most five locations are drawn.
positions = [1, 2, 3, 5, 6]
for i_loc, location in zip(positions, mixing_matrix.keys()):
    ax = fig.add_subplot(2, 3, i_loc)
    ax.imshow(
        np.flipud(np.transpose(mixing_matrix[location])),
        cmap=cm.hot,
        vmin=0,
        vmax=mixing_matrix[location].max(),
        origin="lower"
    )
    ax.set_title(location.replace("_", " "))
    ax.set_xticks([])
    ax.set_yticks([])
fig.suptitle(f"mixing matrices for {baseline['population']['region']}")

## Calibration targets

In [None]:
# Scatter each calibration target for every sub-region, with one figure per output.
plot_left_date = date(2021, 8, 2)
target_outputs = (
    "notifications",
    "hospital_admissions",
    "hospital_occupancy",
    "icu_admissions",
    "icu_occupancy",
    "infection_deaths",
)
for output in target_outputs:
    fig, axes = plt.subplots(4, 2, figsize=(12, 15), sharey="all", sharex="all")
    for panel_idx, ax in enumerate(axes.flat):
        region = Region.VICTORIA_SUBREGIONS[panel_idx]
        # Not used by the plot itself, but leaves `project` bound for the case detection cell below
        project = get_project(model, region)
        region_dir = region.replace("-", "_")
        targets_path = f"../../../autumn/projects/covid_19/victoria/{region_dir}/targets.secret.json"
        output_data = TimeSeriesSet.from_file(targets_path).get(output)
        target_dates = ref_times_to_dti(REF_DATE, output_data.times)
        ax.scatter(target_dates, output_data.values, c="k")
        ax.tick_params(axis="x", labelrotation=45)
        ax.set_title(region)
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d-%m"))
        ax.set_xlim(left=plot_left_date)
    fig.suptitle(output)
    fig.tight_layout(w_pad=1.5, h_pad=3.5)

## Case detection

In [None]:
# Reconstruct the case detection rate (CDR) from Victorian testing numbers using the
# testing_to_detection parameters, and plot each stage of the calculation.
# NOTE(review): `project` is whatever project the preceding cells left bound (the last
# sub-region looped over) - confirm these parameters are the same across sub-regions.
testing_params = project.param_set.baseline["testing_to_detection"]

# Get the CDR function of tests
cdr_from_tests_func = create_cdr_function(
    testing_params["assumed_tests_parameter"],
    testing_params["assumed_cdr_parameter"],
)

# Get the denominator population (all of Victoria, matching the state-level testing data query)
testing_pops = get_population_by_agegroup(
    AGEGROUP_STRATA,
    project.param_set.baseline["country"]["iso3"],
    "Victoria",
    year=project.param_set.baseline["population"]["year"]
)
total_testing_pop = sum(testing_pops)

# Process the data, keeping only the most recent values
last_values_to_keep = 60
test_times, test_values = get_vic_testing_numbers()
test_times = test_times[-last_values_to_keep:]
test_values = test_values[-last_values_to_keep:]

test_dates = ref_times_to_dti(BASE_DATETIME, [int(time) for time in test_times])
per_capita_tests = [i_tests / total_testing_pop for i_tests in test_values]
dummy_tests = np.linspace(0, max(per_capita_tests), 200)

# Smooth the per-capita rates only when a smoothing period is configured
smoothing_period = testing_params["smoothing_period"]
if smoothing_period:
    smoothed_per_capita_tests = apply_moving_average(per_capita_tests, smoothing_period)
else:
    smoothed_per_capita_tests = per_capita_tests

# CDR at each data point, then interpolated into a continuous function of time
cdr_values = [cdr_from_tests_func(test_rate) for test_rate in smoothed_per_capita_tests]
cdr_function_of_time = scale_up_function(
    test_times,
    cdr_values,
    smoothness=0.2, method=4, bound_low=0.,
)

# Plot
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.tight_layout(w_pad=1.5, h_pad=5)

# Plot daily number of tests
axis = axes[0, 0]
axis.plot(test_dates, test_values, marker="o")
axis.tick_params(axis="x", labelrotation=45)
axis.set_title("daily testing numbers")

# Plot daily per capita testing rate, raw and smoothed
axis = axes[0, 1]
axis.plot(test_dates, per_capita_tests, label="raw")
axis.plot(test_dates, smoothed_per_capita_tests, label="smoothed")
axis.tick_params(axis="x", labelrotation=45)
axis.set_title("daily per capita testing rate")
axis.legend()

# Plot relationship of daily tests to CDR proportion
axis = axes[1, 0]
axis.plot(dummy_tests, cdr_from_tests_func(dummy_tests))
axis.scatter(per_capita_tests, [cdr_from_tests_func(i_tests) for i_tests in per_capita_tests], color="r")
axis.set_ylabel("case detection proportion")
axis.set_xlabel("per capita testing rate")
axis.set_title("daily per capita tests to CDR relationship")
axis.set_ylim(top=1.)

# Plot CDR values at the data points against the interpolated CDR function
axis = axes[1, 1]
axis.scatter(test_dates, cdr_values, color="r")
axis.plot(test_dates, [cdr_function_of_time(time) for time in test_times])
axis.set_title("Final case detection rate")
axis.set_ylabel("proportion")
axis.tick_params(axis="x", labelrotation=45)

## Modelled vaccination roll-out

In [None]:
# Vaccination coverage over time for each sub-region, from get_both_vacc_coverage with
# start_age set to lower_age_limit.
lower_age_limit = 12
fig, axes = plt.subplots(4, 2, figsize=(12, 15), sharey="all", sharex="all")
for panel_idx, ax in enumerate(axes.flat):
    region = Region.VICTORIA_SUBREGIONS[panel_idx]
    db_region = region.upper().replace("-", "_")
    vacc_dates, vacc_coverage = get_both_vacc_coverage(db_region, start_age=lower_age_limit)
    # The query returns numpy integers; cast them to plain ints before converting to dates
    int_days = [int(day) for day in vacc_dates]
    vacc_dates = ref_times_to_dti(BASE_DATETIME, int_days)
    ax.plot(vacc_dates, vacc_coverage)
    ax.set_title(region.replace("-", " "))
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_ylim([0., 1.])
fig.tight_layout(w_pad=1.5, h_pad=3.5)