In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.cm as cm
import numpy as np

from summer.utils import ref_times_to_dti

from autumn.settings import Region, Models
from autumn.tools.plots.utils import REF_DATE
from autumn.tools.inputs.covid_au.queries import get_vic_testing_numbers
from autumn.models.covid_19.constants import AGEGROUP_STRATA, BASE_DATETIME
from autumn.tools.project import get_project, run_project_locally
from autumn.tools import inputs
from autumn.tools.inputs.mobility.queries import get_mobility_data
from autumn.tools.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.tools.calibration.targets import get_target_series
from autumn.tools.inputs.demography.queries import get_population_by_agegroup

from autumn.models.covid_19.preprocess.testing import create_cdr_function

In [None]:
# Values shared by every cell below (not modified afterwards)
model = Models.COVID_19
# Integer versions of the age-group labels, used as bar positions on age plots
age_integers = list(map(int, AGEGROUP_STRATA))

In [None]:
# Plot each cluster's population by age group: absolute counts (shared y-axis),
# then each age group's proportion of that cluster's total population.
fig_absolute, axes_absolute = plt.subplots(3, 3, figsize=(12, 8), sharey="all")
fig_relative, axes_relative = plt.subplots(3, 3, figsize=(12, 8))
fig_absolute.suptitle("Population by age group")
fig_relative.suptitle("Proportion of population by age group")

for ax_abs, ax_rel, region in zip(
    axes_absolute.flat, axes_relative.flat, Region.VICTORIA_SUBREGIONS
):
    project = get_project(model, region)
    params = project.param_set.baseline
    # Fetch once per cluster and reuse for both figures
    total_pops = inputs.get_population_by_agegroup(
        AGEGROUP_STRATA,
        params["country"]["iso3"],
        params["population"]["region"],
        year=params["population"]["year"],
    )
    cluster_total = sum(total_pops)
    print(f"Total population of {region} is {round(cluster_total / 1e6, 3)} million")

    ax_abs.bar(age_integers, total_pops, width=4)
    ax_abs.set_title(region)

    # Relative panels show proportions (they previously repeated the absolute counts)
    ax_rel.bar(age_integers, [pop / cluster_total for pop in total_pops], width=4)
    ax_rel.set_title(region)

# Lay out only after titles and bars exist, so tight_layout can account for them
fig_absolute.tight_layout(pad=2)
fig_relative.tight_layout(pad=2)

In [None]:
# Plot the mobility inputs to the model over recent weeks/months
MOBILITY_PLOT_DAYS = 60  # number of most recent mobility data points shown per cluster

fig, axes = plt.subplots(3, 3, figsize=(12, 8), sharey="all")

for ax, region in zip(axes.flat, Region.VICTORIA_SUBREGIONS):
    project = get_project(model, region)
    params = project.param_set.baseline
    mobility_locations = params["mobility"]["google_mobility_locations"]

    mobility_values, mobility_days = get_mobility_data(
        params["country"]["iso3"],
        params["population"]["region"],
        BASE_DATETIME,
        mobility_locations,
    )

    times = ref_times_to_dti(BASE_DATETIME, mobility_days[-MOBILITY_PLOT_DAYS:])
    for location in mobility_locations.keys():
        ax.plot(times, mobility_values[location][-MOBILITY_PLOT_DAYS:], label=location)
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_title(region)
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%d-%m"))

# One legend for the whole grid, on the last panel.
# NOTE(review): assumes every cluster uses the same mobility locations — TODO confirm.
axes[-1, -1].legend(loc="upper right")
# Lay out after rotated tick labels and titles exist so they are accounted for
fig.tight_layout(w_pad=1.5, h_pad=3.5)

In [None]:
# Print the mixing-matrix settings for every cluster, then plot the synthetic
# mixing matrices for one (explicitly chosen) cluster.
for region in Region.VICTORIA_SUBREGIONS:
    project = get_project(model, region)
    params = project.param_set.baseline
    print(f"\nFor cluster: {region}")
    print(f"\tModelled country: {params['country']['iso3']}")
    print(f"\tModelled sub-region: {params['population']['region']}")
    print(f"\tProxy country: {params['mixing_matrices']['source_iso3']}")
    print(f"\tWhether age adjusted: {params['mixing_matrices']['age_adjust']}")

