In [None]:
# Uncomment the line below to install on Colab or similar
#! pip install git+https://github.com/monash-emu/wpro-working.git@pyproject

In [None]:
import os

# XLA reads these flags when the JAX backend initialises, so this must run before JAX is used.
# --xla_force_host_platform_device_count=4 exposes 4 host (CPU) devices, matching num_chains=4 in
#   the MCMC cell so the chains can run in parallel.
# --xla_cpu_multi_thread_eigen=false turns off XLA's multi-threaded Eigen CPU kernels.
# NOTE: an earlier version also contained a bare `intra_op_parallelism_threads=1` token. That is a
# TensorFlow session option, not an XLA flag, and it lacked the `--` prefix, so XLA either ignored it
# or rejected the whole XLA_FLAGS string as containing an unknown flag. It has been dropped.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4 --xla_cpu_multi_thread_eigen=false"

In [None]:
def update():
    """Turn on 64-bit (double precision) types in JAX for the rest of the session.

    Sets the global ``jax_enable_x64`` option, so arrays created afterwards
    (e.g. by ``jnp.linspace``) default to float64.
    """
    from jax import config as jax_config

    jax_config.update("jax_enable_x64", True)

In [None]:
# Switch JAX to 64-bit precision before any arrays are created in the cells below
update()

In [None]:
import jax
from jax import numpy as jnp

# Sanity checks for the configuration above. Only a cell's last expression is displayed, so a
# bare `jax.devices()` on an earlier line produced no output; show both results together instead.
# - jax.devices(): expect 4 CPU devices if XLA_FLAGS took effect (assumes the CPU backend)
# - jnp.linspace(...): should be float64 once update() has enabled x64
jax.devices(), jnp.linspace(0.0, 1.0, 10)

In [None]:
#| warning: false
from jax import jit, random
import pandas as pd
from datetime import datetime,timedelta
import numpyro
from numpyro import distributions as dist
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours
import numpy as np

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve, LinearMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti, plot_uncertainty_patches, PANEL_SUBTITLES, plot_3d_spaghetti
from emu_renewal.calibration import StandardCalib

In [None]:
# Static matplotlib figures for pandas' .plot() calls in the exploratory cells that follow
pd.set_option("plotting.backend", "matplotlib")

In [None]:
from jax import scipy as jsp, numpy as jnp
import jax

In [None]:
# Fixed (non-calibrated) settings
run_in = 0  # 100 was used previously
proc_update_freq = 7
pop = 33e6
analysis_start = datetime(2021, 4, 1)
analysis_end = datetime(2021, 11, 1)

# New-case counts indexed by date; keep the Malaysia ("MYS") column
mys_data = pd.read_csv(
    "https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv",
    index_col=0,
)["MYS"]
mys_data.index = pd.to_datetime(mys_data.index)

# Calibration target: observations inside the analysis window (label slicing includes both ends)
select_data = mys_data.loc[analysis_start:analysis_end]

In [None]:
# Plain (copied) array of the calibration data for array-based helpers
sda = select_data.to_numpy(copy=True)

from summer2.functions.derived import get_rolling_reduction

In [None]:
# 7-point rolling mean built with summer2's helper; applied to the plain array `sda` below
rmean7 = get_rolling_reduction(jnp.mean, 7)

In [None]:
# Check the output length: the plotting cell below wraps it in a Series indexed like select_data,
# so it presumably needs to equal len(sda)
rmean7(sda).shape

In [None]:
# The 50 days immediately before the analysis window (analysis_start itself excluded)
init_data = mys_data.loc[analysis_start - timedelta(days=50) : analysis_start - timedelta(days=1)]

# Overlay on one set of axes: the pre-window data, the calibration data, and two 7-day means
# (pandas' rolling mean vs summer2's rolling reduction) so the two implementations can be compared
ax = init_data.plot()
select_data.plot(ax=ax)
select_data.rolling(7).mean().plot(ax=ax)
pd.Series(rmean7(sda), index=select_data.index).plot(ax=ax)

