List of outstanding tasks with this calibration:
- Initialise parameters from different points
- Reinstate mobility contact scaling
- Consider implementing empiric vaccination-related susceptibility
- Get the notebook running on Colab again
- Write the paper

In [None]:
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pymc as pm
import arviz as az
from jax import numpy as jnp
import re

from summer2.functions.time import get_piecewise_scalar_function, get_linear_interpolation_function, Function

from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.wrappers import pymc as epm

from autumn.infrastructure.remote import springboard
from autumn.core.runs import ManagedRun
from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data, get_ifrs, load_raw_pop_data, get_raw_state_mobility
from aust_covid.model import build_model
from aust_covid.mobility import get_non_wa_mob_averages, get_constants_from_mobility, get_relative_mobility, map_mobility_locations
from aust_covid.plotting import plot_key_outputs, plot_cdr_examples, plot_subvariant_props, plot_dispersion_examples
from emutools.tex import StandardTexDoc
from emutools.inputs import load_param_info
from emutools.calibration import param_table_to_tex, round_sigfig, sample_idata, get_sampled_outputs, plot_param_progression, view_posterior_comparison, tabulate_priors, tabulate_param_results
# Project root, taken as the parent of the current working directory
PROJECT_PATH = Path().resolve().parent

In [None]:
# Dates handed to build_model: the reference date, then the start and end of the analysis
reference_date = datetime(year=2019, month=12, day=31)
analysis_start_date = datetime(year=2021, month=7, day=1)
analysis_end_date = datetime(year=2022, month=10, day=1)

# Calibration target series are sliced to begin at this date
targets_start_date = datetime(year=2022, month=1, day=1)

# Left-hand limit of the x-axis used for the plots in this notebook
plot_start_date = datetime(year=2021, month=12, day=1)

