In [None]:
#| warning: false
from jax import jit, random, vmap
from jax import numpy as jnp
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
import arviz as az
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from IPython.display import Markdown
pd.options.plotting.backend = "plotly"
import plotly.express as px
colours = px.colors.qualitative.Plotly

from estival.sampling import tools as esamp
import summer2
from summer2.utils import Epoch

from emu_renewal.process import sinterp, cosine_multicurve
from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import JaxModel
from emu_renewal.outputs import plot_spaghetti, get_area_from_df, plot_uncertainty_patches

In [None]:
# Days simulated before the calibration window opens.
run_in = 30
# Length of the sampled transmission-process vector `proc`.
n_process_periods = 12

# Case counts for Malaysia (column "MYS"), indexed by date.
mys_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
mys_data = mys_data.set_axis(pd.to_datetime(mys_data.index))

pop = 33e6
fixed_param_desc = (
    "### Fixed parameter values\n "
    f"The target population is initialised as {int(pop)} susceptible persons. "
    f"The simulation runs for a run-in period of {run_in} days before comparison against the calibration data commences.\n"
)

# Calibration window; the .loc date slice includes both end points.
analysis_start_time = datetime(2021, 3, 1)
analysis_end_time = datetime(2021, 11, 1)
select_data = mys_data.loc[analysis_start_time:analysis_end_time]

In [None]:
# The model run starts `run_in` days before the calibration window, so that comparison
# against the data only begins after the run-in period.
start = analysis_start_time - timedelta(days=run_in)
end = analysis_end_time
distri = JaxGammaDens()
epoch = Epoch(datetime(2019, 12, 31))
# Use the shared configuration values rather than repeating literals. This keeps the
# model's population and process-period count in step with `pop` / `n_process_periods`,
# which the description text and the calibration prior for `proc` also use.
# NOTE(review): the meaning of the positional `50` and of the second `run_in` is not
# visible here — confirm against the JaxModel signature.
renew_model = JaxModel(pop, start, end, run_in, n_process_periods, distri, 50, epoch, run_in)


def get_inc_result(gen_mean, gen_sd, proc, seed, cdr):
    """Return modelled notifications for the calibration window.

    Runs the renewal model and drops the first ``run_in`` entries of its incidence
    (the run-in period, assuming one entry per day). The remaining values are scaled
    by the case detection proportion ``cdr``.
    """
    return renew_model.func(gen_mean, gen_sd, proc, seed).incidence[run_in:] * cdr


renewal_wrap = jit(get_inc_result)

In [None]:
# Uniform prior bounds for the scalar calibrated parameters. Keys are the argument names
# passed to the model wrapper; each entry also carries a human-readable label.
scalar_req = {
    key: {"name": label, "lower": lower, "upper": upper}
    for key, label, lower, upper in (
        ("gen_mean", "Generation time mean (days)", 5.0, 12.0),
        ("gen_sd", "Generation time standard deviation (days)", 2.0, 7.0),
        ("cdr", "Case detection proportion", 0.05, 0.4),
        ("seed", "Peak seed rate", 5.0, 15.0),
    )
}
params_df = pd.DataFrame(scalar_req).T

In [None]:
# Methods text for the calibration likelihood. It must match `calib_model`, which compares
# log observed cases with log modelled notifications under a normal distribution whose
# (log-scale) standard deviation is itself sampled.
calib_desc = (
    "\n\n### Calibration targets\n"
    "The model described above was calibrated to the target data using a likelihood in which "
    "the logarithm of the observed number of cases at each data point is normally distributed, "
    "centred at the logarithm of the modelled notifications (i.e. a log-normal observation model), "
    "with the standard deviation on the log scale estimated during calibration. "
    "Modelled notifications are calculated as the product of modelled incidence and the "
    "(constant through time) case detection proportion. "
)

