In [None]:
from jax import jit, random
from jax import numpy as jnp
import numpy as np
import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
import arviz as az
pd.options.plotting.backend = 'plotly'

from estival.sampling import tools as esamp
import summer2

from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import JaxModel

In [None]:
# Density object handed to the model as its last argument (a gamma density, going by the class name).
distri = JaxGammaDens()
# All model settings are passed positionally as literals.
# NOTE(review): 33e6 equals the "pop" entry of calib_kwargs, and 276, 30 and 12 presumably
# correspond to n_times, run_in and n_process_periods, which are only defined in a later cell -
# confirm against the JaxModel signature and keep them in sync if the analysis window changes.
j = JaxModel(33e6, 50, 276, 30, 12, distri)

@jit
def get_inc_result(gen_mean, gen_sd, proc, seed, cdr):
    """Run the renewal model and return the post-run-in incidence scaled by ``cdr``.

    Relies on two notebook globals: the model object ``j`` and ``run_in``.
    ``run_in`` is assigned in a later cell, so that cell must have been
    executed before this function is first called.

    Args:
        gen_mean: First argument forwarded to ``j.func``.
        gen_sd: Second argument forwarded to ``j.func``.
        proc: Third argument forwarded to ``j.func``.
        seed: Fourth argument forwarded to ``j.func`` (the priors cell notes
            that this is the log seed).
        cdr: Multiplier applied to the modelled incidence.

    Returns:
        ``incidence`` from the model result with the first ``run_in`` entries
        dropped, multiplied by ``cdr``.
    """
    model_output = j.func(gen_mean, gen_sd, proc, seed)
    post_run_in = model_output.incidence[run_in:]
    return post_run_in * cdr

In [None]:
# Analysis settings shared by the cells below.
run_in = 30
n_process_periods = 12

# Load the case-count table and keep the "MYS" column with a datetime index.
raw_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
raw_data.index = pd.to_datetime(raw_data.index)

# Restrict to the calibration window, then replace the dates with an integer
# index that starts at run_in.
calib_window = slice(datetime(2021, 3, 1), datetime(2021, 11, 1))
mys_data = raw_data.loc[calib_window].reset_index(drop=True)
mys_data.index = mys_data.index + run_in

# Total number of times: the run-in period plus one per target observation.
n_times = run_in + len(mys_data)
calib_kwargs = {
    "pop": 33e6,
    "n_times": n_times,
    "run_in": run_in,
    "targets": mys_data,
}

In [None]:
# Prior distribution for each calibrated parameter. The keys double as the
# keyword arguments of get_inc_result, so they must match its signature.
priors = {
    "gen_mean": dist.Uniform(8.0, 10.0),
    "gen_sd": dist.Uniform(2.0, 7.0),
    # Vector-valued prior with one (-1, 1) interval per process period; sized from
    # n_process_periods rather than repeating the literal 12.
    "proc": dist.Uniform(
        np.repeat(-1.0, n_process_periods),
        np.repeat(1.0, n_process_periods),
    ),
    "seed": dist.Uniform(7.0, 12.0),  # This is actually the log seed
    "cdr": dist.Uniform(0.1, 0.2),
}

def calib_model():
    """NumPyro model: sample every prior, run the model and register the likelihood.

    Each entry of the global ``priors`` dict is sampled under its own key and
    the draws are passed to ``get_inc_result`` as keyword arguments. The fit to
    the global ``mys_data`` is scored with an independent normal density whose
    standard deviation is 10% of the standard deviation of the targets, and the
    summed log-density is added through a ``numpyro.factor`` site named
    "incidence".
    """
    sampled = {}
    for name, prior in priors.items():
        sampled[name] = numpyro.sample(name, prior)

    modelled = get_inc_result(**sampled)

    targets = jnp.array(mys_data)
    noise_sd = jnp.std(targets) * 0.1
    log_like = dist.Normal(modelled, noise_sd).log_prob(targets).sum()
    numpyro.factor("incidence", log_like)

In [None]:
# Fixed PRNG key for the sampler run.
rng_key = random.PRNGKey(1)

# NUTS over calib_model: two chains, 1000 warm-up and 1000 retained draws each.
kernel = numpyro.infer.NUTS(calib_model)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)

# Also collect the per-draw acceptance probability alongside the samples.
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted NumPyro MCMC object into an ArviZ object for the selection and extraction steps below.
idata = az.from_numpyro(mcmc)

In [None]:
# Number of leading draws to discard from each chain, and size of the posterior subsample.
burn_in = 10
n_samples = 10
# Seed for the subsample: az.extract shuffles the stacked draws before taking
# num_samples of them, so without a fixed seed the selected parameter sets (and
# every plot derived from them) would differ between runs.
sample_seed = 1

idata_burnt = idata.sel(draw=slice(burn_in, None))
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=sample_seed)
# Wrap the sampled draws so they can be iterated over as parameter sets.
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# Run the model once per sampled parameter set and collect the incidence
# trajectories as the columns of a single frame.
sampled_runs = {
    position: np.asarray(get_inc_result(**params))
    for position, params in enumerate(sample_params)
}
spaghetti = pd.DataFrame(sampled_runs)

# Label each column with the string form of its sample index entry.
spaghetti.columns = sample_params.index.to_flat_index().map(str)

# Shift the row index by run_in so it lines up with mys_data, then attach the targets.
spaghetti.index = spaghetti.index + run_in
spaghetti["targets"] = mys_data

In [None]:
from emu_renewal.process import sinterp, cosine_multicurve
from jax import vmap

In [None]:
def get_proc_vals_from_sample(renew_model, sample_idata):
    """Evaluate the exponentiated process curve over model time for each sample.

    Args:
        renew_model: Model object exposing ``model_times`` and ``x_proc_vals``.
        sample_idata: Sampled parameter sets in a form accepted by
            ``esamp.xarray_to_sampleiterator``; must contain a "proc" entry.

    Returns:
        pandas DataFrame with one column per entry of the sample index, each
        holding ``exp`` of ``cosine_multicurve`` evaluated at every element of
        ``renew_model.model_times``.
    """
    sample_df = esamp.xarray_to_sampleiterator(sample_idata).convert("pandas")
    times = np.array(renew_model.model_times)
    x_vals = renew_model.x_proc_vals

    # The vectorised curve evaluator does not depend on the sample, so build it
    # once here instead of on every loop iteration. Only the first argument
    # (time) is mapped over; x and y values are shared across all times.
    curve_over_times = vmap(cosine_multicurve, in_axes=(0, None, None))

    proc_df = pd.DataFrame()
    for i in sample_df.index:
        y_vals = sinterp.get_scale_data(np.array(sample_df.loc[i, "proc"]))
        proc_df[i] = np.exp(curve_over_times(times, x_vals, y_vals))
    return proc_df

In [None]:
# Process curves for the sampled parameter sets, one column per sample.
proc_df = get_proc_vals_from_sample(j, idata_sampled)

# Flatten the column labels to plain strings before plotting.
flat_labels = proc_df.columns.to_flat_index().map(str)
proc_df.columns = flat_labels
proc_df.plot()