In [None]:
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
pd.options.plotting.backend = 'plotly'
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pymc as pm

from estival.model import BayesianCompartmentalModel
import estival.priors as esp
import estival.targets as est
from estival.wrappers import pymc as epm

from aust_covid.inputs import load_calibration_targets, load_who_data, load_serosurvey_data
from aust_covid.model import MATRIX_LOCATIONS, build_model
from general_utils.tex_utils import StandardTexDoc
from general_utils.parameter_utils import load_param_info
from general_utils.calibration_utils import round_sigfig

# Project root is taken as the parent of the current working directory (symlinks resolved)
PROJECT_PATH = Path('.').resolve().parent
DATA_PATH = PROJECT_PATH.joinpath('data')

In [None]:
# Key dates: the reference date is passed to build_model; the analysis window bounds the model run;
# plot_start_date is the left x-axis limit of the results figure
reference_date = datetime(year=2019, month=12, day=31)
analysis_start_date = datetime(year=2021, month=7, day=1)
analysis_end_date = datetime(year=2022, month=10, day=1)
plot_start_date = datetime(year=2021, month=12, day=1)

In [None]:
# TeX supplement document; later cells pass it to the target loaders and build_model
# NOTE(review): title says 2023, but analysis_start_date/analysis_end_date span 2021-2022 — confirm the title
app_doc = StandardTexDoc(
    PROJECT_PATH / 'supplement',
    'supplement',
    "Australia's 2023 Omicron Waves Supplement",
    'austcovid',
)

In [None]:
# Load the calibration target series; each loader also receives app_doc
targets_average_window = 7  # presumably a rolling-average window in days — confirm against loaders
case_targets = load_calibration_targets(
    datetime(2021, 12, 15),  # NOTE(review): presumably the first date of the case targets — confirm
    targets_average_window,
    app_doc,
)
death_targets = load_who_data(
    targets_average_window,
    app_doc,
)
serosurvey_targets = load_serosurvey_data(
    14.0,  # NOTE(review): meaning of this argument is not visible here — name it once confirmed
    app_doc,
)

In [None]:
# Infection fatality rates grouped by value; expanded below into 'ifr_<age>' keys in ascending age order
ifr_bands = [
    ((0, 5, 10), 0.0),
    ((15, 20, 25, 30), 2.6e-5),
    ((35, 40, 45), 5.8e-5),
    ((50, 55), 14.6e-5),
    ((60, 65), 24.6e-5),
    # NOTE(review): this band is 10x the previous one (246e-5 vs 24.6e-5) — confirm the exponent is intended
    ((70, 75), 246e-5),
]

# Base parameter set used to run the model and as the starting point for calibration
parameters = dict(
    ba1_seed_time=620.0,
    start_cdr=0.3,
    contact_rate=0.07,
    vacc_prop=0.4,
    infectious_period=2.5,
    natural_immunity_period=60.0,
    ba2_seed_time=660.0,
    ba2_escape=0.4,
    ba5_seed_time=715.0,
    ba5_escape=0.54,
    latent_period=1.8,
    seed_rate=1.0,
    seed_duration=10.0,
    notifs_shape=2.0,
    notifs_mean=4.0,
    vacc_infect_protect=0.4,
    rural_susc_adj=1.0,
    wa_reopen_period=30.0,
    deaths_shape=2.0,
    deaths_mean=20.0,
)
parameters.update({f'ifr_{age}': ifr for ages, ifr in ifr_bands for age in ages})
param_info = load_param_info(PROJECT_PATH / 'inputs' / 'parameters.yml', parameters)

In [None]:
# Build the model over the analysis window (dated relative to reference_date), then write the supplement
aust_model = build_model(
    reference_date,
    analysis_start_date,
    analysis_end_date,
    app_doc,
)
app_doc.write_doc()

