In [None]:
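# Uncomment the line below to install on Colab or a similar hosted environment.
# Note: this installs from the unpinned `more-datasets` branch, so pin a commit or tag for reproducible results.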
#! pip install git+https://github.com/monash-emu/wpro-working.git@more-datasets

In [None]:
import pandas as pd
import plotly.express as px
import numpy as np
import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from jax import jit, random
import arviz as az
from estival.sampling import tools as esamp
from plotly.express.colors import qualitative as qual_colours
from IPython.display import Markdown

from emu_renewal.outputs import get_spaghetti, get_quant_df_from_spaghetti, plot_uncertainty_patches
from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel, ModelResult
from emu_renewal.calibration import StandardCalib
from emu_renewal.targets import StandardTarget

In [None]:
# Grab data on the 2014-2016 Ebola outbreak; this analysis focuses on the epidemic in Sierra Leone.
# TODO: record data provenance here (the original note was truncated: "This data was don...").
ebola_data = pd.read_csv("https://github.com/monash-emu/wpro-working/raw/more-datasets/data/ebola_2014_2016/ebola_2014_2016_clean.csv", index_col="Date", parse_dates=True)

# Incident cases are recovered by differencing the cumulative counts, then smoothed with a trailing
# rolling mean over `case_smoothing_rows` rows (14 days if rows are daily -- TODO confirm reporting cadence).
# Sorting by date first guarantees diff() compares chronologically consecutive rows even if the CSV is unordered.
cumulative_col = "Cumulative no. of confirmed, probable and suspected cases"
case_smoothing_rows = 14
case_data = (
    ebola_data.loc[ebola_data["Country"] == "Sierra Leone", cumulative_col]
    .sort_index()
    .diff()
    .rolling(case_smoothing_rows)
    .mean()
    .dropna()
)

In [None]:
# Specify fixed parameters and get calibration data
proc_update_freq = 4  # passed to RenewalModel as the process update frequency (presumably in days -- TODO confirm)
window_len = 14  # length of the look-back window passed to RenewalModel
pop = 7.1e6  # population used for Sierra Leone

# Start at the first smoothed Sierra Leone observation rather than the first row of the multi-country
# file, so the model start lines up with the fitted series (case_data.iloc[0] also seeds init_series).
analysis_start = case_data.index[0]
analysis_end = datetime(2015, 5, 1)
select_data = case_data.loc[analysis_start:analysis_end]

In [None]:
# Create exponentially increasing case counts for the init_duration period before the analysis starts,
# growing from 1 towards the first observed (smoothed) case count.
# Pad with zeroes at the front so the series covers the model's full look-back window.
init_duration = 14
seed_cases = case_data.iloc[0]

# np.log needs a positive value and np.zeros needs a non-negative length; fail early with a clear message.
assert seed_cases > 0, f"First smoothed case count must be positive to seed exponential growth, got {seed_cases}"
assert init_duration <= window_len, f"init_duration ({init_duration}) cannot exceed window_len ({window_len})"

exp_coeff = np.log(seed_cases) / init_duration
init_series = np.concatenate([np.zeros(window_len - init_duration), np.exp(exp_coeff * np.arange(init_duration))])

In [None]:
fitter = CosineMultiCurve()
# Use the Sierra Leone population (`pop`) and look-back window (`window_len`) from the parameter cell
# instead of hard-coded values; a population of 33e6 was inconsistent with pop = 7.1e6.
# NOTE(review): the two GammaDens() instances are presumably the generation-time and reporting-delay
# densities (matching the gen_* / report_* priors) -- TODO confirm against RenewalModel's signature.
renew_model = RenewalModel(pop, analysis_start, analysis_end, proc_update_freq, fitter, GammaDens(), window_len, init_series, GammaDens())

In [None]:
# Prior distributions for every calibrated parameter, keyed by the parameter names the calibration expects.
# NOTE(review): rt_init is centred on 0.0, which suggests it is on a log scale -- TODO confirm in the model code.
priors = dict(
    gen_mean=dist.Uniform(10.0, 14.0),
    gen_sd=dist.Uniform(3.0, 7.0),
    cdr=dist.Beta(10.0, 4.0),
    rt_init=dist.Normal(0.0, 0.25),
    report_mean=dist.Uniform(8.0, 12.0),
    report_sd=dist.Uniform(3.0, 6.0),
)

In [None]:
# Calibrate against the smoothed Sierra Leone case series.
# NOTE(review): the meaning of StandardTarget's second argument (0.1) is not visible here -- presumably a
# dispersion/noise scale for the likelihood; TODO confirm.
case_target = StandardTarget(select_data, 0.1)
calib_data = {"cases": case_target}
calib = StandardCalib(renew_model, priors, calib_data)

In [None]:
# Sample the posterior with NUTS using a dense mass matrix, initialising chains via calib.custom_init.
# Note: unless multiple host devices are configured for numpyro, the chains may run sequentially.
mcmc_settings = dict(num_chains=2, num_samples=1000, num_warmup=1000)
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=calib.custom_init(radius=0.5),
)
mcmc = infer.MCMC(kernel, **mcmc_settings)
sampler_key = random.PRNGKey(1)  # fixed seed so the run is reproducible
mcmc.run(sampler_key)

In [None]:
# Convert the MCMC output to ArviZ and draw a random subset of 100 posterior samples for the output plots.
# Passing a seeded generator makes the subset (and therefore the plotted trajectories) reproducible
# across Restart & Run All; without it az.extract shuffles with an unseeded generator.
posterior_subsample_seed = 1
idata = az.from_numpyro(mcmc)
idata_sampled = az.extract(idata, num_samples=100, rng=np.random.default_rng(posterior_subsample_seed))
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# Produce model outputs for each sampled parameter set, then summarise them as 5%/50%/95% quantiles.
# Uses get_spaghetti, imported from emu_renewal.outputs at the top of the notebook. The previous call to
# new_new_get_spaghetti referenced a name not defined in this notebook and failed on a fresh kernel.
# NOTE(review): confirm get_spaghetti accepts (calib, sample_params) in the installed emu_renewal version.
spaghetti = get_spaghetti(calib, sample_params)
key_outputs = ["weekly_sum", "suscept", "r_t", "process"]
quantiles = get_quant_df_from_spaghetti(spaghetti, quantiles=[0.05, 0.5, 0.95])

In [None]:
# Plot the quantile summaries for the key outputs alongside the calibration data,
# using Plotly's qualitative colour palette.
palette = qual_colours.Plotly
plot_uncertainty_patches(quantiles, select_data, palette, outputs=key_outputs)