In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from datetime import date
import numpy as np
import os
import pandas as pd

from summer.utils import ref_times_to_dti

from autumn.core.inputs.mobility.queries import get_mobility_data
from autumn.core.inputs.demography.queries import get_population_by_agegroup
from autumn.core.inputs.database import get_input_db
from autumn.core.inputs.social_mixing.build_synthetic_matrices import build_synthetic_matrices, load_socialmixr_matrices
from autumn.core.inputs.testing.testing_data import get_testing_numbers_for_region

from autumn.core.utils.utils import apply_moving_average
from autumn.core.project import get_project, load_timeseries

from autumn.model_features.curve.scale_up import scale_up_function

from autumn.settings import Region, Models
from autumn.settings.constants import COVID_BASE_DATETIME

from autumn.models.sm_sir.mixing_matrix.macrodistancing import weight_mobility_data
from autumn.models.sm_sir.detection import create_cdr_function

In [None]:
# Model/region under inspection and its baseline parameter set
model, region = Models.SM_SIR, Region.NCR
project = get_project(model, region)
params = project.param_set.baseline

iso3 = params["country"]["iso3"]
pop_region = params["population"]["region"]

# Age-group lower bounds as given in the parameters, plus integer versions for plotting
AGEGROUP_STRATA = params["age_groups"]
age_integers = list(map(int, AGEGROUP_STRATA))


## Population

In [None]:
# Modelled population by age group, for the population year set in the parameters
total_pops = get_population_by_agegroup(
    AGEGROUP_STRATA,
    iso3,
    pop_region,
    year=params["population"]["year"],
)
total_thousands = round(sum(total_pops) / 1e3, 3)
print(f"Total modelled population of {region} is: {total_thousands} thousand")

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.bar(age_integers, total_pops, width=4)
ax.set(title=region, ylabel="population", xlabel="starting age of age bracket")
fig.suptitle("population distribution by age")

## Mobility

In [None]:
# Report whether the baseline parameters switch on the dynamic (mobility-driven) mixing matrix
print("Whether the mobility effects are actually turned on at all:")
dynamic_mixing_on = params["is_dynamic_mixing_matrix"]
dynamic_mixing_on

In [None]:
# Upper y-axis limit shared by both mobility panels
y_upper = 2.

# Collate data together.
# NOTE(review): the returned handle is not used in this cell; presumably the call is kept
# so that the inputs database is available before querying it — confirm.
input_db = get_input_db()

mob_df, int_times = get_mobility_data(iso3, pop_region, COVID_BASE_DATETIME)
mobility_locations = params["mobility"]["google_mobility_locations"]
google_mob_df = weight_mobility_data(mob_df, mobility_locations)
times = ref_times_to_dti(COVID_BASE_DATETIME, int_times)

# Date-axis limits: from the start of 2020 up to the last date with mobility data.
# These names are also reused for the date axes of the case detection plots further down.
plot_left_date = date(2020, 1, 1)
plot_right_date = times[-1]

mob_fig, mob_axes = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: raw Google mobility domains
raw_ax = mob_axes[0]
for mobility_domain in ["grocery_and_pharmacy", "residential", "parks", "retail_and_recreation", "transit_stations"]:
    raw_ax.plot(times, mob_df[mobility_domain], label=mobility_domain)
raw_ax.set_title("raw Google mobility domains")
raw_ax.legend(loc="lower right")

# Right panel: mobility after weighting into the model's locations
model_ax = mob_axes[1]
for location in mobility_locations.keys():
    model_ax.plot(times, google_mob_df[location], label=location)
model_ax.set_title("mobility as implemented in the model")
model_ax.legend(loc="lower left")

# Shared axis formatting for both panels
for ax in mob_axes:
    ax.set_ylim((0., y_upper))
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_xlim(left=plot_left_date, right=plot_right_date)

# Lay out only once all limits, ticks and legends are final
mob_fig.tight_layout(w_pad=1.5, h_pad=3.5)

