In [None]:
# Uncomment the line below to install on Colab or similar
#! pip install git+https://github.com/monash-emu/wpro-working.git@more-datasets

In [None]:
#| warning: false
from datetime import datetime

import arviz as az
import numpy as np
import numpyro
import pandas as pd
from IPython.display import Markdown
from jax import jit, random
from numpyro import distributions as dist
from numpyro import infer
from plotly.express.colors import qualitative as qual_colours

from estival.sampling import tools as esamp

from emu_renewal.calibration import StandardCalib
from emu_renewal.distributions import GammaDens
from emu_renewal.outputs import (
    PANEL_SUBTITLES,
    get_quant_df_from_spaghetti,
    get_spaghetti_from_params,
    plot_3d_spaghetti,
    plot_spaghetti,
    plot_uncertainty_patches,
)
from emu_renewal.process import CosineMultiCurve
from emu_renewal.renew import RenewalModel

In [None]:
# Fixed model settings (used when constructing the renewal model below)
run_in = 30
proc_update_freq = 14
pop = 33_000_000.0
analysis_start, analysis_end = datetime(2021, 3, 1), datetime(2021, 11, 1)

# Reported new cases; the "MYS" column is the Malaysian series
mys_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)

# Restrict the observations to the analysis window (inclusive at both ends)
select_data = mys_data.loc[analysis_start:analysis_end]

In [None]:
# A single CosineMultiCurve instance serves as both the process fitter and the seed fitter
proc_fitter = seed_fitter = CosineMultiCurve()

# Build the renewal model from the settings cell. `pop` is reused rather than
# repeating the 33e6 literal, so the population is defined in exactly one place.
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    run_in,
    proc_update_freq,
    proc_fitter,
    GammaDens(),
    seed_fitter,
    50,  # NOTE(review): meaning not visible here; name it once RenewalModel's signature is confirmed
)

In [None]:
# Calibration object; its `calibration` method is used as the NumPyro model for NUTS below.
# NOTE(review): this passes the full `mys_data` series rather than `select_data`
# (the analysis-window subset built in the settings cell). Confirm that StandardCalib
# restricts the calibration target to the model's analysis period itself.
calib = StandardCalib(renew_model, mys_data)

In [None]:
# Prior distributions (numpyro distribution objects) for the calibrated parameters.
# Insertion order matters: later cells iterate over these keys to build plots.
priors = dict(
    gen_mean=dist.Gamma(10.0, 1.0),
    gen_sd=dist.Gamma(5.0, 1.0),
    cdr=dist.Beta(4.0, 10.0),
    seed=dist.Uniform(3.0, 10.0),
)

In [None]:
# NUTS with a dense mass matrix; chains start from points drawn by init_to_uniform(radius=0.5)
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=infer.init_to_uniform(radius=0.5),
)
mcmc = infer.MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)

# Fixed PRNG key so the sampler run is reproducible; the priors reach the model as `params`
rng_key = random.PRNGKey(1)
mcmc.run(rng_key, params=priors, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted sampler into an ArviZ InferenceData object
idata = az.from_numpyro(posterior=mcmc)

In [None]:
# Posterior summary table (displayed as the cell output)
az.summary(data=idata)

In [None]:
# Post-processing settings
burn_in = 10  # draws dropped from the start of each chain before subsampling
n_samples = 100  # number of posterior draws used for the trajectory plots
quantiles = [0.05, 0.5, 0.95]
extract_seed = 1  # fixes which draws az.extract selects, so the figures are reproducible

idata_burnt = idata.sel(draw=slice(burn_in, None))
# An explicit `rng` makes az.extract shuffle the stacked (chain, draw) samples
# before keeping `n_samples`. The subset is therefore drawn from across all
# chains, and the same subset is obtained on every Restart & Run All.
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=extract_seed)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, seed, cdr):
    """Run the renewal model for a single set of parameter values.

    `cdr` is part of the signature but is not forwarded to the model. This is
    presumably so the function accepts every sampled parameter by name.
    NOTE(review): if these outputs are compared against reported cases,
    confirm whether `cdr` should be applied here.
    """
    model_output = renew_model.renewal_func(gen_mean, gen_sd, proc, seed)
    return model_output