# Observed cases on the log scale. This is computed once here because the data are
# constant across likelihood evaluations.
log_target = jnp.log(jnp.array(select_data))
if not bool(jnp.all(jnp.isfinite(log_target))):
    # A single zero, negative or missing count would make the whole log-likelihood
    # -inf or NaN, and the sampler would fail with a much less helpful message.
    raise ValueError(
        "Calibration data contain zero, negative or missing values, "
        "which the log-scale likelihood cannot handle."
    )


def calib_model():
    """numpyro model calibrating the renewal model to the selected case data.

    Sampling steps:
    - each scalar parameter, from the uniform bounds in ``scalar_req``;
    - a process vector ``proc``: ``n_process_periods`` zero-centred normals whose
      spread ``proc_dispersion`` is also sampled;
    - a log-scale observation ``dispersion``.

    The log-normal notification likelihood is then added to the joint density
    with ``numpyro.factor``.
    """
    param_updates = {
        k: numpyro.sample(k, dist.Uniform(v["lower"], v["upper"])) for k, v in scalar_req.items()
    }
    proc_dispersion = numpyro.sample("proc_dispersion", dist.Uniform(0.0, 1.0))
    proc_dist = dist.Normal(jnp.zeros(n_process_periods), proc_dispersion)
    param_updates["proc"] = numpyro.sample("proc", proc_dist)
    logmodel_res = jnp.log(renewal_wrap(**param_updates))
    dispersion = numpyro.sample("dispersion", dist.Uniform(jnp.log(1.0), jnp.log(1.5)))
    like = dist.Normal(logmodel_res, dispersion).log_prob(log_target).sum()
    numpyro.factor("notifications_ll", like)

In [None]:
# The calibration-target description is displayed once, in the report section further
# down (next to the fixed-parameter description). The trailing semicolon suppresses
# a duplicate "Calibration targets" section in the rendered output here.
Markdown(calib_desc);

In [None]:
# No-U-Turn sampler over the calibration model, using a dense mass matrix.
kernel = numpyro.infer.NUTS(calib_model, dense_mass=True)
# Two chains, each with 1000 warm-up iterations followed by 1000 retained draws.
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
# Fixed key so the posterior draws are reproducible.
rng_key = random.PRNGKey(1)
# Also record the per-iteration acceptance probability.
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted sampler into an ArviZ InferenceData object for post-processing.
idata = az.from_numpyro(posterior=mcmc)

In [None]:
# Extra draws dropped from the start of each chain (on top of the NUTS warm-up).
burn_in = 10
# Number of posterior draws carried forward to the output plots.
n_samples = 100
# Quantiles reported for the output uncertainty bands.
centiles = [0.05, 0.5, 0.95]
# az.extract shuffles the pooled draws when `num_samples` is given. Without a fixed
# generator, the subsample (and every plot built from it) changes on each run.
subsample_seed = 1
idata_burnt = idata.sel(draw=slice(burn_in, None))
idata_sampled = az.extract(
    idata_burnt, num_samples=n_samples, rng=np.random.default_rng(subsample_seed)
)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, seed, cdr):
    """Run the renewal model and return its complete output object.

    ``cdr`` is accepted but not used. This lets a full sample-parameter mapping
    (which includes ``cdr``) be unpacked into this function unchanged.
    """
    del cdr  # not needed for the raw model outputs
    model_output = renew_model.func(gen_mean, gen_sd, proc, seed)
    return model_output


full_wrap = jit(get_full_result)

