## Check inputs to Victoria state-wide simple model before running
Work through all the model inputs that can be checked without actually running the model, before starting calibration.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.cm as cm
import numpy as np
from datetime import datetime
from datetime import date

from summer.utils import ref_times_to_dti

from autumn.settings import Region, Models
from autumn.tools.plots.utils import REF_DATE
from autumn.tools.inputs.covid_au.queries import get_vic_testing_numbers
from autumn.models.covid_19.constants import AGEGROUP_STRATA, BASE_DATETIME, GOOGLE_MOBILITY_LOCATIONS
from autumn.tools.project import get_project, run_project_locally, TimeSeriesSet
from autumn.tools import inputs
from autumn.tools.inputs.mobility.queries import get_mobility_data
from autumn.tools.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices
from autumn.tools.calibration.targets import get_target_series
from autumn.tools.inputs.demography.queries import get_population_by_agegroup
from autumn.tools.utils.utils import apply_moving_average
from autumn.models.covid_19.detection import create_cdr_function
from autumn.tools.curve.scale_up import scale_up_function
from autumn.tools.inputs.database import get_input_db
from autumn.tools.inputs.covid_au.queries import get_both_vacc_coverage
from autumn.models.covid_19.mixing_matrix.macrodistancing import weight_mobility_data

In [None]:
# Handle to the inputs database; queried directly for the raw mobility table in the Mobility section.
input_db = get_input_db()

In [None]:
# Integer form of each age stratum label, used as the x-positions of the population bar chart.
age_integers = list(map(int, AGEGROUP_STRATA))

# Select the model/region pair being checked and load the corresponding project.
model, region = Models.COVID_19, Region.VICTORIA
project = get_project(model, region)

## Population

In [None]:
# Bar chart of the population by age group, using the country, region and year from the baseline parameters.
fig, ax = plt.subplots(1, 1, figsize=(9, 6))

baseline_params = project.param_set.baseline
population_params = baseline_params["population"]
total_pops = inputs.get_population_by_agegroup(
    AGEGROUP_STRATA,
    baseline_params["country"]["iso3"],
    population_params["region"],
    year=population_params["year"],
)

# Report the overall total in millions as a quick sanity check.
total_millions = round(sum(total_pops) / 1e6, 3)
print(f"total population of {region} is: {total_millions} million")

ax.bar(age_integers, total_pops, width=4)
ax.set_title(f"{region} population distribution")

## Mobility

In [None]:
# Side-by-side comparison of the raw Google mobility table and the weighted values passed to the model.
plot_start_date = date(2021, 6, 1)
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
raw_ax, model_ax = axes

# Raw: the mobility table as stored in the inputs database, re-indexed by parsed date
mob_df = input_db.query("mobility", conditions={"iso3": "AUS", "region": "Victoria"})
dates = [datetime.strptime(date_string, "%Y-%m-%d") for date_string in mob_df["date"]]
mob_df.index = dates
mob_df.plot(ax=raw_ax)
raw_ax.set_title("raw Google mobility domains")

# Model inputs: mobility data weighted by the project's google_mobility_locations mapping
baseline_params = project.param_set.baseline
mobility_values, mobility_days = get_mobility_data(
    baseline_params["country"]["iso3"],
    baseline_params["population"]["region"],
    BASE_DATETIME,
)
model_mobility_values = weight_mobility_data(
    mobility_values,
    baseline_params["mobility"]["google_mobility_locations"],
)
# NOTE(review): the index comes from the raw query's dates, not from mobility_days (which is unused here);
# this presumably relies on both sources covering the same rows - confirm.
model_mobility_values.index = dates
model_mobility_values.plot(ax=model_ax)
model_ax.set_title("model mobility inputs")

# Tidying up: same y-range, left date limit and tick rotation on both panels
for ax in (raw_ax, model_ax):
    ax.set_ylim((0., 1.4))
    ax.set_xlim(left=plot_start_date)
    ax.tick_params(axis="x", labelrotation=45)

fig.tight_layout()

## Mixing matrix
### Check how mixing matrix is specified

In [None]:
# Print the baseline parameters that determine how the mixing matrix is built for this region.
project = get_project(model, region)
baseline_params = project.param_set.baseline
mixing_params = baseline_params["mixing_matrices"]

print(f"\nFor cluster: {region}")
print(f"\tModelled country: {baseline_params['country']['iso3']}")
print(f"\tModelled sub-region: {baseline_params['population']['region']}")
print(f"\tProxy country: {mixing_params['source_iso3']}")
print(f"\tWhether age adjusted: {mixing_params['age_adjust']}")

### Display the matrix and the matrix components

In [None]:
# Build the matrices from the same baseline parameters printed above.
baseline_params = project.param_set.baseline
mixing_params = baseline_params["mixing_matrices"]
mixing_matrix = build_synthetic_matrices(
    baseline_params["country"]["iso3"],
    mixing_params["source_iso3"],
    AGEGROUP_STRATA,
    mixing_params["age_adjust"],
    baseline_params["population"]["region"],
)

# One heat map per entry of mixing_matrix on a 2 x 3 grid; slot 4 is skipped by the positions list.
fig = plt.figure(figsize=(12, 8))
positions = [1, 2, 3, 5, 6]
for i_loc, location in zip(positions, mixing_matrix.keys()):
    location_matrix = mixing_matrix[location]
    displayed_matrix = np.flipud(np.transpose(location_matrix))

    ax = fig.add_subplot(2, 3, i_loc)
    # Each panel's colour scale runs from zero to that matrix's own maximum.
    ax.imshow(
        displayed_matrix,
        cmap=cm.hot,
        vmin=0,
        vmax=location_matrix.max(),
        origin="lower",
    )
    ax.set_title(location.replace("_", " "))
    ax.set_xticks([])
    ax.set_yticks([])

