In [None]:
from pathlib import Path
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pymc as pm
import arviz as az
import plotly.express as px
from jax import numpy as jnp

from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.wrappers import pymc as epm

from autumn.infrastructure.remote import springboard
from autumn.core.runs import ManagedRun

from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data, get_ifrs
from aust_covid.model import MATRIX_LOCATIONS, build_model
from general_utils.calibration_utils import param_table_to_tex
from general_utils.tex_utils import StandardTexDoc
from general_utils.parameter_utils import load_param_info
from general_utils.calibration_utils import round_sigfig, sample_idata, get_sampled_outputs, melt_spaghetti, plot_param_progression, plot_param_posterior, tabulate_priors, tabulate_param_results
# Assumes the working directory is one level below the project root (e.g. a notebooks folder)
PROJECT_PATH = Path('.').resolve().parent
DATA_PATH = PROJECT_PATH.joinpath('data')

In [None]:
# Key dates for the model run and plots
reference_date = datetime(2019, 12, 31)  # time reference handed to build_model
analysis_start_date = datetime(2021, 7, 1)
analysis_end_date = datetime(2022, 10, 1)
plot_start_date = datetime(2021, 12, 1)  # left edge of the x-axis in the figures

In [None]:
# LaTeX supplement document; it is passed to the input loaders below and later receives tables via include_table
app_doc = StandardTexDoc(PROJECT_PATH / 'supplement', 'supplement', "Australia's 2023 Omicron Waves Supplement", 'austcovid')

In [None]:
# Load the observed data used as calibration targets; each loader also receives app_doc
targets_average_window = 7  # presumably a moving-average window in days for cases and deaths — confirm in loaders
# NOTE(review): datetime(2021, 12, 15) is presumably the first date kept for case targets — confirm
case_targets = load_calibration_targets(datetime(2021, 12, 15), targets_average_window, app_doc)
death_targets = load_who_data(targets_average_window, app_doc)
# NOTE(review): the meaning of 14.0 is not visible here — confirm in load_serosurvey_data
serosurvey_targets = load_serosurvey_data(14.0, app_doc)

In [None]:
# Baseline parameter values: used for the first model run and as optimiser start values.
# Several seed times sit at or beyond their prior ranges; the MAP cell clips start values into range.
parameters = dict(
    ba1_seed_time=620.0,
    start_cdr=0.3,
    contact_rate=0.065,
    vacc_prop=0.4,
    infectious_period=2.5,
    natural_immunity_period=60.0,
    ba2_seed_time=660.0,
    ba2_escape=0.4,
    ba5_seed_time=715.0,
    ba5_escape=0.54,
    latent_period=1.8,
    seed_rate=1.0,
    seed_duration=10.0,
    notifs_shape=2.0,
    notifs_mean=4.0,
    vacc_infect_protect=0.4,
    wa_reopen_period=30.0,
    deaths_shape=2.0,
    deaths_mean=20.0,
    ba2_rel_ifr=0.5,
    ifr_adjuster=3.0,
)
# Merge the age-specific IFR parameters into the same dictionary (in place)
ifrs = get_ifrs(app_doc)
parameters |= ifrs

In [None]:
# Parameter metadata read from inputs/parameters.yml for every model parameter (baseline values + IFRs)
param_info = load_param_info(PROJECT_PATH.joinpath('inputs', 'parameters.yml'), {**parameters, **ifrs})

In [None]:
# Build the compartmental model over the analysis window
# (reference_date is presumably model time zero, so seed times are days after it — confirm in build_model)
aust_model = build_model(reference_date, analysis_start_date, analysis_end_date, app_doc)

In [None]:
# Run once with the baseline parameter values
aust_model.run(parameters=parameters)

In [None]:
# Set up for calibration or optimisation
def truncation_ceiling(modelled, obs, parameters, time_weights):
    """Hard-ceiling penalty for a custom calibration target.

    Returns -1e11 wherever the modelled value exceeds the observed ceiling and
    0.0 elsewhere (presumably consumed as a log-likelihood term by
    est.CustomTarget). ``parameters`` and ``time_weights`` are unused; they are
    accepted to match the callback signature.
    """
    above_ceiling = obs < modelled
    return jnp.where(above_ceiling, -1e11, 0.0)