# JIT-compile once; presumably evaluated repeatedly, once per posterior sample
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles)

In [None]:
# Render the model's self-description as Markdown
Markdown(data=renew_model.get_description())

In [None]:
# Plot the sampled model outputs in `spaghetti` together with the analysis-window data
plot_spaghetti(spaghetti, select_data)

In [None]:
# Quantile bands from `quantiles_df` shown with the observed `select_data`
# NOTE(review): the `#| label: fig-calib` / `#| fig-cap` options in the kaleido export
# cell further down are not at the top of that cell, so Quarto ignores them. They
# probably belong at the top of this cell, which is the one that displays the figure.
patch_fig = plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)
patch_fig

In [None]:
# Optional - requires kaleido
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# patch_fig.write_image("patch_fig.svg")

In [None]:
# params_df.columns = ["name", "Lower limit", "Upper limit"]
# params_df.index = params_df["name"]
# params_df = params_df.drop(columns=["name"])
# params_df.index.name = None

In [None]:
# Section heading rendered from code
Markdown(data="### Calibration")

In [None]:
# Render the calibration's self-description as Markdown
Markdown(data=calib.get_description())

In [None]:
# Markdown(params_df.to_markdown())

In [None]:
# evidence_table = pd.DataFrame(index=params_df.index, columns=["Evidence"])
# evidence_table.loc[:, "Evidence"] = "To be populated [@cori2013]"
# Markdown(evidence_table.to_markdown())

In [None]:
# plot_3d_spaghetti(spaghetti, ["susceptibles", "transmission potential"])

In [None]:
# Prior-posterior comparison: ArviZ draws each parameter's posterior density,
# then the corresponding prior is overlaid as a grey filled curve.
prior_names = list(priors)
comparison_plot = az.plot_density(idata, var_names=prior_names, point_estimate=None, shade=0.5)
for param_name, ax in zip(prior_names, comparison_plot.ravel()):
    # Evaluate the prior over the x-range ArviZ chose for this panel
    x_lower, x_upper = ax.get_xlim()
    bin_edges = np.linspace(x_lower, x_upper, 100)
    # Prior probability mass per bin (difference of CDF values); any NaN
    # CDF values are treated as zero mass instead of breaking the scaling
    bin_mass = np.nan_to_num(np.diff(np.asarray(priors[param_name].cdf(bin_edges))))
    peak = bin_mass.max()
    if peak > 0:
        # Rescale so the prior curve peaks at 0.94
        bin_mass = bin_mass * (0.94 / peak)
    # Draw each bin's value at its midpoint rather than its left edge (avoids a half-bin shift)
    bin_mids = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    ax.fill_between(bin_mids, bin_mass, color="k", alpha=0.2, linewidth=2)

In [None]:
# Convenience function for PDF of a prior
def plotpdf(p, lower_q=0.001, upper_q=0.999, n_points=100):
    """Tabulate the probability density of a distribution, ready for plotting.

    Args:
        p: Distribution exposing ``icdf`` and ``log_prob`` (e.g. one of the
            numpyro distributions in ``priors``).
        lower_q: Lower quantile bounding the evaluation grid.
        upper_q: Upper quantile bounding the evaluation grid.
        n_points: Number of evenly spaced grid points.

    Returns:
        pd.Series of density values (``exp(log_prob)``) indexed by grid point.
    """
    # Bounding by quantiles keeps the grid on the bulk of the distribution
    x = np.linspace(p.icdf(lower_q), p.icdf(upper_q), n_points)
    return pd.Series(data=np.exp(p.log_prob(x)), index=x)

# Density of the cdr prior; trailing `;` suppresses the Axes repr in the cell output
plotpdf(priors["cdr"]).plot(title="Prior density of cdr", xlabel="cdr", ylabel="Density");