In [None]:
from jax import jit, random

In [None]:
import pandas as pd

In [None]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours
import pathlib
import math

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, PANEL_SUBTITLES, plot_3d_spaghetti, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.utils import get_adjust_idata_index, adjust_summary_cols

In [None]:
from pathlib import Path

In [None]:
import renewal_analysis

# Root directory of the installed renewal_analysis package (first entry of its __path__)
ra_path = Path(next(iter(renewal_analysis.__path__)))

In [None]:
# Location of the WHO weekly time-series target data inside the renewal_analysis package
target_data_path = ra_path.joinpath('data/target-data/case-data.csv')

In [None]:
# Load the WHO weekly target data, converting the first column (used as index) to datetimes
data = (
    pd.read_csv(target_data_path, index_col=0)
    .pipe(lambda frame: frame.set_axis(pd.to_datetime(frame.index)))
)

# Also load the daily OWID time series and keep only the Malaysia ("MYS") column.
# NOTE(review): this URL tracks the `main` branch, so its contents can change between
# runs — consider pinning a commit hash for reproducibility. mys_data_daily is not used
# by any later cell shown here.
mys_data_daily = (
    pd.read_csv("https://github.com/monash-emu/wpro_working/raw/main/data/new_cases.csv", index_col=0)["MYS"]
    .pipe(lambda series: series.set_axis(pd.to_datetime(series.index)))
)

In [None]:
# Inspect the loaded WHO weekly target data
data

In [None]:
# Explore different methods of deriving daily case values from the weekly data

In [None]:
# Method 1 - Interpolate over the cumulative weekly time series on a daily grid, then
# difference consecutive days to recover daily new cases (the first value is NaN).
mys_data_1 = (
    data['Cumulative_cases_MYS']
    .resample('D')
    .interpolate()
    .diff(1)
)

In [None]:
# Method 2 - Interpolate the weekly non-cumulative counts onto a daily grid, then divide
# by 7 to express each value as a daily average.
mys_data_2 = data['New_cases_MYS'].resample('D').interpolate().div(7)

In [None]:
# Method 3 - take the daily average of the weekly non-cumulative time series, and add a
# lag so each value sits at midweek.
# shift(-3) without `freq` would move the *values* back three rows, which on this weekly
# index is three weeks and leaves the last three entries NaN. Passing freq='D' instead
# moves the *index* back three days, which is the intended midweek lag.
# NOTE(review): assumes each weekly date marks the end of its reporting week — confirm.
mys_data_3 = (data['New_cases_MYS'] / 7).shift(-3, freq='D')

In [None]:
# Specify fixed parameters and get calibration data
proc_update_freq = 14
init_time = 50
pop = 33e6
mys_data = data['New_cases_MYS']

# Analysis window; the .loc label slice below includes both endpoints
analysis_start = datetime(2021, 3, 1)
analysis_end = datetime(2021, 11, 1)

# Initialisation window: the init_time days immediately before analysis_start
init_start = analysis_start - timedelta(days=init_time)
init_end = analysis_start - timedelta(days=1)

# NOTE(review): mys_data is the WHO weekly series, so init_data holds only the weekly
# report dates in this window rather than one value per day — confirm this is what
# RenewalModel expects for its init_data argument.
select_data = mys_data.loc[analysis_start:analysis_end]
init_data = mys_data.loc[init_start:init_end]

In [None]:
# Build the renewal model. The population and initialisation length come from the
# fixed-parameter cell (pop, init_time) instead of re-typed literals (33e6, 50), so
# changing them there propagates here. The argument values are unchanged.
proc_fitter = CosineMultiCurve()
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    proc_fitter,
    GammaDens(),  # presumably the generation-interval density — confirm against RenewalModel
    init_time,
    init_data,
    GammaDens(),  # presumably the reporting-delay density — confirm against RenewalModel
)

In [None]:
# Onset-to-report delay from PMC9225637: mean 3.18 and variance 10.33, so the standard
# deviation is the square root of the variance.
report_mean, report_sd = 3.18, math.sqrt(10.33)

In [None]:
# Reporting-delay standard deviation (square root of the 10.33 variance)
report_sd

In [None]:
# Sanity check: squaring the sd should recover the variance of 10.33 (up to float rounding)
report_sd**2