In [None]:
# Set up for calibration or optimisation
# Parameters with uniform priors are the ones varied; the others are presumably held at
# their values in `parameters` — confirm in estival's BayesianCompartmentalModel
priors = [
    esp.UniformPrior('ba1_seed_time', (570.0, 670.0)),
    esp.UniformPrior('contact_rate', (0.03, 0.1)),
    esp.UniformPrior('infectious_period', (2.5, 7.0)),
    esp.UniformPrior('start_cdr', (0.25, 0.5)),
    esp.UniformPrior('latent_period', (1.5, 4.0)),
]

# Proportion of each target series' maximum used as the truncated-normal spread argument
# (presumably the standard deviation — confirm in estival.targets)
target_sd_prop = 0.1
# Count applied to every serosurvey observation (presumably the binomial number of trials — confirm)
serosurvey_sample_size = 20

targets = [
    est.TruncatedNormalTarget('notifications', case_targets, [0.0, np.inf], case_targets.max() * target_sd_prop),
    est.TruncatedNormalTarget('deaths', death_targets, [0.0, np.inf], death_targets.max() * target_sd_prop),
    # Broadcast the scalar over the target index, so the sizes always align with however many
    # serosurvey estimates were loaded (the previous hard-coded length of 4 raised otherwise)
    est.BinomialTarget(
        'adult_seropos_prop',
        serosurvey_targets,
        pd.Series(serosurvey_sample_size, index=serosurvey_targets.index),
    ),
]
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, targets)

In [None]:
# Find the maximum a posteriori parameter set, starting from the current parameter values
# clipped into each prior's bounds(0.99) range
with pm.Model() as pmc_model:
    start_params = {
        k: np.clip(v, *calibration_model.priors[k].bounds(0.99))
        for k, v in parameters.items()
        if k in calibration_model.priors
    }
    variables = epm.use_model(calibration_model)
    map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False)
    map_params = {k: float(v) for k, v in map_params.items()}

print('Best calibration parameters found:')
for param, value in map_params.items():
    # Look the prior up by name: find_MAP's output order need not match the `priors` list,
    # and any entry without a prior is reported without bounds instead of raising IndexError
    if param in calibration_model.priors:
        bounds_text = f' (within bound {calibration_model.priors[param].bounds()})'
    else:
        bounds_text = ''
    print(f'   {param}: {round_sigfig(value, 4)}{bounds_text}')

In [None]:
# Run with optimised parameters
# Merge into a new dict instead of updating `parameters` in place. This keeps the cell idempotent,
# and re-running the MAP cell still starts from the original base values rather than the last optimum.
optimised_params = {**parameters, **map_params}
aust_model.run(parameters=optimised_params)

In [None]:
# Compare the optimised run with each calibration target, and show R and deaths by age group
derived_outputs = aust_model.get_derived_outputs_df()
x_vals = derived_outputs.index
fig = make_subplots(
    rows=3,
    cols=2,
    # Titles fill panels in row-major order; the last panel (row 3, col 2) is unused
    subplot_titles=(
        'notifications',
        'deaths',
        'adult seropositive proportion',
        'reproduction number',
        'deaths by age group',
    ),
)

# (derived output, model label, target series, target label, row, col) for each model-vs-data panel
comparison_panels = (
    ('notifications', 'modelled cases', case_targets, 'reported cases', 1, 1),
    ('deaths', 'modelled deaths', death_targets, 'reported deaths', 1, 2),
    ('adult_seropos_prop', 'adult seropos', serosurvey_targets, 'seropos estimates', 2, 1),
)
for output, model_label, target, target_label, row, col in comparison_panels:
    fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs[output], name=model_label), row=row, col=col)
    fig.add_trace(go.Scatter(x=target.index, y=target, name=target_label), row=row, col=col)

fig.add_trace(go.Scatter(x=x_vals, y=derived_outputs['reproduction_number'], name='reproduction number'), row=2, col=2)
for agegroup in aust_model.stratifications['agegroup'].strata:
    fig.add_trace(
        go.Scatter(x=x_vals, y=derived_outputs[f'deathsXagegroup_{agegroup}'], name=f'{agegroup} deaths'),
        row=3,
        col=1,
    )

fig.update_xaxes(range=(plot_start_date, analysis_end_date))
fig.update_layout(height=600, width=1200)
fig.show()