In [None]:
# TeX document object that the later cells add tables and figures to, and that is written out at the end
# NOTE(review): the title says 2023, but the analysis window defined in this notebook runs 2021-07-01 to 2022-10-01 - confirm the intended year
app_doc = StandardTexDoc(PROJECT_PATH / 'supplement', 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')

In [None]:
# Load the data used as calibration targets; the case and death series are sliced to start at targets_start_date
targets_average_window = 7  # also passed to build_model; presumably the moving-average window in days - confirm
# NOTE(review): the role of datetime(2021, 12, 15) is not visible here - confirm against load_calibration_targets
case_targets = load_calibration_targets(datetime(2021, 12, 15), targets_average_window, app_doc)[targets_start_date:]
death_targets = load_who_data(targets_average_window, app_doc)[targets_start_date:]
# NOTE(review): the meaning of 14.0 is not visible here - confirm against load_serosurvey_data
serosurvey_targets = load_serosurvey_data(14.0, app_doc)

In [None]:
# Base parameter values: used for the single model run below and passed to
# BayesianCompartmentalModel alongside the priors
parameters = {
    'contact_rate': 0.072,
    'latent_period': 1.8,
    'infectious_period': 2.5,
    'natural_immunity_period': 60.0,
    'start_cdr': 0.3,
    'imm_prop': 0.4,
    'imm_infect_protect': 0.4,
    'ifr_adjuster': 3.0,
    # NOTE(review): seed times look like days counted from reference_date - confirm
    'ba1_seed_time': 619.0,
    'ba2_seed_time': 659.0,
    'ba5_seed_time': 715.0,
    'ba2_escape': 0.4,
    'ba5_escape': 0.54,
    'ba2_rel_ifr': 0.5,
    'wa_reopen_period': 50.0,
    'seed_duration': 10.0,  # no prior is defined for this parameter in this notebook
    'seed_rate': 1.0,  # no prior is defined for this parameter in this notebook
    'notifs_mean': 4.0,
    'notifs_shape': 2.0,  # no prior is defined for this parameter in this notebook
    'deaths_mean': 15.93,
    'deaths_shape': 5.0,  # no prior is defined for this parameter in this notebook
}
# Merge the values returned by get_ifrs into the same dictionary
ifrs = get_ifrs(app_doc)
parameters.update(ifrs)

In [None]:
# Load the parameter information from parameters.yml; param_info is later indexed by 'abbreviations'
# and passed to the table and plotting helpers
# NOTE(review): parameters was already updated with ifrs, so the merge yields a new dict with the same contents - confirm whether passing a copy is intentional
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml', parameters | ifrs)

In [None]:
# Build the model from the reference date and analysis period, sharing the targets' averaging window
# NOTE(review): the effect of mobility_sens=True is not visible here - confirm against build_model
aust_model = build_model(reference_date, analysis_start_date, analysis_end_date, app_doc, targets_average_window, mobility_sens=True)

In [None]:
# Run the model once with the base parameter values so its derived outputs can be plotted below
aust_model.run(parameters=parameters)

In [None]:
# f = aust_model.graph.filter('mixing_matrix').get_callable()

In [None]:
# f(model_variables={'time': 1.0, 'parameters': parameters})

In [None]:
# Set up for calibration or optimisation
def truncation_ceiling(modelled, obs, parameters, time_weights):
    """Return a very large negative value wherever the modelled value exceeds the ceiling.

    Args:
        modelled: Modelled values.
        obs: Ceiling values, compared element-wise against ``modelled``.
        parameters: Not used by this function.
        time_weights: Not used by this function.

    Returns:
        -1e11 where ``modelled > obs`` and 0.0 everywhere else.
    """
    above_ceiling = modelled > obs
    return jnp.where(above_ceiling, -1e11, 0.0)

# Prior distributions, one per parameter that has a prior in this calibration
priors = [
    esp.UniformPrior('contact_rate', (0.02, 0.15)),
    esp.GammaPrior.from_mode('latent_period', 2.5, 5.0),
    esp.GammaPrior.from_mode('infectious_period', 3.5, 6.0),
    esp.GammaPrior.from_mode('natural_immunity_period', 180.0, 1000.0),
    esp.UniformPrior('start_cdr', (0.1, 0.6)),
    esp.UniformPrior('imm_prop', (0.0, 1.0)),
    esp.UniformPrior('imm_infect_protect', (0.0, 1.0)),
    esp.TruncNormalPrior('ifr_adjuster', 1.0, 2.0, (0.2, np.inf)),
    # The three seeding-time ranges are contiguous and do not overlap
    esp.UniformPrior('ba1_seed_time', (580.0, 625.0)),
    esp.UniformPrior('ba2_seed_time', (625.0, 660.0)),
    esp.UniformPrior('ba5_seed_time', (660.0, 740.0)),
    esp.BetaPrior.from_mean_and_ci('ba2_escape', 0.4, (0.2, 0.6)),
    esp.BetaPrior.from_mean_and_ci('ba5_escape', 0.4, (0.2, 0.6)),
    esp.TruncNormalPrior('ba2_rel_ifr', 0.7, 0.15, (0.2, np.inf)),
    esp.UniformPrior('wa_reopen_period', (30.0, 75.0)),
    esp.GammaPrior.from_mean('notifs_mean', 4.17, 7.0),
    esp.GammaPrior.from_mean('deaths_mean', 15.93, 18.79),
]

# Targets: negative binomial for the case and death series (each with a uniform prior
# on its dispersion parameter) and binomial for the serosurvey estimates
targets = [
    est.NegativeBinomialTarget('notifications_ma', case_targets, dispersion_param=esp.UniformPrior('notifications_ma_dispersion', (10.0, 140.0))),
    est.NegativeBinomialTarget('deaths_ma', death_targets, dispersion_param=esp.UniformPrior('deaths_ma_dispersion', (60.0, 200.0))),
    # One entry of 20 per serosurvey estimate (presumably the binomial sample size - confirm);
    # sized from the data so the series always matches the target index
    est.BinomialTarget('adult_seropos_prop', serosurvey_targets, pd.Series([20] * len(serosurvey_targets), index=serosurvey_targets.index)),
]

# Ceiling on modelled adult seropositivity: truncation_ceiling returns -1e11 when the modelled
# value on 2021-12-01 exceeds 0.04 (presumably added to the log-likelihood - confirm in estival)
targets.append(est.CustomTarget('seropos_ceiling', pd.Series([0.04], index=[datetime(2021, 12, 1)]), truncation_ceiling, model_key='adult_seropos_prop'))

calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, targets)

In [None]:
# with pm.Model() as pmc_model:
#     start_params = {k: np.clip(v, *calibration_model.priors[k].bounds(0.99)) for k, v in parameters.items() if k in calibration_model.priors}
#     variables = epm.use_model(calibration_model)
#     map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
#     map_params = {k: float(v) for k, v in map_params.items()}
# print('Best calibration parameters found:')
# for i_param, param in enumerate([p for p in map_params if '_dispersion' not in p]):
#     print(f'   {param}: {round_sigfig(map_params[param], 4)} (within bound {priors[i_param].bounds()})')
# parameters.update(map_params)
# aust_model.run(parameters=parameters)

In [None]:
# Compare the outputs of the single model run against the calibration targets on a 3 x 2 grid
derived_outputs = aust_model.get_derived_outputs_df()
x_vals = derived_outputs.index
fig = make_subplots(rows=3, cols=2)

