In [None]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours
from pathlib import Path
import math
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import Markdown

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, PANEL_SUBTITLES, plot_3d_spaghetti, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.utils import get_adjust_idata_index, adjust_summary_cols

In [None]:
# Notebook working directory; the data-loading cell locates the `data/` folder
# two directory levels above it.
cwd: Path = Path.cwd()

In [None]:
# Weekly case counts (WHO weekly time series), re-indexed by parsed dates.
target_data_path = cwd.parent.parent.joinpath("data", "target-data", "case-data.csv")
target_data = pd.read_csv(target_data_path, index_col=0)
target_data.index = pd.to_datetime(target_data.index)

# Variant emergence records, keeping only the rows for Malaysia.
emergence_data_path = cwd.parent.parent.joinpath("data", "variant-data", "variant-emergence.csv")
emergence_data = pd.read_csv(emergence_data_path, index_col=0).loc[
    lambda frame: frame["country"] == "Malaysia"
]

# TODO: load variant prevalence data (placeholder, not implemented yet)

In [None]:
# Display the Malaysia-only variant emergence table.
emergence_data

In [None]:
# Fixed (non-calibrated) settings and the data windows used for calibration.
proc_update_freq = 21  # update interval of the variable process; presumably days (TODO confirm)
init_time = 50  # length of the initialisation window before the analysis, in days
data = target_data['New_cases_MYS']
pop = 33e6  # population size passed to the renewal model
analysis_start = datetime(2021, 5, 1)
analysis_end = datetime(2022, 4, 30)

# Initialisation window: the `init_time` days immediately preceding `analysis_start`.
init_start = analysis_start - timedelta(days=init_time)
init_end = analysis_start - timedelta(days=1)
select_data = data.loc[analysis_start: analysis_end]

# The target series is weekly: upsample to daily, linearly interpolate between
# observations, and divide by 7 to turn weekly totals into per-day values.
init_data = data.resample("D").asfreq().interpolate().loc[init_start: init_end] / 7.0

# Interpolation cannot fill days before the first observation, and the slice silently
# returns fewer rows if the series does not reach back far enough. Fail loudly instead
# of seeding the model with NaNs or a truncated window.
assert not select_data.empty, "no target data inside the analysis window"
assert len(init_data) == init_time, (
    f"expected {init_time} days of initialisation data, got {len(init_data)}"
)
assert init_data.notna().all(), "init_data contains NaNs: case series does not cover the init window"

In [None]:
# Process fitter for the model's time-varying process.
proc_fitter = CosineMultiCurve()

# Build the renewal model. Two separate GammaDens() instances are passed; judging by
# the gen_* / report_* priors they are presumably the generation-interval and
# reporting-delay densities respectively (TODO confirm against RenewalModel).
renew_model = RenewalModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    proc_fitter,
    GammaDens(),
    init_time,
    init_data,
    GammaDens(),
)

In [None]:
# Define parameter ranges
priors = {
    "gen_mean": dist.TruncatedNormal(7.3, 0.5, low=1.0),
    "gen_sd": dist.TruncatedNormal(3.8, 0.5, low=1.0),
    "cdr": dist.Beta(4, 10),
    "rt_init": dist.Normal(0.0, 0.25),
    "report_mean": dist.TruncatedNormal(8, 0.5, low=1.0),
    "report_sd": dist.TruncatedNormal(3, 0.5, low=1.0),
}

In [None]:
#| output: false
# Sample the calibration target with NUTS (dense mass matrix, uniform init within radius 0.5).
calib = StandardCalib(renew_model, priors, select_data, indicator='weekly_sum')
kernel = infer.NUTS(
    calib.calibration,
    dense_mass=True,
    init_strategy=infer.init_to_uniform(radius=0.5),
)
mcmc = infer.MCMC(kernel, num_warmup=100, num_samples=100, num_chains=2)
# Fixed seed so the sampler run is reproducible.
mcmc_rng_key = random.PRNGKey(1)
mcmc.run(mcmc_rng_key)

In [None]:
# Convert the chain-grouped posterior draws to ArviZ InferenceData.
idata = az.from_dict(mcmc.get_samples(group_by_chain=True))
# az.extract shuffles the draws when num_samples is given. Seed it so the subsample
# (and every output plotted from it) is identical on each Restart & Run All.
idata_sampled = az.extract(idata, num_samples=200, rng=1)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
def get_full_result(gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd):
    """Run the renewal model's full simulation for a single parameter set.

    Thin wrapper that forwards all arguments, in order, to
    ``renew_model.renewal_func`` (the notebook-level model) so that it can be
    JIT-compiled.

    Returns:
        Whatever ``renew_model.renewal_func`` returns for these parameters.
    """
    params = (gen_mean, gen_sd, proc, cdr, rt_init, report_mean, report_sd)
    return renew_model.renewal_func(*params)


# JIT-compiled version, used when generating outputs from posterior samples.
full_wrap = jit(get_full_result)

