In [None]:
#| warning: false
from jax import jit, random
import pandas as pd
from datetime import datetime
import numpyro
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours

from estival.sampling import tools as esamp
from summer2.utils import Epoch

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti, plot_uncertainty_patches, PANEL_SUBTITLES
from emu_renewal.calibration import StandardCalib

In [None]:
from emu_renewal.calibration import StandardCalib

In [None]:
# Fixed settings for the renewal model and the analysis window
run_in = 30  # NOTE(review): presumably a run-in period (days) before analysis_start — confirm
proc_update_freq = 14  # NOTE(review): presumably days between process updates — confirm
pop = 33e6  # population size used for the MYS model
analysis_start = datetime(2021, 3, 1)
analysis_end = datetime(2021, 11, 1)
epoch = Epoch(datetime(2019, 12, 31))  # reference date passed to RenewalModel

# Calibration target: the "MYS" column of the new-cases file, indexed by date and
# restricted (label slicing, inclusive at both ends) to the analysis window
mys_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)
select_data = mys_data.loc[analysis_start:analysis_end]

In [None]:
fitter = CosineMultiCurve()
# Build the renewal model, using the population from the configuration cell (`pop`) rather than
# repeating the literal, so changing `pop` there actually reaches the model.
# NOTE(review): arguments are passed positionally and `fitter` is supplied twice; the meaning of
# each slot (including the constant 50) is defined by RenewalModel in emu_renewal.renew — confirm.
renew_model = RenewalModel(pop, analysis_start, analysis_end, run_in, proc_update_freq, fitter, GammaDens(), fitter, 50, epoch)

In [None]:
# Pair the renewal model with the selected case series to form the calibration object
calib = StandardCalib(
    renew_model,
    select_data,
)

In [None]:
# Calibrated parameter ranges, one row per parameter:
# (key, display name, lower limit, upper limit)
_param_specs = [
    ("gen_mean", "Generation time mean (days)", 5.0, 12.0),
    ("gen_sd", "Generation time standard deviation (days)", 2.0, 7.0),
    ("cdr", "Case detection proportion", 0.05, 0.4),
    ("seed", "Peak seed rate", 5.0, 15.0),
]
params_dict = {
    key: {"name": label, "lower": lower, "upper": upper}
    for key, label, lower, upper in _param_specs
}
# Tabular view: one row per parameter key, columns taken from the inner dict keys
params_df = pd.DataFrame(params_dict).T

In [None]:
# Sample calib.calibration with NUTS (dense mass matrix): two chains, each running
# 1000 warmup iterations followed by 1000 retained draws.
kernel = numpyro.infer.NUTS(model=calib.calibration, dense_mass=True)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
rng_key = random.PRNGKey(1)  # fixed seed for the sampler
# Extra keyword arguments to run() are forwarded to the model; also record acceptance probabilities
mcmc.run(rng_key, params=params_dict, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted MCMC object into an ArviZ InferenceData (posterior group)
idata = az.from_numpyro(posterior=mcmc)

In [None]:
# Post-processing settings for the posterior draws
burn_in = 10  # draws dropped from the start of each chain
n_samples = 100  # number of draws kept for the spaghetti and quantile summaries
quantiles = [0.05, 0.5, 0.95]  # quantiles passed to get_quant_df_from_spaghetti
extract_seed = 1  # seed for the random subsample below, so figures are reproducible

idata_burnt = idata.sel(draw=slice(burn_in, None))
# With num_samples set, az.extract shuffles the pooled draws; without an explicit rng it uses an
# unseeded generator, so a different subset (and different figures) would appear on every run.
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=extract_seed)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, seed, cdr):
    """Run the renewal model for a single set of parameter values.

    Only ``gen_mean``, ``gen_sd``, ``proc`` and ``seed`` are forwarded to
    ``renew_model.renewal_func``; ``cdr`` is accepted but not used here.
    NOTE(review): presumably ``cdr`` is in the signature so that a full posterior
    sample can be passed by name (see get_spaghetti_from_params) — confirm.
    """
    model_inputs = (gen_mean, gen_sd, proc, seed)
    return renew_model.renewal_func(*model_inputs)

# JIT-compile the single-sample model run, evaluate it for each sampled parameter set,
# then summarise the resulting trajectories at the requested quantiles.
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(
    renew_model,
    sample_params,
    full_wrap,
)
quantiles_df = get_quant_df_from_spaghetti(
    renew_model,
    spaghetti,
    quantiles,
)

In [None]:
# Render the model's own description text as Markdown
model_description = renew_model.get_description()
Markdown(model_description)

In [None]:
# Sampled trajectories plotted against the calibration target
spaghetti_fig = plot_spaghetti(spaghetti, select_data)
spaghetti_fig

In [None]:
# Quantile bands from the sampled trajectories, coloured with Plotly's qualitative palette
patch_colours = qual_colours.Plotly
patch_fig = plot_uncertainty_patches(quantiles_df, select_data, patch_colours)
patch_fig

In [None]:
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# Save a static SVG copy of the calibration figure (plotly's write_image requires kaleido).
patch_fig.write_image("patch_fig.svg")
# The Quarto label/caption above attach to this cell's figure output, so the cell must emit
# the figure; writing the file alone leaves @fig-calib with nothing to refer to.
# NOTE(review): the unlabelled display of patch_fig in the previous cell may now be redundant.
patch_fig

In [None]:
# Display table of parameter ranges: rows labelled by the human-readable name, with
# "Lower limit" / "Upper limit" columns. It is rebuilt from params_dict (not by mutating
# the existing params_df) so re-running this cell gives the same result. Columns are also
# renamed by label rather than by position, so the result does not depend on column order.
params_df = (
    pd.DataFrame(params_dict)
    .transpose()
    .rename(columns={"lower": "Lower limit", "upper": "Upper limit"})
    .set_index("name")
    .rename_axis(None)
)

In [None]:
# Section heading for the calibration summary that follows
Markdown(data="### Calibration")

In [None]:
# Render the calibration object's description text as Markdown
calib_description = calib.get_description()
Markdown(calib_description)

In [None]:
# Parameter range table rendered as a Markdown table (to_markdown needs the tabulate package)
params_markdown = params_df.to_markdown()
Markdown(params_markdown)

In [None]:
# Evidence column with one row per parameter (same index as params_df); every entry is
# currently the same placeholder citation.
evidence_table = pd.DataFrame(
    {"Evidence": "To be populated [@cori2013]"},
    index=params_df.index,
)
Markdown(evidence_table.to_markdown())