In [None]:
from pathlib import Path

In [None]:
# Uncomment the two lines below to enable parallel evaluation in notebooks.
# Leave them commented out if running under (non-WSL) Windows, where the
# 'forkserver' start method is unavailable, and use single-threaded evaluation in pymc instead.

# import multiprocessing as mp
# mp.set_start_method('forkserver')

In [None]:
# Installation in case running over Colab
try:
    import google.colab
    %pip install estival
    %pip install pylatex==1.4.1
    %pip install kaleido
    ! git clone https://github.com/monash-emu/aust-covid
    %cd aust-covid
    %pip install -e ./
    PROJECT_PATH = Path().resolve()
except:
    PROJECT_PATH = Path().resolve().parent

DATA_PATH = PROJECT_PATH / "data"
SUPPLEMENT_PATH = PROJECT_PATH / "supplement"

In [None]:
import pandas as pd
pd.options.plotting.backend = "plotly"
from datetime import datetime
import pylatex as pl
from pylatex.utils import NoEscape
import matplotlib.pyplot as plt
from aust_covid.model import build_aust_model
from aust_covid.calibration import DocumentedCalibration
from aust_covid.output_utils import convert_idata_to_df, run_samples_through_model, round_sigfig, plot_from_model_runs_df
import yaml
import nevergrad as ng
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from estival.model import BayesianCompartmentalModel
from estival.optimization.nevergrad import optimize_model
from estival.priors import UniformPrior
from estival.targets import NegativeBinomialTarget, CustomTarget
from estival.calibration import pymc as epm
import pymc as pm
import arviz as az

In [None]:
# Data inputs: per-region case series indexed by date
aust_data = pd.read_csv(DATA_PATH / "Aus_covid_data.csv", index_col="date")
aust_data.index = pd.to_datetime(aust_data.index)

# Extract national (the "AUS" rows) and smooth with a trailing 7-row rolling mean;
# dropna removes the leading rows for which the window is incomplete
national_data = aust_data[aust_data["region"] == "AUS"]
smoothed_national_cases = national_data["cases"].rolling(window=7).mean().dropna()

# Extract non-WA: sum the remaining jurisdictions date by date ("AUS" is excluded
# as well as "WA" because it is the national series extracted above).
# numeric_only=True restricts the sum to numeric columns; otherwise pandas 2.x
# concatenates every jurisdiction's "region" string into a meaningless column
# (and pandas 1.x drops it with a deprecation warning).
non_wa_data = aust_data.loc[~aust_data["region"].isin(["AUS", "WA"])]
non_wa_data = non_wa_data.groupby(non_wa_data.index).sum(numeric_only=True)
smoothed_non_wa_cases = non_wa_data["cases"].rolling(window=7).mean().dropna()

In [None]:
# Create the supplementary-material LaTeX document: biblatex with citation-order
# sorting, the project bibliography file and a title page. Later cells append content.
supplement = pl.Document()
for preamble_item in (
    pl.Package("biblatex", options=["sorting=none"]),
    pl.Command("addbibresource", arguments=["austcovid.bib"]),
    pl.Command("title", "Supplemental Appendix"),
):
    supplement.preamble.append(preamble_item)
supplement.append(NoEscape(r"\maketitle"))

In [None]:
# Simulation window for the manual run
start_date = datetime(2021, 9, 1)
end_date = datetime(2022, 10, 1)

# Baseline parameter values for a manual (uncalibrated) run
parameters = {
    "contact_rate": 0.048,
    "infectious_period": 5.0,
    "latent_period": 2.0,
    "cdr": 0.1,
    "seed_rate": 1.0,
    "seed_duration": 1.0,
    "ba1_seed_time": 660.0,
    "ba2_seed_time": 688.0,
    "ba5_seed_time": 720.0,
    "ba2_escape": 0.45,
    "ba5_escape": 0.38,
    "notifs_shape": 2.0,
    "notifs_mean": 4.0,
    "deaths_shape": 2.0,
    "deaths_mean": 20.0,
    "natural_immunity_period": 50.0,
    # "ifr_0", "ifr_5", ..., "ifr_65" are all zero (generated in that order);
    # only "ifr_70" is non-zero
    **{f"ifr_{age}": 0.0 for age in range(0, 70, 5)},
    "ifr_70": 0.01,
}