select_data_ma7 = select_data.rolling(7).mean().dropna()

In [None]:
def _reindex_daily(series):
    """Reindex ``series`` onto a gap-free daily DatetimeIndex from its earliest to its latest date.

    Dates absent from ``series`` become NaN. min/max are used rather than first/last label so an
    unsorted index still yields the full range. An empty series has no range to build and is
    returned as a copy.
    """
    if series.empty:
        return series.copy()
    daily_index = pd.date_range(series.index.min(), series.index.max())
    return series.reindex(daily_index)


def reindex_daily_cumulative(series):
    """Fill missing days of a date-indexed (e.g. cumulative) series by linear interpolation.

    Args:
        series: pandas Series with a DatetimeIndex. Labels must be unique (pandas raises when
            reindexing an index with duplicate labels).

    Returns:
        Series on a daily index spanning the earliest to latest date, with values for days
        missing from ``series`` linearly interpolated from their neighbours.
    """
    return _reindex_daily(series).interpolate()


def report_gaps(series):
    """Flag days on which ``series`` has no value.

    Returns:
        Boolean Series on the same daily index ``reindex_daily_cumulative`` uses; True marks a
        day absent from ``series`` or holding a missing value.
    """
    return _reindex_daily(series).isna()

In [None]:
fitter = CosineMultiCurve()
# Use `pop` from the fixed-parameter cell instead of repeating the 33e6 literal, so the two cannot
# drift apart. The trailing 50 matches the 50 days of init_data (presumably the seeding-window length).
renew_model = RenewalModel(pop, analysis_start, analysis_end, run_in, proc_update_freq, fitter, GammaDens(), fitter, 50)

In [None]:
# Priors for the calibrated parameters.
# cdr and rt0 are not calibrated here: StandardCalib receives them as fixed values.
priors = {
    "gen_mean": dist.TruncatedNormal(6.0, 1.0, low=2.0, high=10.0),
    "gen_sd": dist.Gamma(2.5, 1.0),
}
# Alternative priors tried previously, kept for reference:
#   "gen_mean": dist.TruncatedNormal(5.38, 1.0, low=2.0, high=12.0) or dist.Gamma(10.0, 1.0)
#   "cdr":      dist.Beta(3.5, 10.0) or dist.TruncatedNormal(0.25, 0.1, low=0.1, high=0.35)
#   "seed":     dist.Uniform(0.01, 10.0)  -- approximate upper seed value; relatively insensitive to lower limit
#   "rt0":      dist.TruncatedNormal(0.0, 0.1, low=-1.0, high=1.0)

In [None]:
# Prior means, as a quick check on where each prior is centred
dict((name, prior.mean) for name, prior in priors.items())

In [None]:
# Density of the gen_mean prior over 2.5-12.5
xr = np.linspace(2.5, 12.5, 100)
pd.Series(
    data=np.exp(priors["gen_mean"].log_prob(xr)),
    index=xr,
).plot()

In [None]:
# Density of HalfNormal(0.05) over 0.01-1.0; this distribution is not in `priors`
# (presumably exploring a candidate dispersion prior)
xr = np.linspace(0.01, 1.0, 100)
pd.Series(
    data=np.exp(dist.HalfNormal(0.05).log_prob(xr)),
    index=xr,
).plot()

In [None]:
from jax import numpy as jnp

In [None]:
calib = StandardCalib(
    renew_model,
    select_data,
    priors,
    jnp.array(init_data),
    {"rt0": 0.0, "cdr": 0.25},  # held fixed rather than calibrated
    smoothing=False,
    data_dispersion_sd=0.1,
    process_dispersion_sd=0.1,
)

In [None]:
from emu_renewal.distributions import GammaDens

