In [None]:
from jax import jit, random
from jax import numpy as jnp
import numpy as np
import pandas as pd
from datetime import datetime
import numpyro
from numpyro import distributions as dist
import arviz as az

import summer2

from emu_renewal.distributions import JaxGammaDens
from emu_renewal.renew import JaxModel

In [None]:
# Generation-interval density handed to the renewal model (presumably a gamma
# density, going by the class name — confirm in emu_renewal.distributions).
distri = JaxGammaDens()

# Renewal model built from positional configuration values. NOTE(review): the
# meaning of each argument is defined by emu_renewal.renew.JaxModel. Several
# values appear to duplicate settings from the data cell that follows and would
# need to be kept in sync by hand: 33e6 equals calib_kwargs["pop"], 30 equals
# run_in, 12 equals n_process_periods, and 276 equals n_times when the
# 2021-03-01..2021-11-01 window has all 246 daily rows (30 + 246) — TODO confirm.
j = JaxModel(
    33e6,
    50,
    276,
    30,
    12,
    distri,
)


@jit
def get_inc_result(gen_mean, gen_sd, proc, seed):
    """Run the renewal model once, compiled with ``jax.jit``.

    Forwards the four calibration parameters, in order, to ``j.func`` and
    returns its result unchanged (the calibration model reads ``.incidence``
    from it).
    """
    model_func = j.func
    return model_func(gen_mean, gen_sd, proc, seed)

In [None]:
run_in = 30
n_process_periods = 12

# Daily new cases for Malaysia (column "MYS"), indexed by the CSV's first column
# parsed as dates.
raw_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
raw_data.index = pd.to_datetime(raw_data.index)

# Keep the calibration window (label slicing is inclusive at both ends), drop
# the dates in favour of a 0-based position, then shift the positions by the
# run-in — presumably so they line up with model steps after the run-in
# (cf. incidence[run_in:] in the likelihood).
mys_data = raw_data.loc[datetime(2021, 3, 1): datetime(2021, 11, 1)].reset_index(drop=True)
mys_data.index = mys_data.index + run_in

n_times = run_in + len(mys_data)
calib_kwargs = {
    "pop": 33e6,
    "n_times": n_times,
    "run_in": run_in,
    "targets": mys_data,
}

In [None]:
priors = {
    "gen_mean": dist.Uniform(2.0, 10.0),
    "gen_sd": dist.Uniform(1.0, 5.0),
    # NOTE(review): 4 process values are sampled here, while JaxModel was built
    # with 12 as a positional argument and n_process_periods is 12 — confirm
    # against JaxModel how many process values it expects.
    "proc": dist.Uniform(np.repeat(-2.0, 4), np.repeat(2.0, 4)),
    "seed": dist.Uniform(4.0, 15.0),  # This is actually the log seed
}

# The observed series and the likelihood scale do not depend on the sampled
# parameters, so build them once instead of on every model evaluation.
mys_targets = jnp.array(mys_data)
like_sd = jnp.std(mys_targets) * 0.1


def calib_model():
    """NumPyro model for calibrating the renewal model to Malaysian case data.

    Samples every parameter in ``priors``, runs ``get_inc_result`` with them,
    drops the first ``run_in`` steps of modelled incidence, and adds an
    independent normal log-likelihood of the observed cases (SD fixed at 10%
    of the observed series' SD) as a ``numpyro.factor`` site.

    Raises:
        ValueError: if the post-run-in model output and the observed series
            differ in shape (e.g. the data window has missing days).
    """
    param_updates = {k: numpyro.sample(k, v) for k, v in priors.items()}
    model_res = get_inc_result(**param_updates).incidence[run_in:]
    # Shapes are static under JAX tracing, so this check costs nothing at run
    # time and turns a silent/cryptic broadcast problem into a clear error.
    if model_res.shape != mys_targets.shape:
        raise ValueError(
            f"Model output shape {model_res.shape} does not match "
            f"target shape {mys_targets.shape}"
        )
    like = dist.Normal(model_res, like_sd).log_prob(mys_targets).sum()
    numpyro.factor("incidence", like)

In [None]:
# Fixed PRNG key so the posterior draws are reproducible.
rng_key = random.PRNGKey(1)

# NUTS over the calibration model: 2 chains, each with 1000 warmup and 1000
# retained draws; the per-draw acceptance probability is collected as an
# extra field for diagnostics.
kernel = numpyro.infer.NUTS(model=calib_model)
mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Convert the fitted sampler into an ArviZ InferenceData for diagnostics/plots.
idata = az.from_numpyro(posterior=mcmc)

In [None]:
# Trace and marginal-density plots per parameter; the trailing semicolon
# suppresses the returned axes array's repr.
az.plot_trace(data=idata);