# Build the model without adding its documentation to the supplement, then run it once
aust_model = build_aust_model(start_date, end_date, supplement, add_documentation=False)
aust_model.run(parameters=parameters)

In [None]:
# Calibration settings: parameter documentation from the project's YAML inputs
with open(PROJECT_PATH / "inputs/parameters.yml", "r") as param_file:
    param_info = yaml.safe_load(param_file)
param_descriptions, param_units, param_evidence = (
    param_info[section] for section in ("descriptions", "units", "evidence")
)

# Provisional chain settings (both are reassigned in the main calibration loop before use)
iterations, burn_in = 500, 100

# Uniform priors over the calibrated parameters
priors = [
    UniformPrior("contact_rate", (0.03, 0.06)),
    UniformPrior("infectious_period", (3.0, 7.0)),
    UniformPrior("ba2_escape", (0.3, 0.7)),
    UniformPrior("ba5_escape", (0.3, 0.7)),
    UniformPrior("ba1_seed_time", (645.0, 665.0)),
    UniformPrior("ba2_seed_time", (675.0, 700.0)),
    UniformPrior("ba5_seed_time", (705.0, 730.0)),
    UniformPrior("cdr", (0.05, 0.5)),
]

# Re-index the observed series as integer days since the model's reference date
smoothed_non_wa_cases_intindex = smoothed_non_wa_cases.copy()
smoothed_non_wa_cases_intindex.index = (smoothed_non_wa_cases.index - aust_model.ref_date).days
def least_squares(modelled, obs, parameters, time_weights):
    """Return the negated sum of squared residuals between modelled and observed values.

    Higher (closer to zero) is a better fit. ``parameters`` and ``time_weights`` are
    accepted, presumably to match the CustomTarget callback signature, but are unused.
    """
    squared_residuals = (modelled - obs) ** 2.0
    return 0.0 - squared_residuals.sum()
# Two specifications of the non-WA notifications target: a least-squares custom
# target, and a negative binomial target (third argument presumably its dispersion)
targets = [CustomTarget("notifications", smoothed_non_wa_cases_intindex, least_squares)]
binom_targets = [NegativeBinomialTarget("notifications", smoothed_non_wa_cases_intindex, 500.0)]

# Rebuild the model, this time passing True (documentation on) as the fourth argument,
# and wrap it for calibration against the negative binomial target
aust_model = build_aust_model(start_date, end_date, supplement, True)
calibration_model = BayesianCompartmentalModel(aust_model, parameters, priors, binom_targets)

In [None]:
# Display the evaluator object held for the "notifications" target.
# NOTE(review): `_evaluators` is a private attribute of BayesianCompartmentalModel and may
# change between estival versions; this looks like leftover debugging output.
calibration_model._evaluators["notifications"]

In [None]:
# Optimise the calibration parameters with nevergrad in several rounds,
# printing the loss after each round so progress is visible
optim_rounds = 10  # number of optimisation rounds
optim_budget_per_round = 100  # argument to minimize(), presumably the evaluation budget
print("Optimising with nevergrad \n Progression of loss function values:")
optim_runner = optimize_model(calibration_model)
for i in range(optim_rounds):
    rec = optim_runner.minimize(optim_budget_per_round)
    print(rec.loss)

# NOTE(review): rec.value is presumably nevergrad's (args, kwargs) pair, so [1] is the
# dict of parameter values -- confirm against estival.optimization.nevergrad
optim_params = rec.value[1]
# This mutates the same `parameters` dict that calibration_model was constructed with
parameters.update(optim_params)
aust_model.run(parameters=parameters)
print("Best calibration parameters found:")
optim_params

In [None]:
# Put the observed (smoothed, non-WA) cases and the modelled notifications from the
# optimised run side by side as columns of one frame, then plot them together
axis_labels = dict(index="time", value="cases")
comparison_df = pd.concat(
    [
        smoothed_non_wa_cases,
        aust_model.get_derived_outputs_df()["notifications"],
    ],
    axis=1,
)
comparison_df.plot(labels=axis_labels, title="Optimised parameter outputs")

