Outstanding tasks for this calibration:
- Initialise parameters from different starting points
- Reinstate mobility contact scaling
- Consider implementing empirically based, vaccination-related susceptibility
- Get it running on Colab again
- Write the paper

In [None]:
import re
from datetime import datetime
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pymc as pm
import arviz as az
from jax import numpy as jnp

from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.wrappers import pymc as epm

from inputs.constants import ANALYSIS_START_DATE, ANALYSIS_END_DATE, PLOT_START_DATE, PROJECT_PATH
from autumn.infrastructure.remote import springboard
from autumn.core.runs import ManagedRun
from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data, get_ifrs, load_raw_pop_data, get_raw_state_mobility
from aust_covid.model import build_model
from aust_covid.plotting import plot_key_outputs, plot_cdr_examples, plot_subvariant_props, plot_dispersion_examples
from aust_covid.plotting import plot_state_mobility, plot_processed_mobility, plot_example_model_matrices
from aust_covid.calibration import get_priors
from emutools.tex import StandardTexDoc
from emutools.inputs import load_param_info
from emutools.calibration import param_table_to_tex, round_sigfig, sample_idata, get_sampled_outputs, plot_param_progression, view_posterior_comparison, tabulate_priors, tabulate_param_results

In [None]:
# LaTeX supplement document; later cells add tables and figures to it via app_doc.
supplement_dir = PROJECT_PATH / 'supplement'
app_doc = StandardTexDoc(
    supplement_dir,
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)

In [None]:
# Observed data used as calibration targets: case notifications, WHO death
# counts and serosurvey results (each loader also receives the supplement doc).
case_targets, death_targets, serosurvey_targets = (
    load_calibration_targets(app_doc),
    load_who_data(app_doc),
    load_serosurvey_data(app_doc),
)

In [None]:
# Load the parameter table from parameters.yml, then overwrite matching values
# with the IFRs returned by get_ifrs, and flatten to a name -> value dict.
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml')
ifrs = get_ifrs(app_doc)
# Update a copy and assign it back explicitly. Calling an in-place method on the
# column returned by param_info['value'] is chained assignment, which does not
# propagate back to param_info under pandas Copy-on-Write.
param_values = param_info['value'].copy()
param_values.update(ifrs)
# NOTE(review): Series.update only touches labels already present in the index,
# so any IFR key not declared in parameters.yml is silently ignored — confirm intended.
param_info['value'] = param_values
parameters = param_info['value'].to_dict()

In [None]:
# Build the epidemiological model.
# NOTE(review): mobility_sens=True presumably selects the mobility-sensitivity configuration — confirm in build_model.
aust_model = build_model(
    app_doc,
    mobility_sens=True,
)

In [None]:
# Baseline run at the loaded parameter values (parameters.yml with IFR overrides).
aust_model.run(
    parameters=parameters,
)

In [None]:
# Prior distributions for the calibrated parameters, plus their names for tabulation.
priors = get_priors()
prior_names = [prior.name for prior in priors]

In [None]:
# Set up for calibration or optimisation
def truncation_ceiling(modelled, obs, parameters, time_weights):
    """Log-likelihood penalty that acts as a hard upper bound on a model output.

    Returns a large negative value (-1e11) wherever the modelled value exceeds
    the observed ceiling, and 0.0 elsewhere. ``parameters`` and ``time_weights``
    are unused here; presumably they are required by the CustomTarget call
    signature — confirm against estival.
    """
    above_ceiling = modelled > obs
    return jnp.where(above_ceiling, -1e11, 0.0)


# Calibration targets. The negative binomial dispersion parameters are themselves
# calibrated, with uniform priors over the given ranges.
targets = [
    est.NegativeBinomialTarget('notifications_ma', case_targets, dispersion_param=esp.UniformPrior('notifications_ma_dispersion', (10.0, 140.0))),
    est.NegativeBinomialTarget('deaths_ma', death_targets, dispersion_param=esp.UniformPrior('deaths_ma_dispersion', (60.0, 200.0))),
    # Same value (20, presumably the binomial sample size — confirm) for every serosurvey
    # observation. A scalar broadcasts over however many observations are loaded,
    # rather than raising a length mismatch unless there are exactly four.
    est.BinomialTarget('adult_seropos_prop', serosurvey_targets, pd.Series(20, index=serosurvey_targets.index)),
    # Penalise modelled adult seropositivity above 0.04 on 1 December 2021 (via truncation_ceiling).
    est.CustomTarget('seropos_ceiling', pd.Series([0.04], index=[datetime(2021, 12, 1)]), truncation_ceiling, model_key='adult_seropos_prop'),
]

