In [None]:
#| warning: false
from jax import jit, random
from jax import numpy as jnp
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
import arviz as az
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from IPython.display import Markdown
pd.options.plotting.backend = "plotly"
from plotly.express.colors import qualitative as qual_colours

from estival.sampling import tools as esamp
from summer2.utils import Epoch

from emu_renewal.process import sinterp
from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import JaxModel
from emu_renewal.outputs import get_quantiles_from_spaghetti, plot_spaghetti, plot_uncertainty_patches, PANEL_SUBTITLES
from emu_renewal.utils import format_date_for_str

In [None]:
# Specify fixed parameters and get calibration data
run_in = 30  # days simulated before the first calibration data point (see analysis_start)
proc_update_freq = 7  # passed to JaxModel; presumably days between process update points — TODO confirm
# Daily new cases for Malaysia ("MYS" column), indexed by date
mys_data = pd.read_csv("https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv", index_col=0)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)
pop = 33e6
data_start = datetime(2021, 3, 1)
data_end = datetime(2021, 11, 1)
select_data = mys_data.loc[data_start: data_end]
# The simulation starts run_in days before the calibration window so early dynamics can settle
analysis_start = data_start - timedelta(days=run_in)
analysis_end = data_end
epoch = Epoch(datetime(2019, 12, 31))
fixed_param_desc = (
    "### Fixed parameter values\n " 
    f"The target population is initialised as {str(int(pop))} susceptible persons. " 
    f"The simulation runs from the {format_date_for_str(analysis_start)} "
    f"to the {format_date_for_str(analysis_end)}, "
    f"with a run-in period of {run_in} days before "
    "comparison against the calibration data commences.\n"
)

In [None]:
# Renewal model spanning the whole analysis window (run-in plus calibration period).
# Use the shared `pop` constant so the model and the rendered description cannot diverge.
# NOTE(review): the meaning of the remaining positional arguments (the literal 50 and the
# second run_in) is defined by JaxModel's signature, which is not visible here — confirm.
renew_model = JaxModel(pop, analysis_start, analysis_end, run_in, proc_update_freq, JaxGammaDens(), 50, epoch, run_in)


def get_inc_result(gen_mean, gen_sd, proc, seed, cdr):
    """Return modelled notifications over the calibration period.

    Runs the renewal model, discards the first ``run_in`` time points (the run-in
    before the calibration data start) and scales incidence by the case detection
    proportion ``cdr``.
    """
    return renew_model.func(gen_mean, gen_sd, proc, seed).incidence[run_in:] * cdr


renewal_wrap = jit(get_inc_result)

In [None]:
# Define parameter ranges: a display name plus lower/upper bounds for each calibrated parameter
params_dict = dict(
    gen_mean=dict(name="Generation time mean (days)", lower=5.0, upper=12.0),
    gen_sd=dict(name="Generation time standard deviation (days)", lower=2.0, upper=7.0),
    cdr=dict(name="Case detection proportion", lower=0.05, upper=0.4),
    seed=dict(name="Peak seed rate", lower=5.0, upper=15.0),
)
# One row per parameter (index = parameter key), columns name / lower / upper
params_df = pd.DataFrame(params_dict).T

In [None]:
# Describes the likelihood used in calib_model: a Normal density evaluated on log-transformed
# observations and log-transformed modelled notifications (i.e. a log-normal observation model).
calib_desc = (
    "\n\n### Calibration targets\nThe model described above was fit to the target data "
    "by maximising the likelihood of the observed number of cases at each available data point "
    "under a normal distribution on the log scale, centred at the logarithm of the modelled "
    "notification rate (i.e. a log-normal observation model, with its dispersion parameter "
    "estimated during calibration). "
    "Modelled notifications are calculated as the product of modelled incidence and the "
    "(constant through time) case detection proportion. "
)

# The calibration targets are fixed, so log-transform them once rather than inside the model.
# Any zero or missing count would give -inf/NaN and make the likelihood undefined for every
# parameter set, so fail early with a clear message instead of an opaque MCMC init error.
if not (np.isfinite(select_data.to_numpy()).all() and (select_data > 0).all()):
    raise ValueError("Calibration data must be finite and strictly positive for the log-scale likelihood")
log_target = jnp.log(jnp.array(select_data))


def calib_model():
    """Numpyro model for calibrating the renewal model to notifications.

    Priors:
      * each entry of ``params_dict``: Uniform(lower, upper);
      * ``proc_dispersion``: HalfNormal(1.0), the scale of the process values;
      * ``proc``: independent Normal(0, proc_dispersion), one per process point of the model;
      * ``dispersion``: Uniform(log(1.0), log(1.5)), the observation scale on the log axis.

    The log-likelihood is a Normal density of log observed cases around log modelled
    notifications, summed over time points and added via ``numpyro.factor``.
    """
    param_updates = {k: numpyro.sample(k, dist.Uniform(v["lower"], v["upper"])) for k, v in params_dict.items()}
    proc_dispersion = numpyro.sample("proc_dispersion", dist.HalfNormal(1.0))
    n_process_periods = len(renew_model.x_proc_data.points)
    proc_dist = dist.Normal(jnp.repeat(0.0, n_process_periods), proc_dispersion)
    param_updates["proc"] = numpyro.sample("proc", proc_dist)
    logmodel_res = jnp.log(renewal_wrap(**param_updates))
    dispersion = numpyro.sample("dispersion", dist.Uniform(jnp.log(1.0), jnp.log(1.5)))
    like = dist.Normal(logmodel_res, dispersion).log_prob(log_target).sum()
    numpyro.factor("notifications_ll", like)

