In [None]:
#| warning: false
from jax import jit, random, vmap
from jax import numpy as jnp
import numpy as np
import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
import arviz as az
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from IPython.display import Markdown
pd.options.plotting.backend = "plotly"

from estival.sampling import tools as esamp
import summer2

from emu_renewal.process import sinterp, cosine_multicurve
from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import JaxModel

In [None]:
# Renewal model shared by every downstream cell.
# NOTE(review): the positional arguments are opaque from here. Compared with
# calib_kwargs in the next cell, they appear to be population (33e6), total
# simulated days (276 = 30 run-in days + the number of target observations),
# run-in days (30) and number of process periods (12). Confirm against JaxModel's
# signature, and keep them in sync with run_in / n_process_periods / n_times below.
renew_model = JaxModel(33e6, 276, 30, 12, JaxGammaDens(), 50)


def get_inc_result(renewal_model, gen_mean, gen_sd, proc, seed, cdr):
    """Return modelled notifications over the calibration window.

    Args:
        renewal_model: Renewal model whose ``func`` is evaluated.
        gen_mean: Generation time mean (days).
        gen_sd: Generation time standard deviation (days).
        proc: Random process values.
        seed: Peak seed rate.
        cdr: Case detection proportion applied to modelled incidence.

    Returns:
        Modelled incidence with the first ``run_in`` time points dropped,
        multiplied by ``cdr``. ``run_in`` is the module-level value set in the
        data-loading cell, so that cell must run before this is called.
    """
    # Evaluate the model that was passed in (previously the global renew_model was
    # used, so the argument -- and the static jit argument below -- had no effect).
    incidence = renewal_model.func(gen_mean, gen_sd, proc, seed).incidence
    return incidence[run_in:] * cdr


# The model object is passed as a static argument, so jit treats it as a
# compile-time constant rather than trying to trace it.
renewal_wrap = jit(get_inc_result, static_argnames=["renewal_model"])

In [None]:
# Calibration settings and target data.
run_in = 30  # days simulated before comparison against the data begins
n_process_periods = 12  # number of random-process values (length of the "proc" prior)
raw_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv", index_col=0
)["MYS"]
raw_data.index = pd.to_datetime(raw_data.index)
# Calibration window (label-based slicing is inclusive of both dates). The series is
# re-indexed to start at run_in so it lines up with incidence[run_in:] from the model.
mys_data = raw_data.loc[datetime(2021, 3, 1): datetime(2021, 11, 1)].reset_index()["MYS"]
mys_data.index += run_in
n_times = len(mys_data) + run_in
calib_kwargs = {"pop": 33e6, "n_times": n_times, "run_in": run_in, "targets": mys_data}
# Single quotes are used for the key inside the f-string: reusing the enclosing
# double quote there is a SyntaxError before Python 3.12 (PEP 701).
fixed_param_desc = "### Fixed parameter values\n " \
    f"The target population is initialised as {int(calib_kwargs['pop'])} susceptible persons. " \
    f"The simulation runs for a run-in period of {run_in} days before comparison against the calibration data commences.\n"

In [None]:
# Define parameter ranges
scalar_req = {
    "gen_mean": {"name": "Generation time mean (days)", "lower": 5.0, "upper": 12.0},
    "gen_sd": {"name": "Generation time standard deviation (days)", "lower": 2.0, "upper": 7.0},
    "cdr": {"name": "Case detection proportion", "lower": 0.05, "upper": 0.4},
    "seed": {"name": "Peak seed rate", "lower": 5.0, "upper": 15.0},
}
params_df = pd.DataFrame(scalar_req).transpose()
proc_req = {"name": "Random process values", "lower": -2.0, "upper": 2.0}

priors = {}
for k, v in scalar_req.items():
    priors[k] = dist.Uniform(v["lower"], v["upper"])
priors["proc"] = dist.Uniform(np.repeat(proc_req["lower"], n_process_periods), np.repeat(proc_req["upper"], n_process_periods))

In [None]:
# Report text describing the calibration. The previous wording said the fit
# "minimise[s] the density" of the observations, which is the opposite of what a
# likelihood-based fit does. It also omitted the fixed observation SD used below.
calib_desc = "\n\n### Calibration targets\nThe model described above was fit to the target data " \
    "by Markov chain Monte Carlo sampling, using as the likelihood the density " \
    "of the observed number of cases at each available data point " \
    "under a normal distribution centred at the modelled notification rate, " \
    "with standard deviation equal to 10% of the standard deviation of the target data. " \
    "Modelled notifications are calculated as the product of modelled incidence and the " \
    "(constant through time) case detection proportion. "

