In [None]:
import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from jax import jit, random
import arviz as az
from estival.sampling import tools as esamp
from plotly.express.colors import qualitative as qual_colours

from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti, plot_uncertainty_patches
from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.calibration import StandardCalib

In [None]:
# Load the cleaned SARS 2003 dataset, indexed by (parsed) report date.
sars_data = pd.read_csv(
    "../data/sars_hongkong/sars_2003_complete_dataset_clean.csv",
    index_col="Date",
    parse_dates=True,
)
# To list the countries available in the file: set(sars_data["Country"])

In [None]:
# Keep only the reported cumulative case counts for the selected country.
# Daily increments and smoothing are derived after daily reindexing below.
country_req = "Hong Kong SAR, China"
cum_data = sars_data.loc[sars_data["Country"] == country_req, "Cumulative number of case(s)"]

In [None]:
def reindex_daily_cumulative(series):
    """Put a date-indexed cumulative series onto a gap-free daily index.

    Calendar days missing between the earliest and latest observation are
    inserted and filled by linear interpolation between neighbouring values.

    Args:
        series: pandas Series whose index holds dates (e.g. cumulative case
            counts indexed by report date). The index need not be sorted.

    Returns:
        A Series on a daily DatetimeIndex from the earliest to the latest date
        in ``series``, keeping the series name. An empty input is returned
        unchanged (as a copy).

    Raises:
        ValueError: if the index contains duplicate dates (raised by reindex).
    """
    if series.empty:
        return series.copy()
    # Sort first: with an unsorted index, index[0]/index[-1] are not the true
    # endpoints and the generated date range would be empty or truncated.
    ordered = series.sort_index()
    daily_index = pd.date_range(ordered.index[0], ordered.index[-1])
    # reindex aligns on labels, leaving NaN on unreported days; the index is
    # uniformly daily, so default linear interpolation fills them correctly.
    return ordered.reindex(daily_index).interpolate()

def report_gaps(series):
    """Flag the calendar days on which a date-indexed series has no value.

    Args:
        series: pandas Series whose index holds dates. The index need not be
            sorted.

    Returns:
        Boolean Series on a daily DatetimeIndex from the earliest to the latest
        date in ``series``: True where the date is absent from ``series`` or
        its value is NaN, False otherwise. An empty input gives an empty
        boolean Series.

    Raises:
        ValueError: if the index contains duplicate dates (raised by reindex).
    """
    if series.empty:
        return pd.Series(dtype=bool, index=pd.DatetimeIndex([]), name=series.name)
    # Sort first so index[0]/index[-1] really are the earliest/latest dates.
    ordered = series.sort_index()
    daily_index = pd.date_range(ordered.index[0], ordered.index[-1])
    return ordered.reindex(daily_index).isna()

In [None]:
import jax

In [None]:
# Fill in unreported days so the cumulative series is strictly daily.
cum_data = reindex_daily_cumulative(series=cum_data)

In [None]:
# Sanity check of the interpolated cumulative series. Only `title` is passed so
# the call works under both the matplotlib and plotly pandas backends.
cum_data.plot(title=f"{country_req}: cumulative cases (daily, interpolated)")

In [None]:
# Daily new cases: first difference of the daily cumulative series. The first
# day has no predecessor (NaN) and is dropped. Days filled by linear
# interpolation get equal, possibly fractional, increments across each gap.
case_data = cum_data.diff().dropna()

In [None]:
# Trailing 7-day moving average of daily cases. Averaging the cumulative series
# over 7 days and then differencing gives (c[t] - c[t-7]) / 7, i.e. the trailing
# 7-day mean of the daily increments; the first 7 dates of cum_data give NaN
# and are dropped.
case_data_ma7 = cum_data.rolling(7).mean().diff().dropna()

In [None]:
import numpy as np
from jax import numpy as jnp

In [None]:
# Fill the 50-value initial buffer with the first observed daily increment.
# Reasoning from the original author: at an R value of 1.0, a constant buffer
# reproduces exactly that first value.
# NOTE(review): the calibration target (select_data) is the 7-day moving average,
# which starts 6 days after case_data; confirm the raw first value is the intended seed.
# NOTE(review): the length 50 matches the last argument passed to RenewalModel — keep them in sync.

init_data = jnp.ones(50) * case_data.iloc[0]
init_data