In [None]:
# Bundle the model, its baseline parameters, the priors and the targets for Bayesian calibration.
calibration_model = BayesianCompartmentalModel(
    aust_model,
    parameters,
    priors,
    targets,
)

In [None]:
# Disabled: maximum a posteriori (MAP) search. Starting values are clipped to the
# 0.99 prior bounds, the model is re-run at the best parameters, and the result is printed.
# NOTE(review): the printout previously indexed the `priors` list by position in the
# filtered MAP key list, which need not line up; it now looks up the prior by name.
# with pm.Model() as pmc_model:
#     start_params = {k: np.clip(v, *calibration_model.priors[k].bounds(0.99)) for k, v in parameters.items() if k in calibration_model.priors}
#     variables = epm.use_model(calibration_model)
#     map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
#     map_params = {k: float(v) for k, v in map_params.items()}
# print('Best calibration parameters found:')
# for param in [p for p in map_params if '_dispersion' not in p]:
#     print(f'   {param}: {round_sigfig(map_params[param], 4)} (within bound {calibration_model.priors[param].bounds()})')
# parameters.update(map_params)
# aust_model.run(parameters=parameters)

In [None]:
# Add the parameter table and the prior table to the supplement.
param_table_tex = param_table_to_tex(param_info, prior_names)
app_doc.include_table(
    param_table_tex,
    section='Parameters',
    col_splits=[0.17, 0.15, 0.15, 0.53],
    longtable=True,
)
prior_table_tex = tabulate_priors(priors, param_info)
app_doc.include_table(
    prior_table_tex,
    section='Calibration',
    col_splits=[0.25] * 4,
)

In [None]:
# Save the document content accumulated so far (including the parameter and prior
# tables); a later cell creates a fresh StandardTexDoc and calls load_content().
app_doc.save_content()

In [None]:
# Disabled: short local DEMetropolis sampling run (100 draws, no tuning), saved to
# calibration_out.nc. This appears to be superseded by the remote run_calibration task below.
# with pm.Model() as pm_model:
#     variables = epm.use_model(calibration_model)
#     idata_local = pm.sample(step=[pm.DEMetropolis(variables)], draws=100, tune=0, cores=9, chains=18)
# idata_local.to_netcdf('calibration_out.nc')

In [None]:
def run_calibration(bridge: springboard.task.TaskBridge, calibration_model: BayesianCompartmentalModel):
    """Sample the calibration posterior remotely and save it to the task output directory.

    Runs 8 chains of DEMetropolisZ on 8 cores (2000 tuning steps, then 10000 draws).
    The resulting InferenceData is written to ``bridge.out_path / 'calibration_out.nc'``.

    NOTE(review): multiprocessing.set_start_method raises RuntimeError if a start
    method was already set in this process, so this assumes one call per process.
    """
    import multiprocessing

    multiprocessing.set_start_method('forkserver')

    sampler_settings = dict(draws=10000, tune=2000, cores=8, chains=8, progressbar=False)
    with pm.Model():
        model_vars = epm.use_model(calibration_model)
        step_methods = [pm.DEMetropolisZ(model_vars)]
        idata_raw = pm.sample(step=step_methods, **sampler_settings)

    out_file = bridge.out_path / 'calibration_out.nc'
    idata_raw.to_netcdf(str(out_file))
    bridge.logger.info('Calibration complete')
    
# Remote machine and task specification for the calibration run.
# NOTE(review): EC2MachineSpec(8, 2, 'compute') presumably requests 8 cores to match cores=8 in run_calibration — confirm.
mspec = springboard.EC2MachineSpec(8, 2, 'compute')
task_kwargs = {'calibration_model': calibration_model}
tspec = springboard.TaskSpec(run_calibration, task_kwargs)
run_path = springboard.launch.get_autumn_project_run_path(
    'aust_covid',
    'base_case_analysis',
    'try_DEMZ_10k',
)
run_path  # display the generated run path in the notebook

