In [None]:
from pathlib import Path
from aust_covid.inputs import load_pop_data, get_ifrs
from aust_covid.model import get_processed_mobility_data, get_interp_funcs_from_mobility, build_base_model, adapt_gb_matrices_to_aust, get_age_stratification, set_starting_conditions, get_wa_infection_scaling
from datetime import datetime
from emutools.tex import StandardTexDoc
from summer2.parameters import Function
import pandas as pd
from aust_covid.model import BASE_PATH

# Contact-matrix locations; each is read below from DATA_PATH / '<location>.csv'.
MATRIX_LOCATIONS = ['school', 'home', 'work', 'other_locations']

# Project root is taken as the parent of the current working directory
# (assumes the notebook is run from a sub-folder of the project — TODO confirm).
PROJECT_PATH = Path().resolve().parent
DATA_PATH = BASE_PATH / 'data'

# Simulation window, plus the reference date handed to build_base_model.
start_date, end_date = datetime(2021, 7, 1), datetime(2022, 10, 1)
ref_date = datetime(2019, 12, 31)

# Lower bounds of the age bands: 0, 5, ..., 75.
age_strata = [*range(0, 80, 5)]

# Latent and infectious stages are each split into the same number of sequential compartments.
n_latent_comps = 4
n_infectious_comps = n_latent_comps
latent_compartments = ['latent_' + str(idx) for idx in range(n_latent_comps)]
infectious_compartments = ['infectious_' + str(idx) for idx in range(n_infectious_comps)]
compartments = ['susceptible', 'recovered', 'waned', *infectious_compartments, *latent_compartments]
# Supplement document object; it is passed to every helper below
# (presumably so each can write its description into the supplement — TODO confirm).
# NOTE(review): the title says "2023" while the simulated window ends in
# October 2022 — confirm the year is intended.
tex_doc = StandardTexDoc(PROJECT_PATH / 'supplement', 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')
# Population data for the age bands in age_strata.
model_pops = load_pop_data(age_strata, tex_doc)
# Unstratified model over [start_date, end_date], with ref_date as its reference date.
aust_model = build_base_model(ref_date, compartments, infectious_compartments, start_date, end_date, tex_doc)
# Return value is discarded, so this presumably sets initial conditions on
# aust_model in place — TODO confirm.
set_starting_conditions(aust_model, model_pops, tex_doc)

In [None]:
# Base model parameters (IFRs are merged in separately below).
# Periods and seed times are presumably in days, with seed times counted
# from ref_date — TODO confirm against the model builder.
parameters = dict(
    contact_rate=0.065,
    latent_period=1.8,
    infectious_period=2.5,
    natural_immunity_period=60.0,
    start_cdr=0.3,
    imm_prop=0.4,
    imm_infect_protect=0.4,
    ifr_adjuster=3.0,
    ba1_seed_time=619.0,
    ba2_seed_time=659.0,
    ba5_seed_time=715.0,
    ba2_escape=0.4,
    ba5_escape=0.54,
    ba2_rel_ifr=0.5,
    wa_reopen_period=50.0,
    seed_duration=10.0,
    seed_rate=1.0,
    notifs_mean=4.0,
    notifs_shape=2.0,
    deaths_mean=15.93,
    deaths_shape=5.0,
)
# Infection fatality rates from the project inputs, merged into the parameter
# dict; any key also present above is overwritten by the value from ifrs.
ifrs = get_ifrs(tex_doc)
parameters.update(ifrs)

In [None]:
# One contact matrix per location, read from CSV with the first column used as
# the index and converted to a NumPy array.
raw_matrices = dict(
    (location, pd.read_csv(DATA_PATH / f'{location}.csv', index_col=0).to_numpy())
    for location in MATRIX_LOCATIONS
)
# Adapted to the Australian age structure (the helper name suggests the source
# matrices are for Great Britain — TODO confirm).
adjusted_matrices = adapt_gb_matrices_to_aust(age_strata, raw_matrices, model_pops, tex_doc)

In [None]:
# Processed mobility data used to build the time-varying contact scaling below.
model_mob = get_processed_mobility_data()

In [None]:
# Interpolation functions over the mobility data, built against the model's epoch.
# Later cells index this as interp_funcs['wa' | 'non_wa'][location], with
# location in {'work', 'other_locations'}.
interp_funcs = get_interp_funcs_from_mobility(model_mob, aust_model.get_epoch())

In [None]:
# Each column's share of the grand total population. This assumes model_pops is
# a DataFrame, so that .sum() yields per-column totals; presumably there is one
# column per region — TODO confirm.
state_props = model_pops.sum().pipe(lambda col_totals: col_totals / col_totals.sum())

In [None]:
# Date handed to the WA infection scaling (presumably WA's border reopening,
# 3 March 2022 — TODO confirm against get_wa_infection_scaling).
WA_REOPEN_DATE = datetime(2022, 3, 3)
wa_reopen_func = get_wa_infection_scaling(WA_REOPEN_DATE, aust_model)
# Weight of the WA patch: the reopening scaling times the first entry of
# state_props, which is assumed to be WA's population share (TODO confirm the
# column order of model_pops). `.iloc` makes the positional lookup explicit;
# positional fallback through `Series[int]` is deprecated in pandas >= 2.1.
wa_prop_func = wa_reopen_func * state_props.iloc[0]

In [None]:
def capture_kwargs(*_positional, **named):
    """Return the keyword arguments as a dict, discarding any positional ones.

    Wrapped in ``Function`` below so that a dict of per-location functions can
    be handled as a single object (presumably the model evaluates each value
    and passes back the resulting dict — TODO confirm against summer2).
    """
    return named

In [None]:
# One Function per patch, each wrapping that patch's dict of mobility
# interpolation functions so it can be passed as a single argument.
wa_funcs, non_wa_funcs = (
    Function(capture_kwargs, kwargs=interp_funcs[patch]) for patch in ('wa', 'non_wa')
)

In [None]:
def get_dynamic_matrix(matrices, wa_funcs, non_wa_funcs, wa_prop_func):
    """Combine location-specific contact matrices into a single mixing matrix.

    The 'home' and 'school' matrices are added unscaled. For 'other_locations'
    and 'work', each matrix is scaled by a patch-specific factor and weighted by
    that patch's share: ``wa_prop_func`` for WA and ``1.0 - wa_prop_func`` for
    the rest of the country.

    Args:
        matrices: Mapping from location name ('home', 'school', 'work',
            'other_locations') to a contact matrix. Presumably these are
            same-shaped age-by-age arrays — TODO confirm.
        wa_funcs: Mapping from 'work' / 'other_locations' to the WA scaling factor.
        non_wa_funcs: Same as ``wa_funcs``, for the non-WA patch.
        wa_prop_func: Weight of the WA patch; the non-WA patch gets the complement.

    Returns:
        A new combined matrix. The input matrices are not modified.
    """
    patch_funcs = {
        'wa': wa_funcs,
        'non_wa': non_wa_funcs,
    }
    patch_props = {
        'wa': wa_prop_func,
        'non_wa': 1.0 - wa_prop_func,
    }
    working_matrix = matrices['home'] + matrices['school']
    for location in ['other_locations', 'work']:
        for patch in ['wa', 'non_wa']:
            # Out-of-place accumulation. An in-place `+=` would fail when the
            # matrices are integer-typed, because NumPy will not cast the float
            # product back into an int array.
            scaled = matrices[location] * patch_funcs[patch][location] * patch_props[patch]
            working_matrix = working_matrix + scaled
    return working_matrix

In [None]:
# Wraps get_dynamic_matrix as a Function, passing the adjusted matrices, the two
# per-patch function dicts and the WA weight, in that order.
dynamic_matrix = Function(get_dynamic_matrix, [adjusted_matrices, wa_funcs, non_wa_funcs, wa_prop_func])

In [None]:
# Age stratification over age_strata, given dynamic_matrix (presumably as its
# mixing matrix — TODO confirm), then applied to the model.
age_strat = get_age_stratification(compartments, age_strata, dynamic_matrix, tex_doc)
aust_model.stratify_with(age_strat)

In [None]:
# Run the model with the base parameters plus the merged IFRs.
aust_model.run(parameters=parameters)