# Uniform priors for every calibrated parameter
priors = [
    esp.UniformPrior('ba1_seed_time', (580.0, 620.0)),
    esp.UniformPrior('ba2_seed_time', (620.0, 660.0)),
    esp.UniformPrior('ba5_seed_time', (660.0, 700.0)),
    esp.UniformPrior('ba2_escape', (0.2, 0.8)),
    esp.UniformPrior('ba5_escape', (0.2, 0.8)),
    esp.UniformPrior('contact_rate', (0.02, 0.15)),
    esp.UniformPrior('infectious_period', (2.0, 6.0)),
    esp.UniformPrior('start_cdr', (0.1, 0.5)),
    esp.UniformPrior('latent_period', (1.0, 5.0)),
    esp.UniformPrior('ba2_rel_ifr', (0.2, 1.2)),
    esp.UniformPrior('vacc_prop', (0.2, 0.8)),
    esp.UniformPrior('vacc_infect_protect', (0.2, 0.8)),
    esp.UniformPrior('natural_immunity_period', (40.0, 100.0)),
    esp.UniformPrior('ifr_adjuster', (1.0, 5.0)),
]
# Per-survey value passed to BinomialTarget (presumably the binomial sample size — confirm in estival).
# Broadcast over the survey index so it always matches the number of serosurveys,
# instead of assuming exactly four.
serosurvey_sample_sizes = pd.Series(20, index=serosurvey_targets.index)
targets = [
    est.NegativeBinomialTarget('notifications', case_targets, dispersion_param=esp.UniformPrior('cases_dispersion', (40.0, 140.0))),
    est.NegativeBinomialTarget('deaths', death_targets, dispersion_param=esp.UniformPrior('deaths_dispersion', (60.0, 200.0))),
    est.BinomialTarget('adult_seropos_prop', serosurvey_targets, serosurvey_sample_sizes),
    # Ceiling: heavily penalise modelled adult seropositivity above 0.15 on 1 Jan 2022
    est.CustomTarget('adult_seropos_prop', pd.Series([0.15], index=[datetime(2022, 1, 1)]), truncation_ceiling),
]
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, targets)

In [None]:
with pm.Model() as pmc_model:
    # Start the optimiser from the baseline values, clipped into the range given by each prior's bounds(0.99)
    start_params = {k: np.clip(v, *calibration_model.priors[k].bounds(0.99)) for k, v in parameters.items() if k in calibration_model.priors}
    variables = epm.use_model(calibration_model)
    map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
    map_params = {k: float(v) for k, v in map_params.items()}
print('Best calibration parameters found:')
for param, value in map_params.items():
    if '_dispersion' in param:
        continue
    # Look the prior up by name: map_params is not guaranteed to follow the order of the priors list
    print(f'   {param}: {round_sigfig(value, 4)} (within bound {calibration_model.priors[param].bounds()})')
# Adopt the MAP estimates (including dispersion parameters) and re-run the model with them
parameters.update(map_params)
aust_model.run(parameters=parameters)

In [None]:
# Quick comparison of the current model run against the calibration targets
fig = make_subplots(rows=3, cols=2)
derived_outputs = aust_model.get_derived_outputs_df()
x_vals = derived_outputs.index
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['notifications'], name='modelled cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=case_targets.index, y=case_targets, name='reported cases'), row=1, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['deaths'], name='deaths'), row=1, col=2)
fig.add_trace(go.Scatter(x=death_targets.index, y=death_targets, name='reported deaths'), row=1, col=2)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['adult_seropos_prop'], name='adult seropos'), row=2, col=1)
fig.add_trace(go.Scatter(x=serosurvey_targets.index, y=serosurvey_targets, name='seropos estimates'), row=2, col=1)
fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['reproduction_number'], name='reproduction number'), row=2, col=2)
for agegroup in aust_model.stratifications['agegroup'].strata:
    age_deaths = derived_outputs[f'deathsXagegroup_{agegroup}']
    trace_name = f'{agegroup} deaths'
    # Same series on a linear (left) and log (right) axis; one shared legend entry toggles both
    fig.add_trace(go.Scatter(x=x_vals, y=age_deaths, name=trace_name, legendgroup=trace_name), row=3, col=1)
    fig.add_trace(go.Scatter(x=x_vals, y=age_deaths, name=trace_name, legendgroup=trace_name, showlegend=False), row=3, col=2)