In [None]:
# NUTS with a dense mass matrix; 2 chains, each with 1000 warm-up and 1000 retained draws
kernel = numpyro.infer.NUTS(model=calib_model, dense_mass=True)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
rng_key = random.PRNGKey(1)
# Also record the per-draw acceptance probability for diagnostics
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted sampler's posterior draws to an ArviZ InferenceData object
idata = az.from_numpyro(posterior=mcmc)

In [None]:
burn_in = 10  # extra draws dropped from the start of every chain
n_samples = 100  # posterior draws used to build the output trajectories
quantiles = [0.05, 0.5, 0.95]
sample_seed = 1  # fixes which posterior draws are subsampled, so figures are reproducible
idata_burnt = idata.sel(draw=slice(burn_in, None))
# az.extract shuffles before taking num_samples; seed it, otherwise each run picks different draws
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=np.random.default_rng(sample_seed))
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, seed, cdr):
    """Run the renewal model and return its complete output object.

    ``cdr`` is accepted but not used, so a posterior parameter dict that includes the
    case detection proportion can be passed straight in with ``**``.
    """
    model_output = renew_model.func(gen_mean, gen_sd, proc, seed)
    return model_output


full_wrap = jit(get_full_result)

In [None]:
# One column per (output panel, posterior sample), one row per model date.
column_names = pd.MultiIndex.from_product([PANEL_SUBTITLES, sample_params.index.map(str)])
# Allocate as float: an empty frame otherwise gets object columns, and with pandas >= 2 the
# in-place .loc assignments below keep that object dtype, which breaks/slows numeric summaries.
spaghetti = pd.DataFrame(
    index=renew_model.epoch.index_to_dti(renew_model.model_times),
    columns=column_names,
    dtype=float,
)
for i, p in sample_params.iterrows():
    # Dispersion parameters only enter the priors/likelihood, not the renewal model itself
    params = {k: v for k, v in p.items() if "dispersion" not in k}
    result = full_wrap(**params)
    sample_label = str(i)
    spaghetti.loc[:, ("cases", sample_label)] = result.incidence * p["cdr"]
    spaghetti.loc[:, ("suscept", sample_label)] = result.suscept
    spaghetti.loc[:, ("R", sample_label)] = result.r_t
    spaghetti.loc[:, ("transmission potential", sample_label)] = result.process

In [None]:
# Summarise each output panel's sampled trajectories into the requested quantiles
case_quantiles, suscept_quantiles, r_quantiles, proc_quantiles = (
    get_quantiles_from_spaghetti(spaghetti.loc[:, panel], quantiles)
    for panel in ("cases", "suscept", "R", "transmission potential")
)

In [None]:
# Only a cell's final expression is displayed, so render both sections as one Markdown object
# (previously the fixed-parameter section was silently discarded).
Markdown(fixed_param_desc + calib_desc)

In [None]:
# Plot each sampled trajectory per output panel against the calibration data.
# The per-panel frames come from `spaghetti`; the former *_spagh names were never defined.
# NOTE(review): assumes plot_spaghetti takes per-panel sample frames in this argument order — confirm.
plot_spaghetti(
    spaghetti.loc[:, "cases"],
    select_data,
    spaghetti.loc[:, "suscept"],
    spaghetti.loc[:, "R"],
    spaghetti.loc[:, "transmission potential"],
)

In [None]:
# Quantile bands for each output panel, with the calibration data overlaid
patch_fig = plot_uncertainty_patches(
    case_quantiles,
    select_data,
    suscept_quantiles,
    r_quantiles,
    proc_quantiles,
    qual_colours.Plotly,
)
patch_fig

In [None]:
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# Save a static SVG copy, then display the figure: Quarto attaches the label/caption above to
# this cell's figure output, and a cell with no output leaves @fig-calib unresolved.
patch_fig.write_image("patch_fig.svg")
patch_fig

In [None]:
# Render the renewal model's own description of its structure
model_description = renew_model.get_full_desc()
Markdown(model_description)

In [None]:
# Display version of the parameter table: one row per parameter (labelled by its descriptive
# name) with its bounds. Built from params_dict rather than by mutating params_df in place,
# so re-running this cell gives the same table instead of failing on the column rename.
params_df = (
    pd.DataFrame(params_dict)
    .transpose()
    .rename(columns={"lower": "Lower limit", "upper": "Upper limit"})
    .set_index("name")
    .rename_axis(None)
)

In [None]:
# Render the parameter bounds table as Markdown
params_table_md = params_df.to_markdown()
Markdown(params_table_md)

In [None]:
# Placeholder evidence column, one row per parameter in the bounds table
evidence_table = pd.DataFrame(
    {"Evidence": "To be populated [@cori2013]"},
    index=params_df.index,
)
Markdown(evidence_table.to_markdown())