# School mobility profile, as specified in the mixing parameters
school_params = params["mobility"]["mixing"]["school"]
school_times, school_values = school_params["times"], school_params["values"]
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.plot(school_times, school_values, "bo")
ax.set_title("school mobility profile")
ax.set_xlabel("time")
ax.set_ylabel("value")

## Mixing matrix
### Check how the mixing matrix is specified

In [None]:
# Summarise which country/region the mixing matrix is built for, and the proxy it borrows from
for label, value in (
    ("Modelled country", iso3),
    ("Modelled sub-region", pop_region),
    ("Proxy country", params['ref_mixing_iso3']),
):
    print(f"{label}: {value}")
print("Always age-adjusted under SM-SIR code")

### Display the matrix and the matrix components

In [None]:
# Only the source age breaks of the proxy country's socialmixr matrices are needed here
_, source_age_breaks = load_socialmixr_matrices(params["ref_mixing_iso3"], ["all_locations"])
agegroup_types = {
    "base age groups": source_age_breaks,
    "modelled age groups": params["age_groups"],
}

# Subplot slots on a 2x3 grid; slot 4 (bottom-left) is deliberately left empty
location_slots = [1, 2, 3, 5, 6]

for title, agegroups in agegroup_types.items():
    mixing_matrix = build_synthetic_matrices(
        iso3,
        params["ref_mixing_iso3"],
        agegroups,
        True,  # presumably the age-adjustment flag — TODO confirm against build_synthetic_matrices
        pop_region,
    )

    fig = plt.figure(figsize=(12, 8))
    for slot, (location, location_matrix) in zip(location_slots, mixing_matrix.items()):
        ax = fig.add_subplot(2, 3, slot)
        flipped = np.flipud(np.transpose(location_matrix))
        ax.imshow(flipped, cmap=cm.hot, vmin=0, vmax=location_matrix.max(), origin="lower")
        ax.set_title(location.replace("_", " "))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.suptitle(title)

## Case detection

In [None]:
testing_params = params["testing_to_detection"]

# Function mapping a per-capita daily testing rate to a case detection proportion
cdr_from_tests_func = create_cdr_function(
    testing_params["assumed_tests_parameter"],
    testing_params["assumed_cdr_parameter"],
)

# Denominator population; use the same population year as the population cell above,
# so that per-capita rates are computed against the modelled population
testing_pops = get_population_by_agegroup(
    params["age_groups"],
    iso3,
    pop_region,
    year=params["population"]["year"],
)

# Process the data
test_df = get_testing_numbers_for_region(iso3, pop_region)
test_times, test_values = test_df.index.to_list(), test_df.to_list()
test_dates = ref_times_to_dti(COVID_BASE_DATETIME, [int(time) for time in test_times])
total_testing_pop = sum(testing_pops)
per_capita_tests = [i_tests / total_testing_pop for i_tests in test_values]

# Grid of testing rates for drawing the tests-to-CDR curve
dummy_tests = np.linspace(0, max(per_capita_tests), 200)

# Smooth only when a smoothing period is configured. The previous check looked at
# `assumed_tests_parameter` (a CDR-curve parameter), which says nothing about smoothing.
smoothing_period = testing_params["smoothing_period"]
if smoothing_period:
    smoothed_per_capita_tests = apply_moving_average(per_capita_tests, smoothing_period)
else:
    smoothed_per_capita_tests = per_capita_tests

# CDR as a continuous function of model time, interpolated through the smoothed values
cdr_function_of_time = scale_up_function(
    test_times,
    [cdr_from_tests_func(test_rate) for test_rate in smoothed_per_capita_tests],
    smoothness=0.2, method=4, bound_low=0.,
)

# Plot
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.tight_layout(w_pad=1.5, h_pad=5)