# Log scale for the bottom-right panel (range is in log10 units: 0.01 to 100)
fig.update_yaxes(type='log', range=[-2.0, 2.0], row=3, col=2)
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=600, width=1200)
fig.show()

In [None]:
# Add the parameter table (for the calibrated parameter names) and the priors table to the supplement
calibrated_param_names = [prior.name for prior in priors]
param_table = param_table_to_tex(param_info, calibrated_param_names)
app_doc.include_table(param_table, section='Parameters', widths=[3.0, 2.0, 1.5, 3.5], longtable=True)
app_doc.include_table(tabulate_priors(priors, param_info), section='Calibration')

In [None]:
# Short local DEMetropolis run (100 draws per chain, no tuning), presumably a smoke test before the remote run
with pm.Model() as pm_model:
    variables = epm.use_model(calibration_model)
    idata_local = pm.sample(
        step=[pm.DEMetropolis(variables)],
        draws=100,
        tune=0,
        cores=9,
        chains=18,
        progressbar=False,
    )
idata_local.to_netcdf('calibration_out.nc')

In [None]:
def run_calibration(bridge: springboard.task.TaskBridge, calibration_model: BayesianCompartmentalModel):
    """Remote task: sample the calibration posterior with DEMetropolis and save it.

    Draws 5000 samples from each of 18 chains (no tuning, 9 worker processes).
    Writes the inference data to ``calibration_out.nc`` in the bridge's output
    directory, then logs completion.
    """
    # Imported inside the function, presumably so the import runs on the remote worker
    import multiprocessing

    # Fixes the process start method before pm.sample spawns workers;
    # raises RuntimeError if a start method was already set in this process
    multiprocessing.set_start_method('forkserver')

    with pm.Model():
        variables = epm.use_model(calibration_model)
        de_step = pm.DEMetropolis(variables)
        idata_raw = pm.sample(step=[de_step], draws=5000, tune=0, cores=9, chains=18, progressbar=False)

    out_file = bridge.out_path / 'calibration_out.nc'
    idata_raw.to_netcdf(str(out_file))
    bridge.logger.info('Calibration complete')
    
# Machine and task specifications for the remote calibration run
# NOTE(review): the arguments are presumably (min cores, min RAM, instance category); run_calibration
# uses cores=9, so confirm the selected instance provides at least 9 cores
mspec = springboard.EC2MachineSpec(8, 2, 'compute')
tspec = springboard.TaskSpec(run_calibration, {'calibration_model': calibration_model})
run_path = springboard.launch.get_autumn_project_run_path('aust_covid', 'finalising_epi_model', 'revise_priors')
run_path

In [None]:
# Extra shell commands for the remote machine: clone and install this project so the task can presumably import it
# NOTE(review): always clones the `main` branch, not necessarily the branch this notebook is running from
aust_covid_commands = [
    'git clone --branch main https://github.com/monash-emu/aust-covid',
    'pip install -e ./aust-covid',
]
runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=aust_covid_commands)

In [None]:
# Overrides the run path generated above: point at an existing remote run whose results are analysed below
run_path = 'projects/aust_covid/initial_exploration/2023-08-14T1255-revise_chain_number'

In [None]:
# Download the remote run's output and load it as ArviZ inference data
mr = ManagedRun(run_path)
# NOTE(review): assumes the last entry of each listing is the calibration output file — confirm ordering
mr.remote.download(mr.remote.list_contents()[-1])
idata = az.from_netcdf(mr.list_local()[-1])

In [None]:
# Burn-in: keep draws labelled 1500 onwards (label-based selection, so re-running the cell is harmless)
burn_in_draws = 1500
idata = idata.sel(draw=slice(burn_in_draws, None))

In [None]:
# app_doc.include_table(tabulate_param_results(idata, priors, param_info), section='Calibration', widths=[2.0, 1.2, 1.2, 1.0, 1.0, 1.0, 1.5])

In [None]:
# Plot parameter values across draws (presumably trace plots) and pass app_doc for the supplement
# NOTE(review): the meaning of the trailing True flag is not visible here — confirm in calibration_utils
plot_param_progression(idata, param_info, app_doc, True)

