In [None]:
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pymc as pm
import arviz as az
import plotly.express as px
from jax import numpy as jnp

from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.wrappers import pymc as epm

from autumn.infrastructure.remote import springboard
from autumn.core.runs import ManagedRun

from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data
from aust_covid.model import MATRIX_LOCATIONS, build_model
from general_utils.calibration_utils import param_table_to_tex
from general_utils.tex_utils import StandardTexDoc
from general_utils.parameter_utils import load_param_info
from general_utils.calibration_utils import round_sigfig, sample_idata, get_sampled_outputs, melt_spaghetti, plot_param_progression, plot_param_posterior, tabulate_priors, tabulate_param_results
# Project root is taken as the parent of the current working directory
# (presumably the notebook is launched from a sub-folder of the project — confirm).
PROJECT_PATH = Path.cwd().resolve().parent
DATA_PATH = PROJECT_PATH.joinpath('data')

In [None]:
# Key dates for the analysis.
# reference_date is passed to build_model ahead of the analysis window; presumably it is the
# model's time origin — confirm in aust_covid.model.build_model.
reference_date = datetime(2019, 12, 31)
analysis_start_date = datetime(2021, 7, 1)
analysis_end_date = datetime(2022, 10, 1)
# The x-axes of the plots below are cropped to run from this date to analysis_end_date.
plot_start_date = datetime(2021, 12, 1)

In [None]:
# TeX supplement that later cells (data loaders, model builder, tables, plots) add content to.
# NOTE(review): the title says "2023" but the analysis window ends October 2022 — confirm the intended year.
supplement_dir = PROJECT_PATH / 'supplement'
app_doc = StandardTexDoc(
    supplement_dir,
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)

In [None]:
# Load the calibration targets. app_doc is handed to each loader, presumably so it can document
# the data source in the supplement. targets_average_window is presumably a moving-average window in days.
targets_average_window = 7
case_targets = load_calibration_targets(
    datetime(2021, 12, 15),
    targets_average_window,
    app_doc,
)
death_targets = load_who_data(targets_average_window, app_doc)
# NOTE(review): the meaning of 14.0 is defined inside load_serosurvey_data — confirm and name it.
serosurvey_targets = load_serosurvey_data(14.0, app_doc)

In [None]:
# Infection fatality rates keyed by age-group label (presumably the lower bound of each 5-year band).
ifr_by_agegroup = {
    0: 0.0,
    5: 0.0,
    10: 0.0,
    15: 2.6e-5,
    20: 2.6e-5,
    25: 2.6e-5,
    30: 2.6e-5,
    35: 5.8e-5,
    40: 5.8e-5,
    45: 5.8e-5,
    50: 14.6e-5,
    55: 14.6e-5,
    60: 24.6e-5,
    65: 24.6e-5,
    70: 24.6e-5,
    75: 5e-3,
}

# Baseline parameter set used for the first model run and as the calibration starting point.
# Key order is kept identical to the original listing (IFRs sit between deaths_mean and ba2_rel_ifr).
parameters = {
    'ba1_seed_time': 620.0,
    'start_cdr': 0.3,
    'contact_rate': 0.065,
    'vacc_prop': 0.4,
    'infectious_period': 2.5,
    'natural_immunity_period': 60.0,
    'ba2_seed_time': 660.0,
    'ba2_escape': 0.4,
    'ba5_seed_time': 715.0,
    'ba5_escape': 0.54,
    'latent_period': 1.8,
    'seed_rate': 1.0,
    'seed_duration': 10.0,
    'notifs_shape': 2.0,
    'notifs_mean': 4.0,
    'vacc_infect_protect': 0.4,
    'wa_reopen_period': 30.0,
    'deaths_shape': 2.0,
    'deaths_mean': 20.0,
    **{f'ifr_{age}': ifr for age, ifr in ifr_by_agegroup.items()},
    'ba2_rel_ifr': 0.5,
}
# Parameter metadata from the project's YAML file, matched against the parameter set above.
param_info_path = PROJECT_PATH / 'inputs' / 'parameters.yml'
param_info = load_param_info(param_info_path, parameters)

