In [None]:
# Uncomment the line below to install on Colab or similar
#! pip install git+https://github.com/monash-emu/wpro-working.git@more-datasets

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from jax import jit, random
import arviz as az
from estival.sampling import tools as esamp
from plotly.express.colors import qualitative as qual_colours
from IPython.display import Markdown

from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti, plot_uncertainty_patches
from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.calibration import StandardCalib
from emu_renewal.utils import get_adjust_idata_index, adjust_summary_cols

In [None]:
# Load the cleaned 2003 SARS dataset from the project repository (needs network access);
# the "Date" column is parsed as dates and used as the index.
sars_data_url = "https://github.com/monash-emu/wpro-working/raw/more-datasets/data/sars_hongkong/sars_2003_complete_dataset_clean.csv"
sars_data = pd.read_csv(sars_data_url, index_col="Date", parse_dates=True)
# To list the countries available in the dataset, run:
# set(sars_data["Country"])

In [None]:
# Country to analyse; must match a value in the "Country" column exactly
country_req = "Hong Kong SAR, China"

# Row-to-row increments of the cumulative counts, smoothed with a 4-row rolling mean;
# dropna() removes the leading rows for which the diff or the rolling window is incomplete
is_req_country = sars_data["Country"] == country_req
cumulative_cases = sars_data[is_req_country]["Cumulative number of case(s)"]
case_data = cumulative_cases.diff().rolling(4).mean().dropna()

In [None]:
# Fixed settings for the analysis
pop = 6.7e6  # presumably the population of the selected country - TODO confirm source
window_len = 14
proc_update_freq = 4

# Analysis period: from the first index entry of the full (all-country) dataset to 1 May 2003
analysis_start, analysis_end = sars_data.index[0], datetime(2003, 5, 1)

# Calibration target: the smoothed case series restricted to the analysis period
select_data = case_data.loc[slice(analysis_start, analysis_end)]

In [None]:
# Create exponentially increasing case counts for the init_duration period before the analysis starts.
# The series starts at 1 and grows at a constant rate chosen so that it would reach the first
# calibration value after init_duration steps.
# NOTE(review): the previous comment mentioned padding with zeroes, but no padding is done in
# this cell - presumably RenewalModel handles any look-back padding; confirm.
init_duration = 14
# Use the first value of the analysis-period data (not the unsliced series) so the run-in still
# joins up with the calibration target if analysis_start is moved later
first_obs = select_data.iloc[0]
if not first_obs > 0.0:
    # log() of a non-positive (or NaN) value would silently put nan/zero values into init_series
    raise ValueError(f"First calibration value must be positive to build the run-in series, got {first_obs}")
exp_coeff = np.log(first_obs) / init_duration
init_series = np.exp(exp_coeff * np.arange(init_duration))

In [None]:
# Build the renewal model for the analysis period.
# The first argument (presumably the population size - confirm against the RenewalModel
# signature) was previously hard-coded as 33e6, which disagrees with the `pop` value declared
# for the selected country and left `pop` unused; pass `pop` so the two cannot drift apart.
# NOTE(review): the two GammaDens() arguments presumably describe the generation-interval and
# reporting-delay distributions - confirm.
fitter = CosineMultiCurve()
renew_model = RenewalModel(pop, analysis_start, analysis_end, proc_update_freq, fitter, GammaDens(), window_len, init_series, GammaDens())

In [None]:
# Prior distributions for the parameters that are estimated during calibration
priors = dict(
    gen_mean=dist.Uniform(6.5, 10.5),
    gen_sd=dist.Uniform(3.0, 4.6),
    cdr=dist.Beta(10.0, 4.0),
    rt_init=dist.Normal(0.0, 0.25),
)
# Parameters held at constant values rather than estimated
fixed_params = dict(report_mean=10.0, report_sd=5.0)

In [None]:
# Bundle the model, priors, target data and fixed parameters into the calibration object
calib = StandardCalib(
    renew_model,
    priors,
    select_data,
    fixed_params,
)

In [None]:
# Sample the posterior with NUTS: 2 chains, each with 1000 warm-up and 1000 retained draws
rng_key = random.PRNGKey(1)  # fixed seed for the sampler
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=infer.init_to_uniform(radius=0.5),
)
mcmc = infer.MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
mcmc.run(rng_key, extra_fields=("accept_prob",), priors=priors)

In [None]:
# Convert the fitted MCMC object into an ArviZ InferenceData for analysis
idata = az.from_numpyro(mcmc)

In [None]:
# Post-process the chains: discard the first draws of each chain, then sub-sample parameter sets
burn_in = 10
n_samples = 100
sample_seed = 1
quantiles = [0.05, 0.5, 0.95]
idata_burnt = idata.sel(draw=slice(burn_in, None))
# az.extract shuffles the draws when num_samples is given; seed the shuffle so that
# Restart & Run All always selects the same subset (and so produces the same plots)
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=sample_seed)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr, rt_init):
    """Evaluate the renewal model for one parameter set.

    The five arguments are forwarded unchanged to ``renew_model.renewal_func``,
    followed by the ``report_mean`` and ``report_sd`` values held in ``fixed_params``.
    """
    report_mean = fixed_params["report_mean"]
    report_sd = fixed_params["report_sd"]
    return renew_model.renewal_func(gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd)

# JIT-compile the wrapper, then hand it to the helper together with the sampled parameter sets
# (presumably one model run per sampled set - see get_spaghetti_from_params)
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
# Summarise the collected runs at the requested quantiles
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles)

In [None]:
# Quick visual check of the run-in series built for the period before the analysis starts
pd.Series(init_series).plot()

In [None]:
# Plot the quantile summary together with the calibration data, using Plotly's qualitative palette
plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)

In [None]:
# Plot the sampled model runs together with the calibration data
plot_spaghetti(spaghetti, select_data)

In [None]:
# Tabulate posterior summary statistics, tidy the columns and row labels with the project
# helpers, and render the table as Markdown
raw_summary = az.summary(idata)
summary = adjust_summary_cols(raw_summary)
relabel_index = get_adjust_idata_index(renew_model)
summary.index = summary.index.map(relabel_index)
Markdown(summary.to_markdown())