In [None]:
# Extra shell commands for the remote task: clone and install this repository.
# NOTE(review): this clones `main`, so the remote run uses the pushed main branch, not local uncommitted changes.
aust_covid_commands = [
    'git clone --branch main https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]
runner = springboard.launch.launch_synced_autumn_task(
    tspec,
    mspec,
    run_path,
    branch=None,
    extra_commands=aust_covid_commands,
)

In [None]:
# Recreate the supplement document and restore the content saved by save_content()
# earlier, so the analysis cells can add results without rerunning setup.
app_doc = StandardTexDoc(
    PROJECT_PATH / 'supplement',
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)
app_doc.load_content()

In [None]:
# Completed remote calibration run to analyse. This overrides the run_path generated
# at launch, so the analysis cells can run without relaunching the calibration.
run_path = 'projects/aust_covid/base_case_analysis/2023-08-30T1608-try_DEMZ_10k'

In [None]:
# Fetch the run's output from remote storage and load it as ArviZ InferenceData.
# NOTE(review): both lookups take the last listed item, assuming it is calibration_out.nc — confirm listing order.
mr = ManagedRun(run_path)
remote_items = mr.remote.list_contents()
mr.remote.download(remote_items[-1])
local_items = mr.list_local()
idata = az.from_netcdf(local_items[-1])

In [None]:
# Burn
# Burn-in: keep only draws labelled 2000 onwards (across all chains).
# NOTE(review): pm.sample discards the tune=2000 tuning draws by default, so this
# is extra burn-in on top of tuning — confirm it is intended.
n_burn_draws = 2000
idata = idata.sel(draw=slice(n_burn_draws, None))

In [None]:
# Subsample the posterior and collect the model outputs needed for the results figures.
# NOTE(review): only 2 posterior samples are drawn — looks like a quick-test setting; confirm before final runs.
sampled_idata = sample_idata(idata, 2, calibration_model)
key_outputs = ['notifications_ma', 'adult_seropos_prop', 'deaths_ma']
# Subvariant proportion outputs: derived outputs named 'ba<digit 0-5>_prop'.
# Requires `import re` at the top of the file.
variant_prop_outputs = [i for i in aust_model.derived_outputs.keys() if re.fullmatch(r'ba[0-5]_prop', i)]
req_outputs = key_outputs + variant_prop_outputs
output_results = get_sampled_outputs(calibration_model, sampled_idata, req_outputs, parameters)

In [None]:
# Figure: modelled subvariant proportions over the plotting window.
plot_subvariant_props(
    sampled_idata,
    output_results,
    PLOT_START_DATE,
    ANALYSIS_END_DATE,
    app_doc,
    show_fig=True,
)

In [None]:
# Figure: key outputs (cases, seropositivity, deaths) against their calibration targets.
plot_key_outputs(
    sampled_idata,
    output_results,
    PLOT_START_DATE,
    ANALYSIS_END_DATE,
    app_doc,
    key_outputs,
    case_targets,
    serosurvey_targets,
    death_targets,
    show_fig=True,
)

In [None]:
# Figure: examples built from the sampled 'start_cdr' parameter values.
start_cdr_samples = sampled_idata['start_cdr']
plot_cdr_examples(start_cdr_samples, app_doc, show_fig=True)

In [None]:
# Order in which sections appear in the compiled supplement.
section_order = [
    'Model Structure', 'Population', 'Stratification', 'Mixing',
    'Reinfection', 'Parameters', 'Outputs', 'Targets',
    'Calibration', 'Results', 'Mobility',
]

In [None]:
# Figure: illustrate the negative binomial dispersion for the case and death targets.
# NOTE(review): targets[:2] relies on the two NegativeBinomialTargets being listed first in `targets`.
dispersion_targets = targets[:2]
# Presumably RGB colour strings per output — confirm in plot_dispersion_examples.
output_colours = {'notifications_ma': '10, 10, 100', 'deaths_ma': '100, 10, 10'}
plot_dispersion_examples(
    idata, aust_model, parameters, prior_names, dispersion_targets,
    PLOT_START_DATE, ANALYSIS_END_DATE, output_colours, app_doc,
    np.linspace(0.1, 0.9, 9),
    show_fig=True,
)

In [None]:
# Write out the supplement with sections arranged per section_order.
app_doc.write_doc(
    order=section_order,
)