def sort_axis_dates(ax, left_date=None, right_date=None):
    """Rotate the x tick labels of a date plot and clip its x-axis to a date window.

    Args:
        ax: matplotlib Axes to format. This is the axis that is modified; no
            notebook-level variable is touched.
        left_date: left x-limit. Defaults to the notebook-level ``plot_left_date``.
        right_date: right x-limit. Defaults to the notebook-level ``plot_right_date``.
    """
    # Fall back to the notebook-wide plotting window only when no explicit limit is given
    left = plot_left_date if left_date is None else left_date
    right = plot_right_date if right_date is None else right_date
    ax.tick_params(axis="x", labelrotation=45)
    ax.set_xlim(left=left, right=right)

# Top-left: raw daily number of tests
axis = axes[0, 0]
axis.plot(test_dates, test_values, marker="o")
axis.set_title("daily testing numbers")
axis.set_ylabel("tests per day")
sort_axis_dates(axis)

# Top-right: raw and smoothed per capita daily testing rate
axis = axes[0, 1]
axis.plot(test_dates, per_capita_tests, label="raw")
axis.plot(test_dates, smoothed_per_capita_tests, label="smoothed")
axis.set_title("daily per capita testing rate")
axis.set_ylabel("tests per capita per day")
sort_axis_dates(axis)
axis.legend()

# Bottom-left: tests-to-CDR curve, with the observed (unsmoothed) testing rates overlaid in red
axis = axes[1, 0]
axis.plot(dummy_tests, cdr_from_tests_func(dummy_tests))
axis.scatter(per_capita_tests, [cdr_from_tests_func(i_tests) for i_tests in per_capita_tests], color="r")
axis.set_ylabel("case detection proportion")
axis.set_xlabel("per capita testing rate")
axis.set_title("daily per capita tests to CDR relationship")
axis.set_ylim(top=1.)

# Bottom-right: CDR from the smoothed testing rates (points) against the
# interpolated CDR function of time used by the model (line)
axis = axes[1, 1]
axis.scatter(test_dates, [cdr_from_tests_func(i_test_rate) for i_test_rate in smoothed_per_capita_tests], color="r")
axis.plot(test_dates, [cdr_function_of_time(time) for time in test_times])
axis.set_title("Final case detection rate")
axis.set_ylabel("proportion")
sort_axis_dates(axis)

# Lay out once all labels exist, keeping the padding used when the figure was created
fig.tight_layout(w_pad=1.5, h_pad=5)

## Calibration targets

In [None]:
all_targets = load_timeseries(os.path.join(project.get_path(), "timeseries.json"))
if project.region_name == "national-capital-region":
    # NCR keeps additional targets in a separate file
    new_targets = load_timeseries(os.path.join(project.get_path(), "new_targets.json"))
    all_targets.update(new_targets)

# Convert model times to dates using the notebook-wide reference date
for target_name, target_series in all_targets.items():
    target_series.index = ref_times_to_dti(COVID_BASE_DATETIME, target_series.index)

calibrated_targets = project.calibration.targets
calibrated_targets_names = [t.data.name for t in calibrated_targets]

# Apply the style once, before any of these figures is created, so every figure matches.
# NOTE: plt.style.use changes the global rcParams for the rest of the session.
plt.style.use("ggplot")

for output in all_targets:
    if len(all_targets[output].index) == 0:
        continue
    fig, axis = plt.subplots(figsize=(12, 8))

    # All available data for this output (black)
    axis.scatter(all_targets[output].index, all_targets[output], color="k", s=5, alpha=0.5, zorder=10)

    # Points actually used in calibration (red). Their times are converted with the same
    # reference date as the black points above, rather than a hard-coded origin, so the
    # two series line up on the date axis.
    if output in calibrated_targets_names:
        index = calibrated_targets_names.index(output)
        cal_times = calibrated_targets[index].data.index.tolist()
        cal_values = calibrated_targets[index].data.tolist()
        cal_date_times = ref_times_to_dti(COVID_BASE_DATETIME, cal_times)

        axis.scatter(cal_date_times, cal_values, color="red", s=10.)

    axis.set_title("COVID-19 " + project.plots[output]["output_key"] + " in NCR")