In [None]:
# Plot posterior distributions of the calibrated parameters, passing app_doc for the supplement
plot_param_posterior(idata, param_info, app_doc)

In [None]:
# Order in which sections appear in the compiled supplement
section_order = [
    'Model Structure', 'Population', 'Stratification',
    'Mixing', 'Reinfection', 'Parameters',
    'Outputs', 'Targets', 'Calibration',
]

In [None]:
# Write out the supplement with sections in the order given by section_order
app_doc.write_doc(order=section_order)

In [None]:
# Re-run the model for a subsample of posterior draws (presumably 20) and collect the outputs to plot
sampled_idata = sample_idata(idata, 20, calibration_model)
key_outputs = ['notifications', 'adult_seropos_prop', 'deaths']
derived_output_names = aust_model.get_derived_outputs_df().columns
# NOTE(review): the 'ba' substring is meant to catch variant outputs such as ba1_prop/ba5_prop;
# it would also match any other output name containing 'ba'
variant_prop_outputs = [name for name in derived_output_names if 'ba' in name]
req_outputs = [*key_outputs, *variant_prop_outputs]
output_results = get_sampled_outputs(calibration_model, sampled_idata, req_outputs, parameters)

In [None]:
# Spaghetti plot of BA.1 proportion and (1 - BA.5 proportion), with reported variant emergence dates
fig = go.Figure()
ba1_results = melt_spaghetti(output_results, 'ba1_prop', sampled_idata)
ba5_results = melt_spaghetti(output_results, 'ba5_prop', sampled_idata)
ba5_results['value'] = 1.0 - ba5_results['value']  # Flip BA.5 results
for variant_results in (ba1_results, ba5_results):
    spaghetti_fig = px.line(variant_results, y='value', color='chain', line_group='draw', hover_data=variant_results.columns)
    fig.add_traces(spaghetti_fig.data)
voc_emerge_df = pd.DataFrame(
    {
        'ba1': [datetime(2021, 11, 22), datetime(2021, 11, 29), datetime(2021, 12, 20), 'blue'],
        'ba2': [datetime(2021, 11, 29), datetime(2022, 1, 10), datetime(2022, 3, 7), 'red'],
        'ba5': [datetime(2022, 3, 28), datetime(2022, 5, 16), datetime(2022, 6, 27), 'green'],
    },
    index=['any', '>1%', '>50%', 'colour']
)
lag = timedelta(days=3.5)  # Dates are given as first day of week in which VoC was first detected
# Line style per milestone: dotted for first detection, dashed for >1%, solid (default) for >50%
milestone_line_styles = {'any': {'line_dash': 'dot'}, '>1%': {'line_dash': 'dash'}, '>50%': {}}
for variant_name, variant_info in voc_emerge_df.items():
    for milestone, line_style in milestone_line_styles.items():
        fig.add_vline(variant_info[milestone] + lag, line_color=variant_info['colour'], **line_style)
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_yaxes(range=(0.0, 1.0))

In [None]:
# One panel per key output: posterior spaghetti lines with the observed targets overlaid as markers
fig = make_subplots(rows=3, cols=1, subplot_titles=key_outputs)
for row, output_name in enumerate(key_outputs, start=1):
    spaghetti = melt_spaghetti(output_results, output_name, sampled_idata)
    lines = px.line(spaghetti, y='value', color='chain', line_group='draw', hover_data=spaghetti.columns)
    fig.add_traces(lines.data, rows=row, cols=1)

# (observed series, legend name, marker fill colour, marker size, subplot row)
observed_overlays = [
    (case_targets, 'reported cases', 'LightBlue', 4, 1),
    (serosurvey_targets, 'serosurveys', 'white', 20, 2),
    (death_targets, 'reported deaths', 'Red', 4, 3),
]
for observed, label, fill_colour, marker_size, row in observed_overlays:
    marker_spec = {'color': fill_colour, 'size': marker_size, 'line': {'color': 'black', 'width': 1}}
    fig.add_trace(
        go.Scatter(x=observed.index, y=observed, name=label, mode='markers', marker=marker_spec),
        row=row,
        col=1,
    )
fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=1000, width=1000)
fig