def calib_model():
    """numpyro model: draw every parameter from its prior and score the fit to the data.

    Sites are sampled in the insertion order of ``priors``. The log-likelihood of
    the target case series is added to the joint density through
    ``numpyro.factor``. It is computed under a normal distribution centred on the
    modelled notifications, with a fixed standard deviation.
    """
    param_updates = {name: numpyro.sample(name, prior) for name, prior in priors.items()}
    model_res = renewal_wrap(**param_updates, renewal_model=renew_model)
    # Convert the targets once; previously jnp.array(mys_data) was built twice.
    targets = jnp.array(mys_data)
    # Fixed observation noise: 10% of the empirical SD of the target series.
    obs_sd = jnp.std(targets) * 0.1
    like = dist.Normal(model_res, obs_sd).log_prob(targets).sum()
    numpyro.factor("incidence", like)

In [None]:
# The calibration description is rendered once, after the fixed-parameter
# description and before the calibration figure. Displaying it here as well
# would put a second "Calibration targets" section into the rendered report.

In [None]:
# Fixed PRNG key so the sampler run is reproducible.
rng_key = random.PRNGKey(1)
# NUTS with a dense mass matrix; 2 chains, each with 1000 warmup and 1000 sampling iterations.
kernel = numpyro.infer.NUTS(calib_model, dense_mass=True)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
# Also record the per-draw acceptance probability.
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted MCMC object to ArviZ InferenceData for downstream selection.
idata = az.from_numpyro(posterior=mcmc)

In [None]:
burn_in = 10  # further draws discarded from the start of each chain
n_samples = 10  # posterior draws used for the illustrative trajectories below
# Seed for the random subset drawn by az.extract. Without it, a different set of
# draws (and hence a different figure) is produced on every run.
sample_seed = 1
idata_burnt = idata.sel(draw=slice(burn_in, None))
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=sample_seed)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# One column of modelled notifications per sampled parameter set. Columns are
# labelled by the sample's (flattened) index, and the frame is shifted to model
# time so the target series can be joined on the index.
sample_labels = sample_params.index.to_flat_index().map(str)
notification_runs = {
    col: np.asarray(get_inc_result(renew_model, **sample))
    for col, sample in enumerate(sample_params)
}
spaghetti = pd.DataFrame(notification_runs)
spaghetti.columns = sample_labels
spaghetti.index = spaghetti.index + run_in
spaghetti["targets"] = mys_data

In [None]:
# Random-process curve for each sample. Each sample's "proc" values go through
# sinterp.get_scale_data, then cosine_multicurve is evaluated at every model time
# (vmapped over times) and the result is exponentiated. One column per sample.
times = np.array(renew_model.model_times)
multicurve_at_times = vmap(cosine_multicurve, in_axes=(0, None, None))
proc_df = pd.DataFrame(
    {
        str(label): np.exp(
            multicurve_at_times(
                times,
                renew_model.x_proc_vals,
                sinterp.get_scale_data(np.array(row["proc"])),
            )
        )
        for label, row in sample_params.iterrows()
    }
)

In [None]:
# Render the fixed-parameter description in the report.
Markdown(data=fixed_param_desc)

In [None]:
# Render the calibration-target description in the report.
Markdown(data=calib_desc)

In [None]:
# Two stacked panels sharing the x axis: modelled notifications with targets (top)
# and the sampled random-process curves (bottom).
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05)
fig.update_layout(margin=dict(t=20, b=20, l=20, r=20), height=600)
for panel_row, panel_frame in enumerate((spaghetti, proc_df), start=1):
    fig.add_traces(panel_frame.plot().data, rows=panel_row, cols=1)
fig

In [None]:
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# Export the calibration figure as a static SVG file.
# NOTE(review): this cell displays nothing (write_image returns None), so the
# label/caption above may not attach to any rendered figure. Confirm how
# calib_fig.svg is embedded in the report.
fig.write_image(file="calib_fig.svg")

In [None]:
# Render the renewal model's own full description.
Markdown(data=renew_model.get_full_desc())

In [None]:
# Parameter table for the report, one row per prior (scalars plus the random process).
# The table is rebuilt from scalar_req/proc_req instead of mutating params_df in
# place, which keeps the cell idempotent. Previously a second run failed when it
# assigned three column names to the already-reduced two-column frame.
params_df = (
    pd.DataFrame({**scalar_req, "proc": proc_req})
    .transpose()
    .set_index("name")
    .rename(columns={"lower": "Lower limit", "upper": "Upper limit"})
)
params_df.index.name = None

In [None]:
# Render the parameter table as a Markdown table.
Markdown(data=params_df.to_markdown())

In [None]:
# Placeholder evidence column: one row per parameter, all sharing the same citation text.
evidence_table = pd.DataFrame({"Evidence": "To be populated [@cori2013]"}, index=params_df.index)
Markdown(evidence_table.to_markdown())