In [None]:
# Fixed (non-calibrated) settings and the calibration target.
run_in, proc_update_freq = 0, 5
pop = 6.7e6  # presumably the Hong Kong population for the model — confirm
analysis_start, analysis_end = case_data.index[0], datetime(2003, 6, 1)
# Smoothed daily cases inside the analysis window. case_data_ma7 begins 6 days
# after case_data, so the slice effectively starts at its first value.
select_data = case_data_ma7.loc[analysis_start:analysis_end].dropna()

In [None]:
fitter = CosineMultiCurve()
# Pass `pop` (defined with the other fixed parameters) instead of a hard-coded
# population literal, which did not match it.
# NOTE(review): assumes the first RenewalModel argument is the population, and
# that the trailing 50 is the window length matching len(init_data) — confirm.
renew_model = RenewalModel(pop, analysis_start, analysis_end, run_in, proc_update_freq, fitter, GammaDens(), fitter, 50)

In [None]:
# Alias of init_data. NOTE(review): `init_window` is not used later in this
# notebook (StandardCalib receives init_data directly) — candidate for removal.
# An alternative tried earlier was a linear ramp: jnp.linspace(0.0, select_data.iloc[0], 50).
init_window = init_data

In [None]:
# Priors as numpyro distribution objects.
# NOTE(review): this dict is not referenced below — the calibration uses
# `prior_desc`, which has different gen_mean settings — candidate for removal.
# Other candidates tried earlier: gen_mean ~ Gamma(10.0, 1.0), cdr ~ Beta(10.0, 4.0),
# rt0 ~ Normal(0.0, 0.1), seed ~ Uniform(0.4, 1.5).
priors = dict(
    gen_mean=dist.TruncatedNormal(8.4, 0.5, low=1.0, high=12.0),
    gen_sd=dist.Gamma(5.0, 1.0),
)

In [None]:
# Priors given as (distribution name, positional args, keyword args) tuples
# rather than distribution objects — a workaround that appeared to avoid crashes.
# NOTE(review): rt0 is also passed in fixed_params to StandardCalib; confirm
# whether this rt0 prior is used at all.
prior_desc = dict(
    gen_mean=("TruncatedNormal", (8.4, 1.0), dict(low=1.0, high=16.0)),
    gen_sd=("Gamma", (5.0, 1.0), dict()),
    rt0=("TruncatedNormal", (0.0, 0.11), dict(low=-2.0, high=2.0)),
)

In [None]:
# The target (select_data) is a 7-day moving average, so smoothing=True asks the
# calibration to smooth the model output the same way. rt0 and cdr are passed
# as fixed_params rather than being sampled.
calib = StandardCalib(
    renew_model,
    select_data,
    prior_desc,
    init_data=init_data,
    fixed_params={"rt0": 0.0, "cdr": 1.0},
    data_dispersion_sd=0.1,
    process_dispersion_sd=0.1,
    smoothing=True,
)

In [None]:
# Start every chain with all process values at zero and rt0 at 0.0.
# NOTE(review): rt0 is fixed via fixed_params in StandardCalib; presumably this
# entry is then ignored by the initializer — confirm.
ival = infer.init_to_value(
    values=dict(proc=jnp.zeros(calib.n_process_periods), rt0=0.0)
)

In [None]:
# NUTS with a dense mass matrix, starting from the values in `ival`;
# 4 chains, each with 2000 warmup and 2000 sampling iterations.
kernel = numpyro.infer.NUTS(calib.calibration, dense_mass=True, init_strategy=ival)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    jit_model_args=True,
)
rng_key = random.PRNGKey(9)
# Warmup and sampling are run as separate steps in the cells below
# (rather than a single mcmc.run call) so warmup can be inspected first.

In [None]:
# Adaptation only; keep the warmup draws and diagnostics for the trace plots below.
mcmc.warmup(
    rng_key,
    extra_fields=("accept_prob", "diverging", "potential_energy"),
    collect_warmup=True,
)

In [None]:
# Interactive (plotly) figures for the MCMC diagnostics that follow.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Warmup trace of gen_mean: after transposing, one column (line) per chain.
pd.DataFrame(mcmc.get_samples(group_by_chain=True)["gen_mean"]).T.plot()

In [None]:
# Potential energy over the last 200 warmup iterations, one line per chain.
pd.DataFrame(mcmc.get_extra_fields(group_by_chain=True)["potential_energy"]).T.iloc[-200:].plot()