In [None]:
# Convenience function for the PDF of a prior
def plotpdf(p, x_min=0.0, x_max=20.0, n_points=100):
    """Tabulate the probability density of distribution ``p`` on an evenly spaced grid.

    Despite its name this does not draw anything; it returns a Series that the caller can
    ``.plot()``.

    Args:
        p: any object with a ``log_prob(x)`` method (e.g. an entry of ``priors``).
        x_min, x_max: grid end points, both included (defaults keep the original 0-20 range).
        n_points: number of grid points (default 100, as before).

    Returns:
        pandas Series of ``exp(p.log_prob(x))`` indexed by ``x``.
    """
    x = np.linspace(x_min, x_max, n_points)
    return pd.Series(data=np.exp(p.log_prob(x)), index=x)

# Example: plotpdf(priors["gen_sd"]).plot()

In [None]:
from numpyro import infer

In [None]:
from functools import partial

In [None]:
from numpyro.infer.util import constrain_fn

In [None]:
# Display the calibration target series
select_data

In [None]:
import numpy as np

In [None]:
# Synthetic seeding series of length window_len: zeros, then init_duration days of exponential growth.
# Growth day k (k = 0..init_duration-1) equals select_data.iloc[0] ** (k / init_duration), so it starts
# at 1 and would reach the first observation one step after its last element.
# init_series is not passed to StandardCalib, which receives init_data instead.
init_duration = 14
window_len = 50
exp_coeff = np.log(select_data.iloc[0]) / init_duration
init_series = np.pad(
    np.exp(exp_coeff * np.arange(init_duration)),
    (window_len - init_duration, 0),  # left-pad with zeros up to window_len
)

In [None]:
# Inspect the synthetic seeding series built above
init_series

In [None]:
# Debug check: push one posterior draw through the model with numpyro's constrain_fn.
# `sample_params` is only created further down (posterior sub-sampling section), so on a fresh kernel
# (Restart & Run All) this cell used to fail with NameError; skip it until that name exists.
# NOTE(review): constrain_fn expects *unconstrained* site values, but posterior draws are already
# constrained, so the output is probably not meaningful as written -- confirm the intent.
if "sample_params" in globals():
    from IPython.display import display

    display(constrain_fn(calib.calibration, (), {}, sample_params.iloc[0]))
else:
    print("Skipped: `sample_params` is defined later in the notebook; re-run this cell after it.")

In [None]:
# Initialise the chains with numpyro's init_to_mean, i.e. at each latent site's prior mean (numpyro uses
# the prior median for distributions that do not implement a mean).
# NOTE(review): an earlier comment here described random uniform initial points with a reduced radius;
# that is infer.init_to_uniform(radius=...). Switch to it if dispersed starting points are wanted.
kernel = numpyro.infer.NUTS(calib.calibration, dense_mass=True, init_strategy=infer.init_to_mean)

# Any values can be used for num_samples and num_warmup:
#   1000/1000 - enough 'most of the time' and handy while testing; expect a few bad runs depending on the seed
#   2000/2000 - considerably more robust
#   higher    - may be required for exacting results with 'pristine' R-hat values
# chain_method="parallel" is numpyro's default, written out to make explicit that the 4 chains are meant
# to run on the 4 host devices requested through XLA_FLAGS at the top of the notebook.
mcmc = numpyro.infer.MCMC(kernel, num_chains=4, num_samples=1000, num_warmup=1000, chain_method="parallel")
rng_key = random.PRNGKey(15)
# Single-call alternative to the separate warmup/run cells below:
# mcmc.run(rng_key, extra_fields=("accept_prob", "diverging"), collect_warmup=True)

In [None]:
# Warmup only, so its behaviour can be inspected before committing to a sampling run.
# What to look for:
#   - all chains progressing at roughly the same speed (no order-of-magnitude differences);
#   - chains speeding up during warmup is normal, as they approach the viable region.
mcmc.warmup(
    rng_key,
    collect_warmup=True,  # keep warmup draws so they can be plotted below
    extra_fields=("accept_prob","diverging","potential_energy"),
)