# (x, y, legend name, row, col) for each trace in the top two rows, in drawing order
panel_traces = [
    (x_vals, derived_outputs['notifications_ma'], 'modelled cases', 1, 1),
    (case_targets.index, case_targets, 'reported cases', 1, 1),
    (x_vals, derived_outputs['deaths_ma'], 'deaths_ma', 1, 2),
    (death_targets.index, death_targets, 'reported deaths ma', 1, 2),
    (x_vals, derived_outputs['adult_seropos_prop'], 'adult seropos', 2, 1),
    (serosurvey_targets.index, serosurvey_targets, 'seropos estimates', 2, 1),
    (x_vals, derived_outputs['reproduction_number'], 'reproduction number', 2, 2),
]
for trace_x, trace_y, trace_name, trace_row, trace_col in panel_traces:
    fig.add_trace(go.Scatter(x=trace_x, y=trace_y, name=trace_name), row=trace_row, col=trace_col)

# Bottom row: the same age-specific deaths in both panels; the right-hand one gets a log axis below
for agegroup in aust_model.stratifications['agegroup'].strata:
    age_deaths = derived_outputs[f'deathsXagegroup_{agegroup}']
    for panel_col in (1, 2):
        fig.add_trace(go.Scatter(x=x_vals, y=age_deaths, name=f'{agegroup} deaths'), row=3, col=panel_col)

# yaxis6 is the y-axis of the sixth subplot, i.e. the bottom-right panel
fig.layout.yaxis6.update(type='log', range=[-2.0, 2.0])
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=600, width=1200)
fig.show()

In [None]:
# Names of the parameters that have priors, in the order of the priors list
prior_names = [p.name for p in priors]
# Add the parameter table to the 'Parameters' section and the priors table to the 'Calibration' section
app_doc.include_table(param_table_to_tex(param_info, prior_names), section='Parameters', col_splits=[0.17, 0.15, 0.15, 0.53], longtable=True)
app_doc.include_table(tabulate_priors(priors, param_info), section='Calibration', col_splits=[0.25] * 4)

In [None]:
# with pm.Model() as pm_model:
#     variables = epm.use_model(calibration_model)
#     idata_local = pm.sample(step=[pm.DEMetropolis(variables)], draws=100, tune=0, cores=9, chains=18)
# idata_local.to_netcdf('calibration_out.nc')

In [None]:
def run_calibration(
    bridge: springboard.task.TaskBridge,
    calibration_model: BayesianCompartmentalModel,
    draws: int = 10000,
    tune: int = 2000,
    cores: int = 8,
    chains: int = 8,
):
    """Sample the calibration model with DEMetropolisZ and save the result as NetCDF.

    The sampler settings default to the values previously hard-coded here, so
    existing task specifications that pass only ``calibration_model`` behave
    as before.

    Args:
        bridge: Task bridge; ``bridge.out_path`` receives the output file and
            ``bridge.logger`` the completion message.
        calibration_model: Model wrapped for PyMC through ``epm.use_model``.
        draws: Value passed to ``pm.sample`` as ``draws``.
        tune: Value passed to ``pm.sample`` as ``tune``.
        cores: Value passed to ``pm.sample`` as ``cores``.
        chains: Value passed to ``pm.sample`` as ``chains``.
    """
    import multiprocessing as mp
    # NOTE(review): set_start_method raises RuntimeError if the start method has already
    # been fixed in this process, so this function is expected to run once per process
    mp.set_start_method('forkserver')

    with pm.Model():
        variables = epm.use_model(calibration_model)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables)],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=chains,
            progressbar=False,
        )

    idata_raw.to_netcdf(str(bridge.out_path / 'calibration_out.nc'))
    bridge.logger.info('Calibration complete')
    
# Machine specification, task specification and run path for the calibration task
# NOTE(review): the first argument (8) presumably is the number of cores, matching cores=8 in run_calibration - confirm
mspec = springboard.EC2MachineSpec(8, 2, 'compute')
# The dictionary supplies the keyword arguments for run_calibration other than bridge
tspec = springboard.TaskSpec(run_calibration, {'calibration_model': calibration_model})
run_path = springboard.launch.get_autumn_project_run_path('aust_covid', 'base_case_analysis', 'try_DEMZ_10k')
run_path  # bare expression so the notebook displays the generated path

In [None]:
# Extra shell commands passed to the launcher: clone the main branch of aust-covid and install it in editable mode
aust_covid_commands = [
    'git clone --branch main https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]
# Launch the task with the task spec, machine spec and run path defined above
runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=aust_covid_commands)