In [None]:
# Draw posterior samples, continuing from the adapted warmup state.
# numpyro swaps the warmup state's key for the key passed here, and rng_key
# already seeded warmup, so a derived key is used to avoid reusing the same
# random stream (JAX keys should never be reused).
mcmc.run(random.fold_in(rng_key, 1), extra_fields=("accept_prob", "diverging", "potential_energy"))

In [None]:
# Potential energy across the sampling phase, one line per chain.
pd.DataFrame(mcmc.get_extra_fields(group_by_chain=True)["potential_energy"]).T.plot()

In [None]:
# Wrap the per-chain posterior draws in an InferenceData for ArviZ.
idata = az.from_dict(posterior=mcmc.get_samples(group_by_chain=True))

In [None]:
# Posterior summary (mean, sd, HDI, MCSE, ESS, R-hat) for every sampled site.
az.summary(idata)

In [None]:
burn_in = 0
n_samples = 200
quantiles = [0.05, 0.5, 0.95]
# Seed for az.extract's random subsample; without it, the drawn subset — and
# every figure built from it — changes on each run.
subsample_seed = 9
idata_burnt = idata.sel(draw=slice(burn_in, None))
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=subsample_seed)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# cdr is fixed at 1.0 in the calibration (not sampled), so add it to every
# subsampled draw explicitly. Size the array from n_samples rather than a
# hard-coded count so it stays in step with the subsample above.
sample_params.components["cdr"] = np.ones(n_samples)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr=1.0, rt0=0.0, **kwargs):
    """Run the renewal model for a single parameter set and return its full output.

    The calibration's initial buffer is divided by ``cdr`` before the run.
    Extra keyword arguments are accepted and ignored, so an entire posterior
    draw can be passed in with ``**draw``.
    """
    init_buffer = calib.init_data / cdr
    return renew_model.renewal_func(gen_mean, gen_sd, proc, init_buffer, rt0)

# JIT-compile the model wrapper, run it for the subsampled posterior draws, and
# summarise the resulting trajectories at the requested quantiles.
# NOTE(review): the exact outputs of the emu_renewal.outputs helpers are inferred
# from their names — confirm.
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles)

In [None]:
# Run the model once per posterior draw and keep both the raw and the 7-day
# moving-average incidence, each scaled by that draw's cdr. (Previously the
# compiled model was evaluated twice per draw, once for each dict.)
fres_cases_ma7 = {}
fres_cases = {}
for draw_key, draw_params in sample_params.iterrows():
    draw_result = full_wrap(**draw_params)
    fres_cases_ma7[draw_key] = draw_result.incidence_ma7 * draw_params["cdr"]
    fres_cases[draw_key] = draw_result.incidence * draw_params["cdr"]

In [None]:
# Switch back to static matplotlib figures for the comparison plots.
pd.set_option("plotting.backend", "matplotlib")

In [None]:
# Cumulative incidence: posterior quantiles of the modelled series vs observed.
# Uses the configured `quantiles` so this figure matches quantiles_df.
cq = pd.DataFrame(fres_cases).cumsum().quantile(quantiles, axis=1).T
cq.index = quantiles_df.index
ax = cq.plot(title="Cumulative cases: model quantiles vs data", ylabel="Cumulative cases")
case_data.cumsum().plot(ax=ax, color="black", label="data", legend=True)

In [None]:
# Daily incidence: posterior quantiles of the modelled (unsmoothed) series vs
# observed daily cases. Uses the configured `quantiles` to match quantiles_df.
qdf = pd.DataFrame(fres_cases).quantile(quantiles, axis=1).T
qdf.index = quantiles_df.index
ax = qdf.plot(title="Daily cases: model quantiles vs data", ylabel="Daily cases")
case_data.plot(ax=ax, style=".", color="black", label="data", legend=True)

In [None]:
# 7-day moving-average incidence: posterior quantiles vs the smoothed data the
# model was calibrated against. A distinct name avoids overwriting the raw
# daily quantiles held in `qdf`.
qdf_ma7 = pd.DataFrame(fres_cases_ma7).quantile(quantiles, axis=1).T
qdf_ma7.index = quantiles_df.index
ax = qdf_ma7.plot(title="Daily cases (7-day trailing mean): model quantiles vs data", ylabel="Daily cases")
case_data_ma7.plot(ax=ax, style=".", color="black", label="data (7-day mean)", legend=True)

In [None]:
# Final figure: modelled quantile bands against the calibration target data.
# NOTE(review): argument roles are inferred from names (quantile table, observed
# series, plotly qualitative colour palette) — confirm against emu_renewal.outputs.
plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)