In [None]:
# Potential energy (the *negative* log density, so lower = better fit) of the warmup draws, one line
# per chain, with the first 100 iterations dropped (presumably to hide the initial transient).
#
# By the end of warmup every chain should sit in the same range. If not, there is no point running the
# calibration: either the initial point is bad (shouldn't happen) or something is wrong with the
# model/priors/NUTS configuration. Even when they converge, check whether some chains took unusually
# long to get there -- that can distort the mass-matrix adaptation used for sampling; a longer warmup
# should resolve it.
#
# Occasional dips below the other chains are fine: a chain has found a better (MAP-like) fit outside the
# centre of mass. What matters is that it returns to the common range for most of the trace.
pd.DataFrame(
    mcmc.get_extra_fields(group_by_chain=True)["potential_energy"]
).T.iloc[100:].plot()

In [None]:
# gen_mean draws, one line per chain (warmup draws when run before the main sampling cell)
pd.DataFrame(
    mcmc.get_samples(group_by_chain=True)["gen_mean"]
).T.plot()

In [None]:
# Now run the actual MCMC.
# This should sample a bit faster than the warmup: if everything above went right, all chains are
# properly preconditioned. A chain running much faster or slower than the others signals a problem
# (most likely its mass-matrix tuning differs; check its potential energy/trace to confirm).
#
# rng_key was already consumed by mcmc.warmup above. JAX keys must not be reused, otherwise sampling
# would start from the same random stream as the warmup, so derive a fresh key.
sampling_key = random.split(rng_key)[1]
mcmc.run(sampling_key, extra_fields=("accept_prob","diverging","potential_energy"))

In [None]:
# Wrap the per-chain posterior draws in an ArviZ InferenceData (posterior group only)
idata = az.from_dict(posterior=mcmc.get_samples(group_by_chain=True))

In [None]:
# Posterior summary. The convergence diagnostic (r_hat column) should be an absolute maximum of 1.05
# for any actual inference: good enough not to be misleading, but still not really appropriate for
# publication/policy advice. For this kind of model, 1.00 is the target.
az.summary(idata)

In [None]:
#az.plot_posterior(idata);

In [None]:
from jax import numpy as jnp

In [None]:
burn_in = 0  # extra draws to drop from the start of each chain (idata holds sampling-phase draws)
n_samples = 200  # posterior draws sub-sampled for the spaghetti/quantile outputs
quantiles = [0.05, 0.5, 0.95]
idata_burnt = idata.sel(draw=slice(burn_in, None))
# az.extract picks its random subset with a fresh RNG unless one is given; seed it so the selected draws
# (and every figure built from them) are reproducible on re-run.
idata_sampled = az.extract(idata_burnt, num_samples=n_samples, rng=15)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# Marginal posterior for each sampled variable (trailing ';' hides the returned axes)
az.plot_posterior(data=idata);

In [None]:
# Attach the fixed cdr value to every sub-sampled draw. Size the array from n_samples instead of
# repeating the literal 200, so changing n_samples above cannot silently break this.
# 0.25 mirrors the fixed "cdr" value given to StandardCalib.
sample_params.components["cdr"] = np.full(n_samples, 0.25)
# rt0 is not added; presumably the rt0=0.0 default of get_full_result covers it
#sample_params.components["rt0"] = np.full(n_samples, 0.0)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr=0.25, rt0=0.0):
    """Run the renewal model for a single parameter set.

    The calibration's seeding data are divided by ``cdr`` before being handed to the model
    (presumably converting observed cases into infections).
    """
    seed_values = calib.init_data / cdr
    return renew_model.renewal_func(gen_mean, gen_sd, proc, seed_values, rt0)


# Compile once, evaluate every sub-sampled draw, then summarise at the chosen quantiles
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles)