In [None]:
def _run_posterior_samples(samples):
    """Run the full model once per posterior sample and gather its outputs.

    Returns four dicts keyed by each sample's index label (as a string), holding, in
    order:
    - modelled notifications (incidence scaled by that sample's ``cdr``);
    - susceptibles;
    - ``r_t``;
    - the fitted transmission-process curve.
    """
    cases, suscept, r_t, proc = {}, {}, {}, {}
    for idx, sample in samples.iterrows():
        col = str(idx)
        # The dispersion parameters belong to the calibration priors/likelihood and are
        # not arguments of the model function, so they are left out of the call.
        model_params = {k: v for k, v in sample.items() if "dispersion" not in k}
        result = full_wrap(**model_params)
        cases[col] = np.asarray(result.incidence * sample["cdr"])
        suscept[col] = np.asarray(result.suscept)
        r_t[col] = np.asarray(result.r_t)
        y_vals = sinterp.get_scale_data(np.array(sample["proc"]))
        proc[col] = np.asarray(renew_model.fit_process_curve(y_vals))
    return cases, suscept, r_t, proc


param_idx = [str(i) for i in sample_params.index]
cases_cols, suscept_cols, r_cols, proc_cols = _run_posterior_samples(sample_params)

# A single date index shared by all outputs, converted once from the model time steps.
# Each frame is built in one go instead of being grown column-by-column in a loop.
times = renew_model.epoch.index_to_dti(renew_model.model_times)
cases_spagh = pd.DataFrame(cases_cols, index=times)
suscept_spagh = pd.DataFrame(suscept_cols, index=times)
r_spagh = pd.DataFrame(r_cols, index=times)
proc_df = pd.DataFrame(proc_cols, index=times)

# Per-time quantiles across samples: rows are times, columns are the requested centiles.
case_quantiles = cases_spagh.quantile(centiles, axis=1).T
suscept_quantiles = suscept_spagh.quantile(centiles, axis=1).T
r_quantiles = r_spagh.quantile(centiles, axis=1).T
proc_quantiles = proc_df.quantile(centiles, axis=1).T

In [None]:
# Render the fixed-parameter description.
Markdown(data=fixed_param_desc)

In [None]:
# Render the calibration-target description.
Markdown(data=calib_desc)

In [None]:
# Panel titles passed to the plotting helpers (presumably one per output panel, in order).
titles = ["cases", "susceptibles", "transmission potential", "R"]
# Equal margins on every side; presumably plotly layout margins in px.
margins = dict.fromkeys(("t", "b", "l", "r"), 20)
# Per-sample ("spaghetti") view of the model outputs alongside the target data.
spagh_fig = plot_spaghetti(
    cases_spagh, select_data, proc_df, suscept_spagh, r_spagh, margins, titles
)

In [None]:
# Centile bands (from `centiles`) of each output across posterior samples, with the target data.
patch_fig = plot_uncertainty_patches(
    case_quantiles, select_data, proc_quantiles, suscept_quantiles, r_quantiles,
    margins, titles, colours,
)
patch_fig

In [None]:
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# Export a static SVG copy of the uncertainty-patch figure to the working directory.
# NOTE(review): this cell displays nothing, so the Quarto label/caption above may have no
# figure output to attach to — confirm that @fig-calib renders as intended.
patch_fig.write_image(file="patch_fig.svg")

In [None]:
# Render the description text returned by the renewal model object.
Markdown(data=renew_model.get_full_desc())

In [None]:
# Reformat the prior-bounds table for the report: parameter labels become the (unnamed)
# index, and the bounds get readable headings. The table is rebuilt from `scalar_req`
# instead of mutating `params_df` in place, so re-running this cell yields the same
# table rather than failing on the already-renamed columns.
params_df = (
    pd.DataFrame(scalar_req)
    .T
    .set_index("name")
    .rename(columns={"lower": "Lower limit", "upper": "Upper limit"})
    .rename_axis(None)
)

In [None]:
# Render the formatted prior-bounds table as markdown.
Markdown(data=params_df.to_markdown())

In [None]:
# Placeholder evidence column with one row per parameter in the formatted table.
# TODO: replace the placeholder text with the actual supporting evidence.
evidence_table = pd.DataFrame(
    {"Evidence": "To be populated [@cori2013]"}, index=params_df.index
)
Markdown(data=evidence_table.to_markdown())