In [None]:
# Build the compartmental model over the analysis window; app_doc is passed so the builder can
# contribute to the supplement (presumably model-structure sections — confirm in build_model).
aust_model = build_model(
    reference_date,
    analysis_start_date,
    analysis_end_date,
    app_doc,
)

In [None]:
# Single run with the baseline parameter set; its derived outputs feed the diagnostic plot below.
aust_model.run(
    parameters=parameters,
)

In [None]:
# Set up for calibration or optimisation
def truncation_ceiling(modelled, obs, parameters, time_weights):
    """Score a custom target that forbids the modelled values from exceeding the observed ceiling.

    Args:
        modelled: Modelled output values at the target times.
        obs: Ceiling values that the modelled output must not exceed.
        parameters: Not used here; presumably required by the estival CustomTarget callback signature.
        time_weights: Not used here; presumably required by the estival CustomTarget callback signature.

    Returns:
        Element-wise 0.0 where ``modelled <= obs`` and a large negative penalty (-1e11) where
        ``modelled > obs``. Presumably this is added to the log-likelihood, so any breach effectively
        rules the sample out.
    """
    exceeds_ceiling = modelled > obs
    return jnp.where(exceeds_ceiling, -1e11, 0.0)

# Uniform priors over the calibrated parameters (all other entries of `parameters` stay fixed).
priors = [
    esp.UniformPrior('ba1_seed_time', (600.0, 630.0)),
    esp.UniformPrior('ba2_seed_time', (630.0, 660.0)),
    esp.UniformPrior('ba5_seed_time', (660.0, 700.0)),
    esp.UniformPrior('ba2_escape', (0.2, 0.8)),
    esp.UniformPrior('ba5_escape', (0.2, 0.9)),
    esp.UniformPrior('contact_rate', (0.02, 0.1)),
    # NOTE(review): a lower bound of 0.0 admits a zero infectious period; if the model uses its
    # reciprocal as a rate this is a division by zero at the boundary — confirm in build_model.
    esp.UniformPrior('infectious_period', (0.0, 5.0)),
    esp.UniformPrior('start_cdr', (0.1, 0.5)),
    esp.UniformPrior('latent_period', (1.0, 4.0)),
    esp.UniformPrior('ba2_rel_ifr', (0.2, 1.2)),
    esp.UniformPrior('vacc_prop', (0.0, 1.0)),
    esp.UniformPrior('vacc_infect_protect', (0.0, 1.0)),
    esp.UniformPrior('natural_immunity_period', (40.0, 100.0)),
    esp.UniformPrior('ifr_75', (1e-4, 1e-2)),
]

# Per-survey value for the binomial serosurvey target (presumably the sample size n).
# A scalar broadcast over the index always matches however many serosurvey estimates were
# loaded; the previous hard-coded four-element list raised ValueError for any other count.
serosurvey_sample_size = 20

targets = [
    est.NegativeBinomialTarget('notifications', case_targets, dispersion_param=esp.UniformPrior('cases_dispersion', (60.0, 200.0))),
    est.NegativeBinomialTarget('deaths', death_targets, dispersion_param=esp.UniformPrior('deaths_dispersion', (100.0, 240.0))),
    est.BinomialTarget(
        'adult_seropos_prop',
        serosurvey_targets,
        pd.Series(serosurvey_sample_size, index=serosurvey_targets.index),
    ),
    # Hard ceiling: the modelled adult seropositive proportion must not exceed 0.15 on 2022-01-01.
    # NOTE(review): this shares the name 'adult_seropos_prop' with the binomial target above;
    # confirm estival does not key targets by name (which would silently keep only one of them).
    est.CustomTarget('adult_seropos_prop', pd.Series([0.15], index=[datetime(2022, 1, 1)]), truncation_ceiling),
]
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, targets)