In [None]:
# Main calibration loop: DEMetropolis sampling of the estival calibration model via pymc
iterations = 3000  # draws per chain
burn_in = 800  # initial draws per chain discarded before analysis
n_chains = 10
sampling_seed = 42  # fixed seed so a re-run reproduces the same draws (given a deterministic model)
with pm.Model() as model:
    variables = epm.use_model(calibration_model)
    idata = pm.sample(
        step=[pm.DEMetropolis(variables)],
        draws=iterations,
        tune=0,  # no tuning phase; early draws are discarded as burn-in below instead
        cores=8,
        chains=n_chains,
        random_seed=sampling_seed,
    )
# Discard burn-in by position, so this does not depend on how draws are labelled
burnt_idata = idata.isel(draw=slice(burn_in, None))
calib_df = burnt_idata.to_dataframe(groups="posterior")  # Also get as dataframe

In [None]:
# Parameter values at the first recorded draw of each chain. With tune=0 this serves
# as a proxy for each chain's starting point.
idata["posterior"].isel(draw=0).to_dataframe()

In [None]:
# Report acceptance ratios: fraction of accepted proposals per chain, computed over
# all draws (including those later discarded as burn-in)
(idata.sample_stats["accepted"].sum(dim="draw") / idata.sample_stats.sizes["draw"]).to_dataframe()

In [None]:
# Posterior summary statistics and diagnostics for the post-burn-in draws
az.summary(data=burnt_idata)

In [None]:
from importlib import reload

In [None]:
import aust_covid.calibration

In [None]:
# Re-import aust_covid.calibration so edits made to the module since the kernel started
# are picked up. The documentation cell below uses the attribute form
# aust_covid.calibration.DocumentedCalibration (not the name imported at the top of the
# notebook), so it gets the reloaded class.
# NOTE(review): development convenience -- consider %autoreload instead.
reload(aust_covid.calibration)

In [None]:
from aust_covid.doc_utils import DocumentedProcess, FigElement, TextElement, TableElement

In [None]:
# Assemble the calibration documentation for the supplement, using the class from the
# reloaded aust_covid.calibration module.
# The targets passed here are binom_targets, the ones calibration_model was actually
# sampled against. The least-squares `targets` list is not used by the MCMC, so
# documenting it would misdescribe the calibration.
documented_calib = aust_covid.calibration.DocumentedCalibration(
    priors,
    binom_targets,
    iterations,
    burn_in,
    build_aust_model,
    parameters,
    param_descriptions,
    param_units,
    param_evidence,
    start_date,
    end_date,
    supplement,
)
# Presumably adds a parameter-progression figure for the post-burn-in draws, then
# compiles the documentation into the supplement -- confirm against aust_covid.calibration
documented_calib.graph_param_progression(burnt_idata)
documented_calib.compile_doc()

In [None]:
# Posterior distributions of the calibrated parameters (post-burn-in draws)
az.plot_posterior(data=burnt_idata);

In [None]:
# Run a random subset of posterior draws through the model and overlay the observations
param_names = list(burnt_idata.posterior.data_vars.keys())

num_samples_request = 30
extract_seed = 0  # seed for the random subset, so the plotted runs are reproducible
sampled_idata = az.extract(
    burnt_idata, num_samples=num_samples_request, rng=extract_seed
)  # Sample from the inference data
sampled_df = convert_idata_to_df(sampled_idata, param_names)
sample_model_results = run_samples_through_model(sampled_df, calibration_model)
fig = plot_from_model_runs_df(sample_model_results, sampled_df, param_names)
fig.add_trace(
    go.Scatter(x=smoothed_non_wa_cases.index, y=smoothed_non_wa_cases, marker=dict(color="black"), name="non-WA cases", mode="markers"),
)

In [None]:
# # Look at a subset of the results of calibration
# sample_outputs = documented_calib.get_sample_outputs(50)
# pd.concat((smoothed_non_wa_cases, sample_outputs), axis=1).plot(labels=axis_labels)

In [None]:
# Finish up the supplement document with the bibliography on a new page, then write
# the LaTeX source as "supplement" inside SUPPLEMENT_PATH (pylatex adds the .tex extension)
supplement.append(pl.NewPage())
supplement.append(pl.Command("printbibliography"))
supplement.generate_tex(str(SUPPLEMENT_PATH / "supplement"))