In [None]:
# Quantile bands of the model outputs plotted against the calibration data
patch_fig = plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly)
patch_fig

In [None]:
# NOTE(review): this cell duplicated an earlier one verbatim (same inputs, same figure). Re-display the
# existing figure instead of rebuilding it; delete the cell if the repeated output is not needed.
patch_fig

In [None]:
# NOTE(review): this cell duplicated an earlier one verbatim (same inputs, same figure). Re-display the
# existing figure instead of rebuilding it; delete the cell if the repeated output is not needed.
patch_fig

In [None]:
# NOTE(review): this cell duplicated an earlier one verbatim (same inputs, same figure). Re-display the
# existing figure instead of rebuilding it; delete the cell if the repeated output is not needed.
patch_fig

In [None]:
# NOTE(review): this cell duplicated an earlier one verbatim (same inputs, same figure). Re-display the
# existing figure instead of rebuilding it; delete the cell if the repeated output is not needed.
patch_fig

In [None]:
# Render the renewal model's own description as Markdown
Markdown(renew_model.get_description())

In [None]:
# Switch pandas' .plot() backend to plotly for the interactive figures below
pd.set_option("plotting.backend", "plotly")

In [None]:
# Trajectories from the sub-sampled posterior draws, overlaid on the calibration data
plot_spaghetti(spaghetti, select_data)

In [None]:
# NOTE(review): inspects a hand-picked index pair, presumably (chain, draw) from an earlier run.
# sample_params comes from a random az.extract sub-sample, so this pair may not exist after a re-run.
sample_params.loc[(3,277)]

In [None]:
# NOTE(review): inspects a hand-picked index pair, presumably (chain, draw) from an earlier run.
# sample_params comes from a random az.extract sub-sample, so this pair may not exist after a re-run.
sample_params.loc[(3,366)]

In [None]:
# Optional - requires kaleido
#| label: fig-calib
#| fig-cap: "Calibration to sample data from Malaysia"
# patch_fig.write_image("patch_fig.svg")

In [None]:
# params_df.columns = ["name", "Lower limit", "Upper limit"]
# params_df.index = params_df["name"]
# params_df = params_df.drop(columns=["name"])
# params_df.index.name = None

In [None]:
# Section heading for the calibration description rendered below
Markdown("### Calibration")

In [None]:
# Render the calibration object's own description as Markdown
Markdown(calib.get_description())

In [None]:
# Markdown(params_df.to_markdown())

In [None]:
# evidence_table = pd.DataFrame(index=params_df.index, columns=["Evidence"])
# evidence_table.loc[:, "Evidence"] = "To be populated [@cori2013]"
# Markdown(evidence_table.to_markdown())

In [None]:
# plot_3d_spaghetti(spaghetti, ["susceptibles", "transmission potential"])

In [None]:
# First attempt at prior-posterior comparison graph: ArviZ posterior densities, with each prior's
# density shaded in grey on the same panel.
# `param_names` (not `vars`) so the built-in vars() is not shadowed.
param_names = list(priors.keys())
comparison_plot = az.plot_density(idata, var_names=param_names, point_estimate=None, shade=0.5)
for ax, name in zip(comparison_plot.ravel(), param_names):
    x_min, x_max = ax.get_xlim()
    x_vals = np.linspace(x_min, x_max, 100)
    # Prior mass in each grid interval (CDF differences); proportional to the density at the
    # interval midpoint, so plot against midpoints rather than left edges.
    prior_mass = np.diff(np.asarray(priors[name].cdf(x_vals)))
    x_mid = 0.5 * (x_vals[:-1] + x_vals[1:])
    peak = np.nanmax(prior_mass)
    if np.isfinite(peak) and peak > 0:
        # Rescale so the prior's peak height is 0.94 (value kept from the original version)
        prior_mass = prior_mass * (0.94 / peak)
    ax.fill_between(x_mid, prior_mass, color='k', alpha=0.2, linewidth=2)