In [None]:
# with pm.Model() as pmc_model:
#     start_params = {k: np.clip(v, *calibration_model.priors[k].bounds(0.99)) for k, v in parameters.items() if k in calibration_model.priors}
#     variables = epm.use_model(calibration_model)
#     map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
#     map_params = {k: float(v) for k, v in map_params.items()}
# print('Best calibration parameters found:')
# for i_param, param in enumerate(map_params):
#     print(f'   {param}: {round_sigfig(map_params[param], 4)} (within bound {priors[i_param].bounds()}')
# parameters.update(map_params)
# aust_model.run(parameters=parameters)

In [None]:
# Add the parameter table (calibrated parameters highlighted by name) and the prior table to the supplement.
calibrated_names = [prior.name for prior in priors]
parameter_table = param_table_to_tex(param_info, calibrated_names)
app_doc.include_table(
    parameter_table,
    section='Parameters',
    widths=[3.0, 2.0, 1.5, 3.5],
    longtable=True,
)
app_doc.include_table(tabulate_priors(priors, param_info), section='Calibration')

In [None]:
# Diagnostic plot of the most recent aust_model run against the calibration targets.
# Modelled traces are consistently labelled "modelled ..." so they cannot be confused with the
# reported series sharing the same panel, and each panel carries a title.
fig = make_subplots(
    rows=3,
    cols=2,
    subplot_titles=[
        'notifications',
        'deaths',
        'adult seropositive proportion',
        'reproduction number',
        'deaths by age group',
        '',
    ],
)
derived_outputs = aust_model.get_derived_outputs_df()
x_vals = derived_outputs.index
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['notifications'], name='modelled cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=case_targets.index, y=case_targets, name='reported cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['deaths'], name='modelled deaths'), row=1, col=2)
fig.add_trace(go.Scatter(x=death_targets.index, y=death_targets, name='reported deaths'), row=1, col=2)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['adult_seropos_prop'], name='modelled adult seropos'), row=2, col=1)
fig.add_trace(go.Scatter(x=serosurvey_targets.index, y=serosurvey_targets, name='seropos estimates'), row=2, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['reproduction_number'], name='reproduction number'), row=2, col=2)
# One modelled-deaths trace per age stratum.
for agegroup in aust_model.stratifications['agegroup'].strata:
    fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs[f'deathsXagegroup_{agegroup}'], name=f'{agegroup} deaths'), row=3, col=1)
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=600, width=1200)
fig.show()

In [None]:
# def run_calibration(bridge: springboard.task.TaskBridge, calibration_model: BayesianCompartmentalModel):
#     import multiprocessing as mp
#     mp.set_start_method('forkserver')
    
#     with pm.Model() as pm_model:
#         variables = epm.use_model(calibration_model)
#         idata_raw = pm.sample(step=[pm.DEMetropolis(variables)], draws=20000, tune=0, cores=8, chains=8, progressbar=False)
    
#     idata_raw.to_netcdf(str(bridge.out_path / 'calibration_out.nc'))
#     bridge.logger.info('Calibration complete')
    
# mspec = springboard.EC2MachineSpec(8, 2, 'compute')
# tspec = springboard.TaskSpec(run_calibration, {'calibration_model': calibration_model})
# run_path = springboard.launch.get_autumn_project_run_path('aust_covid', 'initial_exploration', 'fine_tune_priors_long_run')
# run_path

In [None]:
# aust_covid_commands = [
#     'git clone --branch main https://github.com/monash-emu/aust-covid',
#     'pip install -e ./aust-covid',
# ]
# runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=aust_covid_commands)

In [None]:
# Completed remote calibration run to analyse (same project / group / run name as the
# commented-out launch cell, plus the launch timestamp).
run_path = '/'.join([
    'projects',
    'aust_covid',
    'initial_exploration',
    '2023-08-09T2216-fine_tune_priors_long_run',
])

