In [None]:
import pandas as pd

from aust_covid.inputs import get_ifrs
from emutools.tex import StandardTexDoc
from aust_covid.inputs import get_base_vacc_data
from aust_covid.vaccination import add_derived_data_to_vacc, get_model_vacc_vals_from_data
from aust_covid.model import build_model
from aust_covid.plotting import plot_immune_props
from inputs.constants import PROJECT_PATH, SUPPLEMENT_PATH
from emutools.inputs import load_param_info

In [None]:
# Load the base vaccination data, extend it with the derived columns, then pull
# out the two series of model vaccination values (booster and full primary course).
vacc_df = get_base_vacc_data()
ext_vacc_df = add_derived_data_to_vacc(vacc_df)
boost_data, primary_data = (
    get_model_vacc_vals_from_data(ext_vacc_df, vacc_col)
    for vacc_col in ('prop boosted in preceding', 'prop primary full in preceding')
)

In [None]:
# from summer2 import CompartmentalModel, Stratification, population
# from summer2.parameters import Parameter, Function

In [None]:
# from jax import numpy as jnp
# import numpy as np

In [None]:
# from inputs.constants import AGE_STRATA

In [None]:
# from aust_covid.inputs import load_pop_data

In [None]:
# Create the supplement document, assemble the parameter dictionary and build the model.
app_doc = StandardTexDoc(SUPPLEMENT_PATH, 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml')

# Update an explicit copy of the values and assign it back, rather than calling
# .update() directly on param_info['value']. That chained in-place form only
# reaches param_info when the selection happens to be a view, and under pandas
# Copy-on-Write it never propagates, which would silently drop the IFR values.
# Note that Series.update only overwrites labels that already exist in the Series.
param_values = param_info['value'].copy()
param_values.update(get_ifrs(app_doc))
param_info['value'] = param_values
parameters = param_values.to_dict()

# Build the model with the vacc_sens flag switched on and keep its epoch object.
epi_model = build_model(app_doc, vacc_sens=True)
epoch = epi_model.get_epoch()

In [None]:
# vacc_sens_val = 0.0
# vacc_sens = True

# model_pops = load_pop_data(app_doc)
# pops_dict = model_pops.to_dict()
# params = {'prop_imm': 0.4}

# m = CompartmentalModel([0, 100], ['S', 'I', 'R'], ['I'])

# age_strat = Stratification('agegroup', [str(age) for age in AGE_STRATA], ['S', 'I', 'R'])
# m.stratify_with(age_strat)

# state_strat = Stratification('states', list(pops_dict.keys()), ['S', 'I', 'R'])
# m.stratify_with(state_strat)

# imm_strat = Stratification('immunity', ['imm', 'nonimm'], ['S', 'I', 'R'])
# m.stratify_with(imm_strat)

# def get_init_pop(imm_split):
#     init_pop = jnp.zeros(len(m.compartments), dtype=np.float64)
#     for agegroup in m.stratifications['agegroup'].strata:
#         for state in m.stratifications['states'].strata:
#             for imm_status in m.stratifications['immunity'].strata:
#                 q = m.query_compartments({'name': 'S', 'agegroup': agegroup, 'states': state, 'immunity': imm_status}, as_idx=True)
#                 pop = pops_dict[state][int(agegroup)] * imm_split[imm_status]
#                 init_pop = init_pop.at[q].set(pop)
#     return init_pop

# prop_immune = vacc_sens_val if vacc_sens else Parameter('prop_imm')
# imm_split = {'imm': prop_immune, 'nonimm': 1.0 - prop_immune}
# m.init_population_with_graphobject(Function(get_init_pop, [imm_split]))
# m.run(parameters=params)

In [None]:
# init_pop = m.get_initial_population(params)

In [None]:
# model_comps = [c.name for c in epi_model._original_compartment_names]
# for patch in epi_model.stratifications['states'].strata:
#     for age in epi_model.stratifications['agegroup'].strata:
#         sub_strat = f'{patch}_{age}'
#         epi_model.request_output_for_compartments(sub_strat, model_comps, {'agegroup': age, 'states': patch})

In [None]:
# Display the full parameter dictionary that is passed to the model run (inspection only).
parameters

In [None]:
# Run the model with the parameter dictionary assembled earlier in the notebook.
epi_model.run(parameters=parameters)

In [None]:
# For each patch, take the first row of the derived outputs whose names contain
# the patch prefix, strip that prefix from the names, and plot the patches together.
derived_outs = epi_model.get_derived_outputs_df()
start_pops_df = pd.DataFrame()
for patch in ('wa', 'other'):
    prefix = f'{patch}_'
    patch_cols = [col for col in derived_outs.columns if prefix in col]
    patch_data = derived_outs[patch_cols].iloc[0]
    patch_data.index = patch_data.index.str.replace(prefix, '')
    # Column assignment aligns on the index established by the first patch added.
    start_pops_df[patch] = patch_data
start_pops_df.plot()

In [None]:
# Get the model's initial population for these parameters; a later cell sums
# subsets of it by state, age group and immunity stratum.
init_pop = epi_model.get_initial_population(parameters=parameters)

In [None]:
# Display the strata of the model's 'states' stratification.
epi_model.stratifications['states'].strata

In [None]:
# Load model_pops via load_pop_data; a later cell adds per-state check_ columns
# to it, indexed by integer age group.
# NOTE(review): this import sits mid-notebook; consider moving it into the import cell at the top.
from aust_covid.inputs import load_pop_data
model_pops = load_pop_data(app_doc)


In [None]:
# Sum the initial population by age group and state into check_<state> columns
# of model_pops, so they can be compared with the loaded population values.
states = epi_model.stratifications['states'].strata
for age in epi_model.stratifications['agegroup'].strata:
    # NOTE(review): the trailing 'X' presumably keeps e.g. age '5' from also
    # matching '50' in the compartment names - confirm against the naming scheme.
    in_age = init_pop.index.str.contains(f'agegroup_{age}X')
    for state in states:
        in_state = init_pop.index.str.contains(f'states_{state}')
        model_pops.loc[int(age), f'check_{state}'] = init_pop[in_state & in_age].sum()

# Total initial population held in each immunity stratum.
imm_check = {
    imm: init_pop[init_pop.index.str.contains(f'immunity_{imm}')].sum()
    for imm in epi_model.stratifications['immunity'].strata
}

# Print each immunity stratum's share of the total initial population.
imm_total = sum(imm_check.values())
print([imm_pop / imm_total for imm_pop in imm_check.values()])

In [None]:
# Plot all columns of model_pops, which include the check_<state> columns once
# the population check cell above has been run.
model_pops.plot()

In [None]:
# Call the project's plot_immune_props helper with the run model and the
# extended vaccination dataframe built in the vaccination data cell.
plot_immune_props(epi_model, ext_vacc_df)