In [None]:
# Hard-coded path of one specific run; this replaces the run_path generated above
run_path = 'projects/aust_covid/base_case_analysis/2023-08-30T1608-try_DEMZ_10k'

In [None]:
# Download the last item listed in the run's remote contents, then load the last local file as inference data
# NOTE(review): relies on the ordering of list_contents() and list_local() to pick the calibration output - confirm
mr = ManagedRun(run_path)
mr.remote.download(mr.remote.list_contents()[-1])
idata = az.from_netcdf(mr.list_local()[-1])

In [None]:
# Burn-in: keep only the draws from 2000 onwards
idata = idata.sel(draw=np.s_[2000:])

In [None]:
# Add the table of calibration results by parameter to the 'Calibration' section of the supplement
app_doc.include_table(tabulate_param_results(idata, priors, param_info), section='Calibration', col_splits=[0.142] * 7, table_width=12.0, longtable=True)

In [None]:
# Split the calibrated parameters into two groups so the plots are produced in two parts
n_half_priors = round(len(priors) / 2)  # round() sends exact halves to the nearest even integer
# First part: parameters before the split point
plot_param_progression(idata, param_info, app_doc, True, request_vars=prior_names[:n_half_priors], name_ext='_first')

In [None]:
# Second part: parameters from the split point onwards
plot_param_progression(idata, param_info, app_doc, True, request_vars=prior_names[n_half_priors:], name_ext='_last')

In [None]:
# Compare priors with the sampled values for the first group of parameters
# NOTE(review): 0.995 is presumably the prior interval covered by the plots - confirm against view_posterior_comparison
view_posterior_comparison(idata, priors, prior_names[:n_half_priors], param_info['abbreviations'].to_dict(), 0.995, app_doc, name_ext='_first')

In [None]:
# Compare priors with the sampled values for the second group of parameters
# NOTE(review): 0.995 is presumably the prior interval covered by the plots - confirm against view_posterior_comparison
view_posterior_comparison(idata, priors, prior_names[n_half_priors:], param_info['abbreviations'].to_dict(), 0.995, app_doc, name_ext='_last')

In [None]:
# Outputs to extract: the three calibrated outputs plus every derived output named ba<digit 0-5>_prop
key_outputs = ['notifications_ma', 'adult_seropos_prop', 'deaths_ma']
variant_prop_outputs = [name for name in aust_model.derived_outputs.keys() if re.fullmatch('ba[0-5]_prop', name)]
req_outputs = [*key_outputs, *variant_prop_outputs]

# Sample from the inference data and get the requested outputs for the sample
# NOTE(review): the 2 is presumably the number of samples - confirm against sample_idata
sampled_idata = sample_idata(idata, 2, calibration_model)
output_results = get_sampled_outputs(calibration_model, sampled_idata, req_outputs, parameters)

In [None]:
# Plot the sub-variant proportion outputs for the sampled runs over the plotting period
plot_subvariant_props(sampled_idata, output_results, plot_start_date, analysis_end_date, app_doc, show_fig=True)

In [None]:
# Plot the key outputs for the sampled runs, passing in the three target series
plot_key_outputs(sampled_idata, output_results, plot_start_date, analysis_end_date, app_doc, key_outputs, case_targets, serosurvey_targets, death_targets, show_fig=True)

In [None]:
# Plot examples based on the sampled start_cdr values
plot_cdr_examples(sampled_idata['start_cdr'], app_doc, show_fig=True)

In [None]:
# Section names in the order passed to app_doc.write_doc at the end of the notebook
section_order = [
    'Model Structure', 
    'Population', 
    'Stratification', 
    'Mixing', 
    'Reinfection', 
    'Parameters',
    'Outputs', 
    'Targets',
    'Calibration',
    'Results',
    'Mobility',
]

In [None]:
# Plot dispersion examples for the two negative binomial targets (cases and deaths)
# NOTE(review): the per-output strings and the linspace values are interpreted inside
# plot_dispersion_examples; their meaning is not visible here - confirm
plot_dispersion_examples(
    idata,
    aust_model,
    parameters,
    prior_names,
    targets[:2],  # first two targets only: notifications_ma and deaths_ma
    targets_start_date,
    analysis_end_date,
    {'notifications_ma': '10, 10, 100', 'deaths_ma': '100, 10, 10'},
    app_doc,
    np.linspace(0.1, 0.9, 9),  # nine evenly spaced values from 0.1 to 0.9
    show_fig=True,
)

In [None]:
# Write out the supplement document, with its sections in the order given by section_order
app_doc.write_doc(order=section_order)