In [None]:
# Download the last-listed remote file of the run, then load the last-listed local file as InferenceData.
# NOTE(review): the remote and local listings are indexed independently; if the local run folder
# holds more than one file they may not refer to the same file. Consider selecting by name
# (the launch cell writes 'calibration_out.nc').
mr = ManagedRun(run_path)
remote_contents = mr.remote.list_contents()
latest_remote = remote_contents[-1]
mr.remote.download(latest_remote)
local_files = mr.list_local()
idata = az.from_netcdf(local_files[-1])

In [None]:
# Discard burn-in: keep only draws labelled burn_in_draws and later in every chain.
# `sel` is label-based, so (with the stored integer draw labels) re-running this cell is a no-op.
burn_in_draws = 7000
idata = idata.sel(draw=slice(burn_in_draws, None))

In [None]:
# Summary table of the post-burn-in parameter estimates, added to the Calibration section.
results_table = tabulate_param_results(idata, priors, param_info)
results_table_widths = [2.0, 1.2, 1.2, 1.0, 1.0, 1.0, 1.5]
app_doc.include_table(results_table, section='Calibration', widths=results_table_widths)

In [None]:
# Plot how sampled parameter values progress across draws; app_doc is passed, presumably so the
# figure is also added to the supplement.
# NOTE(review): the meaning of the positional `True` lives in general_utils — consider passing it by keyword.
plot_param_progression(
    idata,
    param_info,
    app_doc,
    True,
)

In [None]:
# Plot the post-burn-in parameter distributions; app_doc is passed, presumably so the figure is
# also added to the supplement.
plot_param_posterior(
    idata,
    param_info,
    app_doc,
)

In [None]:
# Order in which sections appear when the supplement is written.
section_order = list((
    'Model Structure',
    'Population',
    'Stratification',
    'Mixing',
    'Reinfection',
    'Parameters',
    'Outputs',
    'Targets',
    'Calibration',
))

In [None]:
# Write the supplement with its sections arranged according to section_order.
app_doc.write_doc(
    order=section_order,
)

In [None]:
# Model outputs to re-run for a subset of the calibration draws (presumably 25 draws are sampled).
req_outputs = ['notifications', 'adult_seropos_prop', 'deaths']
n_samples = 25
sampled_idata = sample_idata(idata, n_samples, calibration_model)
output_results = get_sampled_outputs(calibration_model, sampled_idata, req_outputs, parameters)

In [None]:
# Spaghetti plot of the sampled model runs: one subplot per requested output, with lines coloured
# by chain and the corresponding calibration targets overlaid as markers.
# Targets are matched to subplots by output name rather than hard-coded row numbers, so reordering
# or trimming req_outputs cannot put observations on the wrong panel.
target_overlays = {
    'notifications': (case_targets, 'reported cases', {'color': 'LightBlue', 'size': 4}),
    'adult_seropos_prop': (serosurvey_targets, 'serosurveys', {'color': 'white', 'size': 20}),
    'deaths': (death_targets, 'reported deaths', {'color': 'Pink', 'size': 4}),
}
fig = make_subplots(rows=len(req_outputs), cols=1, subplot_titles=req_outputs)
for i_out, out in enumerate(req_outputs):
    spaghetti = melt_spaghetti(output_results, out, sampled_idata)
    lines = px.line(spaghetti, y='value', color='chain', line_group='draw', hover_data=spaghetti.columns)
    fig.add_traces(lines.data, rows=i_out + 1, cols=1)
for out, (observed, label, marker_style) in target_overlays.items():
    if out not in req_outputs:
        # Output not plotted, so there is no panel to overlay its targets on.
        continue
    fig.add_trace(
        go.Scatter(
            x=observed.index,
            y=observed,
            name=label,
            mode='markers',
            marker={**marker_style, 'line': {'color': 'black', 'width': 1}},
        ),
        row=req_outputs.index(out) + 1,
        col=1,
    )
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=1000, width=1000)
fig