In [None]:
# Approximate standard errors of a mean and an sd estimate from their 95% interval upper bounds
def compute_parameter_sd(mean, mean_ui, sd, sd_ui, z=2.0):
    """Approximate the uncertainty of a mean and an sd estimate from their 95% upper bounds.

    Treats each estimate's 95% interval as symmetric and normal. The distance from the
    point estimate to the upper bound is then ``z`` standard errors. The default
    ``z=2.0`` is the common rounded approximation and matches the original behaviour.
    Pass ``z=1.96`` to use the exact two-sided 95% normal quantile.

    Args:
        mean: Point estimate of the mean.
        mean_ui: Upper bound of the 95% interval for the mean.
        sd: Point estimate of the spread statistic (sd, or whatever the source reports).
        sd_ui: Upper bound of the 95% interval for ``sd``.
        z: Number of standard errors between a point estimate and its upper bound.

    Returns:
        Tuple ``(mean_sd, mean_v, sd_sd, sd_v)``: the standard error and variance of the
        mean estimate, followed by the standard error and variance of the sd estimate.

    Raises:
        ValueError: If ``z`` is not positive.
    """
    if z <= 0:
        raise ValueError(f"z must be positive, got {z!r}")

    def _se_and_var(point, upper):
        # Distance from the point estimate to the upper bound, divided by z, is one standard error.
        se = (upper - point) / z
        return se, se**2

    mean_sd, mean_v = _se_and_var(mean, mean_ui)
    sd_sd, sd_v = _se_and_var(sd, sd_ui)
    return mean_sd, mean_v, sd_sd, sd_v

In [None]:
# Onset-to-report values from this paper: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9225637/
# NOTE(review): 10.33 is passed as `sd` here, but the report_sd cell takes math.sqrt(10.33),
# i.e. treats 10.33 as a variance — confirm which statistic the paper reports.
compute_parameter_sd(mean=3.18, mean_ui=3.19, sd=10.33, sd_ui=10.44)

In [None]:
# COVID-19 reference values:
# incubation period = 6.9 days
# latent period = 5.5 days
# reporting delay (symptom onset to report) = 3.2
# generation interval mean = 7.3
# generation interval sd = 3.8

In [None]:
# Prior distributions for the calibrated parameters
# NOTE(review): gen_mean is centred on 5.0 and report_mean on 5, whereas the reference
# values noted earlier in this notebook are 7.3 and 3.2. Also, report_mean = 3.18 is
# computed earlier but not used here. Confirm these priors are intentional.
priors = dict(
    gen_mean=dist.TruncatedNormal(5.0, 0.4, low=0.0),
    gen_sd=dist.TruncatedNormal(3.8, 0.5, low=0.0),
    cdr=dist.Beta(4.0, 10.0),
    rt_init=dist.Normal(0.0, 0.25),
    report_mean=dist.TruncatedNormal(5, 0.5, low=0.0),
    report_sd=dist.TruncatedNormal(2, 0.5, low=0.0),
)

In [None]:
# Build the calibration target (weekly sums compared against select_data) and sample
# its posterior with NUTS: 2 chains, each with 1000 warmup and 1000 retained draws.
calib = StandardCalib(renew_model, priors, select_data, indicator='weekly_sum')
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=infer.init_to_uniform(radius=0.5),
)
mcmc = infer.MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(1))

In [None]:
# Convert the chain-grouped posterior samples to InferenceData, then take a subsample of
# 200 parameter sets for the output plots. When num_samples is given, az.extract draws a
# random subset. Seeding it via `rng` makes the subsample, and everything plotted from
# it, identical on every Restart & Run All.
SUBSAMPLE_SEED = 1
idata = az.from_dict(mcmc.get_samples(group_by_chain=True))
idata_sampled = az.extract(idata, num_samples=200, rng=SUBSAMPLE_SEED)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd):
    """Run renew_model.renewal_func for one parameter set, forwarding arguments positionally."""
    ordered_args = (gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd)
    return renew_model.renewal_func(*ordered_args)


# JIT-compile the wrapper and pass it, with the subsampled parameters, to
# get_spaghetti_from_params. Then summarise the result at the 5th/50th/95th percentiles
# and plot those bands against the calibration data with the legend hidden.
full_wrap = jit(get_full_result)
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap)
quantiles_df = get_quant_df_from_spaghetti(
    renew_model, spaghetti, quantiles=[0.05, 0.5, 0.95]
)
plot_uncertainty_patches(
    quantiles_df, select_data, qual_colours.Plotly
).update_layout(showlegend=False)

In [None]:
# Inspect the quantile table (0.05 / 0.5 / 0.95) used for the uncertainty plot
quantiles_df

In [None]:
# ArviZ summary statistics of the full posterior (all chains and draws, not the 200-draw subsample)
az.summary(idata)

In [None]:
# Posterior-vs-prior comparison for every parameter in `priors`; the trailing ';' hides the return value's repr
plot_post_prior_comparison(idata, list(priors.keys()), priors);