In [None]:
from pathlib import Path
from aust_covid.inputs import load_pop_data
from aust_covid.model import get_processed_mobility_data, get_interp_funcs_from_mobility, build_base_model, adapt_gb_matrices_to_aust, get_age_stratification, set_starting_conditions
from datetime import datetime
from emutools.tex import StandardTexDoc
from summer2.parameters import Function
import pandas as pd
from aust_covid.model import BASE_PATH

# Contact settings; one '<name>.csv' matrix file is read from DATA_PATH for each.
MATRIX_LOCATIONS = ['school', 'home', 'work', 'other_locations']

# Filesystem locations: the parent of the current working directory,
# and the 'data' folder underneath the imported BASE_PATH.
PROJECT_PATH = Path().resolve().parent
DATA_PATH = BASE_PATH / 'data'

# Dates used when building the model.
start_date = datetime(2021, 7, 1)
end_date = datetime(2022, 10, 1)
ref_date = datetime(2019, 12, 31)

# Age breakpoints in five-year steps: 0, 5, ..., 75.
age_strata = [*range(0, 80, 5)]

# Compartment names; the infectious chain has as many stages as the latent chain.
n_latent_comps = 4
n_infectious_comps = n_latent_comps
latent_compartments = [f'latent_{idx}' for idx in range(n_latent_comps)]
infectious_compartments = [f'infectious_{idx}' for idx in range(n_infectious_comps)]
compartments = ['susceptible', 'recovered', 'waned', *infectious_compartments, *latent_compartments]
# Create the TeX document object that is handed to every helper call below.
# NOTE(review): presumably each helper appends its own description to this document,
# in which case the order of these calls determines the document order - confirm.
tex_doc = StandardTexDoc(PROJECT_PATH / 'supplement', 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')
# Population data for the chosen age breakpoints.
model_pops = load_pop_data(age_strata, tex_doc)
# Build the unstratified model from the dates and compartment names defined above.
aust_model = build_base_model(ref_date, compartments, infectious_compartments, start_date, end_date, tex_doc)
# Set the model's starting conditions using the population data (return value not used here).
set_starting_conditions(aust_model, model_pops, tex_doc)

In [None]:
# Read one contact matrix per location from '<location>.csv' in DATA_PATH,
# using the first CSV column as the index and keeping only the values as an array.
raw_matrices = {
    location: pd.read_csv(DATA_PATH / f'{location}.csv', index_col=0).to_numpy()
    for location in MATRIX_LOCATIONS
}

# Adapt the raw matrices using the model's age breakpoints and population data.
adjusted_matrices = adapt_gb_matrices_to_aust(
    age_strata,
    raw_matrices,
    model_pops,
    tex_doc,
)

In [None]:
# Fetch the mobility data via the project helper (takes no arguments here).
# NOTE(review): the structure of the returned object is not visible in this notebook - confirm.
model_mob = get_processed_mobility_data()

In [None]:
# Turn the mobility data into interpolation functions, using the model's epoch.
# The result is indexed later as interp_funcs['non_wa']['work'] and
# interp_funcs['non_wa']['other_locations'].
interp_funcs = get_interp_funcs_from_mobility(model_mob, aust_model.get_epoch())

In [None]:
def mobility_scaling(matrices, non_wa_work, non_wa_other, nothing):
    """Combine location-specific matrices into one, scaling work and other locations.

    Args:
        matrices: Mapping with the keys 'home', 'school', 'other_locations' and 'work'.
        non_wa_work: Multiplier applied to the 'work' matrix.
        non_wa_other: Multiplier applied to the 'other_locations' matrix.
        nothing: Accepted but never used inside the function.

    Returns:
        The home and school matrices unscaled, plus the scaled other-locations
        and work matrices.
    """
    # Home and school contacts enter the total unscaled.
    unscaled_part = matrices['home'] + matrices['school']
    scaled_other = non_wa_other * matrices['other_locations']
    scaled_work = non_wa_work * matrices['work']
    # Same left-to-right summation order as a single chained expression.
    return unscaled_part + scaled_other + scaled_work

In [None]:
# This runs: the fourth argument (the unused `nothing` parameter of mobility_scaling)
# is a plain float.
mixing_matrix = Function(mobility_scaling, [adjusted_matrices, interp_funcs['non_wa']['work'], interp_funcs['non_wa']['other_locations'], 0.0])

# This doesn't: identical call, except the whole interp_funcs object is passed
# as the unused fourth argument.
# NOTE(review): the cause is not visible here - presumably a nested container of
# function objects is not accepted as a Function argument; confirm against summer2.
# mixing_matrix = Function(mobility_scaling, [adjusted_matrices, interp_funcs['non_wa']['work'], interp_funcs['non_wa']['other_locations'], interp_funcs])

In [None]:
# Build the age stratification from the compartments, age breakpoints and the
# mixing_matrix Function, then apply it to the model.
age_strat = get_age_stratification(compartments, age_strata, mixing_matrix, tex_doc)
aust_model.stratify_with(age_strat)

In [None]:
# Run the stratified model; this is the step that exercises the mixing_matrix Function.
aust_model.run()