# The matrix was previously built from whatever `project` the loop above left
# behind (the last cluster); make that choice explicit and label the figure.
matrix_region = Region.VICTORIA_SUBREGIONS[-1]
matrix_params = get_project(model, matrix_region).param_set.baseline
mixing_matrix = build_synthetic_matrices(
    matrix_params["country"]["iso3"],
    matrix_params["mixing_matrices"]["source_iso3"],
    AGEGROUP_STRATA,
    matrix_params["mixing_matrices"]["age_adjust"],
    matrix_params["population"]["region"],
)

fig = plt.figure(figsize=(12, 8))
fig.suptitle(f"Mixing matrices for cluster: {matrix_region}")
# Slots in a 2x3 subplot grid; slot 4 (bottom-left) is intentionally left empty.
# NOTE(review): zip silently skips any locations beyond these five slots.
positions = [1, 2, 3, 5, 6]
for i_loc, location in zip(positions, mixing_matrix.keys()):
    ax = fig.add_subplot(2, 3, i_loc)
    ax.imshow(
        np.flipud(np.transpose(mixing_matrix[location])),
        cmap=cm.hot,
        vmin=0,
        vmax=mixing_matrix[location].max(),
        origin="lower",
    )
    ax.set_title(location.replace("_", " "))
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# Plot the most recent calibration target observations for each cluster,
# one 3x3 figure per target output.
TARGET_PLOT_POINTS = 60  # number of most recent target observations shown

for output in ("notifications", "hospital_admissions"):
    fig, axes = plt.subplots(3, 3, figsize=(12, 8), sharey="all", sharex="all")
    # Without a title the two figures produced by this loop are indistinguishable
    fig.suptitle(output.replace("_", " "))
    for ax, region in zip(axes.flat, Region.VICTORIA_SUBREGIONS):
        project = get_project(model, region)
        dates, values = get_target_series(project.calibration.targets, REF_DATE, output)
        ax.scatter(dates[-TARGET_PLOT_POINTS:], values[-TARGET_PLOT_POINTS:])
        ax.tick_params(axis="x", labelrotation=45)
        ax.set_title(region)
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%d-%m"))
    # Lay out after titles and rotated tick labels exist so they are accounted for
    fig.tight_layout(w_pad=1.5, h_pad=3.5)

In [None]:
# Plot the case detection proportion (CDR) implied by per-capita testing rates.
# This cell previously relied on `project` leaking from an earlier cell (the last
# Victorian cluster after Run All); the cluster is now chosen explicitly so the
# cell does not depend on execution order.
testing_region = Region.VICTORIA_SUBREGIONS[-1]
testing_params = get_project(model, testing_region).param_set.baseline

# Get the CDR function of tests
cdr_function = create_cdr_function(
    testing_params["testing_to_detection"]["assumed_tests_parameter"],
    testing_params["testing_to_detection"]["assumed_cdr_parameter"],
)

# Get the denominator population
testing_pops = get_population_by_agegroup(
    AGEGROUP_STRATA,
    testing_params["country"]["iso3"],
    testing_params["population"]["region"],
    year=testing_params["population"]["year"],
)

# Process the data.
# NOTE(review): get_vic_testing_numbers looks like Victoria-wide test counts while
# the denominator is a single cluster's population — confirm this matches the model.
test_values = get_vic_testing_numbers()[1]
testing_total_pop = sum(testing_pops)  # loop-invariant, computed once
per_capita_tests = [i_tests / testing_total_pop for i_tests in test_values]
dummy_tests = np.linspace(0, max(per_capita_tests), 200)

# Plot the continuous CDR curve, with the observed testing rates overlaid
fig, axis = plt.subplots(figsize=(12, 8))
axis.plot(dummy_tests, cdr_function(dummy_tests))
axis.scatter(per_capita_tests, [cdr_function(i_tests) for i_tests in per_capita_tests])
axis.set_title(f"Case detection by testing rate ({testing_region})")
axis.set_ylabel("case detection proportion")
axis.set_xlabel("per capita testing rate")
axis.set_ylim(top=1.);