## Calibration targets

In [None]:
# Scatter each requested target series from the project's targets file, one panel per output.
plot_left_date = date(2021, 8, 2)
requested_outputs = (
    "notifications",
    "hospital_admissions",
    "hospital_occupancy",
    "icu_admissions",
    "icu_occupancy",
    "infection_deaths",
)

region = Region.VICTORIA
targets_path = "../../../autumn/projects/covid_19/victoria/victoria/targets.secret.json"
ts_set = TimeSeriesSet.from_file(targets_path)

fig, axes = plt.subplots(3, 2, figsize=(12, 15), sharex="all")
day_month_format = "%d-%m"
# The 3 x 2 grid is flattened so that panels pair up with the six requested outputs in order.
for ax, output in zip(axes.reshape(-1), requested_outputs):
    output_data = ts_set.get(output)
    target_dates = ref_times_to_dti(REF_DATE, output_data.times)
    ax.scatter(target_dates, output_data.values, c="k")
    ax.tick_params(axis="x", labelrotation=45)
    ax.xaxis.set_major_formatter(mdates.DateFormatter(day_month_format))
    ax.set_xlim(left=plot_left_date)
    ax.set_title(output)
fig.tight_layout(w_pad=1.5, h_pad=3.5)

## Case detection

In [None]:
# Check the chain from daily testing numbers to the case detection rate (CDR) over time:
# tests -> per capita tests -> (optionally smoothed) -> CDR value per date -> interpolated CDR function.
testing_params = project.param_set.baseline["testing_to_detection"]

# Get the CDR function of tests
cdr_from_tests_func = create_cdr_function(
    testing_params["assumed_tests_parameter"],
    testing_params["assumed_cdr_parameter"],
)

# Get the denominator population
testing_pops = get_population_by_agegroup(
    AGEGROUP_STRATA, project.param_set.baseline["country"]["iso3"], "Victoria", 
    year=project.param_set.baseline["population"]["year"]
)
total_testing_pop = sum(testing_pops)  # Summed once rather than once per data point

# Process the data
test_times, test_values = get_vic_testing_numbers()
test_dates = ref_times_to_dti(BASE_DATETIME, [int(time) for time in test_times])
per_capita_tests = [i_tests / total_testing_pop for i_tests in test_values]
dummy_tests = np.linspace(0, max(per_capita_tests), 200)

# Smooth only when a smoothing period is actually specified. This was previously gated on
# assumed_tests_parameter, which is unrelated to whether a moving average window is available.
smoothing_period = testing_params["smoothing_period"]
if smoothing_period:
    smoothed_per_capita_tests = apply_moving_average(per_capita_tests, smoothing_period)
else:
    smoothed_per_capita_tests = per_capita_tests

# CDR value at each testing date, used both to build the time function and for the final panel
smoothed_cdr_values = [cdr_from_tests_func(test_rate) for test_rate in smoothed_per_capita_tests]

cdr_function_of_time = scale_up_function(
    test_times,
    smoothed_cdr_values,
    smoothness=0.2, method=4, bound_low=0.,
)

# Plot
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.tight_layout(w_pad=1.5, h_pad=5)

# Plot daily number of tests
axis = axes[0, 0]
axis.plot(test_dates, test_values, marker="o")
axis.tick_params(axis="x", labelrotation=45)
axis.set_title("daily testing numbers")
axis.set_xlim(left=plot_left_date)

# Plot daily per capita testing rate, raw and smoothed
axis = axes[0, 1]
axis.plot(test_dates, per_capita_tests, label="raw")
axis.plot(test_dates, smoothed_per_capita_tests, label="smoothed")
axis.tick_params(axis="x", labelrotation=45)
axis.set_xlim(left=plot_left_date)
axis.set_title("daily per capita testing rate")
axis.legend()

# Plot relationship of daily tests to CDR proportion
axis = axes[1, 0]
axis.plot(dummy_tests, cdr_from_tests_func(dummy_tests))
axis.scatter(per_capita_tests, [cdr_from_tests_func(i_tests) for i_tests in per_capita_tests], color="r")
axis.set_ylabel("case detection proportion")
axis.set_xlabel("per capita testing rate")
axis.set_title("daily per capita tests to CDR relationship")
axis.set_ylim([0., 1.])

# Plot CDR values: points are the per-date values, the line is the fitted function of time
axis = axes[1, 1]
axis.scatter(test_dates, smoothed_cdr_values, color="r")
axis.plot(test_dates, [cdr_function_of_time(time) for time in test_times])
axis.set_title("Final case detection rate")
axis.set_ylabel("proportion")
axis.tick_params(axis="x", labelrotation=45)
axis.set_xlim(left=plot_left_date)
axis.set_ylim([0., 1.])

fig.tight_layout()

## Modelled vaccination roll-out

In [None]:
# Plot vaccination coverage over time for the region, counting ages from lower_age_limit upwards.
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
lower_age_limit = 12
vacc_times, vacc_coverage = get_both_vacc_coverage(region.upper(), start_age=lower_age_limit)

# Converting numpy ints returned into plain ints before building the date index
vacc_dates = ref_times_to_dti(BASE_DATETIME, list(map(int, vacc_times)))

ax.plot(vacc_dates, vacc_coverage)
ax.set_title(f"{region} vaccination coverage")
ax.set_ylim([0., 1.])
ax.tick_params(axis="x", labelrotation=45)