In [None]:
# Simulate outputs for every posterior sample, summarise them as 5/50/95% quantiles,
# and plot median lines with uncertainty bands against the target data.
# `full_wrap` (JIT-compiled `get_full_result`) comes from the preceding cell; re-defining
# it here would only shadow it with an identical copy that has to be traced again.

# Replace the first panel's output with "weekly_sum", the calibration indicator.
panel_subtitles = ["weekly_sum"] + PANEL_SUBTITLES[1:]
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap, outputs=panel_subtitles)
quantiles_df = get_quant_df_from_spaghetti(renew_model, spaghetti, quantiles=[0.05, 0.5, 0.95], outputs=panel_subtitles)
plot_uncertainty_patches(quantiles_df, select_data, qual_colours.Plotly, panel_subtitles).update_layout(showlegend=False)

In [None]:
# Display the 5%, 50% and 95% quantiles of each model output.
quantiles_df

In [None]:
# Show Plotly's qualitative colour palette; the hex colours hard-coded in the
# following figure presumably come from it.
qual_colours.Plotly

In [None]:
# Two-panel summary: modelled vs reported cases (top) and Rt (bottom), annotated with
# the Omicron prevalence threshold date and shaded lockdown phases.

CASES_COLOUR = "#636EFA"
RT_COLOUR = "#00CC96"
LOCKDOWN_COLOUR = "#FFA15A"

# (start, end, label, fill opacity) for each lockdown phase.
LOCKDOWN_PHASES = [
    ("2021-06-01", "2021-07-05", "Phase 1", 0.50),
    ("2021-07-05", "2021-08-04", "Phase 2", 0.30),
    ("2021-08-04", "2021-08-26", "Phase 3", 0.20),
    ("2021-08-26", "2022-01-03", "Phase 4", 0.10),
]


def _to_plotly_ms(date_str):
    """Return midnight UTC of an ISO date string as milliseconds since the Unix epoch.

    Plotly reads numeric x positions on a date axis as UTC milliseconds. pandas treats
    a naive timestamp as UTC, whereas ``datetime.timestamp()`` on a naive datetime uses
    the machine's local time zone and would shift every marker by the local UTC offset.
    """
    return pd.Timestamp(date_str).timestamp() * 1000


def _add_median_and_band(fig, quantiles, column, label, colour, row):
    """Add a median line plus a shaded 5-95% band for ``quantiles[column]`` to ``row``.

    The 5% trace is added immediately before the 95% trace because the band is drawn
    with fill="tonexty", which fills to the previously added trace.
    """
    x = quantiles.index
    fig.add_trace(
        go.Scatter(x=x, y=quantiles[column][0.50], mode="lines", name=label, marker_color=colour),
        row=row, col=1,
    )
    fig.add_traces(
        [
            go.Scatter(x=x, y=quantiles[column][0.05], mode="lines", name=f"{label} (5%)",
                       marker_color=colour, showlegend=False),
            go.Scatter(x=x, y=quantiles[column][0.95], mode="lines", name=f"{label} (95%)",
                       marker_color=colour, showlegend=False, fill="tonexty"),
        ],
        rows=row, cols=1,
    )


fig = make_subplots(2, 1, shared_xaxes=True)

# Top panel: modelled cases with uncertainty band, and the reported cases.
_add_median_and_band(fig, quantiles_df, "weekly_sum", "Modelled cases", CASES_COLOUR, row=1)
fig.add_trace(
    go.Scatter(x=select_data.index, y=select_data, mode="markers", name="Reported cases", marker_color="black"),
    row=1, col=1,
)

# Bottom panel: reproduction number.
_add_median_and_band(fig, quantiles_df, "R", "Rt", RT_COLOUR, row=2)

# Variant marker.
fig.add_vline(
    x=_to_plotly_ms("2021-12-06"),
    annotation_text="Omicron sequence prevalence >10%", annotation_position="right",
)

# Lockdown phases as shaded rectangles, fading with each successive phase.
for start, end, label, opacity in LOCKDOWN_PHASES:
    fig.add_vrect(
        x0=_to_plotly_ms(start), x1=_to_plotly_ms(end),
        annotation_text=label, annotation_position="top left",
        fillcolor=LOCKDOWN_COLOUR, opacity=opacity, line_width=0,
    )

fig.update_yaxes(title_text="Weekly cases", row=1, col=1)
fig.update_yaxes(title_text="Rt", row=2, col=1)
fig.update_layout(height=700, width=1200)

fig.show()

In [None]:
# ArviZ summary table (statistics and diagnostics) for the posterior draws.
az.summary(data=idata)

In [None]:
# Posterior-vs-prior plot for every calibrated parameter (iterating a dict yields its keys);
# the trailing ';' suppresses the return-value repr.
plot_post_prior_comparison(idata, list(priors), priors);

In [None]:
# Render the renewal model's own description as Markdown.
model_description = renew_model.get_description()
Markdown(model_description)